Skip to content

Commit

Permalink
x
Browse files Browse the repository at this point in the history
  • Loading branch information
eyurtsev committed Sep 6, 2024
1 parent a77aab9 commit 387f82a
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 46 deletions.
18 changes: 11 additions & 7 deletions langserve/api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ async def _unpack_request_config(
credentials to a runnable. The RunnableConfig is presented in its
dictionary form. Note that only keys in `config_keys` will be
modifiable by this function.
config_keys: keys that are accepted by the server. This is used to
make sure that the server doesn't allow any keys that it doesn't want
to allow.
"""
config_dicts = []
for config in client_sent_configs:
Expand All @@ -192,6 +195,7 @@ async def _unpack_request_config(
config_dicts.append(model(**config).dict())
else:
raise TypeError(f"Expected a string, dict or BaseModel got {type(config)}")

config = merge_configs(*config_dicts)
if "configurable" in config and config["configurable"]:
if "configurable" not in config_keys:
Expand Down Expand Up @@ -228,7 +232,7 @@ def _update_config_with_defaults(
"""Set up some baseline configuration for the underlying runnable."""

# Currently all defaults are non-overridable
overridable_default_config = RunnableConfig()
overridable_default_config: RunnableConfig = {}

metadata = {
"__useragent": request.headers.get("user-agent"),
Expand Down Expand Up @@ -256,11 +260,10 @@ def _update_config_with_defaults(
}
metadata.update(hosted_metadata)

non_overridable_default_config = RunnableConfig(
run_name=run_name,
metadata=metadata,
)

non_overridable_default_config: RunnableConfig = {
"run_name": run_name,
"metadata": metadata,
}
# merge_configs is last-writer-wins, so we specifically pass in the
# overridable configs first, then the user provided configs, then
# finally the non-overridable configs
Expand All @@ -269,7 +272,6 @@ def _update_config_with_defaults(
incoming_config,
non_overridable_default_config,
)

# run_id may have been set by user (and accepted by server) or
# it may have been by the user on the server request path.
# If it's not set, we'll generate a new one.
Expand Down Expand Up @@ -803,6 +805,8 @@ async def invoke(
with any other configuration. It's the last to be written, so
it will override any other configuration.
"""
from langchain_core.runnables.config import var_child_runnable_config

# We do not use the InvokeRequest model here since configurable runnables
# have dynamic schema -- so the validation below is a bit more involved.
config, input_ = await self._get_config_and_input(
Expand Down
81 changes: 42 additions & 39 deletions tests/unit_tests/test_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,30 +1070,65 @@ async def mul_2(x: int) -> int:
) == StringPromptValue(text="What is your name? Bob")


async def test_input_validation(mocker: MockerFixture) -> None:
"""Test client side and server side exceptions."""
async def test_config_keys_validation(mocker: MockerFixture) -> None:
"""This test should not use a RemoteRunnable.
RemoteRunnable runs in the same process as the server during unit tests
and there's unfortunately an interaction between the config contextvars
that makes it difficult to test this behavior correctly using a RemoteRunnable.
Instead the behavior can be tested correctly by making a regular http request
using a FastAPI test client.
"""

async def add_one(x: int) -> int:
"""Add one to simulate a valid function"""
return x + 1

server_runnable = RunnableLambda(func=add_one)
server_runnable2 = RunnableLambda(func=add_one)

app = FastAPI()
add_routes(
app,
server_runnable,
input_type=int,
path="/add_one",
config_keys=["metadata"],
)
async with AsyncClient(app=app, base_url="http://localhost:9999") as async_client:
server_runnable_spy = mocker.spy(server_runnable, "ainvoke")
response = await async_client.post(
"/invoke",
json={"input": 1, "config": {"tags": ["hello"], "metadata": {"a": 5}}},
)
# Config should be ignored but default debug information
# will still be added
assert response.status_code == 200
config_seen = server_runnable_spy.call_args[0][1]
assert "metadata" in config_seen
assert "a" in config_seen["metadata"]
assert config_seen["tags"] == []
assert "__useragent" in config_seen["metadata"]
assert "__langserve_version" in config_seen["metadata"]
assert "__langserve_endpoint" in config_seen["metadata"]
assert config_seen["metadata"]["__langserve_endpoint"] == "invoke"


async def test_input_validation(mocker: MockerFixture) -> None:
"""Test client side and server side exceptions."""

async def add_one(x: int) -> int:
"""Add one to simulate a valid function"""
return x + 1

server_runnable = RunnableLambda(func=add_one)
server_runnable2 = RunnableLambda(func=add_one)

app = FastAPI()
add_routes(
app,
server_runnable2,
server_runnable,
input_type=int,
path="/add_one_config",
config_keys=["tags", "metadata"],
path="/add_one",
)

async with get_async_remote_runnable(
Expand All @@ -1108,38 +1143,6 @@ async def add_one(x: int) -> int:
with pytest.raises(httpx.HTTPError):
await runnable.abatch(["hello"])

config = {"tags": ["test"], "metadata": {"a": 5}}

server_runnable_spy = mocker.spy(server_runnable, "ainvoke")
# Verify config is handled correctly
async with get_async_remote_runnable(app, path="/add_one") as runnable1:
# Verify that can be invoked with valid input
# Config ignored for runnable1
assert await runnable1.ainvoke(1, config=config) == 2
# Config should be ignored but default debug information
# will still be added
config_seen = server_runnable_spy.call_args[0][1]
assert "metadata" in config_seen
assert "a" not in config_seen["metadata"]
assert "__useragent" in config_seen["metadata"]
assert "__langserve_version" in config_seen["metadata"]
assert "__langserve_endpoint" in config_seen["metadata"]
assert config_seen["metadata"]["__langserve_endpoint"] == "invoke"

server_runnable2_spy = mocker.spy(server_runnable2, "ainvoke")
async with get_async_remote_runnable(app, path="/add_one_config") as runnable2:
# Config accepted for runnable2
assert await runnable2.ainvoke(1, config=config) == 2
# Config ignored

config_seen = server_runnable2_spy.call_args[0][1]
assert config_seen["tags"] == ["test"]
assert config_seen["metadata"]["a"] == 5
assert "__useragent" in config_seen["metadata"]
assert "__langserve_version" in config_seen["metadata"]
assert "__langserve_endpoint" in config_seen["metadata"]
assert config_seen["metadata"]["__langserve_endpoint"] == "invoke"


async def test_input_validation_with_lc_types(event_loop: AbstractEventLoop) -> None:
"""Test client side and server side exceptions."""
Expand Down

0 comments on commit 387f82a

Please sign in to comment.