diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index b9013a9fc..b4a97f7d0 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -698,12 +698,25 @@ class FunctionCallingConfigDict(TypedDict): allowed_function_names: list[str] -FunctionCallingConfigType = Union[FunctionCallingConfigDict, glm.FunctionCallingConfig] +FunctionCallingConfigType = Union[ + FunctionCallingModeType, FunctionCallingConfigDict, glm.FunctionCallingConfig +] def to_function_calling_config(obj: FunctionCallingConfigType) -> glm.FunctionCallingConfig: - if isinstance(obj, (FunctionCallingMode, str, int)): + if isinstance(obj, glm.FunctionCallingConfig): + return obj + elif isinstance(obj, (FunctionCallingMode, str, int)): obj = {"mode": to_function_calling_mode(obj)} + elif isinstance(obj, dict): + obj = obj.copy() + mode = obj.pop("mode") + obj["mode"] = to_function_calling_mode(mode) + else: + raise TypeError( + f"Could not convert input to `glm.FunctionCallingConfig`: \n'" f" type: {type(obj)}\n", + obj, + ) return glm.FunctionCallingConfig(obj) diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 4d5488421..a7cdb3ff0 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -613,6 +613,99 @@ def test_tools(self): self.assertLen(obr.tools, 1) self.assertEqual(type(obr.tools[0]).to_dict(obr.tools[0]), tools) + @parameterized.named_parameters( + dict( + testcase_name="test_FunctionCallingMode_str", + tool_config={"function_calling_config": "any"}, + expected_tool_config={ + "function_calling_config": { + "mode": content_types.FunctionCallingMode.ANY, + "allowed_function_names": [], + } + }, + ), + dict( + testcase_name="test_FunctionCallingMode_int", + tool_config={"function_calling_config": 1}, + expected_tool_config={ + "function_calling_config": { + "mode": content_types.FunctionCallingMode.AUTO, + "allowed_function_names": [], + } + }, + ), + dict( + testcase_name="test_FunctionCallingMode", + tool_config={"function_calling_config": content_types.FunctionCallingMode.NONE}, + expected_tool_config={ + "function_calling_config": { + "mode": content_types.FunctionCallingMode.NONE, + "allowed_function_names": [], + } + }, + ), + dict( + testcase_name="test_glm_FunctionCallingConfig", + tool_config={ + "function_calling_config": glm.FunctionCallingConfig( + mode=content_types.FunctionCallingMode.AUTO + ) + }, + expected_tool_config={ + "function_calling_config": { + "mode": content_types.FunctionCallingMode.AUTO, + "allowed_function_names": [], + } + }, + ), + dict( + testcase_name="test_FunctionCallingConfigDict", + tool_config={ + "function_calling_config": { + "mode": "mode_auto", + "allowed_function_names": ["datetime", "greetings", "random"], + } + }, + expected_tool_config={ + "function_calling_config": { + "mode": content_types.FunctionCallingMode.AUTO, + "allowed_function_names": ["datetime", "greetings", "random"], + } + }, + ), + dict( + testcase_name="test_glm_ToolConfig", + tool_config=glm.ToolConfig( + function_calling_config=glm.FunctionCallingConfig( + mode=content_types.FunctionCallingMode.NONE + ) + ), + expected_tool_config={ + "function_calling_config": { + "mode": content_types.FunctionCallingMode.NONE, + "allowed_function_names": [], + } + }, + ), + ) + def test_tool_config(self, tool_config, expected_tool_config): + tools = dict( + function_declarations=[ + dict(name="datetime", description="Returns the current UTC date and time."), + dict(name="greetings", description="Returns a greeting."), + dict(name="random", description="Returns a random number."), + ] + ) + self.responses["generate_content"] = [simple_response("echo echo")] + + model = generative_models.GenerativeModel("gemini-pro", tools=tools) + _ = model.generate_content("Hello", tools=[tools], tool_config=tool_config) + + req = self.observed_requests[0] + + self.assertLen(type(req.tools[0]).to_dict(req.tools[0]).get("function_declarations"), 3) + self.assertEqual(type(req.tool_config).to_dict(req.tool_config), expected_tool_config) + @parameterized.named_parameters( ["bare_str", "talk like a pirate", simple_part("talk like a pirate")], [ diff --git a/tests/test_generative_models_async.py b/tests/test_generative_models_async.py index 6b30b04fc..8a93c2295 100644 --- a/tests/test_generative_models_async.py +++ b/tests/test_generative_models_async.py @@ -23,6 +23,7 @@ from google.generativeai import client as client_lib from google.generativeai import generative_models +from google.generativeai.types import content_types import google.ai.generativelanguage as glm from absl.testing import absltest @@ -107,6 +108,99 @@ async def responses(): self.assertEqual(response.text, "world!") + @parameterized.named_parameters( + dict( + testcase_name="test_FunctionCallingMode_str", + tool_config={"function_calling_config": "any"}, + expected_tool_config={ + "function_calling_config": { + "mode": content_types.FunctionCallingMode.ANY, + "allowed_function_names": [], + } + }, + ), + dict( + testcase_name="test_FunctionCallingMode_int", + tool_config={"function_calling_config": 1}, + expected_tool_config={ + "function_calling_config": { + "mode": content_types.FunctionCallingMode.AUTO, + "allowed_function_names": [], + } + }, + ), + dict( + testcase_name="test_FunctionCallingMode", + tool_config={"function_calling_config": content_types.FunctionCallingMode.NONE}, + expected_tool_config={ + "function_calling_config": { + "mode": content_types.FunctionCallingMode.NONE, + "allowed_function_names": [], + } + }, + ), + dict( + testcase_name="test_glm_FunctionCallingConfig", + tool_config={ + "function_calling_config": glm.FunctionCallingConfig( + mode=content_types.FunctionCallingMode.AUTO + ) + }, + expected_tool_config={ + "function_calling_config": { + "mode": content_types.FunctionCallingMode.AUTO, + "allowed_function_names": [], + } + }, + ), + dict( + testcase_name="test_FunctionCallingConfigDict", + tool_config={ + "function_calling_config": { + "mode": "mode_auto", + "allowed_function_names": ["datetime", "greetings", "random"], + } + }, + expected_tool_config={ + "function_calling_config": { + "mode": content_types.FunctionCallingMode.AUTO, + "allowed_function_names": ["datetime", "greetings", "random"], + } + }, + ), + dict( + testcase_name="test_glm_ToolConfig", + tool_config=glm.ToolConfig( + function_calling_config=glm.FunctionCallingConfig( + mode=content_types.FunctionCallingMode.NONE + ) + ), + expected_tool_config={ + "function_calling_config": { + "mode": content_types.FunctionCallingMode.NONE, + "allowed_function_names": [], + } + }, + ), + ) + async def test_tool_config(self, tool_config, expected_tool_config): + tools = dict( + function_declarations=[ + dict(name="datetime", description="Returns the current UTC date and time."), + dict(name="greetings", description="Returns a greeting."), + dict(name="random", description="Returns a random number."), + ] + ) + self.responses["generate_content"] = [simple_response("echo echo")] + + model = generative_models.GenerativeModel("gemini-pro", tools=tools) + _ = await model.generate_content_async("Hello", tools=[tools], tool_config=tool_config) + + req = self.observed_requests[0] + + self.assertLen(type(req.tools[0]).to_dict(req.tools[0]).get("function_declarations"), 3) + self.assertEqual(type(req.tool_config).to_dict(req.tool_config), expected_tool_config) + @parameterized.named_parameters( ["basic", "Hello"], ["list", ["Hello"]],