diff --git a/python/model_hosting_container_standards/sagemaker/lora/transforms/register.py b/python/model_hosting_container_standards/sagemaker/lora/transforms/register.py index 9d20114..c451e57 100644 --- a/python/model_hosting_container_standards/sagemaker/lora/transforms/register.py +++ b/python/model_hosting_container_standards/sagemaker/lora/transforms/register.py @@ -5,6 +5,7 @@ from fastapi.exceptions import HTTPException from pydantic import ValidationError +from ....logging_config import logger from ..base_lora_api_transform import BaseLoRAApiTransform from ..constants import ResponseMessage from ..models import BaseLoRATransformRequestOutput, SageMakerRegisterLoRAAdapterRequest @@ -48,11 +49,14 @@ async def transform_request( """ try: request_data = await raw_request.json() - except json.JSONDecodeError as e: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, - detail=f"JSON decode error: {e}", - ) from e + except json.JSONDecodeError: + # if raw request does not have json body + # check if expected data is in the query parms + # and treat query params dict as body + # TODO: remove this once dependencies don't expect + # fields to be in `body.<...>` + logger.warning("No JSON body in the request. Using query parameters.") + request_data = raw_request.query_params request = validate_sagemaker_register_request(request_data) transformed_request = self._transform_request(request, raw_request) return BaseLoRATransformRequestOutput( diff --git a/python/tests/integration/test_sagemaker_lora_integration.py b/python/tests/integration/test_sagemaker_lora_integration.py index e4d9689..d5d5d8f 100644 --- a/python/tests/integration/test_sagemaker_lora_integration.py +++ b/python/tests/integration/test_sagemaker_lora_integration.py @@ -102,22 +102,52 @@ def setup_method(self): sagemaker_standards.bootstrap(self.app) self.client = TestClient(self.app) - def setup_handlers(self): + def make_adapter_request_params(self, test_type, name, src, base_url="/adapters"): + """Helper to generate URL and JSON for adapter requests based on test_type. + + Args: + test_type: Either "body" or "query_params" + name: The adapter name + src: The adapter source path + base_url: The base URL for the request (default: "/adapters") + + Returns: + Tuple of (url, json_data) for use in client requests + """ + if test_type == "query_params": + url = f"{base_url}?name={name}&src={src}" + json_data = None + else: # body + url = base_url + json_data = {"name": name, "src": src} + return url, json_data + + def setup_handlers(self, test_type="body"): """Define handlers for end-to-end lifecycle tests. Sets up three handlers that simulate a LoRA-enabled inference engine: 1. load_lora_adapter - Loads adapters into the registry 2. unload_lora_adapter - Removes adapters from the registry 3. invocations - Handles inference with optional adapter selection + + Args: + test_type: Either "body" or "query_params" to determine request source """ # Simulate a simple adapter registry self.adapters = {} + # Determine request shape based on test type + source_prefix = "body" if test_type == "body" else "query_params" + request_shape = { + "lora_name": f"{source_prefix}.name", + "lora_path": f"{source_prefix}.src", + } + # Handler 1: Load adapter - # The decorator transforms: {"name": "x", "src": "y"} -> {"lora_name": "x", "lora_path": "y"} - @sagemaker_standards.register_load_adapter_handler( - request_shape={"lora_name": "body.name", "lora_path": "body.src"} - ) + # The decorator transforms based on test_type: + # - body: {"name": "x", "src": "y"} -> {"lora_name": "x", "lora_path": "y"} + # - query_params: ?name=x&src=y -> {"lora_name": "x", "lora_path": "y"} + @sagemaker_standards.register_load_adapter_handler(request_shape=request_shape) @self.router.post("/v1/load_lora_adapter") async def load_lora_adapter( request: EngineLoadLoRAAdapterRequest, raw_request: Request @@ -180,12 +210,31 @@ async def invocations(request: Request): class TestLoRARouterRedirection(BaseLoRAIntegrationTest): """Test that bootstrap() correctly mounts LoRA routes from decorated handlers.""" - def test_register_adapter_route_mounted(self): + @pytest.mark.parametrize( + "test_type", + [ + ("body"), + ("query_params"), + ], + ids=[ + "body", + "query_params", + ], + ) + def test_register_adapter_route_mounted(self, test_type): """Test that POST /adapters route is mounted by bootstrap().""" + # Re-setup handlers with the correct test_type for this parametrized test + handler_registry.clear() + self.setup_handlers(test_type) + sagemaker_standards.bootstrap(self.app) + # Call the SageMaker-standard route (not the engine's custom route) - response = self.client.post( - "/adapters", json={"name": "test-adapter", "src": "s3://bucket/adapter"} + lora_name = "test-adapter" + lora_path = "s3://bucket/adapter" + url, json_data = self.make_adapter_request_params( + test_type, lora_name, lora_path ) + response = self.client.post(url, json=json_data) assert response.status_code == 200 @@ -372,16 +421,25 @@ async def invocations(request: Request): assert response.status_code == 200 assert "lora-1" in response.text - def test_nested_jmespath_transformations(self): + @pytest.mark.parametrize( + "test_type", + [ + ("body"), + ("query_params"), + ], + ids=[ + "body", + "query_params", + ], + ) + def test_nested_jmespath_transformations(self, test_type): """Test nested JMESPath expressions in request_shape. Verifies that request_shape can contain nested dictionaries, and JMESPath expressions work at any nesting level. """ + # Re-setup handlers with the correct test_type for this parametrized test handler_registry.clear() - - app = FastAPI() - router = APIRouter() self.capture.clear() # Clear the load capture # Define request model with nested structure @@ -389,15 +447,19 @@ class NestedLoadLoRAAdapterRequest(BaseModel): adapter_config: dict # Nested dict field source_path: str + # Determine request shape based on test type + source_prefix = "body" if test_type == "body" else "query_params" + nested_request_shape = { + "adapter_config": { # Target is a nested dict + "name": f"{source_prefix}.name", # Extract from source.name + }, + "source_path": f"{source_prefix}.src", # Extract from source.src + } + @sagemaker_standards.register_load_adapter_handler( - request_shape={ - "adapter_config": { # Target is a nested dict - "name": "body.name", # Extract from body.name - }, - "source_path": "body.src", # Extract from body.src - } + request_shape=nested_request_shape ) - @router.post("/v1/nested_load") + @self.router.post("/v1/nested_load") async def nested_load( request: NestedLoadLoRAAdapterRequest, raw_request: Request ): @@ -408,17 +470,14 @@ async def nested_load( content=f"name={request.adapter_config['name']},source={request.source_path}", ) - app.include_router(router) - sagemaker_standards.bootstrap(app) - client = TestClient(app) + sagemaker_standards.bootstrap(self.app) - response = client.post( - "/adapters", - json={ - "name": "nested-adapter", - "src": "s3://nested/path", - }, + lora_name = "nested-adapter" + lora_path = "s3://nested/path" + url, json_data = self.make_adapter_request_params( + test_type, lora_name, lora_path ) + response = self.client.post(url, json=json_data) assert response.status_code == 200 assert "nested-adapter" in response.text @@ -500,36 +559,82 @@ class TestLoRAEndToEndFlow(BaseLoRAIntegrationTest): simulating how a user would interact with a LoRA-enabled SageMaker endpoint. """ - def test_full_adapter_lifecycle(self): + @pytest.mark.parametrize( + "test_type", + [ + ("body"), + ("query_params"), + ], + ids=[ + "body", + "query_params", + ], + ) + def test_full_adapter_lifecycle(self, test_type): """Test complete lifecycle: register -> invoke with adapter -> unregister. This is the primary happy path: load an adapter, use it for inference, then unload it. Verifies all three operations work together. """ + # Re-setup handlers with the correct test_type for this parametrized test + handler_registry.clear() + self.setup_handlers(test_type) + sagemaker_standards.bootstrap(self.app) + + lora_name = "lora-1" + lora_path = "s3://bucket/lora-1" # 1. Register an adapter - register_response = self.client.post( - "/adapters", json={"name": "lora-1", "src": "s3://bucket/lora-1"} + url, json_data = self.make_adapter_request_params( + test_type, lora_name, lora_path ) + register_response = self.client.post(url, json=json_data) assert register_response.status_code == 200 # 2. Invoke with the adapter invoke_response = self.client.post( "/invocations", json={"prompt": "hello"}, - headers={"X-Amzn-SageMaker-Adapter-Identifier": "lora-1"}, + headers={"X-Amzn-SageMaker-Adapter-Identifier": lora_name}, ) assert invoke_response.status_code == 200 # 3. Unregister the adapter - unregister_response = self.client.delete("/adapters/lora-1") + unregister_response = self.client.delete(f"/adapters/{lora_name}") assert unregister_response.status_code == 200 - def test_multiple_adapters(self): + @pytest.mark.parametrize( + "test_type", + [ + ("body"), + ("query_params"), + ], + ids=[ + "body", + "query_params", + ], + ) + def test_multiple_adapters(self, test_type): """Test managing multiple adapters simultaneously.""" + # Re-setup handlers with the correct test_type for this parametrized test + handler_registry.clear() + self.setup_handlers(test_type) + sagemaker_standards.bootstrap(self.app) + # Register multiple adapters - self.client.post("/adapters", json={"name": "adapter-a", "src": "s3://a"}) - self.client.post("/adapters", json={"name": "adapter-b", "src": "s3://b"}) - self.client.post("/adapters", json={"name": "adapter-c", "src": "s3://c"}) + url_a, json_a = self.make_adapter_request_params( + test_type, "adapter_a", "s3://a" + ) + self.client.post(url_a, json=json_a) + + url_b, json_b = self.make_adapter_request_params( + test_type, "adapter_b", "s3://b" + ) + self.client.post(url_b, json=json_b) + + url_c, json_c = self.make_adapter_request_params( + test_type, "adapter_c", "s3://c" + ) + self.client.post(url_c, json=json_c) # Invoke with different adapters - each should route correctly response_a = self.client.post( diff --git a/python/tests/sagemaker/lora/transforms/test_register_transform.py b/python/tests/sagemaker/lora/transforms/test_register_transform.py index 3cae845..18f2897 100644 --- a/python/tests/sagemaker/lora/transforms/test_register_transform.py +++ b/python/tests/sagemaker/lora/transforms/test_register_transform.py @@ -1,6 +1,7 @@ """Unit tests for RegisterLoRAApiTransform.""" from http import HTTPStatus +from json import JSONDecodeError from unittest.mock import AsyncMock, Mock, patch import pytest @@ -212,7 +213,9 @@ def test_transform_ok_response(self): @patch( "model_hosting_container_standards.sagemaker.lora.utils.get_adapter_alias_from_request_header" ) - async def test_integration_transform_request_and_response(self, mock_get_alias): + async def test_integration_transform_request_and_response_json_body( + self, mock_get_alias + ): """Test integration between request and response transformation.""" # Setup request transformation with mocked raw request mock_raw_request = Mock(spec=Request) @@ -241,3 +244,41 @@ async def test_integration_transform_request_and_response(self, mock_get_alias): assert "registered" in ok_result.body.decode() # TODO: test error transformation once implemented + + @pytest.mark.asyncio + @patch( + "model_hosting_container_standards.sagemaker.lora.utils.get_adapter_alias_from_request_header" + ) + async def test_integration_transform_request_and_response_query_params( + self, mock_get_alias + ): + """Test integration between request and response transformation.""" + # Setup request transformation with mocked raw request + mock_raw_request = Mock(spec=Request) + mock_raw_request.json.side_effect = JSONDecodeError("test error", doc="", pos=1) + mock_raw_request.query_params = { + "name": "integration-test", + "src": "s3://integration", + } + mock_get_alias.return_value = "integration-alias" + + # Transform request - only pass raw_request + transform_output = await self.transformer.transform_request(mock_raw_request) + + # Verify request transformation + assert transform_output.adapter_name == "integration-test" + + # Transform successful response + mock_ok_response = Mock(spec=Response) + mock_ok_response.status_code = HTTPStatus.OK + mock_ok_response.headers = {} + mock_ok_response.media_type = "application/json" + + ok_result = self.transformer.transform_response( + mock_ok_response, transform_output + ) + + assert ok_result.status_code == HTTPStatus.OK + assert "registered" in ok_result.body.decode() + + # TODO: test error transformation once implemented