Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Improvement] make description optional #306

Merged
merged 5 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions erniebot-agent/src/erniebot_agent/tools/remote_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,19 @@ async def __post_process__(self, tool_response: dict) -> dict:
"请务必确保每个符合'file-'格式的字段只出现一次,无需将其转换为链接,也无需添加任何HTML、Markdown或其他格式化元素。"
)

if self.tool_view.returns is not None:
try:
origin_tool_response = deepcopy(tool_response)
valid_tool_response = self.tool_view.returns(**origin_tool_response).model_dump(mode="json")
tool_response.update(valid_tool_response)
except Exception as e:
_logger.warning(
"Unable to validate the 'tool_response' against the schema defined in the YAML file. "
f"The specific error encountered is: '<{e}>'. "
"As a result, the original response from the tool will be used.",
)
# if self.tool_view.returns is not None:
# try:
# origin_tool_response = deepcopy(tool_response)
# valid_tool_response = self.tool_view.returns(
# **origin_tool_response
# ).model_dump(mode="json")
# tool_response.update(valid_tool_response)
# except Exception as e:
# _logger.warning(
# "Unable to validate the 'tool_response' against the schema defined in the YAML file. "
# f"The specific error encountered is: '<{e}>'. "
# "As a result, the original response from the tool will be used.",
# )
return tool_response

async def __call__(self, **tool_arguments: Dict[str, Any]) -> Any:
Expand Down
3 changes: 0 additions & 3 deletions erniebot-agent/src/erniebot_agent/tools/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,6 @@ def from_openapi_dict(cls, schema: dict) -> Type[ToolParameterView]:
if "type" not in field_dict:
raise ToolError(f"`type` field not found in `{field_name}` property", stage="Loading")

if "description" not in field_dict:
raise ToolError(f"`description` field not found in `{field_name}` property", stage="Loading")

if field_name.startswith("__"):
continue

Expand Down
50 changes: 25 additions & 25 deletions erniebot-agent/tests/unit_tests/tools/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,31 +544,31 @@ async def test_enum_v1(self):
self.assertEqual(result["enum_field"], "2")
self.assertEqual(result["no_enum_field"], "no_enum_value")

@responses.activate
async def test_enum_v1_with_wrong_dtype(self):
tool = self.toolkit.get_tool("enum_v1")

responses.post(
"http://example.com/enum_v1_dtype",
json={"enum_field": 2, "no_enum_field": "no_enum_value"},
)

tool.tool_view.uri = "enum_v1_dtype"
with self.assertLogs("erniebot_agent.tools.remote_tool", level="INFO") as cm:
result = await tool()

logs = [item for item in cm.output if "Unable to validate the 'tool_response'" in item]

# test raise warning log msg
self.assertEqual(len(logs), 1)
warning_log_msg = (
"Unable to validate the 'tool_response' against the schema defined "
"in the YAML file. The specific error encountered is: '<1 validation error for "
)
self.assertIn(warning_log_msg, logs[0])

self.assertEqual(result["enum_field"], 2)
self.assertEqual(result["no_enum_field"], "no_enum_value")
# @responses.activate
# async def test_enum_v1_with_wrong_dtype(self):
# tool = self.toolkit.get_tool("enum_v1")

# responses.post(
# "http://example.com/enum_v1_dtype",
# json={"enum_field": 2, "no_enum_field": "no_enum_value"},
# )

# tool.tool_view.uri = "enum_v1_dtype"
# with self.assertLogs("erniebot_agent.tools.remote_tool", level="INFO") as cm:
# result = await tool()

# logs = [item for item in cm.output if "Unable to validate the 'tool_response'" in item]

# # test raise warning log msg
# self.assertEqual(len(logs), 1)
# warning_log_msg = (
# "Unable to validate the 'tool_response' against the schema defined "
# "in the YAML file. The specific error encountered is: '<1 validation error for "
# )
# self.assertIn(warning_log_msg, logs[0])

# self.assertEqual(result["enum_field"], 2)
# self.assertEqual(result["no_enum_field"], "no_enum_value")

@responses.activate
async def test_enum_v2(self):
Expand Down