Skip to content

Commit

Permalink
refactor: decorator got messy, so I pulled everything into individual…
Browse files Browse the repository at this point in the history
… files
  • Loading branch information
willbakst committed Jun 21, 2024
1 parent 2a87bce commit f036fcc
Show file tree
Hide file tree
Showing 30 changed files with 1,087 additions and 846 deletions.
10 changes: 6 additions & 4 deletions mirascope/core/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
from .call_response_chunk import BaseCallResponseChunk
from .function_return import BaseFunctionReturn
from .message_param import BaseMessageParam
from .prompts import BasePrompt, tags
from .streams import BaseAsyncStream, BaseStream
from .structured_streams import BaseAsyncStructuredStream, BaseStructuredStream
from .tools import BaseTool
from .prompt import BasePrompt, tags
from .stream import BaseStream
from .stream_async import BaseAsyncStream
from .structured_stream import BaseStructuredStream
from .structured_stream_async import BaseAsyncStructuredStream
from .tool import BaseTool

__all__ = [
"BaseAsyncStream",
Expand Down
4 changes: 4 additions & 0 deletions mirascope/core/base/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,10 @@ def extract_tool_return(
)
if is_base_type(response_model):
temp_model = convert_base_type_to_base_tool(response_model, BaseModel) # type: ignore
if allow_partial:
return partial(temp_model).model_validate(json_obj).value # type: ignore
return temp_model.model_validate(json_obj).value # type: ignore

if allow_partial:
return partial(response_model).model_validate(json_obj) # type: ignore
return response_model.model_validate(json_obj) # type: ignore
2 changes: 1 addition & 1 deletion mirascope/core/base/call_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from .call_params import BaseCallParams
from .function_return import BaseFunctionReturn
from .tools import BaseTool
from .tool import BaseTool

_ResponseT = TypeVar("_ResponseT", bound=Any)
_BaseToolT = TypeVar("_BaseToolT", bound=BaseTool)
Expand Down
2 changes: 1 addition & 1 deletion mirascope/core/base/call_response_chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from pydantic import BaseModel, ConfigDict, field_serializer

from .tools import BaseTool
from .tool import BaseTool

_ChunkT = TypeVar("_ChunkT", bound=Any)
_BaseToolT = TypeVar("_BaseToolT", bound=BaseTool)
Expand Down
2 changes: 1 addition & 1 deletion mirascope/core/base/function_return.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing_extensions import NotRequired, TypedDict

from .call_params import BaseCallParams
from .tools import BaseTool
from .tool import BaseTool

_MessageParamT = TypeVar("_MessageParamT", bound=Any)
_CallParamsT = TypeVar("_CallParamsT", bound=BaseCallParams)
Expand Down
File renamed without changes.
61 changes: 61 additions & 0 deletions mirascope/core/base/stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""This module contains the base classes for streaming responses from LLMs."""

from abc import ABC
from collections.abc import Generator
from typing import Any, Generic, TypeVar

from .call_response_chunk import BaseCallResponseChunk
from .tool import BaseTool

_BaseCallResponseChunkT = TypeVar(
"_BaseCallResponseChunkT", bound=BaseCallResponseChunk
)
_UserMessageParamT = TypeVar("_UserMessageParamT", bound=Any)
_AssistantMessageParamT = TypeVar("_AssistantMessageParamT", bound=Any)
_BaseToolT = TypeVar("_BaseToolT", bound=BaseTool)


class BaseStream(
Generic[
_BaseCallResponseChunkT,
_UserMessageParamT,
_AssistantMessageParamT,
_BaseToolT,
],
ABC,
):
"""A base class for streaming responses from LLMs."""

stream: Generator[_BaseCallResponseChunkT, None, None]
message_param_type: type[_AssistantMessageParamT]

cost: float | None = None
user_message_param: _UserMessageParamT | None = None
message_param: _BaseToolT

def __init__(
self,
stream: Generator[_BaseCallResponseChunkT, None, None],
message_param_type: type[_AssistantMessageParamT],
):
"""Initializes an instance of `BaseStream`."""
self.stream = stream
self.message_param_type = message_param_type

def __iter__(
self,
) -> Generator[tuple[_BaseCallResponseChunkT, _BaseToolT | None], None, None]:
"""Iterator over the stream and stores useful information."""
content = ""
for chunk in self.stream:
content += chunk.content
if chunk.cost is not None:
self.cost = chunk.cost
yield chunk, None
self.user_message_param = chunk.user_message_param
kwargs = {"role": "assistant"}
if "message" in self.message_param_type.__annotations__:
kwargs["message"] = content
else:
kwargs["content"] = content
self.message_param = self.message_param_type(**kwargs)
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""This module contains the base classes for streaming responses from LLMs."""
"""This module contains the base classes for async streaming responses from LLMs."""

from abc import ABC
from collections.abc import AsyncGenerator, Generator
from collections.abc import AsyncGenerator
from typing import Any, Generic, TypeVar

from .call_response_chunk import BaseCallResponseChunk
from .tools import BaseTool
from .tool import BaseTool

_BaseCallResponseChunkT = TypeVar(
"_BaseCallResponseChunkT", bound=BaseCallResponseChunk
Expand All @@ -15,52 +15,6 @@
_BaseToolT = TypeVar("_BaseToolT", bound=BaseTool)


class BaseStream(
Generic[
_BaseCallResponseChunkT,
_UserMessageParamT,
_AssistantMessageParamT,
_BaseToolT,
],
ABC,
):
"""A base class for streaming responses from LLMs."""

stream: Generator[_BaseCallResponseChunkT, None, None]
message_param_type: type[_AssistantMessageParamT]

cost: float | None = None
user_message_param: _UserMessageParamT | None = None
message_param: _BaseToolT

def __init__(
self,
stream: Generator[_BaseCallResponseChunkT, None, None],
message_param_type: type[_AssistantMessageParamT],
):
"""Initializes an instance of `BaseStream`."""
self.stream = stream
self.message_param_type = message_param_type

def __iter__(
self,
) -> Generator[tuple[_BaseCallResponseChunkT, _BaseToolT | None], None, None]:
"""Iterator over the stream and stores useful information."""
content = ""
for chunk in self.stream:
content += chunk.content
if chunk.cost is not None:
self.cost = chunk.cost
yield chunk, None
self.user_message_param = chunk.user_message_param
kwargs = {"role": "assistant"}
if "message" in self.message_param_type.__annotations__:
kwargs["message"] = content
else:
kwargs["content"] = content
self.message_param = self.message_param_type(**kwargs)


class BaseAsyncStream(
Generic[
_BaseCallResponseChunkT,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""This module contains the base classes for structured streams from LLMs."""

from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Generator
from collections.abc import Generator
from typing import Any, Generic, TypeVar

from pydantic import BaseModel
Expand Down Expand Up @@ -33,26 +33,3 @@ def __init__(
@abstractmethod
def __iter__(self) -> Generator[_ResponseModelT, None, None]:
"""Iterates over the stream and extracts structured outputs."""


class BaseAsyncStructuredStream(Generic[_ChunkT, _ResponseModelT], ABC):
"""A base class for async streaming structured outputs from LLMs."""

stream: AsyncGenerator[_ChunkT, None]
response_model: type[_ResponseModelT]
json_mode: bool

def __init__(
self,
stream: AsyncGenerator[_ChunkT, None],
response_model: type[_ResponseModelT],
json_mode: bool = False,
):
"""Initializes an instance of `BaseAsyncStructuredStream`."""
self.stream = stream
self.response_model = response_model
self.json_mode = json_mode

@abstractmethod
def __aiter__(self) -> AsyncGenerator[_ResponseModelT, None]:
"""Iterates over the stream and extracts structured outputs."""
35 changes: 35 additions & 0 deletions mirascope/core/base/structured_stream_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""This module contains the base classes for async structured streams from LLMs."""

from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from typing import Any, Generic, TypeVar

from pydantic import BaseModel

from ._utils import BaseType

_ChunkT = TypeVar("_ChunkT", bound=Any)
_ResponseModelT = TypeVar("_ResponseModelT", bound=BaseModel | BaseType)


class BaseAsyncStructuredStream(Generic[_ChunkT, _ResponseModelT], ABC):
"""A base class for async streaming structured outputs from LLMs."""

stream: AsyncGenerator[_ChunkT, None]
response_model: type[_ResponseModelT]
json_mode: bool

def __init__(
self,
stream: AsyncGenerator[_ChunkT, None],
response_model: type[_ResponseModelT],
json_mode: bool = False,
):
"""Initializes an instance of `BaseAsyncStructuredStream`."""
self.stream = stream
self.response_model = response_model
self.json_mode = json_mode

@abstractmethod
def __aiter__(self) -> AsyncGenerator[_ResponseModelT, None]:
"""Iterates over the stream and extracts structured outputs."""
File renamed without changes.
17 changes: 7 additions & 10 deletions mirascope/core/openai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,23 @@
"""The Mirascope OpenAI Module."""

from .call import openai_call
from .call import openai_call as call
from .call_async import openai_call_async
from .call_async import openai_call_async as call_async
from .call_params import OpenAICallParams
from .call_response import OpenAICallResponse
from .call_response_chunk import OpenAICallResponseChunk
from .calls import openai_call, openai_call_async
from .calls import openai_call as call
from .calls import openai_call_async as call_async
from .function_return import OpenAICallFunctionReturn
from .streams import OpenAIAsyncStream, OpenAIStream
from .structured_streams import OpenAIAsyncStructuredStream, OpenAIStructuredStream
from .tools import OpenAITool
from .tool import OpenAITool

__all__ = [
"call",
"OpenAIAsyncStream",
"OpenAIAsyncStructuredStream",
"call_async",
"OpenAICallFunctionReturn",
"OpenAICallParams",
"OpenAICallResponse",
"OpenAICallResponseChunk",
"OpenAIStream",
"OpenAIStructuredStream",
"OpenAITool",
"openai_call",
"openai_call_async",
]
51 changes: 51 additions & 0 deletions mirascope/core/openai/_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""This module contains the OpenAI `call_decorator` function."""

import datetime
import inspect
from functools import wraps
from typing import Callable, ParamSpec

from openai import OpenAI

from ..base import BaseTool
from ._utils import openai_api_calculate_cost, setup_call
from .call_params import OpenAICallParams
from .call_response import OpenAICallResponse
from .function_return import OpenAICallFunctionReturn

_P = ParamSpec("_P")


def call_decorator(
fn: Callable[_P, OpenAICallFunctionReturn],
model: str,
tools: list[type[BaseTool] | Callable] | None,
call_params: OpenAICallParams,
) -> Callable[_P, OpenAICallResponse]:
@wraps(fn)
def inner(*args: _P.args, **kwargs: _P.kwargs) -> OpenAICallResponse:
fn_args = inspect.signature(fn).bind(*args, **kwargs).arguments
fn_return = fn(*args, **kwargs)
prompt_template, messages, tool_types, call_kwargs = setup_call(
fn, fn_args, fn_return, tools, call_params
)
client = OpenAI()
start_time = datetime.datetime.now().timestamp() * 1000
response = client.chat.completions.create(
model=model, stream=False, messages=messages, **call_kwargs
)
return OpenAICallResponse(
response=response,
tool_types=tool_types,
prompt_template=prompt_template,
fn_args=fn_args,
fn_return=fn_return,
messages=messages,
call_params=call_kwargs,
user_message_param=messages[-1] if messages[-1]["role"] == "user" else None,
start_time=start_time,
end_time=datetime.datetime.now().timestamp() * 1000,
cost=openai_api_calculate_cost(response.usage, response.model),
)

return inner
54 changes: 54 additions & 0 deletions mirascope/core/openai/_call_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""This module contains the OpenAI `call_async_decorator` function."""

import datetime
import inspect
from functools import wraps
from typing import Awaitable, Callable, ParamSpec

from openai import AsyncOpenAI

from ..base import BaseTool
from ._utils import (
openai_api_calculate_cost,
setup_call,
)
from .call_params import OpenAICallParams
from .call_response import OpenAICallResponse
from .function_return import OpenAICallFunctionReturn

_P = ParamSpec("_P")


def call_async_decorator(
fn: Callable[_P, Awaitable[OpenAICallFunctionReturn]],
model: str,
tools: list[type[BaseTool] | Callable] | None,
call_params: OpenAICallParams,
) -> Callable[_P, Awaitable[OpenAICallResponse]]:
@wraps(fn)
async def inner_async(*args: _P.args, **kwargs: _P.kwargs) -> OpenAICallResponse:
fn_args = inspect.signature(fn).bind(*args, **kwargs).arguments
fn_return = await fn(*args, **kwargs)
prompt_template, messages, tool_types, call_kwargs = setup_call(
fn, fn_args, fn_return, tools, call_params
)
client = AsyncOpenAI()
start_time = datetime.datetime.now().timestamp() * 1000
response = await client.chat.completions.create(
model=model, stream=False, messages=messages, **call_kwargs
)
return OpenAICallResponse(
response=response,
tool_types=tool_types,
prompt_template=prompt_template,
fn_args=fn_args,
fn_return=fn_return,
messages=messages,
call_params=call_kwargs,
user_message_param=messages[-1] if messages[-1]["role"] == "user" else None,
start_time=start_time,
end_time=datetime.datetime.now().timestamp() * 1000,
cost=openai_api_calculate_cost(response.usage, response.model),
)

return inner_async
Loading

0 comments on commit f036fcc

Please sign in to comment.