Skip to content

Commit

Permalink
more run time warnings (#753)
Browse files Browse the repository at this point in the history
This PR resolves most of the remaining run time warnings
  • Loading branch information
eyurtsev authored Sep 10, 2024
1 parent 54eee64 commit 04236b0
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 58 deletions.
6 changes: 3 additions & 3 deletions langserve/api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ async def _get_config_and_input(
# This takes into account changes in the input type when
# using configuration.
schema = self._runnable.with_config(config).input_schema
input_ = schema.validate(body.input)
input_ = schema.model_validate(body.input)
return config, _unpack_input(input_)
except ValidationError as e:
raise RequestValidationError(e.errors(), body=body)
Expand Down Expand Up @@ -892,7 +892,7 @@ async def batch(
raise RequestValidationError(errors=["Invalid JSON body"])

with _with_validation_error_translation():
body = BatchRequestShallowValidator.validate(body)
body = BatchRequestShallowValidator.model_validate(body)
config = body.config

# First unpack the config
Expand Down Expand Up @@ -943,7 +943,7 @@ async def batch(

inputs = [
_unpack_input(
self._runnable.with_config(config_).input_schema.validate(input_)
self._runnable.with_config(config_).input_schema.model_validate(input_)
)
for config_, input_ in zip(configs_, inputs_)
]
Expand Down
73 changes: 42 additions & 31 deletions langserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
The main entry point is the `add_routes` function which adds the routes to an existing
FastAPI app or APIRouter.
"""
import warnings
import weakref
from typing import (
Any,
Expand Down Expand Up @@ -201,37 +202,47 @@ def _register_path_for_app(
def _setup_global_app_handlers(
app: Union[FastAPI, APIRouter], endpoint_configuration: _EndpointConfiguration
) -> None:
@app.on_event("startup")
async def startup_event():
LANGSERVE = r"""
__ ___ .__ __. _______ _______. _______ .______ ____ ____ _______
| | / \ | \ | | / _____| / || ____|| _ \ \ \ / / | ____|
| | / ^ \ | \| | | | __ | (----`| |__ | |_) | \ \/ / | |__
| | / /_\ \ | . ` | | | |_ | \ \ | __| | / \ / | __|
| `----./ _____ \ | |\ | | |__| | .----) | | |____ | |\ \----. \ / | |____
|_______/__/ \__\ |__| \__| \______| |_______/ |_______|| _| `._____| \__/ |_______|
""" # noqa: E501

def green(text: str) -> str:
"""Return the given text in green."""
return "\x1b[1;32;40m" + text + "\x1b[0m"

def orange(text: str) -> str:
"""Return the given text in orange."""
return "\x1b[1;31;40m" + text + "\x1b[0m"

paths = _APP_TO_PATHS[app]
print(LANGSERVE)
for path in paths:
if endpoint_configuration.is_playground_enabled:
print(
f'{green("LANGSERVE:")} Playground for chain "{path or ""}/" is '
f"live at:"
)
print(f'{green("LANGSERVE:")} │')
print(f'{green("LANGSERVE:")} └──> {path}/playground/')
print(f'{green("LANGSERVE:")}')
print(f'{green("LANGSERVE:")} See all available routes at {app.docs_url}/')
with warnings.catch_warnings():
# We are using deprecated functionality here.
# This code should be re-written to simply construct a pydantic model
# using inspect.signature and create_model.
warnings.filterwarnings(
"ignore",
"[\\s.]*on_event is deprecated[\\s.]*",
category=DeprecationWarning,
)

@app.on_event("startup")
async def startup_event():
LANGSERVE = r"""
__ ___ .__ __. _______ _______. _______ .______ ____ ____ _______
| | / \ | \ | | / _____| / || ____|| _ \ \ \ / / | ____|
| | / ^ \ | \| | | | __ | (----`| |__ | |_) | \ \/ / | |__
| | / /_\ \ | . ` | | | |_ | \ \ | __| | / \ / | __|
| `----./ _____ \ | |\ | | |__| | .----) | | |____ | |\ \----. \ / | |____
|_______/__/ \__\ |__| \__| \______| |_______/ |_______|| _| `._____| \__/ |_______|
""" # noqa: E501

def green(text: str) -> str:
"""Return the given text in green."""
return "\x1b[1;32;40m" + text + "\x1b[0m"

def orange(text: str) -> str:
"""Return the given text in orange."""
return "\x1b[1;31;40m" + text + "\x1b[0m"

paths = _APP_TO_PATHS[app]
print(LANGSERVE)
for path in paths:
if endpoint_configuration.is_playground_enabled:
print(
f'{green("LANGSERVE:")} Playground for chain "{path or ""}/" '
f'is live at:'
)
print(f'{green("LANGSERVE:")} │')
print(f'{green("LANGSERVE:")} └──> {path}/playground/')
print(f'{green("LANGSERVE:")}')
print(f'{green("LANGSERVE:")} See all available routes at {app.docs_url}/')


# PUBLIC API
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,7 @@ addopts = "--strict-markers --strict-config --durations=5 -vv"
# take more than 5 seconds
timeout = 5
asyncio_mode = "auto"
filterwarnings = [
"ignore::langchain_core._api.beta_decorator.LangChainBetaWarning",
]

43 changes: 19 additions & 24 deletions tests/unit_tests/test_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import json
import sys
import uuid
from asyncio import AbstractEventLoop
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass
from enum import Enum
Expand Down Expand Up @@ -123,7 +122,7 @@ def _replace_run_id_in_stream_resp(streamed_resp: str) -> str:
return streamed_resp.replace(uuid, "<REPLACED>")


@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def event_loop():
"""Create an instance of the default event loop for each test case."""
loop = asyncio.get_event_loop()
Expand All @@ -134,7 +133,7 @@ def event_loop():


@pytest.fixture()
def app(event_loop: AbstractEventLoop) -> FastAPI:
def app() -> FastAPI:
"""A simple server that wraps a Runnable and exposes it as an API."""

async def add_one_or_passthrough(
Expand All @@ -158,7 +157,7 @@ async def add_one_or_passthrough(


@pytest.fixture()
def app_for_config(event_loop: AbstractEventLoop) -> FastAPI:
def app_for_config() -> FastAPI:
"""A simple server that wraps a Runnable and exposes it as an API."""

async def return_config(
Expand Down Expand Up @@ -223,7 +222,7 @@ async def get_async_test_client(
app=server,
raise_app_exceptions=raise_app_exceptions,
)
async_client = AsyncClient(app=server, base_url=url, transport=transport)
async_client = AsyncClient(base_url=url, transport=transport)
try:
yield async_client
finally:
Expand Down Expand Up @@ -333,7 +332,7 @@ async def test_server_async(app: FastAPI) -> None:
# test bad requests
async with get_async_test_client(app, raise_app_exceptions=True) as async_client:
# Test invoke
response = await async_client.post("/invoke", data="bad json []")
response = await async_client.post("/invoke", content="bad json []")
# Client side error bad json.
assert response.status_code == 422

Expand All @@ -353,7 +352,7 @@ async def test_server_async(app: FastAPI) -> None:
async with get_async_test_client(app, raise_app_exceptions=True) as async_client:
# Test invoke
# Test bad batch requests
response = await async_client.post("/batch", data="bad json []")
response = await async_client.post("/batch", content="bad json []")
# Client side error bad json.
assert response.status_code == 422

Expand All @@ -378,15 +377,15 @@ async def test_server_async(app: FastAPI) -> None:
# test stream bad requests
async with get_async_test_client(app, raise_app_exceptions=True) as async_client:
# Test bad stream requests
response = await async_client.post("/stream", data="bad json []")
response = await async_client.post("/stream", content="bad json []")
assert response.status_code == 422

response = await async_client.post("/stream", json={})
assert response.status_code == 422

# test stream_log bad requests
async with get_async_test_client(app, raise_app_exceptions=True) as async_client:
response = await async_client.post("/stream_log", data="bad json []")
response = await async_client.post("/stream_log", content="bad json []")
assert response.status_code == 422

response = await async_client.post("/stream_log", json={})
Expand Down Expand Up @@ -448,7 +447,7 @@ async def test_server_astream_events(app: FastAPI) -> None:

# test stream_events with bad requests
async with get_async_test_client(app, raise_app_exceptions=True) as async_client:
response = await async_client.post("/stream_events", data="bad json []")
response = await async_client.post("/stream_events", content="bad json []")
assert response.status_code == 422

response = await async_client.post("/stream_events", json={})
Expand Down Expand Up @@ -854,7 +853,7 @@ async def with_errors(inputs: dict) -> AsyncIterator[int]:
assert e.value.response.status_code == 500


async def test_astream_log_allowlist(event_loop: AbstractEventLoop) -> None:
async def test_astream_log_allowlist() -> None:
"""Test async stream with an allowlist."""

async def add_one(x: int) -> int:
Expand Down Expand Up @@ -1035,7 +1034,7 @@ async def test_invoke_as_part_of_sequence_async(
}


async def test_multiple_runnables(event_loop: AbstractEventLoop) -> None:
async def test_multiple_runnables() -> None:
"""Test serving multiple runnables."""

async def add_one(x: int) -> int:
Expand Down Expand Up @@ -1159,7 +1158,7 @@ async def add_one(x: int) -> int:
await runnable.abatch(["hello"])


async def test_input_validation_with_lc_types(event_loop: AbstractEventLoop) -> None:
async def test_input_validation_with_lc_types() -> None:
"""Test client side and server side exceptions."""

app = FastAPI()
Expand Down Expand Up @@ -1252,9 +1251,7 @@ async def test_async_client_close() -> None:
assert async_client.is_closed is True


async def test_openapi_docs_with_identical_runnables(
event_loop: AbstractEventLoop, mocker: MockerFixture
) -> None:
async def test_openapi_docs_with_identical_runnables(mocker: MockerFixture) -> None:
"""Test client side and server side exceptions."""

async def add_one(x: int) -> int:
Expand Down Expand Up @@ -1301,7 +1298,7 @@ async def add_one(x: int) -> int:
assert response.status_code == 200


async def test_configurable_runnables(event_loop: AbstractEventLoop) -> None:
async def test_configurable_runnables() -> None:
"""Add tests for using langchain's configurable runnables"""

template = PromptTemplate.from_template("say {name}").configurable_fields(
Expand Down Expand Up @@ -1391,7 +1388,7 @@ class Foo(BaseModel):
assert Model.__name__ == "BarFoo"


async def test_input_config_output_schemas(event_loop: AbstractEventLoop) -> None:
async def test_input_config_output_schemas() -> None:
"""Test schemas returned for different configurations."""
# TODO(Fix me): need to fix handling of global state -- we get problems
# gives inconsistent results when running multiple tests / results
Expand Down Expand Up @@ -1753,7 +1750,7 @@ async def test_server_side_error() -> None:
# assert e.response.text == "Internal Server Error"


def test_server_side_error_sync(event_loop: AbstractEventLoop) -> None:
def test_server_side_error_sync() -> None:
"""Test server side error handling."""

app = FastAPI()
Expand Down Expand Up @@ -1982,7 +1979,7 @@ async def test_enforce_trailing_slash_in_client() -> None:
assert r.url == "nosuchurl/"


async def test_per_request_config_modifier(event_loop: AbstractEventLoop) -> None:
async def test_per_request_config_modifier() -> None:
"""Test updating the config based on the raw request object."""

async def add_one(x: int) -> int:
Expand Down Expand Up @@ -2025,9 +2022,7 @@ async def header_passthru_modifier(
assert response.json()["output"] == 2


async def test_per_request_config_modifier_endpoints(
event_loop: AbstractEventLoop,
) -> None:
async def test_per_request_config_modifier_endpoints() -> None:
"""Verify that per request modifier is only applied for the expected endpoints."""

# this test verifies that per request modifier is only
Expand Down Expand Up @@ -2097,7 +2092,7 @@ async def buggy_modifier(
assert response.status_code != 500


async def test_uuid_serialization(event_loop: AbstractEventLoop) -> None:
async def test_uuid_serialization() -> None:
"""Test updating the config based on the raw request object."""
import datetime

Expand Down

0 comments on commit 04236b0

Please sign in to comment.