Skip to content

Commit

Permalink
Merge pull request #889 from ss108/fix-tool-arg-read
Browse files Browse the repository at this point in the history
Fix tool arg read
  • Loading branch information
joschkabraun committed May 21, 2024
2 parents 7f4caf0 + 33b04ec commit 13a1f04
Showing 1 changed file with 32 additions and 4 deletions.
36 changes: 32 additions & 4 deletions parea/wrapper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,21 @@ def clean_json_string(s):
return json_dumps(calls, indent=4)


def _resolve_functions(kwargs):
if "functions" in kwargs:
return kwargs.get("functions", [])
elif "tools" in kwargs:
tools = kwargs["tools"]
if isinstance(tools, list):
return [d.get("function", {}) for d in tools]

return [] # it is either a list or Stainless's `NotGiven`

return []


def _kwargs_to_llm_configuration(kwargs, model=None) -> LLMInputs:
functions = kwargs.get("functions", None) or [d.get("function", {}) for d in kwargs.get("tools", [])]
functions = _resolve_functions(kwargs)
function_call_default = "auto" if functions else None
return LLMInputs(
model=model or kwargs.get("model", None),
Expand Down Expand Up @@ -242,7 +255,12 @@ def _convert_oai_messages(messages: list) -> Union[List[Union[dict, Message]], N
if (is_chat_completion and m.role == "tool") or (isinstance(m, dict) and m.get("role") == "tool"):
tool_call_id = m.tool_call_id if is_chat_completion else m.get("tool_call_id")
content = m.content if is_chat_completion else m.get("content", "")
cleaned_messages.append(Message(role=Role.tool, content=json_dumps({"tool_call_id": tool_call_id, "content": content}, indent=4)))
cleaned_messages.append(
Message(
role=Role.tool,
content=json_dumps({"tool_call_id": tool_call_id, "content": content}, indent=4),
)
)
elif is_chat_completion:
cleaned_messages.append(
Message(
Expand Down Expand Up @@ -332,11 +350,21 @@ def _process_stream_response(content: list, tools: dict, data: dict, trace_id: s


def convert_openai_raw_stream_to_log(content: list, tools: dict, data: dict, trace_id: str):
    """Hand a raw OpenAI streaming response off to a background logging thread.

    Packages the accumulated stream ``content``, collected ``tools``, request
    ``data``, and ``trace_id`` into the kwargs expected by
    ``_process_stream_response`` and schedules it via ``log_in_thread``.
    """
    payload = {
        "content": content,
        "tools": tools,
        "data": data,
        "trace_id": trace_id,
    }
    log_in_thread(_process_stream_response, payload)


def convert_openai_raw_to_log(r: dict, data: dict):
    """Hand a raw (non-streaming) OpenAI response dict off to background logging.

    Re-hydrates ``r`` into a ``ChatCompletion`` object, attaches the current
    trace id, and schedules ``_process_response`` via ``log_in_thread``.
    """
    payload = {
        "response": ChatCompletion(**r),
        "model_inputs": data,
        "trace_id": get_current_trace_id(),
    }
    log_in_thread(_process_response, payload)


def safe_format_template_to_prompt(_template: str, **kwargs) -> str:
Expand Down

0 comments on commit 13a1f04

Please sign in to comment.