From b5348113a711645fd7cf114e7e8f0559acadf122 Mon Sep 17 00:00:00 2001 From: Benjamin Rapaport Date: Wed, 8 May 2024 17:05:44 -0400 Subject: [PATCH 1/3] First attempt at gemini tool calling for nested structures. It might work but it is very complex. The test passes. --- mirascope/gemini/tools.py | 49 ++++++++-- poetry.lock | 179 +++++++++++++++++++++++++------------ pyproject.toml | 5 +- tests/gemini/test_tools.py | 24 ++++- 4 files changed, 189 insertions(+), 68 deletions(-) diff --git a/mirascope/gemini/tools.py b/mirascope/gemini/tools.py index 477e194e..51791138 100644 --- a/mirascope/gemini/tools.py +++ b/mirascope/gemini/tools.py @@ -1,7 +1,7 @@ """Classes for using tools with Google's Gemini API.""" from __future__ import annotations -from typing import Callable, Type +from typing import Any, Callable, Type, Union from google.ai.generativelanguage import FunctionCall from google.generativeai.types import ( # type: ignore @@ -19,6 +19,33 @@ ) +class NoDefsFunctionDeclaration(FunctionDeclaration): + def __init__(self, **kwargs): + self._defs = kwargs.pop("$defs", None) # Store $defs separately + super().__init__(**kwargs) + + +def resolve_refs( + schema: Union[dict[str, Any], list[Any]], defs: dict[str, Any] +) -> Union[dict[str, Any], list[Any]]: + """Recursively resolves $ref references within a schema.""" + if isinstance(schema, dict): + if "$ref" in schema: + ref_path = schema["$ref"].lstrip("#/") + if ref_path.startswith("$defs/"): + ref_name = ref_path.split("/")[1] + ref_schema = defs.get(ref_name) + if ref_schema is None: + raise ValueError(f"Invalid reference: {schema['$ref']}") + return resolve_refs(ref_schema, defs) + else: + return {k: resolve_refs(v, defs) for k, v in schema.items() if k != "title"} + elif isinstance(schema, list): + return [resolve_refs(item, defs) for item in schema] + else: + return schema + + class GeminiTool(BaseTool[FunctionCall]): '''A base class for easy use of tools with the Gemini API. @@ -70,18 +97,28 @@ def tool_schema(cls) -> Tool: """ tool_schema = super().tool_schema() if "parameters" in tool_schema: + # Handle nested structures with $defs if "$defs" in tool_schema["parameters"]: - raise ValueError( - "Unfortunately Google's Gemini API cannot handle nested structures " - "with $defs." - ) + defs = tool_schema["parameters"]["$defs"] + # Resolve references in properties + for key, prop_schema in tool_schema["parameters"]["properties"].items(): + tool_schema["parameters"]["properties"][key] = resolve_refs( + prop_schema, defs + ) + + # Remove $defs after resolving references + del tool_schema["parameters"]["$defs"] + + # Remove title from properties tool_schema["parameters"]["properties"] = { prop: { key: value for key, value in prop_schema.items() if key != "title" } for prop, prop_schema in tool_schema["parameters"]["properties"].items() } - return Tool(function_declarations=[FunctionDeclaration(**tool_schema)]) + + # Use CustomFunctionDeclaration to handle $defs + return Tool(function_declarations=[NoDefsFunctionDeclaration(**tool_schema)]) @classmethod def from_tool_call(cls, tool_call: FunctionCall) -> GeminiTool: diff --git a/poetry.lock b/poetry.lock index 1456e85e..82acd1d7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -164,13 +164,13 @@ files = [ [[package]] name = "anthropic" -version = "0.25.7" +version = "0.25.8" description = "The official Python library for the anthropic API" optional = true python-versions = ">=3.7" files = [ - {file = "anthropic-0.25.7-py3-none-any.whl", hash = "sha256:419a276eb20cfb7ddaac03c7e28e4e12df3ace71bcf33071a68c9a03c0dfcbdd"}, - {file = "anthropic-0.25.7.tar.gz", hash = "sha256:e7de4c8ba8e7e8248ad7f05ed9176634780b95b67c678d23915d8964c8a26f4e"}, + {file = "anthropic-0.25.8-py3-none-any.whl", hash = "sha256:c7a0091916eb22a5e0012b725f5492779eedfcad2da8dc906082e1db7596a65c"}, + {file = "anthropic-0.25.8.tar.gz", hash = "sha256:93f6063e96d5dbeaa172edc177762f630e55b2f81595cedb760278b95a2dd03e"}, ] [package.dependencies] @@ -208,17 +208,6 @@ doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphin test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] trio = ["trio (>=0.23)"] -[[package]] -name = "appdirs" -version = "1.4.4" -description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." -optional = true -python-versions = "*" -files = [ - {file = "appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128"}, - {file = "appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41"}, -] - [[package]] name = "asgiref" version = "3.8.1" @@ -360,7 +349,7 @@ virtualenv = ["virtualenv (>=20.0.35)"] name = "cachetools" version = "5.3.3" description = "Extensible memoizing collections and decorators" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "cachetools-5.3.3-py3-none-any.whl", hash = "sha256:0abad1021d3f8325b2fc1d2e9c8b9c9d57b04c3932657a72465447332c24d945"}, @@ -570,13 +559,13 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} [[package]] name = "cohere" -version = "5.3.4" +version = "5.3.5" description = "" optional = true python-versions = "<4.0,>=3.8" files = [ - {file = "cohere-5.3.4-py3-none-any.whl", hash = "sha256:2d62306ed86c346d472e2355cf653b8fbd8722330e554355d576fb6061e204f2"}, - {file = "cohere-5.3.4.tar.gz", hash = "sha256:b45a2c85b025dd15367d69bf92748de4a30ec82abb4b08fa5301b72e89051fe2"}, + {file = "cohere-5.3.5-py3-none-any.whl", hash = "sha256:49683c813d63f6d3fd212f1105839d171d0db46303a33f96b69c1528c2806b06"}, + {file = "cohere-5.3.5.tar.gz", hash = "sha256:dc113c618312a2d0d004f848e3948809a65f463dc623b3aeca325be062e24d25"}, ] [package.dependencies] @@ -1031,7 +1020,7 @@ test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", name = "google-ai-generativelanguage" version = "0.6.2" description = "Google Ai Generativelanguage API client library" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "google-ai-generativelanguage-0.6.2.tar.gz", hash = "sha256:308791ac3b9dad015b359172970739aa3753dd542142a416d07f9fa047e22386"}, @@ -1048,7 +1037,7 @@ protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4 name = "google-api-core" version = "2.19.0" description = "Google API client core library" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "google-api-core-2.19.0.tar.gz", hash = "sha256:cf1b7c2694047886d2af1128a03ae99e391108a08804f87cfd35970e49c9cd10"}, @@ -1077,13 +1066,13 @@ grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] [[package]] name = "google-api-python-client" -version = "2.127.0" +version = "2.128.0" description = "Google API Client Library for Python" -optional = true +optional = false python-versions = ">=3.7" files = [ - {file = "google-api-python-client-2.127.0.tar.gz", hash = "sha256:bbb51b0fbccdf40e536c26341e372d7800f09afebb53103bbcc94e08f14b523b"}, - {file = "google_api_python_client-2.127.0-py2.py3-none-any.whl", hash = "sha256:d01c70c7840ec37888aa02b1aea5d9baba4c1701e268d1a0251640afd56e5e90"}, + {file = "google-api-python-client-2.128.0.tar.gz", hash = "sha256:908af182dfc1cd79412a489b37fe45e4f3cc99c74e80c7c477ca5babaa54eea5"}, + {file = "google_api_python_client-2.128.0-py2.py3-none-any.whl", hash = "sha256:99da6acb0acc648e309102b0e0262d7fef30f07f6bf56c6eeaa0504ceca113e3"}, ] [package.dependencies] @@ -1097,7 +1086,7 @@ uritemplate = ">=3.0.1,<5" name = "google-auth" version = "2.29.0" description = "Google Authentication Library" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "google-auth-2.29.0.tar.gz", hash = "sha256:672dff332d073227550ffc7457868ac4218d6c500b155fe6cc17d2b13602c360"}, @@ -1120,7 +1109,7 @@ requests = ["requests (>=2.20.0,<3.0.0.dev0)"] name = "google-auth-httplib2" version = "0.2.0" description = "Google Authentication Library: httplib2 transport" -optional = true +optional = false python-versions = "*" files = [ {file = "google-auth-httplib2-0.2.0.tar.gz", hash = "sha256:38aa7badf48f974f1eb9861794e9c0cb2a0511a4ec0679b1f886d108f5640e05"}, @@ -1135,7 +1124,7 @@ httplib2 = ">=0.19.0" name = "google-generativeai" version = "0.5.2" description = "Google Generative AI High level API client library and tools." -optional = true +optional = false python-versions = ">=3.9" files = [ {file = "google_generativeai-0.5.2-py3-none-any.whl", hash = "sha256:56f39485a0a673c93c21ec31c17809cc6a964193fb77b7ce809ad15d0dd72d7b"}, @@ -1158,7 +1147,7 @@ dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "py name = "googleapis-common-protos" version = "1.63.0" description = "Common protobufs used in Google APIs" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "googleapis-common-protos-1.63.0.tar.gz", hash = "sha256:17ad01b11d5f1d0171c06d3ba5c04c54474e883b66b949722b4938ee2694ef4e"}, @@ -1320,7 +1309,7 @@ typing-extensions = ">=4.7,<5" name = "grpcio" version = "1.63.0" description = "HTTP/2-based RPC framework" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "grpcio-1.63.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:2e93aca840c29d4ab5db93f94ed0a0ca899e241f2e8aec6334ab3575dc46125c"}, @@ -1378,7 +1367,7 @@ protobuf = ["grpcio-tools (>=1.63.0)"] name = "grpcio-status" version = "1.62.2" description = "Status proto mapping for gRPC" -optional = true +optional = false python-versions = ">=3.6" files = [ {file = "grpcio-status-1.62.2.tar.gz", hash = "sha256:62e1bfcb02025a1cd73732a2d33672d3e9d0df4d21c12c51e0bbcaf09bab742a"}, @@ -1426,7 +1415,7 @@ trio = ["trio (>=0.22.0,<0.26.0)"] name = "httplib2" version = "0.22.0" description = "A comprehensive HTTP client library." -optional = true +optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ {file = "httplib2-0.22.0-py3-none-any.whl", hash = "sha256:14ae0a53c1ba8f3d37e9e27cf37eabb0fb9980f435ba405d546948b009dd64dc"}, @@ -1720,6 +1709,56 @@ sqlalchemy = ["opentelemetry-instrumentation-sqlalchemy (>=0.42b0)"] starlette = ["opentelemetry-instrumentation-starlette (>=0.42b0)"] system-metrics = ["opentelemetry-instrumentation-system-metrics (>=0.42b0)"] +[[package]] +name = "lz4" +version = "4.3.3" +description = "LZ4 Bindings for Python" +optional = true +python-versions = ">=3.8" +files = [ + {file = "lz4-4.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b891880c187e96339474af2a3b2bfb11a8e4732ff5034be919aa9029484cd201"}, + {file = "lz4-4.3.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:222a7e35137d7539c9c33bb53fcbb26510c5748779364014235afc62b0ec797f"}, + {file = "lz4-4.3.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f76176492ff082657ada0d0f10c794b6da5800249ef1692b35cf49b1e93e8ef7"}, + {file = "lz4-4.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f1d18718f9d78182c6b60f568c9a9cec8a7204d7cb6fad4e511a2ef279e4cb05"}, + {file = "lz4-4.3.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6cdc60e21ec70266947a48839b437d46025076eb4b12c76bd47f8e5eb8a75dcc"}, + {file = "lz4-4.3.3-cp310-cp310-win32.whl", hash = "sha256:c81703b12475da73a5d66618856d04b1307e43428a7e59d98cfe5a5d608a74c6"}, + {file = "lz4-4.3.3-cp310-cp310-win_amd64.whl", hash = "sha256:43cf03059c0f941b772c8aeb42a0813d68d7081c009542301637e5782f8a33e2"}, + {file = "lz4-4.3.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:30e8c20b8857adef7be045c65f47ab1e2c4fabba86a9fa9a997d7674a31ea6b6"}, + {file = "lz4-4.3.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2f7b1839f795315e480fb87d9bc60b186a98e3e5d17203c6e757611ef7dcef61"}, + {file = "lz4-4.3.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edfd858985c23523f4e5a7526ca6ee65ff930207a7ec8a8f57a01eae506aaee7"}, + {file = "lz4-4.3.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e9c410b11a31dbdc94c05ac3c480cb4b222460faf9231f12538d0074e56c563"}, + {file = "lz4-4.3.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d2507ee9c99dbddd191c86f0e0c8b724c76d26b0602db9ea23232304382e1f21"}, + {file = "lz4-4.3.3-cp311-cp311-win32.whl", hash = "sha256:f180904f33bdd1e92967923a43c22899e303906d19b2cf8bb547db6653ea6e7d"}, + {file = "lz4-4.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:b14d948e6dce389f9a7afc666d60dd1e35fa2138a8ec5306d30cd2e30d36b40c"}, + {file = "lz4-4.3.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e36cd7b9d4d920d3bfc2369840da506fa68258f7bb176b8743189793c055e43d"}, + {file = "lz4-4.3.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:31ea4be9d0059c00b2572d700bf2c1bc82f241f2c3282034a759c9a4d6ca4dc2"}, + {file = "lz4-4.3.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:33c9a6fd20767ccaf70649982f8f3eeb0884035c150c0b818ea660152cf3c809"}, + {file = "lz4-4.3.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bca8fccc15e3add173da91be8f34121578dc777711ffd98d399be35487c934bf"}, + {file = "lz4-4.3.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e7d84b479ddf39fe3ea05387f10b779155fc0990125f4fb35d636114e1c63a2e"}, + {file = "lz4-4.3.3-cp312-cp312-win32.whl", hash = "sha256:337cb94488a1b060ef1685187d6ad4ba8bc61d26d631d7ba909ee984ea736be1"}, + {file = "lz4-4.3.3-cp312-cp312-win_amd64.whl", hash = "sha256:5d35533bf2cee56f38ced91f766cd0038b6abf46f438a80d50c52750088be93f"}, + {file = "lz4-4.3.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:363ab65bf31338eb364062a15f302fc0fab0a49426051429866d71c793c23394"}, + {file = "lz4-4.3.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0a136e44a16fc98b1abc404fbabf7f1fada2bdab6a7e970974fb81cf55b636d0"}, + {file = "lz4-4.3.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:abc197e4aca8b63f5ae200af03eb95fb4b5055a8f990079b5bdf042f568469dd"}, + {file = "lz4-4.3.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56f4fe9c6327adb97406f27a66420b22ce02d71a5c365c48d6b656b4aaeb7775"}, + {file = "lz4-4.3.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f0e822cd7644995d9ba248cb4b67859701748a93e2ab7fc9bc18c599a52e4604"}, + {file = "lz4-4.3.3-cp38-cp38-win32.whl", hash = "sha256:24b3206de56b7a537eda3a8123c644a2b7bf111f0af53bc14bed90ce5562d1aa"}, + {file = "lz4-4.3.3-cp38-cp38-win_amd64.whl", hash = "sha256:b47839b53956e2737229d70714f1d75f33e8ac26e52c267f0197b3189ca6de24"}, + {file = "lz4-4.3.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6756212507405f270b66b3ff7f564618de0606395c0fe10a7ae2ffcbbe0b1fba"}, + {file = "lz4-4.3.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ee9ff50557a942d187ec85462bb0960207e7ec5b19b3b48949263993771c6205"}, + {file = "lz4-4.3.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2b901c7784caac9a1ded4555258207d9e9697e746cc8532129f150ffe1f6ba0d"}, + {file = "lz4-4.3.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6d9ec061b9eca86e4dcc003d93334b95d53909afd5a32c6e4f222157b50c071"}, + {file = "lz4-4.3.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f4c7bf687303ca47d69f9f0133274958fd672efaa33fb5bcde467862d6c621f0"}, + {file = "lz4-4.3.3-cp39-cp39-win32.whl", hash = "sha256:054b4631a355606e99a42396f5db4d22046a3397ffc3269a348ec41eaebd69d2"}, + {file = "lz4-4.3.3-cp39-cp39-win_amd64.whl", hash = "sha256:eac9af361e0d98335a02ff12fb56caeb7ea1196cf1a49dbf6f17828a131da807"}, + {file = "lz4-4.3.3.tar.gz", hash = "sha256:01fe674ef2889dbb9899d8a67361e0c4a2c833af5aeb37dd505727cf5d2a131e"}, +] + +[package.extras] +docs = ["sphinx (>=1.6.0)", "sphinx-bootstrap-theme"] +flake8 = ["flake8"] +tests = ["psutil", "pytest (!=3.3.0)", "pytest-cov"] + [[package]] name = "markdown" version = "3.5.2" @@ -2829,17 +2868,25 @@ files = [ [[package]] name = "pinecone-client" -version = "3.2.2" +version = "4.0.0" description = "Pinecone client and SDK" optional = true python-versions = "<4.0,>=3.8" files = [ - {file = "pinecone_client-3.2.2-py3-none-any.whl", hash = "sha256:7e492fdda23c73726bc0cb94c689bb950d06fb94e82b701a0c610c2e830db327"}, - {file = "pinecone_client-3.2.2.tar.gz", hash = "sha256:887a12405f90ac11c396490f605fc479f31cf282361034d1ae0fccc02ac75bee"}, + {file = "pinecone_client-4.0.0-py3-none-any.whl", hash = "sha256:606a6acb70e387ebcc2ae29b1848f197e58226d1d7267ea87a22cb561a36bdf1"}, + {file = "pinecone_client-4.0.0.tar.gz", hash = "sha256:d44db212e64aa343d14efc9b08e9e45d98ba7b681f63298b87033542ea017f23"}, ] [package.dependencies] certifi = ">=2019.11.17" +googleapis-common-protos = {version = ">=1.53.0", optional = true, markers = "extra == \"grpc\""} +grpcio = [ + {version = ">=1.44.0", optional = true, markers = "python_version >= \"3.8\" and python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.59.0", optional = true, markers = "python_version >= \"3.11\" and python_version < \"4.0\" and extra == \"grpc\""}, +] +lz4 = {version = ">=3.1.3", optional = true, markers = "extra == \"grpc\""} +protobuf = {version = ">=4.25,<5.0", optional = true, markers = "extra == \"grpc\""} +protoc-gen-openapiv2 = {version = ">=0.0.1,<0.0.2", optional = true, markers = "extra == \"grpc\""} tqdm = ">=4.64.1" typing-extensions = ">=3.7.4" urllib3 = [ @@ -2848,7 +2895,7 @@ urllib3 = [ ] [package.extras] -grpc = ["googleapis-common-protos (>=1.53.0)", "grpc-gateway-protoc-gen-openapiv2 (==0.1.0)", "grpcio (>=1.44.0)", "grpcio (>=1.59.0)", "lz4 (>=3.1.3)", "protobuf (>=3.20.0,<3.21.0)"] +grpc = ["googleapis-common-protos (>=1.53.0)", "grpcio (>=1.44.0)", "grpcio (>=1.59.0)", "lz4 (>=3.1.3)", "protobuf (>=4.25,<5.0)", "protoc-gen-openapiv2 (>=0.0.1,<0.0.2)"] [[package]] name = "platformdirs" @@ -2908,7 +2955,7 @@ test = ["coverage", "flake8", "freezegun (==0.3.15)", "mock (>=2.0.0)", "pylint" name = "proto-plus" version = "1.23.0" description = "Beautiful, Pythonic protocol buffers." -optional = true +optional = false python-versions = ">=3.6" files = [ {file = "proto-plus-1.23.0.tar.gz", hash = "sha256:89075171ef11988b3fa157f5dbd8b9cf09d65fffee97e29ce403cd8defba19d2"}, @@ -2925,7 +2972,7 @@ testing = ["google-api-core[grpc] (>=1.31.5)"] name = "protobuf" version = "4.25.3" description = "" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "protobuf-4.25.3-cp310-abi3-win32.whl", hash = "sha256:d4198877797a83cbfe9bffa3803602bbe1625dc30d8a097365dbc762e5790faa"}, @@ -2941,6 +2988,21 @@ files = [ {file = "protobuf-4.25.3.tar.gz", hash = "sha256:25b5d0b42fd000320bd7830b349e3b696435f3b329810427a6bcce6a5492cc5c"}, ] +[[package]] +name = "protoc-gen-openapiv2" +version = "0.0.1" +description = "Provides the missing pieces for gRPC Gateway." +optional = true +python-versions = ">=3.6" +files = [ + {file = "protoc-gen-openapiv2-0.0.1.tar.gz", hash = "sha256:6f79188d842c13177c9c0558845442c340b43011bf67dfef1dfc3bc067506409"}, + {file = "protoc_gen_openapiv2-0.0.1-py3-none-any.whl", hash = "sha256:18090c8be3877c438e7da0f7eb7cace45a9a210306bca4707708dbad367857be"}, +] + +[package.dependencies] +googleapis-common-protos = "*" +protobuf = ">=4.21.0" + [[package]] name = "psutil" version = "5.9.8" @@ -3043,7 +3105,7 @@ numpy = ">=1.16.6" name = "pyasn1" version = "0.6.0" description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "pyasn1-0.6.0-py2.py3-none-any.whl", hash = "sha256:cca4bb0f2df5504f02f6f8a775b6e416ff9b0b3b16f7ee80b5a3153d9b804473"}, @@ -3054,7 +3116,7 @@ files = [ name = "pyasn1-modules" version = "0.4.0" description = "A collection of ASN.1-based protocols modules" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "pyasn1_modules-0.4.0-py3-none-any.whl", hash = "sha256:be04f15b66c206eed667e0bb5ab27e2b1855ea54a842e5037738099e8ca4ae0b"}, @@ -3616,7 +3678,7 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"] name = "rsa" version = "4.9" description = "Pure-Python RSA implementation" -optional = true +optional = false python-versions = ">=3.6,<4" files = [ {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"}, @@ -3983,17 +4045,18 @@ mpmath = ">=0.19" [[package]] name = "tenacity" -version = "8.2.3" +version = "8.3.0" description = "Retry code until it succeeds" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "tenacity-8.2.3-py3-none-any.whl", hash = "sha256:ce510e327a630c9e1beaf17d42e6ffacc88185044ad85cf74c0a8887c6a0f88c"}, - {file = "tenacity-8.2.3.tar.gz", hash = "sha256:5398ef0d78e63f40007c1fb4c0bff96e1911394d2fa8d194f77619c05ff6cc8a"}, + {file = "tenacity-8.3.0-py3-none-any.whl", hash = "sha256:3649f6443dbc0d9b01b9d8020a9c4ec7a1ff5f6f3c6c8a036ef371f573fe9185"}, + {file = "tenacity-8.3.0.tar.gz", hash = "sha256:953d4e6ad24357bceffbc9707bc74349aca9d245f68eb65419cf0c249a1949a2"}, ] [package.extras] -doc = ["reno", "sphinx", "tornado (>=4.5)"] +doc = ["reno", "sphinx"] +test = ["pytest", "tornado (>=4.5)", "typeguard"] [[package]] name = "tiktoken" @@ -4263,7 +4326,7 @@ files = [ name = "uritemplate" version = "4.1.1" description = "Implementation of RFC 6570 URI Templates" -optional = true +optional = false python-versions = ">=3.6" files = [ {file = "uritemplate-4.1.1-py2.py3-none-any.whl", hash = "sha256:830c08b8d99bdd312ea4ead05994a38e8936266f84b9a7878232db50b044e02e"}, @@ -4373,26 +4436,31 @@ test = ["coverage", "flake8 (>=3.7)", "mypy", "pretend", "pytest"] [[package]] name = "wandb" -version = "0.16.6" +version = "0.17.0" description = "A CLI and library for interacting with the Weights & Biases API." optional = true python-versions = ">=3.7" files = [ - {file = "wandb-0.16.6-py3-none-any.whl", hash = "sha256:5810019a3b981c796e98ea58557a7c380f18834e0c6bdaed15df115522e5616e"}, - {file = "wandb-0.16.6.tar.gz", hash = "sha256:86f491e3012d715e0d7d7421a4d6de41abef643b7403046261f962f3e512fe1c"}, + {file = "wandb-0.17.0-py3-none-any.whl", hash = "sha256:b1b056b4cad83b00436cb76049fd29ecedc6045999dcaa5eba40db6680960ac2"}, + {file = "wandb-0.17.0-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:e1e6f04e093a6a027dcb100618ca23b122d032204b2ed4c62e4e991a48041a6b"}, + {file = "wandb-0.17.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:feeb60d4ff506d2a6bc67f953b310d70b004faa789479c03ccd1559c6f1a9633"}, + {file = "wandb-0.17.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7bed8a3dd404a639e6bf5fea38c6efe2fb98d416ff1db4fb51be741278ed328"}, + {file = "wandb-0.17.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56a1dd6e0e635cba3f6ed30b52c71739bdc2a3e57df155619d2d80ee952b4201"}, + {file = "wandb-0.17.0-py3-none-win32.whl", hash = "sha256:1f692d3063a0d50474022cfe6668e1828260436d1cd40827d1e136b7f730c74c"}, + {file = "wandb-0.17.0-py3-none-win_amd64.whl", hash = "sha256:ab582ca0d54d52ef5b991de0717350b835400d9ac2d3adab210022b68338d694"}, ] [package.dependencies] -appdirs = ">=1.4.3" -Click = ">=7.1,<8.0.0 || >8.0.0" +click = ">=7.1,<8.0.0 || >8.0.0" docker-pycreds = ">=0.4.0" -GitPython = ">=1.0.0,<3.1.29 || >3.1.29" +gitpython = ">=1.0.0,<3.1.29 || >3.1.29" +platformdirs = "*" protobuf = [ {version = ">=3.15.0,<4.21.0 || >4.21.0,<5", markers = "python_version == \"3.9\" and sys_platform == \"linux\""}, {version = ">=3.19.0,<4.21.0 || >4.21.0,<5", markers = "python_version > \"3.9\" or sys_platform != \"linux\""}, ] psutil = ">=5.0.0" -PyYAML = "*" +pyyaml = "*" requests = ">=2.0.0,<3" sentry-sdk = ">=1.0.0" setproctitle = "*" @@ -4400,13 +4468,12 @@ setuptools = "*" typing-extensions = {version = "*", markers = "python_version < \"3.10\""} [package.extras] -async = ["httpx (>=0.23.0)"] aws = ["boto3"] azure = ["azure-identity", "azure-storage-blob"] gcp = ["google-cloud-storage"] importers = ["filelock", "mlflow", "polars", "rich", "tenacity"] kubeflow = ["google-cloud-storage", "kubernetes", "minio", "sh"] -launch = ["PyYAML (>=6.0.0)", "awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "tomli", "typing-extensions"] +launch = ["awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "pyyaml (>=6.0.0)", "tomli", "typing-extensions"] media = ["bokeh", "moviepy", "numpy", "pillow", "plotly (>=5.18.0)", "rdkit-pypi", "soundfile"] models = ["cloudpickle"] perf = ["orjson"] @@ -4907,4 +4974,4 @@ weave = ["weave"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "b5fff4ad626a2b0edcf4a1baf254774256b6c623fd942a24f1dbd639c3a63534" +content-hash = "1fa987a6e51859de42a6bf45dcfc6a866bf21076f9829c53894b6be18d86e7d2" diff --git a/pyproject.toml b/pyproject.toml index 13ef40dc..0a33c9eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,14 +22,15 @@ tenacity = "^8.2.3" mirascope-cli = { version = ">=0.1.3,<1.0.0", optional = true } wandb = { version = ">=0.16.4,<1.0.0", optional = true } weave = { version = ">=0.50.2,<1.0.0", optional = true } -google-generativeai = { version = ">=0.4.0,<1.0.0", optional = true } chromadb = { version = ">=0.4.24,<1.0.0", optional = true } anthropic = { version = ">=0.23.1,<1.0.0", optional = true } mistralai = { version = ">=0.1.6,<1.0.0", optional = true } groq = { version = ">=0.4.2,<1.0.0", optional = true } cohere = { version = "^5.2.5", optional = true } -pinecone-client = { version = "^3.2.2", optional = true } logfire = { version = ">=0.26.0,<1.0.0", optional = true } +pinecone-client = {extras = ["grpc"], version = "^4.0.0", optional = true} +google-generativeai = "^0.5.2" +google-ai-generativelanguage = "0.6.2" [tool.poetry.extras] cli = ["mirascope-cli"] diff --git a/tests/gemini/test_tools.py b/tests/gemini/test_tools.py index f47902a8..e87e33f3 100644 --- a/tests/gemini/test_tools.py +++ b/tests/gemini/test_tools.py @@ -20,14 +20,30 @@ def test_from_tool_call_no_args() -> None: def test_no_nested_tools() -> None: """Tests that a `ValueError` is raised when using nested tools with Gemin.""" + class Author(BaseModel): + given_name: str + family_name: str + class Book(BaseModel): - title: str + # Note: title cannot be a field name + author: Author + description: str + year: int - class Books(GeminiTool): + class Books(BaseModel): books: list[Book] - with pytest.raises(ValueError): - Books.tool_schema() + class BooksGeminiTool(GeminiTool): + books: list[Book] + + # with pytest.raises(ValueError): + # Books.tool_schema() + + print(f"{Books.model_json_schema()=}") + + s = BooksGeminiTool.tool_schema() + print(f"{s.to_proto()=}") + raise ValueError("uh oh") def fake_tool(param: str): From d45980606f812fd9ef841a456910c61f97a46c32 Mon Sep 17 00:00:00 2001 From: Benjamin Rapaport Date: Wed, 8 May 2024 17:42:10 -0400 Subject: [PATCH 2/3] trying with googles built in tools. But the only seem to work for dataclasses. --- mirascope/gemini/tools.py | 110 ++++++++++++++++++++------------------ poetry.lock | 13 ++++- pyproject.toml | 1 + 3 files changed, 71 insertions(+), 53 deletions(-) diff --git a/mirascope/gemini/tools.py b/mirascope/gemini/tools.py index 51791138..309fea14 100644 --- a/mirascope/gemini/tools.py +++ b/mirascope/gemini/tools.py @@ -1,49 +1,69 @@ """Classes for using tools with Google's Gemini API.""" from __future__ import annotations -from typing import Any, Callable, Type, Union +import dataclasses +from typing import Any, Callable, Type from google.ai.generativelanguage import FunctionCall from google.generativeai.types import ( # type: ignore - FunctionDeclaration, Tool, + content_types, +) +from google.generativeai.types.content_types import ( # type: ignore + CallableFunctionDeclaration, ) from pydantic import BaseModel, ConfigDict +from pydantic_core import PydanticUndefined from ..base import ( BaseTool, BaseType, convert_base_model_to_tool, - convert_base_type_to_tool, - convert_function_to_tool, + convert_base_type_to_tool, # type: ignore + convert_function_to_tool, # type: ignore ) -class NoDefsFunctionDeclaration(FunctionDeclaration): - def __init__(self, **kwargs): - self._defs = kwargs.pop("$defs", None) # Store $defs separately - super().__init__(**kwargs) - - -def resolve_refs( - schema: Union[dict[str, Any], list[Any]], defs: dict[str, Any] -) -> Union[dict[str, Any], list[Any]]: - """Recursively resolves $ref references within a schema.""" - if isinstance(schema, dict): - if "$ref" in schema: - ref_path = schema["$ref"].lstrip("#/") - if ref_path.startswith("$defs/"): - ref_name = ref_path.split("/")[1] - ref_schema = defs.get(ref_name) - if ref_schema is None: - raise ValueError(f"Invalid reference: {schema['$ref']}") - return resolve_refs(ref_schema, defs) +def pydantic_to_dataclass( + klass: Type[BaseModel], + classname: str | None = None, +) -> Any: + """ + Dataclass from Pydantic model + + Transferred entities: + * Field names + * Type annotations, except of Annotated etc + * Default factory or default value + + Validators are not transferred. + + Order of fields may change due to dataclass's positional arguments. + + """ + dataclass_args = [] + for name, info in klass.model_fields.items(): + if info.default_factory is not None: + dataclass_field = dataclasses.field( + default_factory=info.default_factory, + ) + dataclass_arg = (name, info.annotation, dataclass_field) + elif info.default is not PydanticUndefined: + dataclass_field = dataclasses.field( + default=info.get_default(), + ) + dataclass_arg = (name, info.annotation, dataclass_field) else: - return {k: resolve_refs(v, defs) for k, v in schema.items() if k != "title"} - elif isinstance(schema, list): - return [resolve_refs(item, defs) for item in schema] - else: - return schema + dataclass_arg = (name, info.annotation) + + dataclass_args.append(dataclass_arg) + + dataclass_args.sort(key=lambda arg: len(arg) > 2) + + return dataclasses.make_dataclass( + classname or f"__{klass.__name__}Dataclass", + dataclass_args, + ) class GeminiTool(BaseTool[FunctionCall]): @@ -95,30 +115,16 @@ def tool_schema(cls) -> Tool: Returns: The constructed `Tool` schema. """ - tool_schema = super().tool_schema() - if "parameters" in tool_schema: - # Handle nested structures with $defs - if "$defs" in tool_schema["parameters"]: - defs = tool_schema["parameters"]["$defs"] - # Resolve references in properties - for key, prop_schema in tool_schema["parameters"]["properties"].items(): - tool_schema["parameters"]["properties"][key] = resolve_refs( - prop_schema, defs - ) - - # Remove $defs after resolving references - del tool_schema["parameters"]["$defs"] - - # Remove title from properties - tool_schema["parameters"]["properties"] = { - prop: { - key: value for key, value in prop_schema.items() if key != "title" - } - for prop, prop_schema in tool_schema["parameters"]["properties"].items() - } - - # Use CustomFunctionDeclaration to handle $defs - return Tool(function_declarations=[NoDefsFunctionDeclaration(**tool_schema)]) + + dc_cls = pydantic_to_dataclass(cls) + + def fun(a: Type[dc_cls]): + pass + + cfd: CallableFunctionDeclaration = ( + content_types.FunctionDeclaration.from_function(fun) + ) + return Tool(function_declarations=[cfd]) @classmethod def from_tool_call(cls, tool_call: FunctionCall) -> GeminiTool: diff --git a/poetry.lock b/poetry.lock index 82acd1d7..10983ab2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1646,6 +1646,17 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "jsonref" +version = "1.1.0" +description = "jsonref is a library for automatic dereferencing of JSON Reference objects for Python." +optional = false +python-versions = ">=3.7" +files = [ + {file = "jsonref-1.1.0-py3-none-any.whl", hash = "sha256:590dc7773df6c21cbf948b5dac07a72a251db28b0238ceecce0a2abfa8ec30a9"}, + {file = "jsonref-1.1.0.tar.gz", hash = "sha256:32fe8e1d85af0fdefbebce950af85590b22b60f9e95443176adbde4e1ecea552"}, +] + [[package]] name = "kubernetes" version = "29.0.0" @@ -4974,4 +4985,4 @@ weave = ["weave"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "1fa987a6e51859de42a6bf45dcfc6a866bf21076f9829c53894b6be18d86e7d2" +content-hash = "a78fe281e3f79d562b16888aeccb94f3d099c123f86f7c2dbe8a097a770d63bb" diff --git a/pyproject.toml b/pyproject.toml index 0a33c9eb..2239ff67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ logfire = { version = ">=0.26.0,<1.0.0", optional = true } pinecone-client = {extras = ["grpc"], version = "^4.0.0", optional = true} google-generativeai = "^0.5.2" google-ai-generativelanguage = "0.6.2" +jsonref = "^1.1.0" [tool.poetry.extras] cli = ["mirascope-cli"] From 38fe3ebe88f2ead75cfa4627092c3eab0f812acc Mon Sep 17 00:00:00 2001 From: Benjamin Rapaport Date: Wed, 8 May 2024 20:26:13 -0400 Subject: [PATCH 3/3] wip --- mirascope/gemini/tools.py | 102 +++++++++++++----------------- mirascope/gemini/utils.py | 55 ++++++++++++++++ tests/gemini/test_tools.py | 124 +++++++++++++++++++++++++++++++------ 3 files changed, 202 insertions(+), 79 deletions(-) create mode 100644 mirascope/gemini/utils.py diff --git a/mirascope/gemini/tools.py b/mirascope/gemini/tools.py index 309fea14..c230b594 100644 --- a/mirascope/gemini/tools.py +++ b/mirascope/gemini/tools.py @@ -1,71 +1,29 @@ """Classes for using tools with Google's Gemini API.""" from __future__ import annotations -import dataclasses +import pprint from typing import Any, Callable, Type +import jsonref from google.ai.generativelanguage import FunctionCall from google.generativeai.types import ( # type: ignore + FunctionDeclaration, Tool, - content_types, -) -from google.generativeai.types.content_types import ( # type: ignore - CallableFunctionDeclaration, ) from pydantic import BaseModel, ConfigDict -from pydantic_core import PydanticUndefined + +from mirascope.base.tools import DEFAULT_TOOL_DOCSTRING +from mirascope.gemini.utils import remove_invalid_title_keys_from_parameters from ..base import ( BaseTool, BaseType, convert_base_model_to_tool, - convert_base_type_to_tool, # type: ignore - convert_function_to_tool, # type: ignore + convert_base_type_to_tool, + convert_function_to_tool, ) -def pydantic_to_dataclass( - klass: Type[BaseModel], - classname: str | None = None, -) -> Any: - """ - Dataclass from Pydantic model - - Transferred entities: - * Field names - * Type annotations, except of Annotated etc - * Default factory or default value - - Validators are not transferred. - - Order of fields may change due to dataclass's positional arguments. - - """ - dataclass_args = [] - for name, info in klass.model_fields.items(): - if info.default_factory is not None: - dataclass_field = dataclasses.field( - default_factory=info.default_factory, - ) - dataclass_arg = (name, info.annotation, dataclass_field) - elif info.default is not PydanticUndefined: - dataclass_field = dataclasses.field( - default=info.get_default(), - ) - dataclass_arg = (name, info.annotation, dataclass_field) - else: - dataclass_arg = (name, info.annotation) - - dataclass_args.append(dataclass_arg) - - dataclass_args.sort(key=lambda arg: len(arg) > 2) - - return dataclasses.make_dataclass( - classname or f"__{klass.__name__}Dataclass", - dataclass_args, - ) - - class GeminiTool(BaseTool[FunctionCall]): '''A base class for easy use of tools with the Gemini API. @@ -115,16 +73,42 @@ def tool_schema(cls) -> Tool: Returns: The constructed `Tool` schema. """ - - dc_cls = pydantic_to_dataclass(cls) - - def fun(a: Type[dc_cls]): - pass - - cfd: CallableFunctionDeclaration = ( - content_types.FunctionDeclaration.from_function(fun) + super().tool_schema() + model_schema: dict[str, Any] = cls.model_json_schema() + pprint.pprint(model_schema) + + # Replace all references with their values + without_refs: dict[str, Any] = jsonref.replace_refs(model_schema) # type: ignore + pprint.pprint(without_refs) + + # Remove all Defs + without_refs.pop("$defs") + pprint.pprint(without_refs) + + # Get the name and description, and remove them from the schema + name: str = without_refs.pop("title") # type: ignore + description: str = ( # type: ignore + without_refs.pop("description", None) or DEFAULT_TOOL_DOCSTRING + ) + parameters: dict[str, Any] = without_refs + + # Remove all instances of title key in each param definition + # This is careful not to delete keys that represent a field with the name title + remove_invalid_title_keys_from_parameters(parameters) + + print(f"{name=}") + print(f"{description=}") + print(f"{parameters=}") + + return Tool( + function_declarations=[ + FunctionDeclaration( + name=name, + description=description, + parameters=parameters, + ) + ] ) - return Tool(function_declarations=[cfd]) @classmethod def from_tool_call(cls, tool_call: FunctionCall) -> GeminiTool: diff --git a/mirascope/gemini/utils.py b/mirascope/gemini/utils.py new file mode 100644 index 00000000..1a9374ea --- /dev/null +++ b/mirascope/gemini/utils.py @@ -0,0 +1,55 @@ +from typing import Any + + +def remove_invalid_title_keys_from_parameters(d: dict[str, Any] | Any) -> None: + """ + For each property, remove the title key. However, we make sure to only remove + the title key in each schema + Before + { + "properties": { + "books": { + "items": { + "properties": { + "author_name": {"title": "Author Name", "type": "string"}, + "title": {"title": "Title", "type": "string"}, + }, + "required": ["author_name", "title"], + "title": "Book", + "type": "object", + }, + "title": "Books", + "type": "array", + } + }, + "required": ["books"], + "title": "Books", + "type": "object", + } + + AFTER + { + "properties": { + "books": { + "items": { + "properties": { + "author_name": {"type": "string"}, + "title": {"type": "string"}, + }, + "required": ["author_name", "title"], + "type": "object", + }, + "type": "array", + } + }, + "required": ["books"], + "type": "object", + } + + """ + if isinstance(d, dict): + for key in list(d.keys()): + if key == "title" and "type" in d.keys(): + del d[key] + else: + remove_invalid_title_keys_from_parameters(d[key]) diff --git a/tests/gemini/test_tools.py b/tests/gemini/test_tools.py index e87e33f3..f0f462d5 100644 --- a/tests/gemini/test_tools.py +++ b/tests/gemini/test_tools.py @@ -1,4 +1,8 @@ """Tests for the `mirascope.gemini.tools` module.""" + +from enum import StrEnum, auto +from typing import Annotated + import pytest from google.ai.generativelanguage import FunctionCall from pydantic import BaseModel, Field @@ -17,33 +21,113 @@ def test_from_tool_call_no_args() -> None: NoDescription.from_tool_call(FunctionCall(name="NoDescription")) -def test_no_nested_tools() -> None: +""" +{ + "properties": { + "books": { + "items": { + "properties": { + "author_name": { + "description": "The formal name of the author.", + "type": "string", + }, + "title": { + "description": "The title of the book.", + "type": "string", + }, + "published_at": { + "anyOf": [ + { + "properties": { + "year": { + "default": None, + "title": "Year", + "type": "integer", + }, + "month": { + "anyOf": [ + {"type": "integer"}, + {"type": "null"}, + ], + "default": None, + "title": "Month", + }, + "day": { + "anyOf": [ + {"type": "integer"}, + {"type": "null"}, + ], + "default": None, + "title": "Day", + }, + }, + "title": "Date", + "type": "object", + }, + {"type": "null"}, + ], + "default": None, + "description": "When the book was published.", + }, + "category": { + "allOf": [ + { + "enum": ["fiction", "non_fiction"], + "title": "BookCategory", + "type": "string", + } + ], + "description": "The category of the book.", + }, + }, + "required": ["author_name", "title", "category"], + "type": "object", + }, + "type": "array", + } + }, + "required": ["books"], + "type": "object", +} +""" + + +def test_nested_tools() -> None: """Tests that a `ValueError` is raised when using nested tools with Gemin.""" - class Author(BaseModel): - given_name: str - family_name: str + class BookCategory(StrEnum): + FICTION = auto() + NON_FICTION = auto() - class Book(BaseModel): - # Note: title cannot be a field name - author: Author - description: str - year: int + class Date(BaseModel): + year: int = Field(default=None) + month: int | None = Field(default=None) + day: int | None = Field(default=None) - class Books(BaseModel): - books: list[Book] - - class BooksGeminiTool(GeminiTool): + class Book(BaseModel): + author_name: Annotated[ + str, + Field(..., description="The formal name of the author."), + ] + title: Annotated[ + str, + Field(..., description="The title of the book."), + ] + published_at: Annotated[ + Date | None, + Field(default=None, description="When the book was published."), + ] + + category: Annotated[ + BookCategory, Field(..., description="The category of the book.") + ] + + class Books(GeminiTool): books: list[Book] - # with pytest.raises(ValueError): - # Books.tool_schema() - - print(f"{Books.model_json_schema()=}") + Books.tool_schema() - s = BooksGeminiTool.tool_schema() - print(f"{s.to_proto()=}") - raise ValueError("uh oh") + raise ValueError("Raising to get print") def fake_tool(param: str):