Merge pull request #213 from Mirascope/feature/logfire-generators
fix: updated streaming to behave like instrument_openai by logfire
brenkao committed May 7, 2024
2 parents 2bb7fdd + 89830d0 commit 269c333
Showing 1 changed file with 70 additions and 38 deletions.
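The change replaces per-wrapper span bookkeeping with a shared record_streaming context manager, mirroring how logfire's instrument_openai reports streamed responses: time the whole stream, then emit a single log record carrying the combined chunk content. Below is a minimal, self-contained sketch of that pattern (not part of the diff); record_streaming_sketch, fake_stream, and the print-based 'span' are illustrative stand-ins for logfire and a real provider stream.

import time
from contextlib import contextmanager

@contextmanager
def record_streaming_sketch(span_data: dict):
    """Times a stream and emits one summary record when it ends."""
    content: list[str] = []
    start = time.perf_counter()
    try:
        # The caller records each chunk's text through this callable.
        yield content.append
    finally:
        duration = time.perf_counter() - start
        # Stand-in for logfire_span.info(STEAMING_MSG_TEMPLATE, ...).
        print(
            f"streaming response from {span_data['model']!r} took {duration:.2f}s",
            {"combined_chunk_content": "".join(content), "chunk_count": len(content)},
        )

def fake_stream():
    yield from ["Hello", ", ", "world"]

with record_streaming_sketch({"model": "gpt-4"}) as record_chunk:
    for chunk in fake_stream():
        record_chunk(chunk)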
108 changes: 70 additions & 38 deletions mirascope/logfire/logfire.py
@@ -2,10 +2,11 @@
 import inspect
 from contextlib import contextmanager
 from functools import wraps
-from typing import Any, Callable, Optional, overload
+from typing import Any, Callable, Optional, Union, overload
 
 import logfire
 from pydantic import BaseModel
+from typing_extensions import LiteralString
 
 from mirascope.base.types import (
     BaseCallResponse,
@@ -22,6 +23,11 @@
     BaseVectorStoreT,
 )
 
+ONE_SECOND_IN_NANOSECONDS = 1_000_000_000
+STEAMING_MSG_TEMPLATE: LiteralString = (
+    "streaming response from {request_data[model]!r} took {duration:.2f}s"
+)
+
 
 def mirascope_logfire_span(fn: Callable):
     """Wraps a pydantic class method with a Logfire span."""
@@ -118,13 +124,10 @@ def _extract_chunk_content(
     return response_chunk_type(chunk=chunk).content
 
 
-@contextmanager
-def _mirascope_llm_span(
-    fn: Callable, suffix: str, is_async: bool, args: tuple[Any], kwargs: dict[str, Any]
-):
-    """Wraps a pydantic class method with a Logfire span."""
+def _get_span_data(
+    suffix: str, is_async: bool, args: tuple[Any], kwargs: dict[str, Any]
+) -> dict[str, Any]:
     additional_request_data = {}
-    model = kwargs.get("model", "unknown")
     if suffix == "gemini":
         gemini_messages = args[0]
         additional_request_data["messages"] = [
@@ -133,10 +136,19 @@ def _mirascope_llm_span(
         ]
         model = kwargs.pop("model")
         additional_request_data["model"] = model
-    span_data: dict[str, Any] = {
+    return {
         "async": is_async,
         "request_data": kwargs | additional_request_data,
     }
+
+
+@contextmanager
+def _mirascope_llm_span(
+    fn: Callable, suffix: str, is_async: bool, args: tuple[Any], kwargs: dict[str, Any]
+):
+    """Wraps a pydantic class method with a Logfire span."""
+    model = kwargs.get("model", "unknown")
+    span_data = _get_span_data(suffix, is_async, args, kwargs)
     with logfire.with_settings(custom_scope_suffix=suffix, tags=["llm"]).span(
         f"{suffix}.{fn.__name__} with {model}", **span_data
     ) as logfire_span:
@@ -181,36 +193,64 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
+@contextmanager
+def record_streaming(
+    logfire_span: logfire.Logfire,
+    span_data: dict[str, Any],
+    content_from_stream: Callable[
+        [ChunkT, type[BaseCallResponseChunk]], Union[str, None]
+    ],
+):
+    """Logfire record_streaming with Mirascope providers"""
+    content: list[str] = []
+
+    def record_chunk(
+        chunk: ChunkT, response_chunk_type: type[BaseCallResponseChunk]
+    ) -> Any:
+        """Handles all provider chunk_types instead of only OpenAI"""
+        chunk_content = content_from_stream(chunk, response_chunk_type)
+        if chunk_content is not None:
+            content.append(chunk_content)
+
+    timer = logfire_span._config.ns_timestamp_generator  # type: ignore
+    start = timer()
+    try:
+        yield record_chunk
+    finally:
+        duration = (timer() - start) / ONE_SECOND_IN_NANOSECONDS
+        logfire_span.info(
+            STEAMING_MSG_TEMPLATE,
+            **span_data,
+            duration=duration,
+            response_data={
+                "combined_chunk_content": "".join(content),
+                "chunk_count": len(content),
+            },
+        )
+
+
 def mirascope_logfire_stream(
     fn: Callable, suffix: str, response_chunk_type: type[BaseCallResponseChunk]
 ) -> Callable:
     """Wraps a function that yields a generator with a Logfire span."""
 
     @wraps(fn)
     def wrapper(*args, **kwargs):
-        with _mirascope_llm_span(fn, suffix, False, args, kwargs) as logfire_span:
-            content = []
+        logfire_span = logfire.with_settings(custom_scope_suffix=suffix, tags=["llm"])
+        span_data = _get_span_data(suffix, False, args, kwargs)
+        with record_streaming(
+            logfire_span, span_data, _extract_chunk_content
+        ) as record_chunk:
             stream = fn(*args, **kwargs)
             if suffix != "anthropic":
                 for chunk in stream:
-                    chunk_content = _extract_chunk_content(chunk, response_chunk_type)
-                    content.append(chunk_content)
+                    record_chunk(chunk, response_chunk_type)
                     yield chunk
             else:
                 with stream as s:
                     for chunk in s:
-                        chunk_content = _extract_chunk_content(
-                            chunk, response_chunk_type
-                        )
-                        content.append(chunk_content)
+                        record_chunk(chunk, response_chunk_type)
                         yield chunk
-            logfire_span.set_attribute(
-                "response_data",
-                {
-                    "combined_chunk_content": "".join(content),
-                    "chunk_count": len(content),
-                },
-            )
 
     return wrapper
@@ -260,32 +300,24 @@ def mirascope_logfire_stream_async(
 
     @wraps(fn)
    async def wrapper(*args, **kwargs):
-        with _mirascope_llm_span(fn, suffix, True, args, kwargs) as logfire_span:
-            content = []
+        logfire_span = logfire.with_settings(custom_scope_suffix=suffix, tags=["llm"])
+        span_data = _get_span_data(suffix, True, args, kwargs)
+        with record_streaming(
+            logfire_span, span_data, _extract_chunk_content
+        ) as record_chunk:
             if suffix == "groq":
                 stream = await fn(*args, **kwargs)
             else:
                 stream = fn(*args, **kwargs)
             if suffix != "anthropic":
                 async for chunk in stream:
-                    chunk_content = _extract_chunk_content(chunk, response_chunk_type)
-                    content.append(chunk_content)
+                    record_chunk(chunk, response_chunk_type)
                     yield chunk
             else:
                 async with stream as s:
                     async for chunk in s:
-                        chunk_content = _extract_chunk_content(
-                            chunk, response_chunk_type
-                        )
-                        content.append(chunk_content)
+                        record_chunk(chunk, response_chunk_type)
                         yield chunk
-            logfire_span.set_attribute(
-                "response_data",
-                {
-                    "combined_chunk_content": "".join(content),
-                    "chunk_count": len(content),
-                },
-            )
 
     return wrapper

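For illustration, here is a minimal sketch (hypothetical wrap_stream and stream_words; no logfire) of the re-yielding pattern both stream wrappers above rely on: the generator records each chunk as the caller consumes it, so the caller still receives every chunk unchanged while the summary emitted in the finally block reflects the full consumption of the stream.

from functools import wraps
from typing import Callable, Iterator

def wrap_stream(fn: Callable[..., Iterator[str]]) -> Callable[..., Iterator[str]]:
    @wraps(fn)
    def wrapper(*args, **kwargs):
        chunks: list[str] = []
        try:
            for chunk in fn(*args, **kwargs):
                chunks.append(chunk)  # record_chunk(...) in the real code
                yield chunk  # the caller sees every chunk unchanged
        finally:
            # logfire_span.info(...) fires here in the real code, even if
            # the caller abandons the stream early.
            print(f"recorded {len(chunks)} chunks: {''.join(chunks)!r}")
    return wrapper

@wrap_stream
def stream_words() -> Iterator[str]:
    yield from ("Hello", ", ", "world")

assert list(stream_words()) == ["Hello", ", ", "world"]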