Skip to content

Commit 199e9c5

Browse files
authored
core[patch]: Fix tool args schema inherited field parsing (#24936)
Fix #24925
1 parent fba65ba commit 199e9c5

File tree

3 files changed

+289
-7
lines changed

3 files changed

+289
-7
lines changed

libs/core/langchain_core/tools.py

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
get_type_hints,
4646
)
4747

48-
from typing_extensions import Annotated, cast, get_args, get_origin
48+
from typing_extensions import Annotated, TypeVar, cast, get_args, get_origin
4949

5050
from langchain_core._api import deprecated
5151
from langchain_core.callbacks import (
@@ -88,11 +88,16 @@
8888
run_in_executor,
8989
)
9090
from langchain_core.runnables.utils import accepts_context
91-
from langchain_core.utils.function_calling import _parse_google_docstring
91+
from langchain_core.utils.function_calling import (
92+
_parse_google_docstring,
93+
_py_38_safe_origin,
94+
)
9295
from langchain_core.utils.pydantic import (
9396
TypeBaseModel,
9497
_create_subset_model,
9598
is_basemodel_subclass,
99+
is_pydantic_v1_subclass,
100+
is_pydantic_v2_subclass,
96101
)
97102

98103
FILTERED_ARGS = ("run_manager", "callbacks")
@@ -387,7 +392,7 @@ def args(self) -> dict:
387392
def tool_call_schema(self) -> Type[BaseModel]:
388393
full_schema = self.get_input_schema()
389394
fields = []
390-
for name, type_ in full_schema.__annotations__.items():
395+
for name, type_ in _get_all_basemodel_annotations(full_schema).items():
391396
if not _is_injected_arg_type(type_):
392397
fields.append(name)
393398
return _create_subset_model(
@@ -1650,3 +1655,80 @@ def _filter_schema_args(func: Callable) -> List[str]:
16501655
filter_args.append(config_param)
16511656
# filter_args.extend(_get_non_model_params(type_hints))
16521657
return filter_args
1658+
1659+
1660+
def _get_all_basemodel_annotations(
1661+
cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True
1662+
) -> Dict[str, Type]:
1663+
# cls has no subscript: cls = FooBar
1664+
if isinstance(cls, type):
1665+
annotations: Dict[str, Type] = {}
1666+
for name, param in inspect.signature(cls).parameters.items():
1667+
annotations[name] = param.annotation
1668+
orig_bases: Tuple = getattr(cls, "__orig_bases__", tuple())
1669+
# cls has subscript: cls = FooBar[int]
1670+
else:
1671+
annotations = _get_all_basemodel_annotations(
1672+
get_origin(cls), default_to_bound=False
1673+
)
1674+
orig_bases = (cls,)
1675+
1676+
# Pydantic v2 automatically resolves inherited generics, Pydantic v1 does not.
1677+
if not (isinstance(cls, type) and is_pydantic_v2_subclass(cls)):
1678+
# if cls = FooBar inherits from Baz[str], orig_bases will contain Baz[str]
1679+
# if cls = FooBar inherits from Baz, orig_bases will contain Baz
1680+
# if cls = FooBar[int], orig_bases will contain FooBar[int]
1681+
for parent in orig_bases:
1682+
# if class = FooBar inherits from Baz, parent = Baz
1683+
if isinstance(parent, type) and is_pydantic_v1_subclass(parent):
1684+
annotations.update(
1685+
_get_all_basemodel_annotations(parent, default_to_bound=False)
1686+
)
1687+
continue
1688+
1689+
parent_origin = get_origin(parent)
1690+
1691+
# if class = FooBar inherits from non-pydantic class
1692+
if not parent_origin:
1693+
continue
1694+
1695+
# if class = FooBar inherits from Baz[str]:
1696+
# parent = Baz[str],
1697+
# parent_origin = Baz,
1698+
# generic_type_vars = (type vars in Baz)
1699+
# generic_map = {type var in Baz: str}
1700+
generic_type_vars: Tuple = getattr(parent_origin, "__parameters__", tuple())
1701+
generic_map = {
1702+
type_var: t for type_var, t in zip(generic_type_vars, get_args(parent))
1703+
}
1704+
for field in getattr(parent_origin, "__annotations__", dict()):
1705+
annotations[field] = _replace_type_vars(
1706+
annotations[field], generic_map, default_to_bound
1707+
)
1708+
1709+
return {
1710+
k: _replace_type_vars(v, default_to_bound=default_to_bound)
1711+
for k, v in annotations.items()
1712+
}
1713+
1714+
1715+
def _replace_type_vars(
1716+
type_: Type,
1717+
generic_map: Optional[Dict[TypeVar, Type]] = None,
1718+
default_to_bound: bool = True,
1719+
) -> Type:
1720+
generic_map = generic_map or {}
1721+
if isinstance(type_, TypeVar):
1722+
if type_ in generic_map:
1723+
return generic_map[type_]
1724+
elif default_to_bound:
1725+
return type_.__bound__ or Any
1726+
else:
1727+
return type_
1728+
elif (origin := get_origin(type_)) and (args := get_args(type_)):
1729+
new_args = tuple(
1730+
_replace_type_vars(arg, generic_map, default_to_bound) for arg in args
1731+
)
1732+
return _py_38_safe_origin(origin)[new_args]
1733+
else:
1734+
return type_

libs/core/langchain_core/utils/pydantic.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ def is_pydantic_v1_subclass(cls: Type) -> bool:
5353
return False
5454

5555

56+
def is_pydantic_v2_subclass(cls: Type) -> bool:
57+
"""Check if the installed Pydantic version is 1.x-like."""
58+
from pydantic import BaseModel
59+
60+
return PYDANTIC_MAJOR_VERSION == 2 and issubclass(cls, BaseModel)
61+
62+
5663
def is_basemodel_subclass(cls: Type) -> bool:
5764
"""Check if the given class is a subclass of Pydantic BaseModel.
5865

libs/core/tests/unit_tests/test_tools.py

Lines changed: 197 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,22 @@
88
from datetime import datetime
99
from enum import Enum
1010
from functools import partial
11-
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
11+
from typing import (
12+
Any,
13+
Callable,
14+
Dict,
15+
Generic,
16+
List,
17+
Literal,
18+
Optional,
19+
Tuple,
20+
Type,
21+
Union,
22+
)
1223

1324
import pytest
14-
from typing_extensions import Annotated, TypedDict
25+
from pydantic import BaseModel as BaseModelProper # pydantic: ignore
26+
from typing_extensions import Annotated, TypedDict, TypeVar
1527

1628
from langchain_core.callbacks import (
1729
AsyncCallbackManagerForToolRun,
@@ -32,12 +44,13 @@
3244
StructuredTool,
3345
Tool,
3446
ToolException,
47+
_get_all_basemodel_annotations,
3548
_is_message_content_block,
3649
_is_message_content_type,
3750
tool,
3851
)
3952
from langchain_core.utils.function_calling import convert_to_openai_function
40-
from langchain_core.utils.pydantic import _create_subset_model
53+
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, _create_subset_model
4154
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
4255
from tests.unit_tests.pydantic_utils import _schema
4356

@@ -1452,6 +1465,66 @@ def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None:
14521465
}
14531466

14541467

1468+
def test_tool_inherited_injected_arg() -> None:
1469+
class barSchema(BaseModel):
1470+
"""bar."""
1471+
1472+
y: Annotated[str, "foobar comment", InjectedToolArg()] = Field(
1473+
..., description="123"
1474+
)
1475+
1476+
class fooSchema(barSchema):
1477+
"""foo."""
1478+
1479+
x: int = Field(..., description="abc")
1480+
1481+
class InheritedInjectedArgTool(BaseTool):
1482+
name: str = "foo"
1483+
description: str = "foo."
1484+
args_schema: Type[BaseModel] = fooSchema
1485+
1486+
def _run(self, x: int, y: str) -> Any:
1487+
return y
1488+
1489+
tool_ = InheritedInjectedArgTool()
1490+
assert tool_.get_input_schema().schema() == {
1491+
"title": "fooSchema",
1492+
"description": "foo.",
1493+
"type": "object",
1494+
"properties": {
1495+
"x": {"description": "abc", "title": "X", "type": "integer"},
1496+
"y": {"description": "123", "title": "Y", "type": "string"},
1497+
},
1498+
"required": ["y", "x"],
1499+
}
1500+
assert tool_.tool_call_schema.schema() == {
1501+
"title": "foo",
1502+
"description": "foo.",
1503+
"type": "object",
1504+
"properties": {"x": {"description": "abc", "title": "X", "type": "integer"}},
1505+
"required": ["x"],
1506+
}
1507+
assert tool_.invoke({"x": 5, "y": "bar"}) == "bar"
1508+
assert tool_.invoke(
1509+
{"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"}
1510+
) == ToolMessage("bar", tool_call_id="123", name="foo")
1511+
expected_error = (
1512+
ValidationError if not isinstance(tool_, InjectedTool) else TypeError
1513+
)
1514+
with pytest.raises(expected_error):
1515+
tool_.invoke({"x": 5})
1516+
1517+
assert convert_to_openai_function(tool_) == {
1518+
"name": "foo",
1519+
"description": "foo.",
1520+
"parameters": {
1521+
"type": "object",
1522+
"properties": {"x": {"type": "integer", "description": "abc"}},
1523+
"required": ["x"],
1524+
},
1525+
}
1526+
1527+
14551528
def _get_parametrized_tools() -> list:
14561529
def my_tool(x: int, y: str, some_tool: Annotated[Any, InjectedToolArg]) -> str:
14571530
"""my_tool."""
@@ -1484,7 +1557,6 @@ def test_fn_injected_arg_with_schema(tool_: Callable) -> None:
14841557

14851558
def generate_models() -> List[Any]:
14861559
"""Generate a list of base models depending on the pydantic version."""
1487-
from pydantic import BaseModel as BaseModelProper # pydantic: ignore
14881560

14891561
class FooProper(BaseModelProper):
14901562
a: int
@@ -1670,3 +1742,124 @@ def test__is_message_content_block(obj: Any, expected: bool) -> None:
16701742
)
16711743
def test__is_message_content_type(obj: Any, expected: bool) -> None:
16721744
assert _is_message_content_type(obj) is expected
1745+
1746+
1747+
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.")
1748+
@pytest.mark.parametrize("use_v1_namespace", [True, False])
1749+
def test__get_all_basemodel_annotations_v2(use_v1_namespace: bool) -> None:
1750+
A = TypeVar("A")
1751+
1752+
if use_v1_namespace:
1753+
1754+
class ModelA(BaseModel, Generic[A]):
1755+
a: A
1756+
else:
1757+
1758+
class ModelA(BaseModelProper, Generic[A]): # type: ignore[no-redef]
1759+
a: A
1760+
1761+
class ModelB(ModelA[str]):
1762+
b: Annotated[ModelA[Dict[str, Any]], "foo"]
1763+
1764+
class Mixin(object):
1765+
def foo(self) -> str:
1766+
return "foo"
1767+
1768+
class ModelC(Mixin, ModelB):
1769+
c: dict
1770+
1771+
expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"], "c": dict}
1772+
actual = _get_all_basemodel_annotations(ModelC)
1773+
assert actual == expected
1774+
1775+
expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"]}
1776+
actual = _get_all_basemodel_annotations(ModelB)
1777+
assert actual == expected
1778+
1779+
expected = {"a": Any}
1780+
actual = _get_all_basemodel_annotations(ModelA)
1781+
assert actual == expected
1782+
1783+
expected = {"a": int}
1784+
actual = _get_all_basemodel_annotations(ModelA[int])
1785+
assert actual == expected
1786+
1787+
D = TypeVar("D", bound=Union[str, int])
1788+
1789+
class ModelD(ModelC, Generic[D]):
1790+
d: Optional[D]
1791+
1792+
expected = {
1793+
"a": str,
1794+
"b": Annotated[ModelA[Dict[str, Any]], "foo"],
1795+
"c": dict,
1796+
"d": Union[str, int, None],
1797+
}
1798+
actual = _get_all_basemodel_annotations(ModelD)
1799+
assert actual == expected
1800+
1801+
expected = {
1802+
"a": str,
1803+
"b": Annotated[ModelA[Dict[str, Any]], "foo"],
1804+
"c": dict,
1805+
"d": Union[int, None],
1806+
}
1807+
actual = _get_all_basemodel_annotations(ModelD[int])
1808+
assert actual == expected
1809+
1810+
1811+
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Testing pydantic v1.")
1812+
def test__get_all_basemodel_annotations_v1() -> None:
1813+
A = TypeVar("A")
1814+
1815+
class ModelA(BaseModel, Generic[A]):
1816+
a: A
1817+
1818+
class ModelB(ModelA[str]):
1819+
b: Annotated[ModelA[Dict[str, Any]], "foo"]
1820+
1821+
class Mixin(object):
1822+
def foo(self) -> str:
1823+
return "foo"
1824+
1825+
class ModelC(Mixin, ModelB):
1826+
c: dict
1827+
1828+
expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"], "c": dict}
1829+
actual = _get_all_basemodel_annotations(ModelC)
1830+
assert actual == expected
1831+
1832+
expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"]}
1833+
actual = _get_all_basemodel_annotations(ModelB)
1834+
assert actual == expected
1835+
1836+
expected = {"a": Any}
1837+
actual = _get_all_basemodel_annotations(ModelA)
1838+
assert actual == expected
1839+
1840+
expected = {"a": int}
1841+
actual = _get_all_basemodel_annotations(ModelA[int])
1842+
assert actual == expected
1843+
1844+
D = TypeVar("D", bound=Union[str, int])
1845+
1846+
class ModelD(ModelC, Generic[D]):
1847+
d: Optional[D]
1848+
1849+
expected = {
1850+
"a": str,
1851+
"b": Annotated[ModelA[Dict[str, Any]], "foo"],
1852+
"c": dict,
1853+
"d": Union[str, int, None],
1854+
}
1855+
actual = _get_all_basemodel_annotations(ModelD)
1856+
assert actual == expected
1857+
1858+
expected = {
1859+
"a": str,
1860+
"b": Annotated[ModelA[Dict[str, Any]], "foo"],
1861+
"c": dict,
1862+
"d": Union[int, None],
1863+
}
1864+
actual = _get_all_basemodel_annotations(ModelD[int])
1865+
assert actual == expected

0 commit comments

Comments
 (0)