diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index 4ac5ceba5087c..e664703cfb257 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -34,6 +34,7 @@ model_validator, validate_arguments, ) +from pydantic._internal._model_construction import ModelMetaclass from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import ValidationError as ValidationErrorV1 from pydantic.v1 import validate_arguments as validate_arguments_v1 @@ -150,7 +151,10 @@ def _infer_arg_descriptions( fn, annotations, error_on_invalid_docstring=error_on_invalid_docstring ) else: - description = inspect.getdoc(fn) or "" + if isinstance(fn, ModelMetaclass): + description = fn.__doc__ or "" + else: + description = inspect.getdoc(fn) or "" arg_descriptions = {} if parse_docstring: _validate_docstring_args_against_annotations(arg_descriptions, annotations) diff --git a/libs/core/langchain_core/tools/structured.py b/libs/core/langchain_core/tools/structured.py index 4e202dfea3a01..0f3000f4cc381 100644 --- a/libs/core/langchain_core/tools/structured.py +++ b/libs/core/langchain_core/tools/structured.py @@ -16,6 +16,7 @@ ) from pydantic import Field, SkipValidation +from pydantic._internal._model_construction import ModelMetaclass from typing_extensions import override from langchain_core.callbacks import ( @@ -197,7 +198,10 @@ def add(a: int, b: int) -> int: description_ = source_function.__doc__ or None if description_ is None and args_schema: if isinstance(args_schema, type) and is_basemodel_subclass(args_schema): - description_ = args_schema.__doc__ or None + if isinstance(source_function, ModelMetaclass): + description_ = args_schema.__doc__ + else: + description_ = args_schema.__doc__ or None elif isinstance(args_schema, dict): description_ = args_schema.get("description") else: diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index f724bce74fd8f..5fc5613c12ee0 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -700,6 +700,12 @@ def test_missing_docstring() -> None: def search_api(query: str) -> str: return "API result" + @tool + class MyTool(BaseModel): + foo: str + + assert MyTool.description == "" # type: ignore[attr-defined] + def test_create_tool_positional_args() -> None: """Test that positional arguments are allowed."""