Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Temporary workaround for generating unique model names #47

Merged
merged 10 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 51 additions & 25 deletions langserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
FastAPI app or APIRouter.
"""
import json
import re
from inspect import isclass
from typing import (
Any,
Expand Down Expand Up @@ -100,13 +101,40 @@ def _unpack_input(validated_model: BaseModel) -> Any:
return model


def _rename_pydantic_model(model: Type[BaseModel], name: str) -> Type[BaseModel]:
"""Rename the given pydantic model to the given name."""
return create_model(
name,
__config__=model.__config__,
**{
fieldname: (
field.annotation,
Field(
field.default,
title=fieldname,
description=field.field_info.description,
),
)
for fieldname, field in model.__fields__.items()
},
)


# This is a global registry of models to avoid creating the same model
# multiple times.
# Duplicated model names break fastapi's openapi generation.
_MODEL_REGISTRY = {}
_SEEN_NAMES = set()


def _replace_non_alphanumeric_with_underscores(s: str) -> str:
"""Replace non-alphanumeric characters with underscores."""
return re.sub(r"[^a-zA-Z0-9]", "_", s)

def _resolve_model(type_: Union[Type, BaseModel], default_name: str) -> Type[BaseModel]:

def _resolve_model(
type_: Union[Type, BaseModel], default_name: str, namespace: str
) -> Type[BaseModel]:
"""Resolve the input type to a BaseModel."""
if isclass(type_) and issubclass(type_, BaseModel):
model = type_
Expand All @@ -115,8 +143,17 @@ def _resolve_model(type_: Union[Type, BaseModel], default_name: str) -> Type[Bas

hash_ = model.schema_json()

if model.__name__ in _SEEN_NAMES and hash_ not in _MODEL_REGISTRY:
# If the model name has been seen before, but the model itself is different
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the order models get loaded in deterministic? Could this result in an indeterministic openapi spec depending on that order?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Depends on order of add_routes invocation so I think deterministic in this case. [This is because we're only registering top level model names (which is also why we can't resolve the full issue)]

# generate a new name for the model.
model_to_use = _rename_pydantic_model(model, f"{namespace}{model.__name__}")
hash_ = model_to_use.schema_json()
else:
model_to_use = model

if hash_ not in _MODEL_REGISTRY:
_MODEL_REGISTRY[hash_] = model
_SEEN_NAMES.add(model_to_use.__name__)
_MODEL_REGISTRY[hash_] = model_to_use

return _MODEL_REGISTRY[hash_]

Expand All @@ -134,24 +171,9 @@ def _add_namespace_to_model(namespace: str, model: Type[BaseModel]) -> Type[Base
Returns:
A new model with name prepended with the given namespace.
"""

class Config:
arbitrary_types_allowed = True

model_with_unique_name = create_model(
model_with_unique_name = _rename_pydantic_model(
model,
f"{namespace}{model.__name__}",
config=Config,
**{
name: (
field.annotation,
Field(
field.default,
title=name,
description=field.field_info.description,
),
)
for name, field in model.__fields__.items()
},
)
model_with_unique_name.update_forward_refs()
return model_with_unique_name
Expand Down Expand Up @@ -222,18 +244,22 @@ def add_routes(
"Use `pip install sse_starlette` to install."
)

namespace = path or ""

model_namespace = _replace_non_alphanumeric_with_underscores(path.strip("/"))

input_type_ = _resolve_model(
runnable.input_schema if input_type == "auto" else input_type, "Input"
runnable.input_schema if input_type == "auto" else input_type,
"Input",
model_namespace,
)

output_type_ = _resolve_model(
runnable.output_schema if output_type == "auto" else output_type, "Output"
runnable.output_schema if output_type == "auto" else output_type,
"Output",
model_namespace,
)

namespace = path or ""

model_namespace = path.strip("/").replace("/", "_")

ConfigPayload = _add_namespace_to_model(
model_namespace, runnable.config_schema(include=config_keys)
)
Expand Down
25 changes: 23 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fastapi = {version = ">=0.90.1", optional = true}
sse-starlette = {version = "^1.3.0", optional = true}
httpx-sse = {version = ">=0.3.1", optional = true}
pydantic = "^1"
langchain = ">=0.0.313"
langchain = ">=0.0.316"

[tool.poetry.group.dev.dependencies]
jupyterlab = "^3.6.1"
Expand Down
75 changes: 74 additions & 1 deletion tests/unit_tests/test_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from httpx import AsyncClient
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
from langchain.prompts import PromptTemplate
from langchain.prompts.base import StringPromptValue
from langchain.schema.messages import HumanMessage, SystemMessage
from langchain.schema.runnable import RunnableConfig, RunnablePassthrough
from langchain.schema.runnable.base import RunnableLambda
Expand All @@ -21,9 +22,18 @@

from langserve.client import RemoteRunnable
from langserve.lzstring import LZString
from langserve.server import add_routes
from langserve.server import (
_rename_pydantic_model,
_replace_non_alphanumeric_with_underscores,
add_routes,
)
from tests.unit_tests.utils import FakeListLLM

try:
from pydantic.v1 import BaseModel, Field
except ImportError:
from pydantic import BaseModel, Field


@pytest.fixture(scope="session")
def event_loop():
Expand Down Expand Up @@ -455,6 +465,12 @@ async def mul_2(x: int) -> int:
path="/mul_2",
)

