From 0fb91e381847d3f979d5251e3b6e14db6927f98b Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Fri, 12 Dec 2025 22:54:27 +0000 Subject: [PATCH 1/4] fix(sagemaker/lora): allow no json body for register lora and parse query params as json body. --- .../sagemaker/lora/transforms/register.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/model_hosting_container_standards/sagemaker/lora/transforms/register.py b/python/model_hosting_container_standards/sagemaker/lora/transforms/register.py index 9d20114..92c2347 100644 --- a/python/model_hosting_container_standards/sagemaker/lora/transforms/register.py +++ b/python/model_hosting_container_standards/sagemaker/lora/transforms/register.py @@ -5,6 +5,7 @@ from fastapi.exceptions import HTTPException from pydantic import ValidationError +from ....logging_config import logger from ..base_lora_api_transform import BaseLoRAApiTransform from ..constants import ResponseMessage from ..models import BaseLoRATransformRequestOutput, SageMakerRegisterLoRAAdapterRequest @@ -48,11 +49,12 @@ async def transform_request( """ try: request_data = await raw_request.json() - except json.JSONDecodeError as e: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, - detail=f"JSON decode error: {e}", - ) from e + except json.JSONDecodeError: + # if raw request does not have json body + # check if expected data is in the query parms + # and treat query params dict as body + logger.warning("No JSON body in the request. Using query parameters.") + request_data = raw_request.query_params request = validate_sagemaker_register_request(request_data) transformed_request = self._transform_request(request, raw_request) return BaseLoRATransformRequestOutput( From a98dfb684b557b4acf09cf80914194933c573930 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Sat, 13 Dec 2025 00:14:13 +0000 Subject: [PATCH 2/4] Update integration test to use query params option. Add unit test for query param happy path. --- .../test_sagemaker_lora_integration.py | 74 ++++++++++++++++--- .../transforms/test_register_transform.py | 43 ++++++++++- 2 files changed, 106 insertions(+), 11 deletions(-) diff --git a/python/tests/integration/test_sagemaker_lora_integration.py b/python/tests/integration/test_sagemaker_lora_integration.py index e4d9689..477fbed 100644 --- a/python/tests/integration/test_sagemaker_lora_integration.py +++ b/python/tests/integration/test_sagemaker_lora_integration.py @@ -180,11 +180,25 @@ async def invocations(request: Request): class TestLoRARouterRedirection(BaseLoRAIntegrationTest): """Test that bootstrap() correctly mounts LoRA routes from decorated handlers.""" - def test_register_adapter_route_mounted(self): + @pytest.mark.parametrize( + "test_type", + [ + ("body"), + ("query_params"), + ], + ids=[ + "body", + "query_params", + ], + ) + def test_register_adapter_route_mounted(self, test_type): """Test that POST /adapters route is mounted by bootstrap().""" # Call the SageMaker-standard route (not the engine's custom route) + lora_name = "test-adapter" + lora_path = "s3://bucket/adapter" response = self.client.post( - "/adapters", json={"name": "test-adapter", "src": "s3://bucket/adapter"} + f"/adapters{f'?name={lora_name}&src={lora_path}' if test_type == 'query_params' else ''}", + json={"name": lora_name, "src": lora_path} if test_type == "body" else None, ) assert response.status_code == 200 @@ -500,15 +514,29 @@ class TestLoRAEndToEndFlow(BaseLoRAIntegrationTest): simulating how a user would interact with a LoRA-enabled SageMaker endpoint. """ - def test_full_adapter_lifecycle(self): + @pytest.mark.parametrize( + "test_type", + [ + ("body"), + ("query_params"), + ], + ids=[ + "body", + "query_params", + ], + ) + def test_full_adapter_lifecycle(self, test_type): """Test complete lifecycle: register -> invoke with adapter -> unregister. This is the primary happy path: load an adapter, use it for inference, then unload it. Verifies all three operations work together. """ + lora_name = "lora-1" + lora_path = "s3://bucket/lora-1" # 1. Register an adapter register_response = self.client.post( - "/adapters", json={"name": "lora-1", "src": "s3://bucket/lora-1"} + f"/adapters{f'?name={lora_name}&src={lora_path}' if test_type == 'query_params' else ''}", + json={"name": lora_name, "src": lora_path} if test_type == "body" else None, ) assert register_response.status_code == 200 @@ -516,20 +544,46 @@ def test_full_adapter_lifecycle(self): invoke_response = self.client.post( "/invocations", json={"prompt": "hello"}, - headers={"X-Amzn-SageMaker-Adapter-Identifier": "lora-1"}, + headers={"X-Amzn-SageMaker-Adapter-Identifier": lora_name}, ) assert invoke_response.status_code == 200 # 3. Unregister the adapter - unregister_response = self.client.delete("/adapters/lora-1") + unregister_response = self.client.delete(f"/adapters/{lora_name}") assert unregister_response.status_code == 200 - def test_multiple_adapters(self): + @pytest.mark.parametrize( + "test_type", + [ + ("body"), + ("query_params"), + ], + ids=[ + "body", + "query_params", + ], + ) + def test_multiple_adapters(self, test_type): """Test managing multiple adapters simultaneously.""" # Register multiple adapters - self.client.post("/adapters", json={"name": "adapter-a", "src": "s3://a"}) - self.client.post("/adapters", json={"name": "adapter-b", "src": "s3://b"}) - self.client.post("/adapters", json={"name": "adapter-c", "src": "s3://c"}) + self.client.post( + f"/adapters{'?name=adapter_a&src=s3://a' if test_type == 'query_params' else ''}", + json=( + {"name": "adapter_a", "src": "s3://a"} if test_type == "body" else None + ), + ) + self.client.post( + f"/adapters{'?name=adapter_b&src=s3://b' if test_type == 'query_params' else ''}", + json=( + {"name": "adapter_b", "src": "s3://b"} if test_type == "body" else None + ), + ) + self.client.post( + f"/adapters{'?name=adapter_c&src=s3://c' if test_type == 'query_params' else ''}", + json=( + {"name": "adapter_c", "src": "s3://c"} if test_type == "body" else None + ), + ) # Invoke with different adapters - each should route correctly response_a = self.client.post( diff --git a/python/tests/sagemaker/lora/transforms/test_register_transform.py b/python/tests/sagemaker/lora/transforms/test_register_transform.py index 3cae845..18f2897 100644 --- a/python/tests/sagemaker/lora/transforms/test_register_transform.py +++ b/python/tests/sagemaker/lora/transforms/test_register_transform.py @@ -1,6 +1,7 @@ """Unit tests for RegisterLoRAApiTransform.""" from http import HTTPStatus +from json import JSONDecodeError from unittest.mock import AsyncMock, Mock, patch import pytest @@ -212,7 +213,9 @@ def test_transform_ok_response(self): @patch( "model_hosting_container_standards.sagemaker.lora.utils.get_adapter_alias_from_request_header" ) - async def test_integration_transform_request_and_response(self, mock_get_alias): + async def test_integration_transform_request_and_response_json_body( + self, mock_get_alias + ): """Test integration between request and response transformation.""" # Setup request transformation with mocked raw request mock_raw_request = Mock(spec=Request) @@ -241,3 +244,41 @@ async def test_integration_transform_request_and_response(self, mock_get_alias): assert "registered" in ok_result.body.decode() # TODO: test error transformation once implemented + + @pytest.mark.asyncio + @patch( + "model_hosting_container_standards.sagemaker.lora.utils.get_adapter_alias_from_request_header" + ) + async def test_integration_transform_request_and_response_query_params( + self, mock_get_alias + ): + """Test integration between request and response transformation.""" + # Setup request transformation with mocked raw request + mock_raw_request = Mock(spec=Request) + mock_raw_request.json.side_effect = JSONDecodeError("test error", doc="", pos=1) + mock_raw_request.query_params = { + "name": "integration-test", + "src": "s3://integration", + } + mock_get_alias.return_value = "integration-alias" + + # Transform request - only pass raw_request + transform_output = await self.transformer.transform_request(mock_raw_request) + + # Verify request transformation + assert transform_output.adapter_name == "integration-test" + + # Transform successful response + mock_ok_response = Mock(spec=Response) + mock_ok_response.status_code = HTTPStatus.OK + mock_ok_response.headers = {} + mock_ok_response.media_type = "application/json" + + ok_result = self.transformer.transform_response( + mock_ok_response, transform_output + ) + + assert ok_result.status_code == HTTPStatus.OK + assert "registered" in ok_result.body.decode() + + # TODO: test error transformation once implemented From 57c2414d7a067c2956639b9a4a20cd32f2905092 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Sat, 13 Dec 2025 01:58:45 +0000 Subject: [PATCH 3/4] Improve unit tests. --- .../test_sagemaker_lora_integration.py | 143 ++++++++++++------ 1 file changed, 97 insertions(+), 46 deletions(-) diff --git a/python/tests/integration/test_sagemaker_lora_integration.py b/python/tests/integration/test_sagemaker_lora_integration.py index 477fbed..d5d5d8f 100644 --- a/python/tests/integration/test_sagemaker_lora_integration.py +++ b/python/tests/integration/test_sagemaker_lora_integration.py @@ -102,22 +102,52 @@ def setup_method(self): sagemaker_standards.bootstrap(self.app) self.client = TestClient(self.app) - def setup_handlers(self): + def make_adapter_request_params(self, test_type, name, src, base_url="/adapters"): + """Helper to generate URL and JSON for adapter requests based on test_type. + + Args: + test_type: Either "body" or "query_params" + name: The adapter name + src: The adapter source path + base_url: The base URL for the request (default: "/adapters") + + Returns: + Tuple of (url, json_data) for use in client requests + """ + if test_type == "query_params": + url = f"{base_url}?name={name}&src={src}" + json_data = None + else: # body + url = base_url + json_data = {"name": name, "src": src} + return url, json_data + + def setup_handlers(self, test_type="body"): """Define handlers for end-to-end lifecycle tests. Sets up three handlers that simulate a LoRA-enabled inference engine: 1. load_lora_adapter - Loads adapters into the registry 2. unload_lora_adapter - Removes adapters from the registry 3. invocations - Handles inference with optional adapter selection + + Args: + test_type: Either "body" or "query_params" to determine request source """ # Simulate a simple adapter registry self.adapters = {} + # Determine request shape based on test type + source_prefix = "body" if test_type == "body" else "query_params" + request_shape = { + "lora_name": f"{source_prefix}.name", + "lora_path": f"{source_prefix}.src", + } + # Handler 1: Load adapter - # The decorator transforms: {"name": "x", "src": "y"} -> {"lora_name": "x", "lora_path": "y"} - @sagemaker_standards.register_load_adapter_handler( - request_shape={"lora_name": "body.name", "lora_path": "body.src"} - ) + # The decorator transforms based on test_type: + # - body: {"name": "x", "src": "y"} -> {"lora_name": "x", "lora_path": "y"} + # - query_params: ?name=x&src=y -> {"lora_name": "x", "lora_path": "y"} + @sagemaker_standards.register_load_adapter_handler(request_shape=request_shape) @self.router.post("/v1/load_lora_adapter") async def load_lora_adapter( request: EngineLoadLoRAAdapterRequest, raw_request: Request @@ -193,13 +223,18 @@ class TestLoRARouterRedirection(BaseLoRAIntegrationTest): ) def test_register_adapter_route_mounted(self, test_type): """Test that POST /adapters route is mounted by bootstrap().""" + # Re-setup handlers with the correct test_type for this parametrized test + handler_registry.clear() + self.setup_handlers(test_type) + sagemaker_standards.bootstrap(self.app) + # Call the SageMaker-standard route (not the engine's custom route) lora_name = "test-adapter" lora_path = "s3://bucket/adapter" - response = self.client.post( - f"/adapters{f'?name={lora_name}&src={lora_path}' if test_type == 'query_params' else ''}", - json={"name": lora_name, "src": lora_path} if test_type == "body" else None, + url, json_data = self.make_adapter_request_params( + test_type, lora_name, lora_path ) + response = self.client.post(url, json=json_data) assert response.status_code == 200 @@ -386,16 +421,25 @@ async def invocations(request: Request): assert response.status_code == 200 assert "lora-1" in response.text - def test_nested_jmespath_transformations(self): + @pytest.mark.parametrize( + "test_type", + [ + ("body"), + ("query_params"), + ], + ids=[ + "body", + "query_params", + ], + ) + def test_nested_jmespath_transformations(self, test_type): """Test nested JMESPath expressions in request_shape. Verifies that request_shape can contain nested dictionaries, and JMESPath expressions work at any nesting level. """ + # Re-setup handlers with the correct test_type for this parametrized test handler_registry.clear() - - app = FastAPI() - router = APIRouter() self.capture.clear() # Clear the load capture # Define request model with nested structure @@ -403,15 +447,19 @@ class NestedLoadLoRAAdapterRequest(BaseModel): adapter_config: dict # Nested dict field source_path: str + # Determine request shape based on test type + source_prefix = "body" if test_type == "body" else "query_params" + nested_request_shape = { + "adapter_config": { # Target is a nested dict + "name": f"{source_prefix}.name", # Extract from source.name + }, + "source_path": f"{source_prefix}.src", # Extract from source.src + } + @sagemaker_standards.register_load_adapter_handler( - request_shape={ - "adapter_config": { # Target is a nested dict - "name": "body.name", # Extract from body.name - }, - "source_path": "body.src", # Extract from body.src - } + request_shape=nested_request_shape ) - @router.post("/v1/nested_load") + @self.router.post("/v1/nested_load") async def nested_load( request: NestedLoadLoRAAdapterRequest, raw_request: Request ): @@ -422,17 +470,14 @@ async def nested_load( content=f"name={request.adapter_config['name']},source={request.source_path}", ) - app.include_router(router) - sagemaker_standards.bootstrap(app) - client = TestClient(app) + sagemaker_standards.bootstrap(self.app) - response = client.post( - "/adapters", - json={ - "name": "nested-adapter", - "src": "s3://nested/path", - }, + lora_name = "nested-adapter" + lora_path = "s3://nested/path" + url, json_data = self.make_adapter_request_params( + test_type, lora_name, lora_path ) + response = self.client.post(url, json=json_data) assert response.status_code == 200 assert "nested-adapter" in response.text @@ -531,13 +576,18 @@ def test_full_adapter_lifecycle(self, test_type): This is the primary happy path: load an adapter, use it for inference, then unload it. Verifies all three operations work together. """ + # Re-setup handlers with the correct test_type for this parametrized test + handler_registry.clear() + self.setup_handlers(test_type) + sagemaker_standards.bootstrap(self.app) + lora_name = "lora-1" lora_path = "s3://bucket/lora-1" # 1. Register an adapter - register_response = self.client.post( - f"/adapters{f'?name={lora_name}&src={lora_path}' if test_type == 'query_params' else ''}", - json={"name": lora_name, "src": lora_path} if test_type == "body" else None, + url, json_data = self.make_adapter_request_params( + test_type, lora_name, lora_path ) + register_response = self.client.post(url, json=json_data) assert register_response.status_code == 200 # 2. Invoke with the adapter @@ -565,25 +615,26 @@ def test_full_adapter_lifecycle(self, test_type): ) def test_multiple_adapters(self, test_type): """Test managing multiple adapters simultaneously.""" + # Re-setup handlers with the correct test_type for this parametrized test + handler_registry.clear() + self.setup_handlers(test_type) + sagemaker_standards.bootstrap(self.app) + # Register multiple adapters - self.client.post( - f"/adapters{'?name=adapter_a&src=s3://a' if test_type == 'query_params' else ''}", - json=( - {"name": "adapter_a", "src": "s3://a"} if test_type == "body" else None - ), + url_a, json_a = self.make_adapter_request_params( + test_type, "adapter_a", "s3://a" ) - self.client.post( - f"/adapters{'?name=adapter_b&src=s3://b' if test_type == 'query_params' else ''}", - json=( - {"name": "adapter_b", "src": "s3://b"} if test_type == "body" else None - ), + self.client.post(url_a, json=json_a) + + url_b, json_b = self.make_adapter_request_params( + test_type, "adapter_b", "s3://b" ) - self.client.post( - f"/adapters{'?name=adapter_c&src=s3://c' if test_type == 'query_params' else ''}", - json=( - {"name": "adapter_c", "src": "s3://c"} if test_type == "body" else None - ), + self.client.post(url_b, json=json_b) + + url_c, json_c = self.make_adapter_request_params( + test_type, "adapter_c", "s3://c" ) + self.client.post(url_c, json=json_c) # Invoke with different adapters - each should route correctly response_a = self.client.post( From 0df4365168946e49affaba42c1781c060630de60 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Mon, 15 Dec 2025 19:15:39 +0000 Subject: [PATCH 4/4] add TODO to not use query params once dependencies are updated --- .../sagemaker/lora/transforms/register.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/model_hosting_container_standards/sagemaker/lora/transforms/register.py b/python/model_hosting_container_standards/sagemaker/lora/transforms/register.py index 92c2347..c451e57 100644 --- a/python/model_hosting_container_standards/sagemaker/lora/transforms/register.py +++ b/python/model_hosting_container_standards/sagemaker/lora/transforms/register.py @@ -53,6 +53,8 @@ async def transform_request( # if raw request does not have json body # check if expected data is in the query parms # and treat query params dict as body + # TODO: remove this once dependencies don't expect + # fields to be in `body.<...>` logger.warning("No JSON body in the request. Using query parameters.") request_data = raw_request.query_params request = validate_sagemaker_register_request(request_data)