fix for reading tool argument to chat functions when creating LLMInputs
ss108 committed May 21, 2024
1 parent 6bccc0b commit 33b04ec
Showing 1 changed file with 20 additions and 58 deletions.
78 changes: 20 additions & 58 deletions parea/wrapper/utils.py
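For context, OpenAI's chat API nests function schemas under the newer "tools"/"tool_choice" parameters ("functions"/"function_call" are the legacy ones), so a wrapper that only reads "functions" misses them when building LLMInputs. The snippet below is an illustrative sketch of that kwarg shape, not part of this commit; the tool definition is a made-up example.

# Hypothetical kwargs as they might be passed to a chat-completion call.
kwargs = {
    "model": "gpt-3.5-turbo-0613",
    "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
    # Newer style: each entry wraps the function schema under a "function" key.
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",  # made-up tool for illustration
                "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
            },
        }
    ],
    "tool_choice": "auto",
}
# The schema the logger needs is kwargs["tools"][0]["function"];
# the legacy location would have been kwargs["functions"][0].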
@@ -1,8 +1,9 @@
-from typing import Callable, Dict, List, Optional, Union

import json
import re
import sys
from functools import lru_cache, wraps
+from typing import Callable, Dict, List, Optional, Union

import tiktoken
from openai import __version__ as openai_version
@@ -11,13 +12,7 @@
from parea.parea_logger import parea_logger
from parea.schemas.log import LLMInputs, Message, ModelParams, Role
from parea.schemas.models import UpdateLog, UpdateTraceScenario
-from parea.utils.trace_utils import (
-fill_trace_data,
-get_current_trace_id,
-log_in_thread,
-trace_data,
-trace_insert,
-)
+from parea.utils.trace_utils import fill_trace_data, get_current_trace_id, log_in_thread, trace_data, trace_insert
from parea.utils.universal_encoder import json_dumps

is_openai_1 = openai_version.startswith("1.")
@@ -36,10 +31,7 @@ def wrapper(*args, **kwargs):
while frame:
caller_names += frame.f_code.co_name + "|"
frame = frame.f_back
-if any(
-func_to_check.__name__ in caller_names
-for func_to_check in funcs_to_check
-):
+if any(func_to_check.__name__ in caller_names for func_to_check in funcs_to_check):
return func(*args, **kwargs)
return decorator(self, func)(*args, **kwargs) # Include self
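The wrapper above skips the wrapped decorator when one of the given functions already appears in the call stack. Below is a self-contained sketch of that frame-walking check, using only the standard library; the helper name is made up and this is not the library's exact code.

import sys

def called_from(func_name: str) -> bool:
    # Walk parent frames and collect their code-object names, as the wrapper above does.
    frame = sys._getframe().f_back
    caller_names = ""
    while frame:
        caller_names += frame.f_code.co_name + "|"
        frame = frame.f_back
    return func_name in caller_names

def outer():
    return called_from("outer")

print(called_from("outer"))  # False: "outer" is not on the stack here
print(outer())               # True: "outer" is on the stack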

@@ -58,9 +50,7 @@ def _safe_encode(encoding, text):
return 0


-def _num_tokens_from_messages(
-messages, model="gpt-3.5-turbo-0613", is_azure: bool = False
-):
+def _num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613", is_azure: bool = False):
"""Return the number of tokens used by a list of messages.
source: https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken
"""
@@ -90,14 +80,10 @@ def _num_tokens_from_messages(
tokens_per_message = 3
tokens_per_name = 1
elif model == "gpt-3.5-turbo-0301":
-tokens_per_message = (
-4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
-)
+tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_name = -1 # if there's a name, the role is omitted
elif "gpt-4" in model:
-print(
-"Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613."
-)
+print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
return _num_tokens_from_messages(messages, model="gpt-4-0613")
else:
print(
@@ -165,9 +151,7 @@ def _num_tokens_from_functions(functions, function_call, model="gpt-3.5-turbo-0613"):
num_tokens += 10
function_call_tokens = min(_safe_encode(encoding, "auto") - 1, 0)
if isinstance(function_call, dict):
-function_call_tokens = min(
-_safe_encode(encoding, json_dumps(function_call)) - 1, 0
-)
+function_call_tokens = min(_safe_encode(encoding, json_dumps(function_call)) - 1, 0)
return num_tokens + function_call_tokens


@@ -189,11 +173,7 @@ def _calculate_input_tokens(
) -> int:
is_azure = model.startswith("azure_") or model in AZURE_MODEL_INFO
num_function_tokens = _num_tokens_from_functions(functions, function_call, model)
-num_input_tokens = (
-_num_tokens_from_string(json_dumps(messages), model)
-if model == "gpt-4-vision-preview"
-else _num_tokens_from_messages(messages, model, is_azure)
-)
+num_input_tokens = _num_tokens_from_string(json_dumps(messages), model) if model == "gpt-4-vision-preview" else _num_tokens_from_messages(messages, model, is_azure)
return num_input_tokens + num_function_tokens


@@ -240,10 +220,10 @@ def _resolve_functions(kwargs):
if isinstance(tools, list):
return [d.get("function", {}) for d in tools]

-return [] # it is either a list or Stainless's `NotGiven`
+return []  # it is either a list or Stainless's `NotGiven`

return []
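To make the commit's fix concrete: with the newer tools parameter the function schema sits one level deeper than with the legacy functions parameter. The following is a simplified, self-contained sketch of the resolution logic shown in the hunk above, not the library's exact code.

def resolve_functions(kwargs: dict) -> list:
    # Legacy parameter: already a plain list of function schemas.
    if "functions" in kwargs:
        return kwargs.get("functions") or []
    # Newer parameter: unwrap the nested "function" key from each tool entry.
    tools = kwargs.get("tools")
    if isinstance(tools, list):
        return [d.get("function", {}) for d in tools]
    return []  # anything else (e.g. an omitted/NotGiven value) means no functions

legacy = {"functions": [{"name": "get_weather"}]}
newer = {"tools": [{"type": "function", "function": {"name": "get_weather"}}]}
assert resolve_functions(legacy) == resolve_functions(newer) == [{"name": "get_weather"}]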


def _kwargs_to_llm_configuration(kwargs, model=None) -> LLMInputs:
functions = _resolve_functions(kwargs)
@@ -253,8 +233,7 @@ def _kwargs_to_llm_configuration(kwargs, model=None) -> LLMInputs:
provider="openai",
messages=_convert_oai_messages(kwargs.get("messages", None)),
functions=functions,
-function_call=kwargs.get("function_call", function_call_default)
-or kwargs.get("tool_choice", function_call_default),
+function_call=kwargs.get("function_call", function_call_default) or kwargs.get("tool_choice", function_call_default),
model_params=ModelParams(
temp=kwargs.get("temperature", 1.0),
max_length=kwargs.get("max_tokens", None),
@@ -273,19 +252,13 @@ def _convert_oai_messages(messages: list) -> Union[List[Union[dict, Message]], None]:
cleaned_messages = []
for m in messages:
is_chat_completion = isinstance(m, ChatCompletionMessage)
-if (is_chat_completion and m.role == "tool") or (
-isinstance(m, dict) and m.get("role") == "tool"
-):
-tool_call_id = (
-m.tool_call_id if is_chat_completion else m.get("tool_call_id")
-)
+if (is_chat_completion and m.role == "tool") or (isinstance(m, dict) and m.get("role") == "tool"):
+tool_call_id = m.tool_call_id if is_chat_completion else m.get("tool_call_id")
content = m.content if is_chat_completion else m.get("content", "")
cleaned_messages.append(
Message(
role=Role.tool,
-content=json_dumps(
-{"tool_call_id": tool_call_id, "content": content}, indent=4
-),
+content=json_dumps({"tool_call_id": tool_call_id, "content": content}, indent=4),
)
)
elif is_chat_completion:
@@ -315,13 +288,8 @@ def _compute_cost(prompt_tokens: int, completion_tokens: int, model: str) -> float:
if model in AZURE_MODEL_INFO:
cost_per_token = AZURE_MODEL_INFO[model]
else:
-cost_per_token = ALL_NON_AZURE_MODELS_INFO.get(
-model, {"prompt": 0, "completion": 0}
-)
-cost = (
-(prompt_tokens * cost_per_token["prompt"])
-+ (completion_tokens * cost_per_token["completion"])
-) / 1_000_000
+cost_per_token = ALL_NON_AZURE_MODELS_INFO.get(model, {"prompt": 0, "completion": 0})
+cost = ((prompt_tokens * cost_per_token["prompt"]) + (completion_tokens * cost_per_token["completion"])) / 1_000_000
cost = round(cost, 10)
return cost
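A worked example of the per-million-token pricing used above, with made-up prices; the real values come from the model info tables referenced in the code.

# Hypothetical prices in USD per 1M tokens; the real numbers live in
# AZURE_MODEL_INFO / ALL_NON_AZURE_MODELS_INFO.
cost_per_token = {"prompt": 0.5, "completion": 1.5}
prompt_tokens, completion_tokens = 1_200, 300

cost = ((prompt_tokens * cost_per_token["prompt"]) + (completion_tokens * cost_per_token["completion"])) / 1_000_000
print(round(cost, 10))  # 0.00105 = (1200 * 0.5 + 300 * 1.5) / 1_000_000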

@@ -341,9 +309,7 @@ def _process_response(response, model_inputs, trace_id):
"input_tokens": usage.prompt_tokens,
"output_tokens": usage.completion_tokens,
"total_tokens": usage.prompt_tokens + usage.completion_tokens,
"cost": _compute_cost(
usage.prompt_tokens, usage.completion_tokens, response.model
),
"cost": _compute_cost(usage.prompt_tokens, usage.completion_tokens, response.model),
},
trace_id,
)
@@ -369,9 +335,7 @@ def _process_stream_response(content: list, tools: dict, data: dict, trace_id: str):
data.get("function_call", "auto") or data.get("tool_choice", "auto"),
data.get("model"),
)
-completion_tokens = _num_tokens_from_string(
-final_content if final_content else json_dumps(tool_calls), model
-)
+completion_tokens = _num_tokens_from_string(final_content if final_content else json_dumps(tool_calls), model)
data = {
"configuration": _kwargs_to_llm_configuration(data, model),
"output": completion,
@@ -385,9 +349,7 @@ def _process_stream_response(content: list, tools: dict, data: dict, trace_id: str):
parea_logger.default_log(data=data_with_config)


-def convert_openai_raw_stream_to_log(
-content: list, tools: dict, data: dict, trace_id: str
-):
+def convert_openai_raw_stream_to_log(content: list, tools: dict, data: dict, trace_id: str):
log_in_thread(
_process_stream_response,
{"content": content, "tools": tools, "data": data, "trace_id": trace_id},