add_routes(app, PromptTemplate.from_template("{question}"), path="/prompt_1")

add_routes(
app, PromptTemplate.from_template("{question} {answer}"), path="/prompt_2"
)

async with get_async_client(app, path="/add_one") as runnable:
async with get_async_client(app, path="/mul_2") as runnable2:
assert await runnable.ainvoke(1) == 2
Expand All @@ -467,6 +483,16 @@ async def mul_2(x: int) -> int:
composite_runnable_2 = runnable | add_one | runnable2
assert await composite_runnable_2.ainvoke(3) == 10

async with get_async_client(app, path="/prompt_1") as runnable:
assert await runnable.ainvoke(
{"question": "What is your name?"}
) == StringPromptValue(text="What is your name?")

async with get_async_client(app, path="/prompt_2") as runnable:
assert await runnable.ainvoke(
{"question": "What is your name?", "answer": "Bob"}
) == StringPromptValue(text="What is your name? Bob")


@pytest.mark.asyncio
async def test_input_validation(
Expand Down Expand Up @@ -620,6 +646,8 @@ async def add_one(x: int) -> int:

server_runnable = RunnableLambda(func=add_one)
server_runnable2 = RunnableLambda(func=add_one)
server_runnable3 = PromptTemplate.from_template("say {name}")
server_runnable4 = PromptTemplate.from_template("say {name} {hello}")

app = FastAPI()
add_routes(
Expand All @@ -635,6 +663,20 @@ async def add_one(x: int) -> int:
config_keys=["tags"],
)

add_routes(
app,
server_runnable3,
path="/c",
config_keys=["tags"],
)

add_routes(
app,
server_runnable4,
path="/d",
config_keys=["tags"],
)

async with AsyncClient(app=app, base_url="http://localhost:9999") as async_client:
response = await async_client.get("/openapi.json")
assert response.status_code == 200
Expand Down Expand Up @@ -687,3 +729,34 @@ async def test_configurable_runnables(event_loop: AbstractEventLoop) -> None:
)
== "hello Mr. Kitten!"
)


# Test for utilities


@pytest.mark.parametrize(
"s,expected",
[
("hello", "hello"),
("hello world", "hello_world"),
("hello-world", "hello_world"),
("hello_world", "hello_world"),
("hello.world", "hello_world"),
],
)
def test_replace_non_alphanumeric(s: str, expected: str) -> None:
"""Test replace non alphanumeric."""
assert _replace_non_alphanumeric_with_underscores(s) == expected


def test_rename_pydantic_model() -> None:
"""Test rename pydantic model."""

class Foo(BaseModel):
bar: str = Field(..., description="A bar")
baz: str = Field(..., description="A baz")

Model = _rename_pydantic_model(Foo, "Bar")

assert isinstance(Model, type)
assert Model.__name__ == "Bar"