From 72c200ff819307727af3f9a91de569689c7323cd Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 10 Sep 2024 10:34:16 -0400 Subject: [PATCH] first pass at removing deprecated usaged (#751) Reduces a significant number of run time warnings when running unit tests (down to ~350) --- langserve/api_handler.py | 25 +++++++++++++++---------- langserve/playground.py | 8 +++++--- langserve/serialization.py | 12 ++++++------ tests/unit_tests/test_serialization.py | 4 ++-- tests/unit_tests/test_validation.py | 2 +- 5 files changed, 29 insertions(+), 22 deletions(-) diff --git a/langserve/api_handler.py b/langserve/api_handler.py index 6e70d0b9..23f29c35 100644 --- a/langserve/api_handler.py +++ b/langserve/api_handler.py @@ -186,11 +186,11 @@ async def _unpack_request_config( config_dicts = [] for config in client_sent_configs: if isinstance(config, str): - config_dicts.append(model(**_config_from_hash(config)).dict()) + config_dicts.append(model(**_config_from_hash(config)).model_dump()) elif isinstance(config, BaseModel): - config_dicts.append(config.dict()) + config_dicts.append(config.model_dump()) elif isinstance(config, Mapping): - config_dicts.append(model(**config).dict()) + config_dicts.append(model(**config).model_dump()) else: raise TypeError(f"Expected a string, dict or BaseModel got {type(config)}") config = merge_configs(*config_dicts) @@ -298,7 +298,7 @@ def _unpack_input(validated_model: BaseModel) -> Any: # This logic should be applied recursively to nested models. return { fieldname: _unpack_input(getattr(model, fieldname)) - for fieldname in model.__fields__.keys() + for fieldname in model.model_fields.keys() } return model @@ -330,6 +330,11 @@ def _replace_non_alphanumeric_with_underscores(s: str) -> str: return re.sub(r"[^a-zA-Z0-9]", "_", s) +def _schema_json(model: Type[BaseModel]) -> str: + """Return the JSON representation of the model schema.""" + return json.dumps(model.model_json_schema(), sort_keys=True, indent=False) + + def _resolve_model( type_: Union[Type, BaseModel], default_name: str, namespace: str ) -> Type[BaseModel]: @@ -339,13 +344,13 @@ def _resolve_model( else: model = _create_root_model(default_name, type_) - hash_ = model.schema_json() + hash_ = _schema_json(model) if model.__name__ in _SEEN_NAMES and hash_ not in _MODEL_REGISTRY: # If the model name has been seen before, but the model itself is different # generate a new name for the model. model_to_use = _rename_pydantic_model(model, namespace) - hash_ = model_to_use.schema_json() + hash_ = _schema_json(model_to_use) else: model_to_use = model @@ -755,7 +760,7 @@ async def _get_config_and_input( except json.JSONDecodeError: raise RequestValidationError(errors=["Invalid JSON body"]) try: - body = InvokeRequestShallowValidator.validate(body) + body = InvokeRequestShallowValidator.model_validate(body) # Merge the config from the path with the config from the body. user_provided_config = await _unpack_request_config( @@ -1407,7 +1412,7 @@ async def input_schema( self._run_name, user_provided_config, request ) - return self._runnable.get_input_schema(config).schema() + return self._runnable.get_input_schema(config).model_json_schema() async def output_schema( self, @@ -1434,7 +1439,7 @@ async def output_schema( config = _update_config_with_defaults( self._run_name, user_provided_config, request ) - return self._runnable.get_output_schema(config).schema() + return self._runnable.get_output_schema(config).model_json_schema() async def config_schema( self, @@ -1464,7 +1469,7 @@ async def config_schema( return ( self._runnable.with_config(config) .config_schema(include=self._config_keys) - .schema() + .model_json_schema() ) async def playground( diff --git a/langserve/playground.py b/langserve/playground.py index c9b77569..5ece4fa5 100644 --- a/langserve/playground.py +++ b/langserve/playground.py @@ -89,10 +89,12 @@ async def serve_playground( if base_url.startswith("/") else base_url, LANGSERVE_CONFIG_SCHEMA=json.dumps( - runnable.config_schema(include=config_keys).schema() + runnable.config_schema(include=config_keys).model_json_schema() + ), + LANGSERVE_INPUT_SCHEMA=json.dumps(input_schema.model_json_schema()), + LANGSERVE_OUTPUT_SCHEMA=json.dumps( + output_schema.model_json_schema() ), - LANGSERVE_INPUT_SCHEMA=json.dumps(input_schema.schema()), - LANGSERVE_OUTPUT_SCHEMA=json.dumps(output_schema.schema()), LANGSERVE_FEEDBACK_ENABLED=json.dumps( "true" if feedback_enabled else "false" ), diff --git a/langserve/serialization.py b/langserve/serialization.py index 1690c487..c5d45014 100644 --- a/langserve/serialization.py +++ b/langserve/serialization.py @@ -84,7 +84,7 @@ def _log_error_message_once(error_message: str) -> None: def default(obj) -> Any: """Default serialization for well known objects.""" if isinstance(obj, BaseModel): - return obj.dict() + return obj.model_dump() return super().default(obj) @@ -96,7 +96,7 @@ def _decode_lc_objects(value: Any) -> Any: try: obj = WellKnownLCObject.model_validate(v) parsed = obj.root - if set(parsed.dict()) != set(value): + if set(parsed.model_dump()) != set(value): raise ValueError("Invalid object") return parsed except (ValidationError, ValueError, TypeError): @@ -121,11 +121,11 @@ def _decode_event_data(value: Any) -> Any: """Decode the event data from a JSON object representation.""" if isinstance(value, dict): try: - obj = CallbackEvent.parse_obj(value) + obj = CallbackEvent.model_validate(value) return obj.root except ValidationError: try: - obj = WellKnownLCObject.parse_obj(value) + obj = WellKnownLCObject.model_validate(value) return obj.root except ValidationError: return {key: _decode_event_data(v) for key, v in value.items()} @@ -176,7 +176,7 @@ def loads(self, s: bytes) -> Any: def _project_top_level(model: BaseModel) -> Dict[str, Any]: """Project the top level of the model as dict.""" - return {key: getattr(model, key) for key in model.__fields__} + return {key: getattr(model, key) for key in model.model_fields} def load_events(events: Any) -> List[Dict[str, Any]]: @@ -207,7 +207,7 @@ def load_events(events: Any) -> List[Dict[str, Any]]: # Then validate the event try: - full_event = CallbackEvent.parse_obj(decoded_event_data) + full_event = CallbackEvent.model_validate(decoded_event_data) except ValidationError as e: msg = f"Encountered an invalid event: {e}" if "type" in decoded_event_data: diff --git a/tests/unit_tests/test_serialization.py b/tests/unit_tests/test_serialization.py index 86a8094f..0c700657 100644 --- a/tests/unit_tests/test_serialization.py +++ b/tests/unit_tests/test_serialization.py @@ -18,7 +18,7 @@ def test_document_serialization() -> None: """Simple test. Exhaustive tests follow below.""" doc = Document(page_content="hello") - d = doc.dict() + d = doc.model_dump() WellKnownLCObject.model_validate(d) @@ -87,7 +87,7 @@ def _get_full_representation(data: Any) -> Any: elif isinstance(data, list): return [_get_full_representation(value) for value in data] elif isinstance(data, BaseModel): - return data.schema() + return data.model_json_schema() else: return data diff --git a/tests/unit_tests/test_validation.py b/tests/unit_tests/test_validation.py index 3f364a1b..3e63d3e8 100644 --- a/tests/unit_tests/test_validation.py +++ b/tests/unit_tests/test_validation.py @@ -174,7 +174,7 @@ async def test_invoke_request_with_runnables() -> None: assert request.config.tags == ["hello"] assert request.config.run_name == "run" assert isinstance(request.config.configurable, BaseModel) - assert request.config.configurable.dict() == { + assert request.config.configurable.model_dump() == { "template": "goodbye {name}", }