Skip to content

Commit

Permalink
x
Browse files Browse the repository at this point in the history
  • Loading branch information
eyurtsev committed Sep 5, 2024
1 parent 9a033b6 commit a77aab9
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 17 deletions.
8 changes: 6 additions & 2 deletions langserve/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,11 +431,15 @@ def batch(
self,
inputs: List[Input],
config: Optional[RunnableConfig] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> List[Output]:
if kwargs:
raise NotImplementedError("kwargs not implemented yet.")
return self._batch_with_config(self._batch, inputs, config)
raise NotImplementedError(f"kwargs not implemented yet. Got {kwargs}")
return self._batch_with_config(
self._batch, inputs, config, return_exceptions=return_exceptions
)

async def _abatch(
self,
Expand Down
40 changes: 25 additions & 15 deletions tests/unit_tests/test_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,13 +509,18 @@ def test_invoke(sync_remote_runnable: RemoteRunnable) -> None:
# Test tracing
tracer = FakeTracer()
assert sync_remote_runnable.invoke(1, config={"callbacks": [tracer]}) == 2
assert len(tracer.runs) == 1
# Light test to verify that we're picking up information about the server side
# function being invoked via a callback.
assert tracer.runs[0].child_runs[0].name == "RunnableLambda"
assert (
tracer.runs[0].child_runs[0].extra["kwargs"]["name"] == "add_one_or_passthrough"
# Picking up the run from the server side, and client side should also log a run
# from the RemoteRunnable that will have as a child the server side run.
assert len(tracer.runs) == 2

first_run = tracer.runs[0]

remote_runnable_run = (
tracer.runs[0] if first_run.name == "RemoteRunnable" else tracer.runs[1]
)
assert remote_runnable_run.name == "RemoteRunnable"

assert remote_runnable_run.child_runs[0].name == "add_one_or_passthrough"


def test_batch(sync_remote_runnable: RemoteRunnable) -> None:
Expand Down Expand Up @@ -577,13 +582,18 @@ async def test_ainvoke(async_remote_runnable: RemoteRunnable) -> None:
# Test tracing
tracer = FakeTracer()
assert await async_remote_runnable.ainvoke(1, config={"callbacks": [tracer]}) == 2
assert len(tracer.runs) == 1
# Light test to verify that we're picking up information about the server side
# function being invoked via a callback.
assert tracer.runs[0].child_runs[0].name == "RunnableLambda"
assert (
tracer.runs[0].child_runs[0].extra["kwargs"]["name"] == "add_one_or_passthrough"
# Picking up the run from the server side, and client side should also log a run
# from the RemoteRunnable that will have as a child the server side run.
assert len(tracer.runs) == 2

first_run = tracer.runs[0]

remote_runnable_run = (
tracer.runs[0] if first_run.name == "RemoteRunnable" else tracer.runs[1]
)
assert remote_runnable_run.name == "RemoteRunnable"

assert remote_runnable_run.child_runs[0].name == "add_one_or_passthrough"


async def test_abatch(async_remote_runnable: RemoteRunnable) -> None:
Expand Down Expand Up @@ -1060,9 +1070,7 @@ async def mul_2(x: int) -> int:
) == StringPromptValue(text="What is your name? Bob")


async def test_input_validation(
event_loop: AbstractEventLoop, mocker: MockerFixture
) -> None:
async def test_input_validation(mocker: MockerFixture) -> None:
"""Test client side and server side exceptions."""

async def add_one(x: int) -> int:
Expand Down Expand Up @@ -1415,13 +1423,15 @@ async def add_two(y: int) -> int:
"properties": {"question": {"title": "Question", "type": "string"}},
"title": "PromptInput",
"type": "object",
"required": ["question"],
}

response = await async_client.get("/prompt_2/input_schema")
assert response.json() == {
"properties": {"name": {"title": "Name", "type": "string"}},
"title": "PromptInput",
"type": "object",
"required": ["name"],
}

# output schema
Expand Down

0 comments on commit a77aab9

Please sign in to comment.