-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: the decorator got messy, so I pulled everything into individual files
- Loading branch information
Showing
30 changed files
with
1,087 additions
and
846 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
"""This module contains the base classes for streaming responses from LLMs.""" | ||
|
||
from abc import ABC | ||
from collections.abc import Generator | ||
from typing import Any, Generic, TypeVar | ||
|
||
from .call_response_chunk import BaseCallResponseChunk | ||
from .tool import BaseTool | ||
|
||
# Type variables that parameterize `BaseStream` for a specific provider:
# the chunk type yielded by the stream, the provider's user/assistant
# message param types, and the provider's tool type.
_BaseCallResponseChunkT = TypeVar(
    "_BaseCallResponseChunkT", bound=BaseCallResponseChunk
)
_UserMessageParamT = TypeVar("_UserMessageParamT", bound=Any)
_AssistantMessageParamT = TypeVar("_AssistantMessageParamT", bound=Any)
_BaseToolT = TypeVar("_BaseToolT", bound=BaseTool)
||
class BaseStream(
    Generic[
        _BaseCallResponseChunkT,
        _UserMessageParamT,
        _AssistantMessageParamT,
        _BaseToolT,
    ],
    ABC,
):
    """A base class for streaming responses from LLMs.

    Wraps a generator of provider response chunks. While the caller iterates,
    the streamed content is accumulated so that, after exhaustion, `cost`,
    `user_message_param`, and the reconstructed assistant `message_param`
    are available on the instance.
    """

    # The generator of provider response chunks being wrapped.
    stream: Generator[_BaseCallResponseChunkT, None, None]
    # The provider's assistant message param type, used to rebuild the final
    # assistant message from the accumulated streamed content.
    message_param_type: type[_AssistantMessageParamT]

    # Last non-None cost reported by a chunk, if any.
    cost: float | None = None
    user_message_param: _UserMessageParamT | None = None
    # FIX: was annotated `_BaseToolT`, but `__iter__` assigns an instance of
    # `message_param_type`, i.e. an assistant message param.
    message_param: _AssistantMessageParamT

    def __init__(
        self,
        stream: Generator[_BaseCallResponseChunkT, None, None],
        message_param_type: type[_AssistantMessageParamT],
    ):
        """Initializes an instance of `BaseStream`.

        Args:
            stream: The generator of provider response chunks to wrap.
            message_param_type: The provider's assistant message param type.
        """
        self.stream = stream
        self.message_param_type = message_param_type

    def __iter__(
        self,
    ) -> Generator[tuple[_BaseCallResponseChunkT, _BaseToolT | None], None, None]:
        """Iterates over the stream and stores useful information.

        Yields `(chunk, None)` pairs; the tool slot is populated by provider
        subclasses. After the stream is exhausted, the accumulated content is
        packaged into `self.message_param`.
        """
        content = ""
        last_chunk = None
        for chunk in self.stream:
            content += chunk.content
            if chunk.cost is not None:
                self.cost = chunk.cost
            yield chunk, None
            last_chunk = chunk
        # FIX: the original read `chunk` after the loop, which raises
        # `UnboundLocalError` when the stream yields no chunks.
        if last_chunk is not None:
            self.user_message_param = last_chunk.user_message_param
        kwargs = {"role": "assistant"}
        # Some provider message param types name the content field "message"
        # rather than "content"; pick whichever the type declares.
        if "message" in self.message_param_type.__annotations__:
            kwargs["message"] = content
        else:
            kwargs["content"] = content
        self.message_param = self.message_param_type(**kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
"""This module contains the base classes for async structured streams from LLMs.""" | ||
|
||
from abc import ABC, abstractmethod | ||
from collections.abc import AsyncGenerator | ||
from typing import Any, Generic, TypeVar | ||
|
||
from pydantic import BaseModel | ||
|
||
from ._utils import BaseType | ||
|
||
# Type variables parameterizing `BaseAsyncStructuredStream`: the provider
# chunk type and the structured response model type (a pydantic `BaseModel`
# or a `BaseType`).
_ChunkT = TypeVar("_ChunkT", bound=Any)
_ResponseModelT = TypeVar("_ResponseModelT", bound=BaseModel | BaseType)
||
class BaseAsyncStructuredStream(Generic[_ChunkT, _ResponseModelT], ABC):
    """A base class for async streaming structured outputs from LLMs.

    Holds the raw async chunk stream together with the response model that
    subclasses extract from it via `__aiter__`.
    """

    # The async generator of raw provider chunks being wrapped.
    stream: AsyncGenerator[_ChunkT, None]
    # The structured output type to extract from the stream.
    response_model: type[_ResponseModelT]
    # Whether the underlying call was made in JSON mode.
    json_mode: bool

    def __init__(
        self,
        stream: AsyncGenerator[_ChunkT, None],
        response_model: type[_ResponseModelT],
        json_mode: bool = False,
    ):
        """Initializes an instance of `BaseAsyncStructuredStream`."""
        self.json_mode = json_mode
        self.response_model = response_model
        self.stream = stream

    @abstractmethod
    def __aiter__(self) -> AsyncGenerator[_ResponseModelT, None]:
        """Iterates over the stream and extracts structured outputs."""
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,23 @@ | ||
"""The Mirascope OpenAI Module.""" | ||
|
||
from .call import openai_call | ||
from .call import openai_call as call | ||
from .call_async import openai_call_async | ||
from .call_async import openai_call_async as call_async | ||
from .call_params import OpenAICallParams | ||
from .call_response import OpenAICallResponse | ||
from .call_response_chunk import OpenAICallResponseChunk | ||
from .calls import openai_call, openai_call_async | ||
from .calls import openai_call as call | ||
from .calls import openai_call_async as call_async | ||
from .function_return import OpenAICallFunctionReturn | ||
from .streams import OpenAIAsyncStream, OpenAIStream | ||
from .structured_streams import OpenAIAsyncStructuredStream, OpenAIStructuredStream | ||
from .tools import OpenAITool | ||
from .tool import OpenAITool | ||
|
||
__all__ = [ | ||
"call", | ||
"OpenAIAsyncStream", | ||
"OpenAIAsyncStructuredStream", | ||
"call_async", | ||
"OpenAICallFunctionReturn", | ||
"OpenAICallParams", | ||
"OpenAICallResponse", | ||
"OpenAICallResponseChunk", | ||
"OpenAIStream", | ||
"OpenAIStructuredStream", | ||
"OpenAITool", | ||
"openai_call", | ||
"openai_call_async", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
"""This module contains the OpenAI `call_decorator` function.""" | ||
|
||
import datetime | ||
import inspect | ||
from functools import wraps | ||
from typing import Callable, ParamSpec | ||
|
||
from openai import OpenAI | ||
|
||
from ..base import BaseTool | ||
from ._utils import openai_api_calculate_cost, setup_call | ||
from .call_params import OpenAICallParams | ||
from .call_response import OpenAICallResponse | ||
from .function_return import OpenAICallFunctionReturn | ||
|
||
_P = ParamSpec("_P") | ||
|
||
|
||
def call_decorator(
    fn: Callable[_P, OpenAICallFunctionReturn],
    model: str,
    tools: list[type[BaseTool] | Callable] | None,
    call_params: OpenAICallParams,
) -> Callable[_P, OpenAICallResponse]:
    """Decorates `fn` to run a synchronous OpenAI chat completion call.

    Args:
        fn: The prompt function returning an `OpenAICallFunctionReturn`.
        model: The OpenAI model identifier to call.
        tools: Optional tools to make available to the model.
        call_params: Additional OpenAI call parameters.

    Returns:
        The wrapped function, which returns an `OpenAICallResponse`.
    """

    @wraps(fn)
    def inner(*args: _P.args, **kwargs: _P.kwargs) -> OpenAICallResponse:
        # Bind the caller's arguments by name for prompt templating.
        bound_args = inspect.signature(fn).bind(*args, **kwargs).arguments
        dynamic_config = fn(*args, **kwargs)
        prompt_template, messages, tool_types, call_kwargs = setup_call(
            fn, bound_args, dynamic_config, tools, call_params
        )
        client = OpenAI()
        # Timestamps are recorded in milliseconds since the epoch.
        started_ms = datetime.datetime.now().timestamp() * 1000
        response = client.chat.completions.create(
            model=model, stream=False, messages=messages, **call_kwargs
        )
        finished_ms = datetime.datetime.now().timestamp() * 1000
        last_message = messages[-1]
        return OpenAICallResponse(
            response=response,
            tool_types=tool_types,
            prompt_template=prompt_template,
            fn_args=bound_args,
            fn_return=dynamic_config,
            messages=messages,
            call_params=call_kwargs,
            user_message_param=last_message if last_message["role"] == "user" else None,
            start_time=started_ms,
            end_time=finished_ms,
            cost=openai_api_calculate_cost(response.usage, response.model),
        )

    return inner
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
"""This module contains the OpenAI `call_async_decorator` function.""" | ||
|
||
import datetime | ||
import inspect | ||
from functools import wraps | ||
from typing import Awaitable, Callable, ParamSpec | ||
|
||
from openai import AsyncOpenAI | ||
|
||
from ..base import BaseTool | ||
from ._utils import ( | ||
openai_api_calculate_cost, | ||
setup_call, | ||
) | ||
from .call_params import OpenAICallParams | ||
from .call_response import OpenAICallResponse | ||
from .function_return import OpenAICallFunctionReturn | ||
|
||
_P = ParamSpec("_P") | ||
|
||
|
||
def call_async_decorator(
    fn: Callable[_P, Awaitable[OpenAICallFunctionReturn]],
    model: str,
    tools: list[type[BaseTool] | Callable] | None,
    call_params: OpenAICallParams,
) -> Callable[_P, Awaitable[OpenAICallResponse]]:
    """Decorates an async `fn` to run an OpenAI chat completion call.

    Args:
        fn: The async prompt function returning an `OpenAICallFunctionReturn`.
        model: The OpenAI model identifier to call.
        tools: Optional tools to make available to the model.
        call_params: Additional OpenAI call parameters.

    Returns:
        The wrapped async function, which returns an `OpenAICallResponse`.
    """

    @wraps(fn)
    async def inner_async(*args: _P.args, **kwargs: _P.kwargs) -> OpenAICallResponse:
        # Bind the caller's arguments by name for prompt templating.
        bound_args = inspect.signature(fn).bind(*args, **kwargs).arguments
        dynamic_config = await fn(*args, **kwargs)
        prompt_template, messages, tool_types, call_kwargs = setup_call(
            fn, bound_args, dynamic_config, tools, call_params
        )
        client = AsyncOpenAI()
        # Timestamps are recorded in milliseconds since the epoch.
        started_ms = datetime.datetime.now().timestamp() * 1000
        response = await client.chat.completions.create(
            model=model, stream=False, messages=messages, **call_kwargs
        )
        finished_ms = datetime.datetime.now().timestamp() * 1000
        last_message = messages[-1]
        return OpenAICallResponse(
            response=response,
            tool_types=tool_types,
            prompt_template=prompt_template,
            fn_args=bound_args,
            fn_return=dynamic_config,
            messages=messages,
            call_params=call_kwargs,
            user_message_param=last_message if last_message["role"] == "user" else None,
            start_time=started_ms,
            end_time=finished_ms,
            cost=openai_api_calculate_cost(response.usage, response.model),
        )

    return inner_async
Oops, something went wrong.