diff --git a/parea/wrapper/openai.py b/parea/wrapper/openai.py index fb50aba5..d5157923 100644 --- a/parea/wrapper/openai.py +++ b/parea/wrapper/openai.py @@ -8,7 +8,6 @@ from ..utils.trace_utils import trace_data from .wrapper import Wrapper - MODEL_COST_MAPPING: Dict[str, float] = { "gpt-4": 0.03, "gpt-4-0314": 0.03, @@ -75,9 +74,7 @@ def resolver(trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], respon model_rate = OpenAIWrapper.get_model_cost(model) model_completion_rate = OpenAIWrapper.get_model_cost(model, is_completion=True) - completion_cost = model_completion_rate * ( - usage.get("completion_tokens", 0) / 1000 - ) + completion_cost = model_completion_rate * (usage.get("completion_tokens", 0) / 1000) prompt_cost = model_rate * (usage.get("prompt_tokens", 0) / 1000) total_cost = sum([prompt_cost, completion_cost]) @@ -119,12 +116,7 @@ def get_model_cost(model_name: str, is_completion: bool = False) -> float: cost = MODEL_COST_MAPPING.get(model_name, None) if cost is None: - msg = ( - f"Unknown model: {model_name}. " - f"Please provide a valid OpenAI model name. " - f"Known models are: {', '.join(MODEL_COST_MAPPING.keys())}" - ) + msg = f"Unknown model: {model_name}. " f"Please provide a valid OpenAI model name. " f"Known models are: {', '.join(MODEL_COST_MAPPING.keys())}" raise ValueError(msg) return cost -