Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions temporalio/contrib/openai_agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,13 @@
from temporalio.contrib.openai_agents._temporal_openai_agents import (
OpenAIAgentsPlugin,
OpenAIPayloadConverter,
TestModel,
TestModelProvider,
)
from temporalio.contrib.openai_agents._trace_interceptor import (
OpenAIAgentsTracingInterceptor,
)
from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError

from . import workflow
from . import testing, workflow

__all__ = [
"AgentsWorkflowError",
Expand All @@ -33,7 +31,6 @@
"OpenAIPayloadConverter",
"StatelessMCPServerProvider",
"StatefulMCPServerProvider",
"TestModel",
"TestModelProvider",
"testing",
"workflow",
]
66 changes: 1 addition & 65 deletions temporalio/contrib/openai_agents/_temporal_openai_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,7 @@
from datetime import timedelta
from typing import AsyncIterator, Callable, Optional, Sequence, Union

from agents import (
AgentOutputSchemaBase,
Handoff,
Model,
ModelProvider,
ModelResponse,
ModelSettings,
ModelTracing,
Tool,
TResponseInputItem,
set_trace_provider,
)
from agents.items import TResponseStreamEvent
from agents import ModelProvider, set_trace_provider
from agents.run import get_default_agent_runner, set_default_agent_runner
from agents.tracing import get_trace_provider
from agents.tracing.provider import DefaultTraceProvider
Expand Down Expand Up @@ -97,58 +85,6 @@ def set_open_ai_agent_temporal_overrides(
set_trace_provider(previous_trace_provider or DefaultTraceProvider())


class TestModelProvider(ModelProvider):
"""Test model provider which simply returns the given module."""

__test__ = False

def __init__(self, model: Model):
"""Initialize a test model provider with a model."""
self._model = model

def get_model(self, model_name: Union[str, None]) -> Model:
"""Get a model from the model provider."""
return self._model


class TestModel(Model):
"""Test model for use mocking model responses."""

__test__ = False

def __init__(self, fn: Callable[[], ModelResponse]) -> None:
"""Initialize a test model with a callable."""
self.fn = fn

async def get_response(
self,
system_instructions: Union[str, None],
input: Union[str, list[TResponseInputItem]],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: Union[AgentOutputSchemaBase, None],
handoffs: list[Handoff],
tracing: ModelTracing,
**kwargs,
) -> ModelResponse:
"""Get a response from the model."""
return self.fn()

def stream_response(
self,
system_instructions: Optional[str],
input: Union[str, list[TResponseInputItem]],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: Optional[AgentOutputSchemaBase],
handoffs: list[Handoff],
tracing: ModelTracing,
**kwargs,
) -> AsyncIterator[TResponseStreamEvent]:
"""Get a streamed response from the model. Unimplemented."""
raise NotImplementedError()


class OpenAIPayloadConverter(PydanticPayloadConverter):
"""PayloadConverter for OpenAI agents."""

Expand Down
Loading
Loading