|
8 | 8 | from datetime import datetime
|
9 | 9 | from enum import Enum
|
10 | 10 | from functools import partial
|
11 |
| -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union |
| 11 | +from typing import ( |
| 12 | + Any, |
| 13 | + Callable, |
| 14 | + Dict, |
| 15 | + Generic, |
| 16 | + List, |
| 17 | + Literal, |
| 18 | + Optional, |
| 19 | + Tuple, |
| 20 | + Type, |
| 21 | + Union, |
| 22 | +) |
12 | 23 |
|
13 | 24 | import pytest
|
14 |
| -from typing_extensions import Annotated, TypedDict |
| 25 | +from pydantic import BaseModel as BaseModelProper # pydantic: ignore |
| 26 | +from typing_extensions import Annotated, TypedDict, TypeVar |
15 | 27 |
|
16 | 28 | from langchain_core.callbacks import (
|
17 | 29 | AsyncCallbackManagerForToolRun,
|
|
32 | 44 | StructuredTool,
|
33 | 45 | Tool,
|
34 | 46 | ToolException,
|
| 47 | + _get_all_basemodel_annotations, |
35 | 48 | _is_message_content_block,
|
36 | 49 | _is_message_content_type,
|
37 | 50 | tool,
|
38 | 51 | )
|
39 | 52 | from langchain_core.utils.function_calling import convert_to_openai_function
|
40 |
| -from langchain_core.utils.pydantic import _create_subset_model |
| 53 | +from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, _create_subset_model |
41 | 54 | from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
42 | 55 | from tests.unit_tests.pydantic_utils import _schema
|
43 | 56 |
|
@@ -1452,6 +1465,66 @@ def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None:
|
1452 | 1465 | }
|
1453 | 1466 |
|
1454 | 1467 |
|
| 1468 | +def test_tool_inherited_injected_arg() -> None: |
| 1469 | + class barSchema(BaseModel): |
| 1470 | + """bar.""" |
| 1471 | + |
| 1472 | + y: Annotated[str, "foobar comment", InjectedToolArg()] = Field( |
| 1473 | + ..., description="123" |
| 1474 | + ) |
| 1475 | + |
| 1476 | + class fooSchema(barSchema): |
| 1477 | + """foo.""" |
| 1478 | + |
| 1479 | + x: int = Field(..., description="abc") |
| 1480 | + |
| 1481 | + class InheritedInjectedArgTool(BaseTool): |
| 1482 | + name: str = "foo" |
| 1483 | + description: str = "foo." |
| 1484 | + args_schema: Type[BaseModel] = fooSchema |
| 1485 | + |
| 1486 | + def _run(self, x: int, y: str) -> Any: |
| 1487 | + return y |
| 1488 | + |
| 1489 | + tool_ = InheritedInjectedArgTool() |
| 1490 | + assert tool_.get_input_schema().schema() == { |
| 1491 | + "title": "fooSchema", |
| 1492 | + "description": "foo.", |
| 1493 | + "type": "object", |
| 1494 | + "properties": { |
| 1495 | + "x": {"description": "abc", "title": "X", "type": "integer"}, |
| 1496 | + "y": {"description": "123", "title": "Y", "type": "string"}, |
| 1497 | + }, |
| 1498 | + "required": ["y", "x"], |
| 1499 | + } |
| 1500 | + assert tool_.tool_call_schema.schema() == { |
| 1501 | + "title": "foo", |
| 1502 | + "description": "foo.", |
| 1503 | + "type": "object", |
| 1504 | + "properties": {"x": {"description": "abc", "title": "X", "type": "integer"}}, |
| 1505 | + "required": ["x"], |
| 1506 | + } |
| 1507 | + assert tool_.invoke({"x": 5, "y": "bar"}) == "bar" |
| 1508 | + assert tool_.invoke( |
| 1509 | + {"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"} |
| 1510 | + ) == ToolMessage("bar", tool_call_id="123", name="foo") |
| 1511 | + expected_error = ( |
| 1512 | + ValidationError if not isinstance(tool_, InjectedTool) else TypeError |
| 1513 | + ) |
| 1514 | + with pytest.raises(expected_error): |
| 1515 | + tool_.invoke({"x": 5}) |
| 1516 | + |
| 1517 | + assert convert_to_openai_function(tool_) == { |
| 1518 | + "name": "foo", |
| 1519 | + "description": "foo.", |
| 1520 | + "parameters": { |
| 1521 | + "type": "object", |
| 1522 | + "properties": {"x": {"type": "integer", "description": "abc"}}, |
| 1523 | + "required": ["x"], |
| 1524 | + }, |
| 1525 | + } |
| 1526 | + |
| 1527 | + |
1455 | 1528 | def _get_parametrized_tools() -> list:
|
1456 | 1529 | def my_tool(x: int, y: str, some_tool: Annotated[Any, InjectedToolArg]) -> str:
|
1457 | 1530 | """my_tool."""
|
@@ -1484,7 +1557,6 @@ def test_fn_injected_arg_with_schema(tool_: Callable) -> None:
|
1484 | 1557 |
|
1485 | 1558 | def generate_models() -> List[Any]:
|
1486 | 1559 | """Generate a list of base models depending on the pydantic version."""
|
1487 |
| - from pydantic import BaseModel as BaseModelProper # pydantic: ignore |
1488 | 1560 |
|
1489 | 1561 | class FooProper(BaseModelProper):
|
1490 | 1562 | a: int
|
@@ -1670,3 +1742,124 @@ def test__is_message_content_block(obj: Any, expected: bool) -> None:
|
1670 | 1742 | )
|
1671 | 1743 | def test__is_message_content_type(obj: Any, expected: bool) -> None:
|
1672 | 1744 | assert _is_message_content_type(obj) is expected
|
| 1745 | + |
| 1746 | + |
| 1747 | +@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.") |
| 1748 | +@pytest.mark.parametrize("use_v1_namespace", [True, False]) |
| 1749 | +def test__get_all_basemodel_annotations_v2(use_v1_namespace: bool) -> None: |
| 1750 | + A = TypeVar("A") |
| 1751 | + |
| 1752 | + if use_v1_namespace: |
| 1753 | + |
| 1754 | + class ModelA(BaseModel, Generic[A]): |
| 1755 | + a: A |
| 1756 | + else: |
| 1757 | + |
| 1758 | + class ModelA(BaseModelProper, Generic[A]): # type: ignore[no-redef] |
| 1759 | + a: A |
| 1760 | + |
| 1761 | + class ModelB(ModelA[str]): |
| 1762 | + b: Annotated[ModelA[Dict[str, Any]], "foo"] |
| 1763 | + |
| 1764 | + class Mixin(object): |
| 1765 | + def foo(self) -> str: |
| 1766 | + return "foo" |
| 1767 | + |
| 1768 | + class ModelC(Mixin, ModelB): |
| 1769 | + c: dict |
| 1770 | + |
| 1771 | + expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"], "c": dict} |
| 1772 | + actual = _get_all_basemodel_annotations(ModelC) |
| 1773 | + assert actual == expected |
| 1774 | + |
| 1775 | + expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"]} |
| 1776 | + actual = _get_all_basemodel_annotations(ModelB) |
| 1777 | + assert actual == expected |
| 1778 | + |
| 1779 | + expected = {"a": Any} |
| 1780 | + actual = _get_all_basemodel_annotations(ModelA) |
| 1781 | + assert actual == expected |
| 1782 | + |
| 1783 | + expected = {"a": int} |
| 1784 | + actual = _get_all_basemodel_annotations(ModelA[int]) |
| 1785 | + assert actual == expected |
| 1786 | + |
| 1787 | + D = TypeVar("D", bound=Union[str, int]) |
| 1788 | + |
| 1789 | + class ModelD(ModelC, Generic[D]): |
| 1790 | + d: Optional[D] |
| 1791 | + |
| 1792 | + expected = { |
| 1793 | + "a": str, |
| 1794 | + "b": Annotated[ModelA[Dict[str, Any]], "foo"], |
| 1795 | + "c": dict, |
| 1796 | + "d": Union[str, int, None], |
| 1797 | + } |
| 1798 | + actual = _get_all_basemodel_annotations(ModelD) |
| 1799 | + assert actual == expected |
| 1800 | + |
| 1801 | + expected = { |
| 1802 | + "a": str, |
| 1803 | + "b": Annotated[ModelA[Dict[str, Any]], "foo"], |
| 1804 | + "c": dict, |
| 1805 | + "d": Union[int, None], |
| 1806 | + } |
| 1807 | + actual = _get_all_basemodel_annotations(ModelD[int]) |
| 1808 | + assert actual == expected |
| 1809 | + |
| 1810 | + |
| 1811 | +@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="Testing pydantic v1.") |
| 1812 | +def test__get_all_basemodel_annotations_v1() -> None: |
| 1813 | + A = TypeVar("A") |
| 1814 | + |
| 1815 | + class ModelA(BaseModel, Generic[A]): |
| 1816 | + a: A |
| 1817 | + |
| 1818 | + class ModelB(ModelA[str]): |
| 1819 | + b: Annotated[ModelA[Dict[str, Any]], "foo"] |
| 1820 | + |
| 1821 | + class Mixin(object): |
| 1822 | + def foo(self) -> str: |
| 1823 | + return "foo" |
| 1824 | + |
| 1825 | + class ModelC(Mixin, ModelB): |
| 1826 | + c: dict |
| 1827 | + |
| 1828 | + expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"], "c": dict} |
| 1829 | + actual = _get_all_basemodel_annotations(ModelC) |
| 1830 | + assert actual == expected |
| 1831 | + |
| 1832 | + expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"]} |
| 1833 | + actual = _get_all_basemodel_annotations(ModelB) |
| 1834 | + assert actual == expected |
| 1835 | + |
| 1836 | + expected = {"a": Any} |
| 1837 | + actual = _get_all_basemodel_annotations(ModelA) |
| 1838 | + assert actual == expected |
| 1839 | + |
| 1840 | + expected = {"a": int} |
| 1841 | + actual = _get_all_basemodel_annotations(ModelA[int]) |
| 1842 | + assert actual == expected |
| 1843 | + |
| 1844 | + D = TypeVar("D", bound=Union[str, int]) |
| 1845 | + |
| 1846 | + class ModelD(ModelC, Generic[D]): |
| 1847 | + d: Optional[D] |
| 1848 | + |
| 1849 | + expected = { |
| 1850 | + "a": str, |
| 1851 | + "b": Annotated[ModelA[Dict[str, Any]], "foo"], |
| 1852 | + "c": dict, |
| 1853 | + "d": Union[str, int, None], |
| 1854 | + } |
| 1855 | + actual = _get_all_basemodel_annotations(ModelD) |
| 1856 | + assert actual == expected |
| 1857 | + |
| 1858 | + expected = { |
| 1859 | + "a": str, |
| 1860 | + "b": Annotated[ModelA[Dict[str, Any]], "foo"], |
| 1861 | + "c": dict, |
| 1862 | + "d": Union[int, None], |
| 1863 | + } |
| 1864 | + actual = _get_all_basemodel_annotations(ModelD[int]) |
| 1865 | + assert actual == expected |
0 commit comments