chore: cleanup streaming
willbakst committed Jun 21, 2024
1 parent a1dbe92 commit 2a87bce
Showing 7 changed files with 80 additions and 84 deletions.
27 changes: 27 additions & 0 deletions mirascope/core/base/_utils.py
@@ -20,10 +20,12 @@
get_type_hints,
)

import jiter
from docstring_parser import parse
from pydantic import BaseModel, create_model
from pydantic.fields import FieldInfo

from ._partial import partial
from .message_param import BaseMessageParam

DEFAULT_TOOL_DOCSTRING = """\
@@ -257,3 +259,28 @@ def convert_base_type_to_base_tool(
__doc__=DEFAULT_TOOL_DOCSTRING,
value=(schema, ...),
)


_ResponseModelT = TypeVar("_ResponseModelT", bound=BaseModel | BaseType)


def setup_extract_tool(
response_model: type[BaseModel] | type[BaseType], tool_type: type[BaseToolT]
) -> type[BaseToolT]:
if is_base_type(response_model):
return convert_base_type_to_base_tool(response_model, tool_type) # type: ignore
return convert_base_model_to_base_tool(response_model, tool_type) # type: ignore


def extract_tool_return(
response_model: type[_ResponseModelT], json_output: str, allow_partial: bool
) -> _ResponseModelT:
json_obj = jiter.from_json(
json_output.encode(),
partial_mode="trailing-strings" if allow_partial else "off",
)
if is_base_type(response_model):
temp_model = convert_base_type_to_base_tool(response_model, BaseModel) # type: ignore
return temp_model.model_validate(json_obj).value # type: ignore

return response_model.model_validate(json_obj) # type: ignore
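
The two helpers above replace per-provider copies (the OpenAI versions are deleted below). A minimal usage sketch of extract_tool_return follows; it is illustrative, not part of the diff, imports from the private _utils module, and uses optional fields so partial validation can succeed:

from pydantic import BaseModel

from mirascope.core.base._utils import extract_tool_return


class Book(BaseModel):
    title: str | None = None
    author: str | None = None


# Complete JSON: jiter parses strictly (partial_mode="off").
book = extract_tool_return(Book, '{"title": "Dune", "author": "Herbert"}', False)
assert book.author == "Herbert"

# Truncated JSON mid-stream: allow_partial=True switches jiter to
# "trailing-strings" mode, which keeps the cut-off string value.
partial_book = extract_tool_return(Book, '{"title": "Du', True)
assert partial_book.title == "Du" and partial_book.author is None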
3 changes: 1 addition & 2 deletions mirascope/core/base/tools.py
@@ -4,8 +4,7 @@
from abc import abstractmethod
from typing import Any

from pydantic import BaseModel, ConfigDict, GetCoreSchemaHandler
from pydantic_core import core_schema
from pydantic import BaseModel, ConfigDict

from . import _utils

31 changes: 4 additions & 27 deletions mirascope/core/openai/_utils.py
@@ -30,7 +30,8 @@ def setup_call(
list[ChatCompletionMessageParam],
None,
OpenAICallParams,
]: ... # pragma: no cover
]:
... # pragma: no cover


@overload
@@ -45,7 +46,8 @@ def setup_call(
list[ChatCompletionMessageParam],
list[type[OpenAITool]],
OpenAICallParams,
]: ... # pragma: no cover
]:
... # pragma: no cover


def setup_call(
@@ -134,31 +136,6 @@ def setup_extract(
return json_mode, messages, call_kwargs


def setup_extract_tool(
response_model: type[BaseModel] | type[_utils.BaseType],
) -> type[OpenAITool]:
if _utils.is_base_type(response_model):
return _utils.convert_base_type_to_base_tool(response_model, OpenAITool) # type: ignore
return _utils.convert_base_model_to_base_tool(response_model, OpenAITool) # type: ignore


def extract_tool_return(
response_model: type[_ResponseModelT], json_output: str, allow_partial: bool
) -> _ResponseModelT:
temp_model = response_model
if is_base_type := _utils.is_base_type(response_model):
temp_model = _utils.convert_base_type_to_base_tool(response_model, BaseModel) # type: ignore

if allow_partial:
json_obj = jiter.from_json(
json_output.encode(), partial_mode="trailing-strings"
)
output = _partial.partial(temp_model).model_validate(json_obj) # type: ignore
else:
output = temp_model.model_validate_json(json_output) # type: ignore
return output if not is_base_type else output.value # type: ignore


def openai_api_calculate_cost(
usage: CompletionUsage | None, model="gpt-3.5-turbo-16k"
) -> float | None:
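
Both provider-local helpers above are deleted in favor of the shared versions added to mirascope/core/base/_utils.py, which take the provider's tool class as a parameter. A sketch of the resulting call pattern, with a hypothetical Book model (the real call sites appear in calls.py below):

from pydantic import BaseModel

from mirascope.core.base import _utils
from mirascope.core.openai.tools import OpenAITool


class Book(BaseModel):
    title: str
    author: str


# One shared helper parameterized by tool type, instead of one copy per provider.
tool_type = _utils.setup_extract_tool(Book, OpenAITool)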
48 changes: 32 additions & 16 deletions mirascope/core/openai/calls.py
@@ -2,6 +2,7 @@

import datetime
import inspect
from functools import wraps
from typing import (
AsyncIterable,
Awaitable,
@@ -19,18 +20,17 @@

from ..base import BaseTool, _utils
from ._utils import (
extract_tool_return,
openai_api_calculate_cost,
setup_call,
setup_extract,
setup_extract_tool,
)
from .call_params import OpenAICallParams
from .call_response import OpenAICallResponse
from .call_response_chunk import OpenAICallResponseChunk
from .function_return import OpenAICallFunctionReturn
from .streams import OpenAIAsyncStream, OpenAIStream
from .structured_streams import OpenAIAsyncStructuredStream, OpenAIStructuredStream
from .tools import OpenAITool

_P = ParamSpec("_P")
_ResponseModelT = TypeVar("_ResponseModelT", bound=BaseModel | _utils.BaseType)
@@ -47,7 +47,8 @@ def openai_call(
) -> Callable[
[Callable[_P, OpenAICallFunctionReturn]],
Callable[_P, OpenAICallResponse],
]: ... # pragma: no cover
]:
... # pragma: no cover


@overload
@@ -61,7 +62,8 @@ def openai_call(
) -> Callable[
[Callable[_P, OpenAICallFunctionReturn]],
Callable[_P, _ResponseModelT],
]: ... # pragma: no cover
]:
... # pragma: no cover


@overload
@@ -75,7 +77,8 @@ def openai_call(
) -> Callable[
[Callable[_P, OpenAICallFunctionReturn]],
Callable[_P, OpenAIStream],
]: ... # pragma: no cover
]:
... # pragma: no cover


@overload
@@ -89,7 +92,8 @@ def openai_call(
) -> Callable[
[Callable[_P, OpenAICallFunctionReturn]],
Callable[_P, Iterable[_ResponseModelT]],
]: ... # pragma: no cover
]:
... # pragma: no cover


def openai_call(
@@ -135,6 +139,7 @@ def recommend_book(genre: str):
def call_decorator(
fn: Callable[_P, OpenAICallFunctionReturn],
) -> Callable[_P, OpenAICallResponse]:
@wraps(fn)
def inner(*args: _P.args, **kwargs: _P.kwargs) -> OpenAICallResponse:
fn_args = inspect.signature(fn).bind(*args, **kwargs).arguments
fn_return = fn(*args, **kwargs)
@@ -167,6 +172,7 @@ def inner(*args: _P.args, **kwargs: _P.kwargs) -> OpenAICallResponse:
def stream_decorator(
fn: Callable[_P, OpenAICallFunctionReturn],
) -> Callable[_P, OpenAIStream]:
@wraps(fn)
def inner(*args: _P.args, **kwargs: _P.kwargs) -> OpenAIStream:
fn_args = inspect.signature(fn).bind(*args, **kwargs).arguments
fn_return = fn(*args, **kwargs)
@@ -201,8 +207,9 @@ def extract_decorator(
fn: Callable[_P, OpenAICallFunctionReturn],
) -> Callable[_P, _ResponseModelT]:
assert response_model is not None
tool = setup_extract_tool(response_model)
tool = _utils.setup_extract_tool(response_model, OpenAITool)

@wraps(fn)
def inner(*args: _P.args, **kwargs: _P.kwargs) -> _ResponseModelT:
assert response_model is not None
fn_args = inspect.signature(fn).bind(*args, **kwargs).arguments
Expand All @@ -222,7 +229,7 @@ def inner(*args: _P.args, **kwargs: _P.kwargs) -> _ResponseModelT:
else:
raise ValueError("No tool call or JSON object found in response.")

output = extract_tool_return(response_model, json_output, False)
output = _utils.extract_tool_return(response_model, json_output, False)
if isinstance(response_model, BaseModel):
output._response = response # type: ignore
return output
@@ -233,8 +240,9 @@ def extract_stream_decorator(
fn: Callable[_P, OpenAICallFunctionReturn],
) -> Callable[_P, Iterable[_ResponseModelT]]:
assert response_model is not None
tool = setup_extract_tool(response_model)
tool = _utils.setup_extract_tool(response_model, OpenAITool)

@wraps(fn)
def inner(*args: _P.args, **kwargs: _P.kwargs) -> Iterable[_ResponseModelT]:
assert response_model is not None
fn_args = inspect.signature(fn).bind(*args, **kwargs).arguments
@@ -274,7 +282,8 @@ def openai_call_async(
) -> Callable[
[Callable[_P, Awaitable[OpenAICallFunctionReturn]]],
Callable[_P, Awaitable[OpenAICallResponse]],
]: ... # pragma: no cover
]:
... # pragma: no cover


@overload
@@ -288,7 +297,8 @@ def openai_call_async(
) -> Callable[
[Callable[_P, Awaitable[OpenAICallFunctionReturn]]],
Callable[_P, Awaitable[_ResponseModelT]],
]: ... # pragma: no cover
]:
... # pragma: no cover


@overload
@@ -302,7 +312,8 @@ def openai_call_async(
) -> Callable[
[Callable[_P, Awaitable[OpenAICallFunctionReturn]]],
Callable[_P, Awaitable[OpenAIAsyncStream]],
]: ... # pragma: no cover
]:
... # pragma: no cover


@overload
@@ -316,7 +327,8 @@ def openai_call_async(
) -> Callable[
[Callable[_P, Awaitable[OpenAICallFunctionReturn]]],
Callable[_P, Awaitable[AsyncIterable[_ResponseModelT]]],
]: ... # pragma: no cover
]:
... # pragma: no cover


def openai_call_async(
@@ -368,6 +380,7 @@ async def run():
def call_decorator(
fn: Callable[_P, Awaitable[OpenAICallFunctionReturn]],
) -> Callable[_P, Awaitable[OpenAICallResponse]]:
@wraps(fn)
async def inner_async(
*args: _P.args, **kwargs: _P.kwargs
) -> OpenAICallResponse:
@@ -402,6 +415,7 @@ async def inner_async(
def stream_decorator(
fn: Callable[_P, Awaitable[OpenAICallFunctionReturn]],
) -> Callable[_P, Awaitable[OpenAIAsyncStream]]:
@wraps(fn)
async def inner_async(*args: _P.args, **kwargs: _P.kwargs) -> OpenAIAsyncStream:
fn_args = inspect.signature(fn).bind(*args, **kwargs).arguments
fn_return = await fn(*args, **kwargs)
@@ -437,8 +451,9 @@ def extract_decorator(
) -> Callable[_P, Awaitable[_ResponseModelT]]:
nonlocal response_model
assert response_model is not None
tool = setup_extract_tool(response_model)
tool = _utils.setup_extract_tool(response_model, OpenAITool)

@wraps(fn)
async def inner(*args: _P.args, **kwargs: _P.kwargs) -> _ResponseModelT:
assert response_model is not None
fn_args = inspect.signature(fn).bind(*args, **kwargs).arguments
@@ -458,7 +473,7 @@ async def inner(*args: _P.args, **kwargs: _P.kwargs) -> _ResponseModelT:
else:
raise ValueError("No tool call or JSON object found in response.")

output = extract_tool_return(response_model, json_output, False)
output = _utils.extract_tool_return(response_model, json_output, False)
if isinstance(response_model, BaseModel):
output._response = response # type: ignore
return output
@@ -469,8 +484,9 @@ def extract_stream_decorator(
fn: Callable[_P, Awaitable[OpenAICallFunctionReturn]],
) -> Callable[_P, Awaitable[AsyncIterable[_ResponseModelT]]]:
assert response_model is not None
tool = setup_extract_tool(response_model)
tool = _utils.setup_extract_tool(response_model, OpenAITool)

@wraps(fn)
async def inner(
*args: _P.args, **kwargs: _P.kwargs
) -> AsyncIterable[_ResponseModelT]:
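
A recurring change in this file is the new @wraps(fn) on each inner wrapper, which copies __name__, __doc__, and other metadata from the decorated function onto the wrapper; that is presumably important here since the docstring carries the prompt template. A standard-library illustration of the difference:

from functools import wraps


def decorate_without_wraps(fn):
    def inner(*args, **kwargs):
        return fn(*args, **kwargs)

    return inner


def decorate_with_wraps(fn):
    @wraps(fn)  # copies __name__, __doc__, __module__, etc. onto inner
    def inner(*args, **kwargs):
        return fn(*args, **kwargs)

    return inner


def recommend_book(genre: str):
    """Recommend a {genre} book."""


print(decorate_without_wraps(recommend_book).__name__)  # -> inner
print(decorate_without_wraps(recommend_book).__doc__)   # -> None
print(decorate_with_wraps(recommend_book).__name__)     # -> recommend_book
print(decorate_with_wraps(recommend_book).__doc__)      # -> Recommend a {genre} book.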
35 changes: 8 additions & 27 deletions mirascope/core/openai/streams.py
@@ -10,7 +10,6 @@
from openai.types.chat.chat_completion_message_tool_call import Function

from ..base import BaseAsyncStream, BaseStream
from ..base._partial import partial
from .call_response import OpenAICallResponse
from .call_response_chunk import OpenAICallResponseChunk
from .tools import OpenAITool
@@ -20,16 +19,14 @@ def _handle_chunk(
chunk: OpenAICallResponseChunk,
current_tool_call: ChatCompletionMessageToolCall,
current_tool_type: type[OpenAITool] | None,
allow_partial: bool,
) -> tuple[
OpenAITool | None,
ChatCompletionMessageToolCall,
type[OpenAITool] | None,
bool,
]:
"""Handles a chunk of the stream."""
if not chunk.tool_types or not chunk.tool_calls:
return None, current_tool_call, current_tool_type, False
return None, current_tool_call, current_tool_type

tool_call = chunk.tool_calls[0]
# Reset on new tool
@@ -55,29 +52,16 @@
)
if previous_tool_call.id and previous_tool_type is not None:
return (
previous_tool_type.from_tool_call(
previous_tool_call, allow_partial=allow_partial
),
previous_tool_type.from_tool_call(previous_tool_call),
current_tool_call,
current_tool_type,
True,
)

# Update arguments with each chunk
if tool_call.function and tool_call.function.arguments:
current_tool_call.function.arguments += tool_call.function.arguments

if allow_partial and current_tool_type:
return (
partial(current_tool_type).from_tool_call(
current_tool_call, allow_partial=True
),
current_tool_call,
current_tool_type,
False,
)

return None, current_tool_call, current_tool_type, False
return None, current_tool_call, current_tool_type


class OpenAIStream(
@@ -110,8 +94,8 @@ def __iter__(
current_tool_type = None
else:
yield chunk, None
tool, current_tool_call, current_tool_type, _ = _handle_chunk(
chunk, current_tool_call, current_tool_type, False
tool, current_tool_call, current_tool_type = _handle_chunk(
chunk, current_tool_call, current_tool_type
)
if tool is not None:
yield chunk, tool
@@ -158,12 +142,9 @@ async def generator():
current_tool_type = None
else:
yield chunk, None
(
tool,
current_tool_call,
current_tool_type,
_,
) = _handle_chunk(chunk, current_tool_call, current_tool_type, False)
tool, current_tool_call, current_tool_type = _handle_chunk(
chunk, current_tool_call, current_tool_type
)
if tool is not None:
yield chunk, tool
tool_calls.append(tool.tool_call)
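
With partial-tool streaming removed from _handle_chunk (the allow_partial flag and the fourth tuple element are gone), the helper now emits a tool only once the next tool call begins. A standalone sketch of that accumulation pattern, using simplified stand-in types rather than the real OpenAI ones:

from dataclasses import dataclass


@dataclass
class Delta:
    id: str | None  # set only on the first chunk of a new tool call
    arguments: str  # JSON fragment streamed in this chunk


@dataclass
class Buffered:
    id: str = ""
    arguments: str = ""


def handle_delta(delta: Delta, current: Buffered) -> tuple[Buffered | None, Buffered]:
    """Return (completed tool call or None, updated buffer)."""
    if delta.id and delta.id != current.id:
        # A fresh id starts a new tool call; the previous one, if any, is done.
        completed = current if current.id else None
        return completed, Buffered(delta.id, delta.arguments)
    # Otherwise keep concatenating argument fragments into the current call.
    current.arguments += delta.arguments
    return None, current


# One call split across three chunks, then the start of a second call:
chunks = [Delta("a1", '{"title": '), Delta(None, '"Dune"'), Delta(None, "}"), Delta("b2", "{")]
buffer = Buffered()
for chunk in chunks:
    done, buffer = handle_delta(chunk, buffer)
    if done:
        print(done.id, done.arguments)  # -> a1 {"title": "Dune"}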
