Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from fastapi.exceptions import HTTPException
from pydantic import ValidationError

from ....logging_config import logger
from ..base_lora_api_transform import BaseLoRAApiTransform
from ..constants import ResponseMessage
from ..models import BaseLoRATransformRequestOutput, SageMakerRegisterLoRAAdapterRequest
Expand Down Expand Up @@ -48,11 +49,14 @@ async def transform_request(
"""
try:
request_data = await raw_request.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail=f"JSON decode error: {e}",
) from e
except json.JSONDecodeError:
# if raw request does not have json body
# check if expected data is in the query parms
# and treat query params dict as body
# TODO: remove this once dependencies don't expect
# fields to be in `body.<...>`
logger.warning("No JSON body in the request. Using query parameters.")
request_data = raw_request.query_params
request = validate_sagemaker_register_request(request_data)
transformed_request = self._transform_request(request, raw_request)
return BaseLoRATransformRequestOutput(
Expand Down
179 changes: 142 additions & 37 deletions python/tests/integration/test_sagemaker_lora_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,22 +102,52 @@ def setup_method(self):
sagemaker_standards.bootstrap(self.app)
self.client = TestClient(self.app)

def setup_handlers(self):
def make_adapter_request_params(self, test_type, name, src, base_url="/adapters"):
"""Helper to generate URL and JSON for adapter requests based on test_type.

Args:
test_type: Either "body" or "query_params"
name: The adapter name
src: The adapter source path
base_url: The base URL for the request (default: "/adapters")

Returns:
Tuple of (url, json_data) for use in client requests
"""
if test_type == "query_params":
url = f"{base_url}?name={name}&src={src}"
json_data = None
else: # body
url = base_url
json_data = {"name": name, "src": src}
return url, json_data

def setup_handlers(self, test_type="body"):
"""Define handlers for end-to-end lifecycle tests.

Sets up three handlers that simulate a LoRA-enabled inference engine:
1. load_lora_adapter - Loads adapters into the registry
2. unload_lora_adapter - Removes adapters from the registry
3. invocations - Handles inference with optional adapter selection

Args:
test_type: Either "body" or "query_params" to determine request source
"""
# Simulate a simple adapter registry
self.adapters = {}

# Determine request shape based on test type
source_prefix = "body" if test_type == "body" else "query_params"
request_shape = {
"lora_name": f"{source_prefix}.name",
"lora_path": f"{source_prefix}.src",
}

# Handler 1: Load adapter
# The decorator transforms: {"name": "x", "src": "y"} -> {"lora_name": "x", "lora_path": "y"}
@sagemaker_standards.register_load_adapter_handler(
request_shape={"lora_name": "body.name", "lora_path": "body.src"}
)
# The decorator transforms based on test_type:
# - body: {"name": "x", "src": "y"} -> {"lora_name": "x", "lora_path": "y"}
# - query_params: ?name=x&src=y -> {"lora_name": "x", "lora_path": "y"}
@sagemaker_standards.register_load_adapter_handler(request_shape=request_shape)
@self.router.post("/v1/load_lora_adapter")
async def load_lora_adapter(
request: EngineLoadLoRAAdapterRequest, raw_request: Request
Expand Down Expand Up @@ -180,12 +210,31 @@ async def invocations(request: Request):
class TestLoRARouterRedirection(BaseLoRAIntegrationTest):
"""Test that bootstrap() correctly mounts LoRA routes from decorated handlers."""

def test_register_adapter_route_mounted(self):
@pytest.mark.parametrize(
"test_type",
[
("body"),
("query_params"),
],
ids=[
"body",
"query_params",
],
)
def test_register_adapter_route_mounted(self, test_type):
"""Test that POST /adapters route is mounted by bootstrap()."""
# Re-setup handlers with the correct test_type for this parametrized test
handler_registry.clear()
self.setup_handlers(test_type)
sagemaker_standards.bootstrap(self.app)

# Call the SageMaker-standard route (not the engine's custom route)
response = self.client.post(
"/adapters", json={"name": "test-adapter", "src": "s3://bucket/adapter"}
lora_name = "test-adapter"
lora_path = "s3://bucket/adapter"
url, json_data = self.make_adapter_request_params(
test_type, lora_name, lora_path
)
response = self.client.post(url, json=json_data)

assert response.status_code == 200

Expand Down Expand Up @@ -372,32 +421,45 @@ async def invocations(request: Request):
assert response.status_code == 200
assert "lora-1" in response.text

def test_nested_jmespath_transformations(self):
@pytest.mark.parametrize(
"test_type",
[
("body"),
("query_params"),
],
ids=[
"body",
"query_params",
],
)
def test_nested_jmespath_transformations(self, test_type):
"""Test nested JMESPath expressions in request_shape.

Verifies that request_shape can contain nested dictionaries, and
JMESPath expressions work at any nesting level.
"""
# Re-setup handlers with the correct test_type for this parametrized test
handler_registry.clear()

app = FastAPI()
router = APIRouter()
self.capture.clear() # Clear the load capture

# Define request model with nested structure
class NestedLoadLoRAAdapterRequest(BaseModel):
adapter_config: dict # Nested dict field
source_path: str

# Determine request shape based on test type
source_prefix = "body" if test_type == "body" else "query_params"
nested_request_shape = {
"adapter_config": { # Target is a nested dict
"name": f"{source_prefix}.name", # Extract from source.name
},
"source_path": f"{source_prefix}.src", # Extract from source.src
}

@sagemaker_standards.register_load_adapter_handler(
request_shape={
"adapter_config": { # Target is a nested dict
"name": "body.name", # Extract from body.name
},
"source_path": "body.src", # Extract from body.src
}
request_shape=nested_request_shape
)
@router.post("/v1/nested_load")
@self.router.post("/v1/nested_load")
async def nested_load(
request: NestedLoadLoRAAdapterRequest, raw_request: Request
):
Expand All @@ -408,17 +470,14 @@ async def nested_load(
content=f"name={request.adapter_config['name']},source={request.source_path}",
)

app.include_router(router)
sagemaker_standards.bootstrap(app)
client = TestClient(app)
sagemaker_standards.bootstrap(self.app)

response = client.post(
"/adapters",
json={
"name": "nested-adapter",
"src": "s3://nested/path",
},
lora_name = "nested-adapter"
lora_path = "s3://nested/path"
url, json_data = self.make_adapter_request_params(
test_type, lora_name, lora_path
)
response = self.client.post(url, json=json_data)

assert response.status_code == 200
assert "nested-adapter" in response.text
Expand Down Expand Up @@ -500,36 +559,82 @@ class TestLoRAEndToEndFlow(BaseLoRAIntegrationTest):
simulating how a user would interact with a LoRA-enabled SageMaker endpoint.
"""

def test_full_adapter_lifecycle(self):
@pytest.mark.parametrize(
"test_type",
[
("body"),
("query_params"),
],
ids=[
"body",
"query_params",
],
)
def test_full_adapter_lifecycle(self, test_type):
"""Test complete lifecycle: register -> invoke with adapter -> unregister.

This is the primary happy path: load an adapter, use it for inference,
then unload it. Verifies all three operations work together.
"""
# Re-setup handlers with the correct test_type for this parametrized test
handler_registry.clear()
self.setup_handlers(test_type)
sagemaker_standards.bootstrap(self.app)

lora_name = "lora-1"
lora_path = "s3://bucket/lora-1"
# 1. Register an adapter
register_response = self.client.post(
"/adapters", json={"name": "lora-1", "src": "s3://bucket/lora-1"}
url, json_data = self.make_adapter_request_params(
test_type, lora_name, lora_path
)
register_response = self.client.post(url, json=json_data)
assert register_response.status_code == 200

# 2. Invoke with the adapter
invoke_response = self.client.post(
"/invocations",
json={"prompt": "hello"},
headers={"X-Amzn-SageMaker-Adapter-Identifier": "lora-1"},
headers={"X-Amzn-SageMaker-Adapter-Identifier": lora_name},
)
assert invoke_response.status_code == 200

# 3. Unregister the adapter
unregister_response = self.client.delete("/adapters/lora-1")
unregister_response = self.client.delete(f"/adapters/{lora_name}")
assert unregister_response.status_code == 200

def test_multiple_adapters(self):
@pytest.mark.parametrize(
"test_type",
[
("body"),
("query_params"),
],
ids=[
"body",
"query_params",
],
)
def test_multiple_adapters(self, test_type):
"""Test managing multiple adapters simultaneously."""
# Re-setup handlers with the correct test_type for this parametrized test
handler_registry.clear()
self.setup_handlers(test_type)
sagemaker_standards.bootstrap(self.app)

# Register multiple adapters
self.client.post("/adapters", json={"name": "adapter-a", "src": "s3://a"})
self.client.post("/adapters", json={"name": "adapter-b", "src": "s3://b"})
self.client.post("/adapters", json={"name": "adapter-c", "src": "s3://c"})
url_a, json_a = self.make_adapter_request_params(
test_type, "adapter_a", "s3://a"
)
self.client.post(url_a, json=json_a)

url_b, json_b = self.make_adapter_request_params(
test_type, "adapter_b", "s3://b"
)
self.client.post(url_b, json=json_b)

url_c, json_c = self.make_adapter_request_params(
test_type, "adapter_c", "s3://c"
)
self.client.post(url_c, json=json_c)

# Invoke with different adapters - each should route correctly
response_a = self.client.post(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Unit tests for RegisterLoRAApiTransform."""

from http import HTTPStatus
from json import JSONDecodeError
from unittest.mock import AsyncMock, Mock, patch

import pytest
Expand Down Expand Up @@ -212,7 +213,9 @@ def test_transform_ok_response(self):
@patch(
"model_hosting_container_standards.sagemaker.lora.utils.get_adapter_alias_from_request_header"
)
async def test_integration_transform_request_and_response(self, mock_get_alias):
async def test_integration_transform_request_and_response_json_body(
self, mock_get_alias
):
"""Test integration between request and response transformation."""
# Setup request transformation with mocked raw request
mock_raw_request = Mock(spec=Request)
Expand Down Expand Up @@ -241,3 +244,41 @@ async def test_integration_transform_request_and_response(self, mock_get_alias):
assert "registered" in ok_result.body.decode()

# TODO: test error transformation once implemented

@pytest.mark.asyncio
@patch(
"model_hosting_container_standards.sagemaker.lora.utils.get_adapter_alias_from_request_header"
)
async def test_integration_transform_request_and_response_query_params(
self, mock_get_alias
):
"""Test integration between request and response transformation."""
# Setup request transformation with mocked raw request
mock_raw_request = Mock(spec=Request)
mock_raw_request.json.side_effect = JSONDecodeError("test error", doc="", pos=1)
mock_raw_request.query_params = {
"name": "integration-test",
"src": "s3://integration",
}
mock_get_alias.return_value = "integration-alias"

# Transform request - only pass raw_request
transform_output = await self.transformer.transform_request(mock_raw_request)

# Verify request transformation
assert transform_output.adapter_name == "integration-test"

# Transform successful response
mock_ok_response = Mock(spec=Response)
mock_ok_response.status_code = HTTPStatus.OK
mock_ok_response.headers = {}
mock_ok_response.media_type = "application/json"

ok_result = self.transformer.transform_response(
mock_ok_response, transform_output
)

assert ok_result.status_code == HTTPStatus.OK
assert "registered" in ok_result.body.decode()

# TODO: test error transformation once implemented