Skip to content

Commit

Permalink
first pass at removing deprecated usaged (#751)
Browse files Browse the repository at this point in the history
Reduces a significant number of run time warnings when running unit
tests (down to ~350)
  • Loading branch information
eyurtsev authored Sep 10, 2024
1 parent 21c2e3d commit 72c200f
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 22 deletions.
25 changes: 15 additions & 10 deletions langserve/api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,11 @@ async def _unpack_request_config(
config_dicts = []
for config in client_sent_configs:
if isinstance(config, str):
config_dicts.append(model(**_config_from_hash(config)).dict())
config_dicts.append(model(**_config_from_hash(config)).model_dump())
elif isinstance(config, BaseModel):
config_dicts.append(config.dict())
config_dicts.append(config.model_dump())
elif isinstance(config, Mapping):
config_dicts.append(model(**config).dict())
config_dicts.append(model(**config).model_dump())
else:
raise TypeError(f"Expected a string, dict or BaseModel got {type(config)}")
config = merge_configs(*config_dicts)
Expand Down Expand Up @@ -298,7 +298,7 @@ def _unpack_input(validated_model: BaseModel) -> Any:
# This logic should be applied recursively to nested models.
return {
fieldname: _unpack_input(getattr(model, fieldname))
for fieldname in model.__fields__.keys()
for fieldname in model.model_fields.keys()
}

return model
Expand Down Expand Up @@ -330,6 +330,11 @@ def _replace_non_alphanumeric_with_underscores(s: str) -> str:
return re.sub(r"[^a-zA-Z0-9]", "_", s)


def _schema_json(model: Type[BaseModel]) -> str:
"""Return the JSON representation of the model schema."""
return json.dumps(model.model_json_schema(), sort_keys=True, indent=False)


def _resolve_model(
type_: Union[Type, BaseModel], default_name: str, namespace: str
) -> Type[BaseModel]:
Expand All @@ -339,13 +344,13 @@ def _resolve_model(
else:
model = _create_root_model(default_name, type_)

hash_ = model.schema_json()
hash_ = _schema_json(model)

if model.__name__ in _SEEN_NAMES and hash_ not in _MODEL_REGISTRY:
# If the model name has been seen before, but the model itself is different
# generate a new name for the model.
model_to_use = _rename_pydantic_model(model, namespace)
hash_ = model_to_use.schema_json()
hash_ = _schema_json(model_to_use)
else:
model_to_use = model

Expand Down Expand Up @@ -755,7 +760,7 @@ async def _get_config_and_input(
except json.JSONDecodeError:
raise RequestValidationError(errors=["Invalid JSON body"])
try:
body = InvokeRequestShallowValidator.validate(body)
body = InvokeRequestShallowValidator.model_validate(body)

# Merge the config from the path with the config from the body.
user_provided_config = await _unpack_request_config(
Expand Down Expand Up @@ -1407,7 +1412,7 @@ async def input_schema(
self._run_name, user_provided_config, request
)

return self._runnable.get_input_schema(config).schema()
return self._runnable.get_input_schema(config).model_json_schema()

async def output_schema(
self,
Expand All @@ -1434,7 +1439,7 @@ async def output_schema(
config = _update_config_with_defaults(
self._run_name, user_provided_config, request
)
return self._runnable.get_output_schema(config).schema()
return self._runnable.get_output_schema(config).model_json_schema()

async def config_schema(
self,
Expand Down Expand Up @@ -1464,7 +1469,7 @@ async def config_schema(
return (
self._runnable.with_config(config)
.config_schema(include=self._config_keys)
.schema()
.model_json_schema()
)

async def playground(
Expand Down
8 changes: 5 additions & 3 deletions langserve/playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,12 @@ async def serve_playground(
if base_url.startswith("/")
else base_url,
LANGSERVE_CONFIG_SCHEMA=json.dumps(
runnable.config_schema(include=config_keys).schema()
runnable.config_schema(include=config_keys).model_json_schema()
),
LANGSERVE_INPUT_SCHEMA=json.dumps(input_schema.model_json_schema()),
LANGSERVE_OUTPUT_SCHEMA=json.dumps(
output_schema.model_json_schema()
),
LANGSERVE_INPUT_SCHEMA=json.dumps(input_schema.schema()),
LANGSERVE_OUTPUT_SCHEMA=json.dumps(output_schema.schema()),
LANGSERVE_FEEDBACK_ENABLED=json.dumps(
"true" if feedback_enabled else "false"
),
Expand Down
12 changes: 6 additions & 6 deletions langserve/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _log_error_message_once(error_message: str) -> None:
def default(obj) -> Any:
"""Default serialization for well known objects."""
if isinstance(obj, BaseModel):
return obj.dict()
return obj.model_dump()
return super().default(obj)


Expand All @@ -96,7 +96,7 @@ def _decode_lc_objects(value: Any) -> Any:
try:
obj = WellKnownLCObject.model_validate(v)
parsed = obj.root
if set(parsed.dict()) != set(value):
if set(parsed.model_dump()) != set(value):
raise ValueError("Invalid object")
return parsed
except (ValidationError, ValueError, TypeError):
Expand All @@ -121,11 +121,11 @@ def _decode_event_data(value: Any) -> Any:
"""Decode the event data from a JSON object representation."""
if isinstance(value, dict):
try:
obj = CallbackEvent.parse_obj(value)
obj = CallbackEvent.model_validate(value)
return obj.root
except ValidationError:
try:
obj = WellKnownLCObject.parse_obj(value)
obj = WellKnownLCObject.model_validate(value)
return obj.root
except ValidationError:
return {key: _decode_event_data(v) for key, v in value.items()}
Expand Down Expand Up @@ -176,7 +176,7 @@ def loads(self, s: bytes) -> Any:

def _project_top_level(model: BaseModel) -> Dict[str, Any]:
"""Project the top level of the model as dict."""
return {key: getattr(model, key) for key in model.__fields__}
return {key: getattr(model, key) for key in model.model_fields}


def load_events(events: Any) -> List[Dict[str, Any]]:
Expand Down Expand Up @@ -207,7 +207,7 @@ def load_events(events: Any) -> List[Dict[str, Any]]:

# Then validate the event
try:
full_event = CallbackEvent.parse_obj(decoded_event_data)
full_event = CallbackEvent.model_validate(decoded_event_data)
except ValidationError as e:
msg = f"Encountered an invalid event: {e}"
if "type" in decoded_event_data:
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
def test_document_serialization() -> None:
"""Simple test. Exhaustive tests follow below."""
doc = Document(page_content="hello")
d = doc.dict()
d = doc.model_dump()
WellKnownLCObject.model_validate(d)


Expand Down Expand Up @@ -87,7 +87,7 @@ def _get_full_representation(data: Any) -> Any:
elif isinstance(data, list):
return [_get_full_representation(value) for value in data]
elif isinstance(data, BaseModel):
return data.schema()
return data.model_json_schema()
else:
return data

Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ async def test_invoke_request_with_runnables() -> None:
assert request.config.tags == ["hello"]
assert request.config.run_name == "run"
assert isinstance(request.config.configurable, BaseModel)
assert request.config.configurable.dict() == {
assert request.config.configurable.model_dump() == {
"template": "goodbye {name}",
}

Expand Down

0 comments on commit 72c200f

Please sign in to comment.