
Commit

Merge pull request #974 from parea-ai/PAI-1313-instructor-integration-key-error

Pai 1313 instructor integration key error
joschkabraun committed Jun 20, 2024
2 parents fdccc79 + ad6dbb5 commit cf43d2a
Showing 11 changed files with 95 additions and 12 deletions.
Empty file removed parea/cookbook/assets/__init__.py
63 changes: 63 additions & 0 deletions parea/cookbook/instructor/instructor_streaming.py
@@ -0,0 +1,63 @@
import os

import instructor
from dotenv import load_dotenv
from openai import AsyncOpenAI
from pydantic import BaseModel

from parea import Parea

load_dotenv()

client = AsyncOpenAI()

# wrap the raw OpenAI client with Parea before handing it to instructor
p = Parea(api_key=os.getenv("PAREA_API_KEY"))
p.wrap_openai_client(client, "instructor")

client = instructor.from_openai(client)


class UserDetail(BaseModel):
    name: str
    age: int


async def main():
    # stream partial UserDetail objects as the model fills in the fields
    user = client.completions.create_partial(
        model="gpt-3.5-turbo",
        max_tokens=1024,
        max_retries=3,
        messages=[
            {
                "role": "user",
                "content": "Please create a user",
            }
        ],
        response_model=UserDetail,
    )
    async for u in user:
        print(u)

    # second streaming call: back-to-back streams exercise the trace-context handling below
    user2 = client.completions.create_partial(
        model="gpt-3.5-turbo",
        max_tokens=1024,
        max_retries=3,
        messages=[
            {
                "role": "user",
                "content": "Please create a user",
            }
        ],
        response_model=UserDetail,
    )
    async for u in user2:
        print(u)


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
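Note: running this cookbook assumes OPENAI_API_KEY and PAREA_API_KEY are set in the environment (or in a .env file read by load_dotenv()). The two back-to-back create_partial calls matter: consecutive streaming requests are exactly the pattern the parea/wrapper/wrapper.py change below guards against.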
Empty file removed parea/cookbook/marvin/__init__.py
Empty file removed parea/cookbook/openai/__init__.py
3 changes: 2 additions & 1 deletion parea/schemas/models.py
@@ -1,6 +1,7 @@
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
 import json
 from enum import Enum
-from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 from attrs import define, field, validators

20 changes: 12 additions & 8 deletions parea/utils/trace_integrations/instructor.py
@@ -1,8 +1,10 @@
 from typing import Any, Callable, Mapping, Tuple
 
 import contextvars
+from json import JSONDecodeError
 
 from instructor.retry import InstructorRetryException
+from pydantic import ValidationError
 from wrapt import wrap_object
 
 from parea import trace
@@ -33,6 +35,12 @@ def instrument_instructor_validation_errors() -> None:
     )
 
 
+def get_reasons(exception: Exception) -> list[str]:
+    if isinstance(exception, InstructorRetryException):
+        return [str(arg) for arg in exception.args]
+    return [str(exception)]
+
+
 def report_instructor_validation_errors() -> None:
     reason = "\n\n\n".join(instructor_val_errs.get())
     if reason:
@@ -82,11 +90,9 @@ def __call__(
             )(
                 wrapped
             )(*args, **kwargs)
-        except InstructorRetryException as e:
+        except (InstructorRetryException, ValidationError, JSONDecodeError) as e:
             instructor_val_err_count.set(instructor_val_err_count.get() + 1)
-            reasons = []
-            for arg in e.args:
-                reasons.append(str(arg))
+            reasons = get_reasons(e)
             instructor_val_errs.set(instructor_val_errs.get() + reasons)
 
             report_instructor_validation_errors()
@@ -105,11 +111,9 @@ def __call__(
         kwargs: Mapping[str, Any],
     ) -> Any:
         if instructor_trace_id.get() is not None:
-            if len(args) > 1 and args[1] is not None and isinstance(args[1], InstructorRetryException):
+            if len(args) > 1 and args[1] is not None and isinstance(args[1], (InstructorRetryException, ValidationError, JSONDecodeError)):
                 instructor_val_err_count.set(instructor_val_err_count.get() + 1)
-                reasons = []
-                for arg in args[1].args:
-                    reasons.append(str(arg))
+                reasons = get_reasons(args[1])
                 instructor_val_errs.set(instructor_val_errs.get() + reasons)
             else:
                 report_instructor_validation_errors()
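For context on the broadened except clause: InstructorRetryException aggregates its retry reasons in .args, while ValidationError and JSONDecodeError fall through to the generic str(exception) branch of get_reasons, which already yields a readable message. A minimal standalone sketch of that fallback (pydantic v2 assumed; the UserDetail model here is illustrative only, not part of the SDK):

import json

from pydantic import BaseModel, ValidationError


class UserDetail(BaseModel):  # illustrative model
    name: str
    age: int


def get_reasons(exception: Exception) -> list[str]:
    # the generic branch of the helper above: one readable string per exception
    return [str(exception)]


try:
    UserDetail(name="Ada")  # missing `age` -> pydantic ValidationError
except ValidationError as e:
    print(get_reasons(e))  # ['1 validation error for UserDetail\nage\n  Field required ...']

try:
    json.loads("{not valid json")  # malformed model output -> json.JSONDecodeError
except json.JSONDecodeError as e:
    print(get_reasons(e))  # ['Expecting property name enclosed in double quotes: ...']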
19 changes: 17 additions & 2 deletions parea/wrapper/wrapper.py
@@ -12,8 +12,8 @@
 from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID
 from parea.evals.utils import _make_evaluations
 from parea.helpers import is_logging_disabled, timezone_aware_now
-from parea.schemas.models import TraceLog, UpdateTraceScenario
-from parea.utils.trace_utils import call_eval_funcs_then_log, execution_order_counters, fill_trace_data, trace_context, trace_data
+from parea.schemas.models import TraceLog, UpdateLog, UpdateTraceScenario
+from parea.utils.trace_utils import call_eval_funcs_then_log, execution_order_counters, fill_trace_data, logger_update_record, trace_context, trace_data
 from parea.wrapper.utils import safe_format_template_to_prompt, skip_decorator_if_func_in_stack
 
 logger = logging.getLogger()
@@ -213,7 +213,22 @@ def final_log():

         self.log(trace_id)
         try:
+            trace_context_before_reset = trace_context.get()
             trace_context.reset(context_token)
+            trace_context_after_reset = trace_context.get()
+            if len(trace_context_after_reset) > len(trace_context_before_reset):
+                # this can happen if this is a streaming call and the LLM client got modified (e.g. instructor),
+                # so we need to manually reset the trace context to the previous state
+                trace_context.set(trace_context_before_reset)
+                # if the parent trace didn't have any output, we can also update the output
+                if (parent_trace_id := trace_data.get()[trace_id].parent_trace_id) is not None and not trace_data.get()[parent_trace_id].output:
+                    logger_update_record(
+                        UpdateLog(
+                            trace_id=parent_trace_id,
+                            field_name_to_value_map={"output": trace_data.get()[trace_id].output},
+                            root_trace_id=trace_data.get()[parent_trace_id].root_trace_id,
+                        )
+                    )
         except IndexError:
             pass

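The new guard handles out-of-order ContextVar token resets: with a streaming response, the parent trace can unwind before the wrapped (e.g. instructor-modified) client resets its own token, which re-exposes a stale, longer trace stack. A minimal self-contained illustration of the detected condition (toy variable names, not Parea's actual internals):

import contextvars

stack: contextvars.ContextVar[tuple] = contextvars.ContextVar("stack", default=())

outer = stack.set(("parent",))          # parent trace pushes itself
inner = stack.set(("parent", "child"))  # streaming LLM call pushes a child trace

stack.reset(outer)   # parent unwinds first because the response is still streaming
before = stack.get() # () -- the stack already looks fully unwound

stack.reset(inner)   # the child's later reset re-exposes the stale, longer stack
after = stack.get()  # ("parent",)

assert len(after) > len(before)  # exactly the condition wrapper.py detects and repairs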
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
 [tool.poetry]
 name = "parea-ai"
 packages = [{ include = "parea" }]
-version = "0.2.178"
+version = "0.2.179"
 description = "Parea python sdk"
 readme = "README.md"
 authors = ["joel-parea-ai <[email protected]>"]
