Skip to content

Commit

Permalink
Function calling mode patch (#271)
Browse files Browse the repository at this point in the history
* Handle function_calling_mode when passed as a dict with allowed_func_names

* Add tests for tool_config

* Update alias FunctionCallingConfigType

* format

* Remove unnecessary test functions

* Replace alias with concrete data types for instance checks
  • Loading branch information
mayureshagashe2105 authored Apr 5, 2024
1 parent ba6b439 commit 7d2c17e
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 2 deletions.
17 changes: 15 additions & 2 deletions google/generativeai/types/content_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,12 +698,25 @@ class FunctionCallingConfigDict(TypedDict):
allowed_function_names: list[str]


FunctionCallingConfigType = Union[FunctionCallingConfigDict, glm.FunctionCallingConfig]
FunctionCallingConfigType = Union[
FunctionCallingModeType, FunctionCallingConfigDict, glm.FunctionCallingConfig
]


def to_function_calling_config(obj: FunctionCallingConfigType) -> glm.FunctionCallingConfig:
if isinstance(obj, (FunctionCallingMode, str, int)):
if isinstance(obj, glm.FunctionCallingConfig):
return obj
elif isinstance(obj, (FunctionCallingMode, str, int)):
obj = {"mode": to_function_calling_mode(obj)}
elif isinstance(obj, dict):
obj = obj.copy()
mode = obj.pop("mode")
obj["mode"] = to_function_calling_mode(mode)
else:
raise TypeError(
f"Could not convert input to `glm.FunctionCallingConfig`: \n'" f" type: {type(obj)}\n",
obj,
)

return glm.FunctionCallingConfig(obj)

Expand Down
93 changes: 93 additions & 0 deletions tests/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,99 @@ def test_tools(self):
self.assertLen(obr.tools, 1)
self.assertEqual(type(obr.tools[0]).to_dict(obr.tools[0]), tools)

@parameterized.named_parameters(
dict(
testcase_name="test_FunctionCallingMode_str",
tool_config={"function_calling_config": "any"},
expected_tool_config={
"function_calling_config": {
"mode": content_types.FunctionCallingMode.ANY,
"allowed_function_names": [],
}
},
),
dict(
testcase_name="test_FunctionCallingMode_int",
tool_config={"function_calling_config": 1},
expected_tool_config={
"function_calling_config": {
"mode": content_types.FunctionCallingMode.AUTO,
"allowed_function_names": [],
}
},
),
dict(
testcase_name="test_FunctionCallingMode",
tool_config={"function_calling_config": content_types.FunctionCallingMode.NONE},
expected_tool_config={
"function_calling_config": {
"mode": content_types.FunctionCallingMode.NONE,
"allowed_function_names": [],
}
},
),
dict(
testcase_name="test_glm_FunctionCallingConfig",
tool_config={
"function_calling_config": glm.FunctionCallingConfig(
mode=content_types.FunctionCallingMode.AUTO
)
},
expected_tool_config={
"function_calling_config": {
"mode": content_types.FunctionCallingMode.AUTO,
"allowed_function_names": [],
}
},
),
dict(
testcase_name="test_FunctionCallingConfigDict",
tool_config={
"function_calling_config": {
"mode": "mode_auto",
"allowed_function_names": ["datetime", "greetings", "random"],
}
},
expected_tool_config={
"function_calling_config": {
"mode": content_types.FunctionCallingMode.AUTO,
"allowed_function_names": ["datetime", "greetings", "random"],
}
},
),
dict(
testcase_name="test_glm_ToolConfig",
tool_config=glm.ToolConfig(
function_calling_config=glm.FunctionCallingConfig(
mode=content_types.FunctionCallingMode.NONE
)
),
expected_tool_config={
"function_calling_config": {
"mode": content_types.FunctionCallingMode.NONE,
"allowed_function_names": [],
}
},
),
)
def test_tool_config(self, tool_config, expected_tool_config):
tools = dict(
function_declarations=[
dict(name="datetime", description="Returns the current UTC date and time."),
dict(name="greetings", description="Returns a greeting."),
dict(name="random", description="Returns a random number."),
]
)
self.responses["generate_content"] = [simple_response("echo echo")]

model = generative_models.GenerativeModel("gemini-pro", tools=tools)
_ = model.generate_content("Hello", tools=[tools], tool_config=tool_config)

req = self.observed_requests[0]

self.assertLen(type(req.tools[0]).to_dict(req.tools[0]).get("function_declarations"), 3)
self.assertEqual(type(req.tool_config).to_dict(req.tool_config), expected_tool_config)

@parameterized.named_parameters(
["bare_str", "talk like a pirate", simple_part("talk like a pirate")],
[
Expand Down
94 changes: 94 additions & 0 deletions tests/test_generative_models_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from google.generativeai import client as client_lib
from google.generativeai import generative_models
from google.generativeai.types import content_types
import google.ai.generativelanguage as glm

from absl.testing import absltest
Expand Down Expand Up @@ -107,6 +108,99 @@ async def responses():

self.assertEqual(response.text, "world!")

@parameterized.named_parameters(
dict(
testcase_name="test_FunctionCallingMode_str",
tool_config={"function_calling_config": "any"},
expected_tool_config={
"function_calling_config": {
"mode": content_types.FunctionCallingMode.ANY,
"allowed_function_names": [],
}
},
),
dict(
testcase_name="test_FunctionCallingMode_int",
tool_config={"function_calling_config": 1},
expected_tool_config={
"function_calling_config": {
"mode": content_types.FunctionCallingMode.AUTO,
"allowed_function_names": [],
}
},
),
dict(
testcase_name="test_FunctionCallingMode",
tool_config={"function_calling_config": content_types.FunctionCallingMode.NONE},
expected_tool_config={
"function_calling_config": {
"mode": content_types.FunctionCallingMode.NONE,
"allowed_function_names": [],
}
},
),
dict(
testcase_name="test_glm_FunctionCallingConfig",
tool_config={
"function_calling_config": glm.FunctionCallingConfig(
mode=content_types.FunctionCallingMode.AUTO
)
},
expected_tool_config={
"function_calling_config": {
"mode": content_types.FunctionCallingMode.AUTO,
"allowed_function_names": [],
}
},
),
dict(
testcase_name="test_FunctionCallingConfigDict",
tool_config={
"function_calling_config": {
"mode": "mode_auto",
"allowed_function_names": ["datetime", "greetings", "random"],
}
},
expected_tool_config={
"function_calling_config": {
"mode": content_types.FunctionCallingMode.AUTO,
"allowed_function_names": ["datetime", "greetings", "random"],
}
},
),
dict(
testcase_name="test_glm_ToolConfig",
tool_config=glm.ToolConfig(
function_calling_config=glm.FunctionCallingConfig(
mode=content_types.FunctionCallingMode.NONE
)
),
expected_tool_config={
"function_calling_config": {
"mode": content_types.FunctionCallingMode.NONE,
"allowed_function_names": [],
}
},
),
)
async def test_tool_config(self, tool_config, expected_tool_config):
tools = dict(
function_declarations=[
dict(name="datetime", description="Returns the current UTC date and time."),
dict(name="greetings", description="Returns a greeting."),
dict(name="random", description="Returns a random number."),
]
)
self.responses["generate_content"] = [simple_response("echo echo")]

model = generative_models.GenerativeModel("gemini-pro", tools=tools)
_ = await model.generate_content_async("Hello", tools=[tools], tool_config=tool_config)

req = self.observed_requests[0]

self.assertLen(type(req.tools[0]).to_dict(req.tools[0]).get("function_declarations"), 3)
self.assertEqual(type(req.tool_config).to_dict(req.tool_config), expected_tool_config)

@parameterized.named_parameters(
["basic", "Hello"],
["list", ["Hello"]],
Expand Down

0 comments on commit 7d2c17e

Please sign in to comment.