9 changes: 9 additions & 0 deletions coolprompt/assistant.py
@@ -85,6 +85,15 @@ def __init__(
 
         logger.info("PromptTuner successfully initialized")
 
+    def get_stats(self):
+        if hasattr(self._target_model, "get_stats"):
+            return self._target_model.get_stats()
+        return None
+
+    def reset_stats(self):
+        if hasattr(self._target_model, "reset_stats"):
+            self._target_model.reset_stats()
+
     def get_task_prompt_template(self, task: str, method: str) -> str:
         """Returns the prompt template for the given task.
 
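
A minimal usage sketch for the new passthrough methods, assuming the tuner's target model comes from the tracked factory added in this PR (the PromptTuner constructor arguments here are illustrative, not the real signature):

from coolprompt.assistant import PromptTuner
from coolprompt.language_model import create_chat_model

model = create_chat_model("gpt-4o-mini")  # model name is illustrative
tuner = PromptTuner(target_model=model)   # constructor signature assumed

# ... run some tuning ...

print(tuner.get_stats())  # tracker stats dict, or None if the model is untracked
tuner.reset_stats()       # clear the counters before the next run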
9 changes: 8 additions & 1 deletion coolprompt/data_generator/generator.py
@@ -52,8 +52,15 @@ def _generate(
         Returns:
             Any: generated data
         """
-        if not isinstance(self.model, BaseChatModel):
+        if hasattr(self.model, 'model'):
+            wrapped_model = self.model.model
+        else:
+            wrapped_model = self.model
+
+        if not isinstance(wrapped_model, BaseChatModel):
             output = self.model.invoke(request)
+            if isinstance(output, AIMessage):
+                output = output.content
             return extract_json(output)[field_name]
 
         structured_model = self.model.with_structured_output(
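
Why the unwrap above is needed, as a minimal sketch: TrackedLLMWrapper subclasses BaseLanguageModel rather than BaseChatModel, so an isinstance check against the wrapper itself always fails, even when a chat model sits inside (assumes OPENAI_API_KEY is set; the model name is illustrative):

from langchain_core.language_models.chat_models import BaseChatModel

from coolprompt.language_model import create_chat_model

tracked = create_chat_model("gpt-4o-mini")  # returns a TrackedLLMWrapper

print(isinstance(tracked, BaseChatModel))        # False: the wrapper is not a chat model
print(isinstance(tracked.model, BaseChatModel))  # True: the wrapped ChatOpenAI is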
16 changes: 16 additions & 0 deletions coolprompt/language_model/__init__.py
@@ -0,0 +1,16 @@
from coolprompt.language_model.tracker import (
    OpenAITracker,
    TrackedLLMWrapper,
    create_chat_model,
    model_tracker,
)
from coolprompt.language_model.llm import DefaultLLM

__all__ = [
    "OpenAITracker",
    "TrackedLLMWrapper",
    "create_chat_model",
    "model_tracker",
    "DefaultLLM",
]

113 changes: 113 additions & 0 deletions coolprompt/language_model/tracker.py
@@ -0,0 +1,113 @@
from langchain_community.callbacks import get_openai_callback

Collaborator: All new classes and methods need to be wrapped in Google Style docstrings (see how it's done in our other scripts).

from langchain_core.language_models.base import BaseLanguageModel
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage
from langchain_openai import ChatOpenAI
from typing import Any

Collaborator: Move `from typing import Any` to the very top of the imports.


class OpenAITracker:
    """Singleton that accumulates token usage and cost stats for OpenAI calls."""

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._reset_stats()
        return cls._instance

    def _reset_stats(self):
        """Resets all counters to their initial zero values."""
        self.stats = {
            "total_calls": 0,
            "total_tokens": 0,
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_cost": 0.0,
            "invoke_calls": 0,
            "batch_calls": 0,
            "batch_items": 0,
        }

    def _update_stats(self, callback, invoke_flag, **kwargs):
        """Adds the usage recorded by an OpenAI callback to the running totals.

        Args:
            callback: get_openai_callback handler holding token/cost counters.
            invoke_flag: True for a single invoke() call, False for a batch().
        """
        self.stats["total_calls"] += 1
        self.stats["total_tokens"] += callback.total_tokens
        self.stats["prompt_tokens"] += callback.prompt_tokens
        self.stats["completion_tokens"] += callback.completion_tokens
        self.stats["total_cost"] += callback.total_cost

        if invoke_flag:
            self.stats["invoke_calls"] += 1
        else:
            self.stats["batch_calls"] += 1
            self.stats["batch_items"] += kwargs.get("batch_size", 0)

Collaborator: Why does batch_size default to 0 here, if the model call itself did happen?

    def wrap_model(self, model):
        """Wraps a model in a TrackedLLMWrapper reporting to this tracker."""
        return TrackedLLMWrapper(model, self)

    def get_stats(self):
        """Returns a copy of the accumulated usage statistics."""
        return self.stats.copy()

    def reset_stats(self):
        """Resets the accumulated usage statistics to zero."""
        self._reset_stats()


class TrackedLLMWrapper(BaseLanguageModel):
    """Transparent proxy around a LangChain model that records usage stats.

    Delegates all calls to the wrapped model and reports token and cost
    usage to the shared OpenAITracker via get_openai_callback.
    """

    model: Any
    tracker: Any

    def __init__(self, model, tracker):
        super().__init__(model=model, tracker=tracker)

    @property
    def _llm_type(self):
        return "tracked_" + getattr(self.model, "_llm_type", "llm")

    def generate_prompt(self, prompts, stop=None, **kwargs):
        return self.model.generate_prompt(prompts, stop=stop, **kwargs)

    async def agenerate_prompt(self, prompts, stop=None, **kwargs):
        return await self.model.agenerate_prompt(prompts, stop=stop, **kwargs)

    def invoke(self, input, **kwargs):
        """Invokes the wrapped model and records usage as a single call."""
        with get_openai_callback() as cb:
            result = self.model.invoke(input, **kwargs)
            self.tracker._update_stats(cb, True)
        return result

    def batch(self, inputs, **kwargs):
        """Runs a batch through the wrapped model and records its usage."""
        with get_openai_callback() as cb:
            results = self.model.batch(inputs, **kwargs)
            self.tracker._update_stats(cb, False, batch_size=len(inputs))
        return results

    def with_structured_output(self, schema, **kwargs):
        """Delegates structured output to the wrapped model, if supported."""
        if hasattr(self.model, 'with_structured_output'):
            return self.model.with_structured_output(schema, **kwargs)
        raise NotImplementedError(
            f"Model {type(self.model)} does not support structured output"
        )

    def reset_stats(self):
        self.tracker.reset_stats()

    def get_stats(self):
        return self.tracker.get_stats()

    def __getattr__(self, name):
        # Fall back to the wrapped model for any attribute not defined here.
        return getattr(self.model, name)


# Shared module-level tracker instance.
model_tracker = OpenAITracker()


def create_chat_model(model=None, **kwargs):
    """Creates a chat model wrapped for usage tracking.

    Accepts an existing LangChain model, a model name (passed to ChatOpenAI),
    or ChatOpenAI kwargs only; the result is wrapped by the shared tracker.
    """
    if isinstance(model, BaseLanguageModel):
        base_model = model
    elif model is not None:
        kwargs["model"] = model
        base_model = ChatOpenAI(**kwargs)
    else:
        base_model = ChatOpenAI(**kwargs)

    return model_tracker.wrap_model(base_model)
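
A minimal end-to-end sketch of the tracker, assuming OPENAI_API_KEY is set in the environment (the model name is illustrative):

from coolprompt.language_model import create_chat_model, model_tracker

llm = create_chat_model("gpt-4o-mini")  # wrapped in a TrackedLLMWrapper

llm.invoke("Say hello in one word.")
llm.batch(["2+2?", "3+3?"])

stats = model_tracker.get_stats()
# e.g. {'total_calls': 2, 'invoke_calls': 1, 'batch_calls': 1,
#       'batch_items': 2, 'total_tokens': ..., 'total_cost': ...}
print(stats)

model_tracker.reset_stats()  # start a fresh measurement window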

9 changes: 8 additions & 1 deletion coolprompt/task_detector/detector.py
@@ -42,8 +42,15 @@ def _generate(
         Returns:
             Any: generated data
         """
-        if not isinstance(self.model, BaseChatModel):
+        if hasattr(self.model, 'model'):
+            wrapped_model = self.model.model
+        else:
+            wrapped_model = self.model
+
+        if not isinstance(wrapped_model, BaseChatModel):
             output = self.model.invoke(request)
+            if isinstance(output, AIMessage):
+                output = output.content
             return extract_json(output)[field_name]
 
         structured_model = self.model.with_structured_output(
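
The non-chat fallback path shared by both _generate implementations relies on extract_json pulling a JSON object out of raw model output. That helper lives elsewhere in coolprompt; the stand-in below is only an illustrative sketch of the behaviour it is expected to have, not the actual implementation:

import json
import re

def extract_json(text: str) -> dict:
    # Illustrative stand-in: grab the first {...} block from raw model output.
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if match is None:
        raise ValueError("no JSON object found in model output")
    return json.loads(match.group(0))

raw = 'Sure! Here is the result: {"task": "classification"}'
print(extract_json(raw)["task"])  # prints: classification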