diff --git a/parea/wrapper/utils.py b/parea/wrapper/utils.py index cd1899f0..c0e218b0 100644 --- a/parea/wrapper/utils.py +++ b/parea/wrapper/utils.py @@ -212,8 +212,21 @@ def clean_json_string(s): return json_dumps(calls, indent=4) +def _resolve_functions(kwargs): + if "functions" in kwargs: + return kwargs.get("functions", []) + elif "tools" in kwargs: + tools = kwargs["tools"] + if isinstance(tools, list): + return [d.get("function", {}) for d in tools] + + return [] # it is either a list or Stainless's `NotGiven` + + return [] + + def _kwargs_to_llm_configuration(kwargs, model=None) -> LLMInputs: - functions = kwargs.get("functions", None) or [d.get("function", {}) for d in kwargs.get("tools", [])] + functions = _resolve_functions(kwargs) function_call_default = "auto" if functions else None return LLMInputs( model=model or kwargs.get("model", None), @@ -242,7 +255,12 @@ def _convert_oai_messages(messages: list) -> Union[List[Union[dict, Message]], N if (is_chat_completion and m.role == "tool") or (isinstance(m, dict) and m.get("role") == "tool"): tool_call_id = m.tool_call_id if is_chat_completion else m.get("tool_call_id") content = m.content if is_chat_completion else m.get("content", "") - cleaned_messages.append(Message(role=Role.tool, content=json_dumps({"tool_call_id": tool_call_id, "content": content}, indent=4))) + cleaned_messages.append( + Message( + role=Role.tool, + content=json_dumps({"tool_call_id": tool_call_id, "content": content}, indent=4), + ) + ) elif is_chat_completion: cleaned_messages.append( Message( @@ -332,11 +350,21 @@ def _process_stream_response(content: list, tools: dict, data: dict, trace_id: s def convert_openai_raw_stream_to_log(content: list, tools: dict, data: dict, trace_id: str): - log_in_thread(_process_stream_response, {"content": content, "tools": tools, "data": data, "trace_id": trace_id}) + log_in_thread( + _process_stream_response, + {"content": content, "tools": tools, "data": data, "trace_id": trace_id}, + ) def convert_openai_raw_to_log(r: dict, data: dict): - log_in_thread(_process_response, {"response": ChatCompletion(**r), "model_inputs": data, "trace_id": get_current_trace_id()}) + log_in_thread( + _process_response, + { + "response": ChatCompletion(**r), + "model_inputs": data, + "trace_id": get_current_trace_id(), + }, + ) def safe_format_template_to_prompt(_template: str, **kwargs) -> str: