From 73342bf51ad395433bda91d1f468bf33ef76f7c9 Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Tue, 26 May 2026 00:57:18 -0700 Subject: [PATCH 1/3] feat: bedrock-oss-provisioned-throughput-polling --- .../sagemaker/serve/bedrock_model_builder.py | 176 +++++++++- .../tests/unit/test_bedrock_model_builder.py | 331 +++++++++++++++++- 2 files changed, 500 insertions(+), 7 deletions(-) diff --git a/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py b/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py index 38cbba09c2..663ed41579 100644 --- a/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py @@ -118,12 +118,18 @@ def deploy( client_request_token: Optional[str] = None, imported_model_kms_key_id: Optional[str] = None, deployment_name: Optional[str] = None, + provisioned_model_name: Optional[str] = None, + model_units: int = 1, + commitment_duration: Optional[str] = None, + provisioned_model_tags: Optional[list] = None, ) -> Dict[str, Any]: """Deploy the model to Bedrock. Automatically detects if the model is a Nova model and uses the appropriate Bedrock API (create_custom_model for Nova, create_model_import_job for others). For Nova models, also creates a custom model deployment for inference. + For OSS models, creates a model import job, waits for completion, then creates + provisioned throughput and waits for it to become InService. Args: job_name: Name for the model import job (non-Nova models only). @@ -137,14 +143,25 @@ def deploy( imported_model_kms_key_id: KMS key ID for encryption (non-Nova models only). deployment_name: Name for the deployment (Nova models only). If not provided, defaults to custom_model_name suffixed with '-deployment'. + provisioned_model_name: Name for the provisioned throughput resource + (non-Nova models only). If not provided, defaults to + imported_model_name suffixed with '-throughput'. + model_units: Number of model units for provisioned throughput (non-Nova + models only). Defaults to 1. + commitment_duration: Commitment duration for provisioned throughput + (non-Nova models only). Valid values: 'OneMonth', 'SixMonths'. + If not provided, no commitment is set (on-demand). + provisioned_model_tags: Tags for the provisioned throughput resource + (non-Nova models only). Returns: Response from Bedrock API. For Nova models, returns the - create_custom_model_deployment response. For others, returns - the create_model_import_job response. + create_custom_model_deployment response. For OSS models, returns + the create_provisioned_model_throughput response. Raises: ValueError: If model_package is not set or required parameters are missing. + RuntimeError: If the import job or provisioned throughput fails or times out. """ if not self.model_package: raise ValueError( @@ -190,7 +207,26 @@ def deploy( params = {k: v for k, v in params.items() if v is not None} logger.info("Creating model import job for non-Nova deployment") - return self._get_bedrock_client().create_model_import_job(**params) + import_response = self._get_bedrock_client().create_model_import_job(**params) + + job_arn = import_response.get("jobArn") + self._wait_for_import_job_complete(job_arn) + + # Get the imported model ARN from the completed job + job_details = self._get_bedrock_client().get_model_import_job( + jobIdentifier=job_arn + ) + imported_model_arn = job_details.get("importedModelArn") + + # Create provisioned throughput + pt_name = provisioned_model_name or f"{imported_model_name}-throughput" + return self.create_provisioned_throughput( + model_id=imported_model_arn, + provisioned_model_name=pt_name, + model_units=model_units, + commitment_duration=commitment_duration, + tags=provisioned_model_tags, + ) def create_deployment( self, @@ -243,6 +279,140 @@ def create_deployment( return response + def create_provisioned_throughput( + self, + model_id: str, + provisioned_model_name: str, + model_units: int = 1, + commitment_duration: Optional[str] = None, + tags: Optional[list] = None, + poll_interval: int = 60, + max_wait: int = 3600, + ) -> Dict[str, Any]: + """Create provisioned throughput for an imported model on Bedrock. + + Calls CreateProvisionedModelThroughput and polls until the provisioned + throughput reaches InService status. + + Args: + model_id: ARN or ID of the imported model. + provisioned_model_name: Name for the provisioned throughput resource. + model_units: Number of model units to provision. Defaults to 1. + commitment_duration: Commitment duration. Valid values: 'OneMonth', + 'SixMonths'. If not provided, no commitment is set (on-demand). + tags: Tags for the provisioned throughput resource. + poll_interval: Seconds between status checks. Defaults to 60. + max_wait: Maximum seconds to wait. Defaults to 3600. + + Returns: + Response from Bedrock create_provisioned_model_throughput API. + + Raises: + RuntimeError: If the provisioned throughput fails or times out. + ValueError: If model_id or provisioned_model_name is not provided. + """ + if not model_id: + raise ValueError("model_id is required for create_provisioned_throughput.") + if not provisioned_model_name: + raise ValueError( + "provisioned_model_name is required for create_provisioned_throughput." + ) + + params = { + "modelId": model_id, + "provisionedModelName": provisioned_model_name, + "modelUnits": model_units, + } + if commitment_duration: + params["commitmentDuration"] = commitment_duration + if tags: + params["tags"] = tags + + logger.info( + "Creating provisioned throughput '%s' for model %s with %d model units", + provisioned_model_name, + model_id, + model_units, + ) + response = self._get_bedrock_client().create_provisioned_model_throughput(**params) + + provisioned_model_arn = response.get("provisionedModelArn") + if provisioned_model_arn: + self._wait_for_provisioned_throughput_in_service( + provisioned_model_arn, poll_interval=poll_interval, max_wait=max_wait + ) + + return response + + def _wait_for_import_job_complete( + self, job_arn: str, poll_interval: int = 60, max_wait: int = 3600 + ): + """Poll Bedrock until the model import job reaches Completed status. + + Args: + job_arn: ARN of the model import job. + poll_interval: Seconds between status checks. Defaults to 60. + max_wait: Maximum seconds to wait. Defaults to 3600. + + Raises: + RuntimeError: If the import job fails or times out. + """ + elapsed = 0 + status = None + while elapsed < max_wait: + resp = self._get_bedrock_client().get_model_import_job(jobIdentifier=job_arn) + status = resp.get("status") + logger.info("Import job status: %s (elapsed %ds)", status, elapsed) + if status == "Completed": + return + if status == "Failed": + failure_reason = resp.get("failureMessage", "Unknown") + raise RuntimeError( + f"Model import job {job_arn} failed. Reason: {failure_reason}" + ) + time.sleep(poll_interval) + elapsed += poll_interval + raise RuntimeError( + f"Timed out after {max_wait}s waiting for import job {job_arn} to complete. " + f"Last status: {status}" + ) + + def _wait_for_provisioned_throughput_in_service( + self, provisioned_model_arn: str, poll_interval: int = 60, max_wait: int = 3600 + ): + """Poll Bedrock until provisioned throughput reaches InService status. + + Args: + provisioned_model_arn: ARN of the provisioned model throughput. + poll_interval: Seconds between status checks. Defaults to 60. + max_wait: Maximum seconds to wait. Defaults to 3600. + + Raises: + RuntimeError: If the provisioned throughput fails or times out. + """ + elapsed = 0 + status = None + while elapsed < max_wait: + resp = self._get_bedrock_client().get_provisioned_model_throughput( + provisionedModelId=provisioned_model_arn + ) + status = resp.get("status") + logger.info("Provisioned throughput status: %s (elapsed %ds)", status, elapsed) + if status == "InService": + return + if status == "Failed": + failure_reason = resp.get("failureMessage", "Unknown") + raise RuntimeError( + f"Provisioned throughput {provisioned_model_arn} failed. " + f"Reason: {failure_reason}" + ) + time.sleep(poll_interval) + elapsed += poll_interval + raise RuntimeError( + f"Timed out after {max_wait}s waiting for provisioned throughput " + f"{provisioned_model_arn} to become InService. Last status: {status}" + ) + def _wait_for_model_active( self, model_arn: str, poll_interval: int = 60, max_wait: int = 3600 ): diff --git a/sagemaker-serve/tests/unit/test_bedrock_model_builder.py b/sagemaker-serve/tests/unit/test_bedrock_model_builder.py index 57a5c7abc9..5679b624d9 100644 --- a/sagemaker-serve/tests/unit/test_bedrock_model_builder.py +++ b/sagemaker-serve/tests/unit/test_bedrock_model_builder.py @@ -469,15 +469,120 @@ def test_timeout_raises(self): class TestDeploy: - def test_non_nova(self): + def test_non_nova_full_chain(self): + """Non-Nova deploy: import job → wait → get model ARN → create PT → wait PT.""" c = _make_container(s3_uri="s3://b/m.tar.gz") b = _builder() b.model_package = _make_model_package(c) b.s3_model_artifacts = "s3://b/m.tar.gz" b._bedrock_client = Mock() b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} - result = b.deploy(job_name="j", imported_model_name="m", role_arn="r") - assert result == {"jobArn": "arn:job"} + b._bedrock_client.get_model_import_job.return_value = { + "status": "Completed", + "importedModelArn": "arn:aws:bedrock:us-west-2:123:imported-model/m", + } + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:aws:bedrock:us-west-2:123:provisioned-model/m-throughput", + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService", + } + + with patch(f"{MODULE}.time.sleep"): + result = b.deploy(job_name="j", imported_model_name="m", role_arn="r") + + b._bedrock_client.create_model_import_job.assert_called_once() + b._bedrock_client.get_model_import_job.assert_called() + b._bedrock_client.create_provisioned_model_throughput.assert_called_once() + b._bedrock_client.get_provisioned_model_throughput.assert_called_once() + assert result["provisionedModelArn"] == "arn:aws:bedrock:us-west-2:123:provisioned-model/m-throughput" + + def test_non_nova_default_provisioned_model_name(self): + """Default provisioned model name is imported_model_name + '-throughput'.""" + c = _make_container(s3_uri="s3://b/m.tar.gz") + b = _builder() + b.model_package = _make_model_package(c) + b.s3_model_artifacts = "s3://b/m.tar.gz" + b._bedrock_client = Mock() + b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} + b._bedrock_client.get_model_import_job.return_value = { + "status": "Completed", + "importedModelArn": "arn:model", + } + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt", + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService", + } + + with patch(f"{MODULE}.time.sleep"): + b.deploy(job_name="j", imported_model_name="my-model", role_arn="r") + + kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert kw["provisionedModelName"] == "my-model-throughput" + + def test_non_nova_custom_provisioned_model_name(self): + """User can override provisioned model name.""" + c = _make_container(s3_uri="s3://b/m.tar.gz") + b = _builder() + b.model_package = _make_model_package(c) + b.s3_model_artifacts = "s3://b/m.tar.gz" + b._bedrock_client = Mock() + b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} + b._bedrock_client.get_model_import_job.return_value = { + "status": "Completed", + "importedModelArn": "arn:model", + } + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt", + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService", + } + + with patch(f"{MODULE}.time.sleep"): + b.deploy( + job_name="j", + imported_model_name="m", + role_arn="r", + provisioned_model_name="custom-pt-name", + ) + + kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert kw["provisionedModelName"] == "custom-pt-name" + + def test_non_nova_with_model_units_and_commitment(self): + """User can specify model_units and commitment_duration.""" + c = _make_container(s3_uri="s3://b/m.tar.gz") + b = _builder() + b.model_package = _make_model_package(c) + b.s3_model_artifacts = "s3://b/m.tar.gz" + b._bedrock_client = Mock() + b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} + b._bedrock_client.get_model_import_job.return_value = { + "status": "Completed", + "importedModelArn": "arn:model", + } + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt", + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService", + } + + with patch(f"{MODULE}.time.sleep"): + b.deploy( + job_name="j", + imported_model_name="m", + role_arn="r", + model_units=3, + commitment_duration="SixMonths", + ) + + kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert kw["modelUnits"] == 3 + assert kw["commitmentDuration"] == "SixMonths" def test_nova_full_chain(self): c = _make_container(recipe_name="nova-micro", hub_content_name="nova") @@ -579,7 +684,225 @@ def test_non_nova_strips_none_params(self): b.s3_model_artifacts = "s3://b/k" b._bedrock_client = Mock() b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn"} - b.deploy(job_name="j", imported_model_name="m", role_arn="r") + b._bedrock_client.get_model_import_job.return_value = { + "status": "Completed", + "importedModelArn": "arn:model", + } + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt", + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService", + } + + with patch(f"{MODULE}.time.sleep"): + b.deploy(job_name="j", imported_model_name="m", role_arn="r") + kw = b._bedrock_client.create_model_import_job.call_args[1] assert "importedModelKmsKeyId" not in kw assert "clientRequestToken" not in kw + + +# ── _wait_for_import_job_complete ─────────────────────────────────────────── + + +class TestWaitForImportJobComplete: + def test_immediate_completed(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_model_import_job.return_value = {"status": "Completed"} + b._wait_for_import_job_complete("arn:job") + b._bedrock_client.get_model_import_job.assert_called_once_with( + jobIdentifier="arn:job" + ) + + def test_polls_then_completed(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_model_import_job.side_effect = [ + {"status": "InProgress"}, + {"status": "InProgress"}, + {"status": "Completed"}, + ] + with patch(f"{MODULE}.time.sleep"): + b._wait_for_import_job_complete("arn:job", poll_interval=1, max_wait=10) + assert b._bedrock_client.get_model_import_job.call_count == 3 + + def test_failed_raises(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_model_import_job.return_value = { + "status": "Failed", + "failureMessage": "Invalid model format", + } + with pytest.raises(RuntimeError, match="Invalid model format"): + b._wait_for_import_job_complete("arn:job") + + def test_failed_unknown_reason(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_model_import_job.return_value = {"status": "Failed"} + with pytest.raises(RuntimeError, match="Unknown"): + b._wait_for_import_job_complete("arn:job") + + def test_timeout_raises(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_model_import_job.return_value = {"status": "InProgress"} + with patch(f"{MODULE}.time.sleep"): + with pytest.raises(RuntimeError, match="Timed out"): + b._wait_for_import_job_complete("arn:job", poll_interval=1, max_wait=2) + + +# ── create_provisioned_throughput ─────────────────────────────────────────── + + +class TestCreateProvisionedThroughput: + def test_creates_and_polls(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt" + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService" + } + + result = b.create_provisioned_throughput( + model_id="arn:model", provisioned_model_name="my-pt" + ) + + b._bedrock_client.create_provisioned_model_throughput.assert_called_once_with( + modelId="arn:model", + provisionedModelName="my-pt", + modelUnits=1, + ) + b._bedrock_client.get_provisioned_model_throughput.assert_called_once() + assert result["provisionedModelArn"] == "arn:pt" + + def test_passes_commitment_duration(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt" + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService" + } + + b.create_provisioned_throughput( + model_id="arn:model", + provisioned_model_name="pt", + model_units=5, + commitment_duration="OneMonth", + ) + + kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert kw["modelUnits"] == 5 + assert kw["commitmentDuration"] == "OneMonth" + + def test_passes_tags(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt" + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService" + } + + tags = [{"Key": "team", "Value": "ml"}] + b.create_provisioned_throughput( + model_id="arn:model", provisioned_model_name="pt", tags=tags + ) + + kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert kw["tags"] == tags + + def test_skips_polling_when_no_arn_in_response(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.create_provisioned_model_throughput.return_value = {} + + b.create_provisioned_throughput( + model_id="arn:model", provisioned_model_name="pt" + ) + b._bedrock_client.get_provisioned_model_throughput.assert_not_called() + + def test_empty_model_id_raises(self): + b = _builder() + with pytest.raises(ValueError, match="model_id is required"): + b.create_provisioned_throughput(model_id="", provisioned_model_name="pt") + + def test_none_model_id_raises(self): + b = _builder() + with pytest.raises(ValueError, match="model_id is required"): + b.create_provisioned_throughput(model_id=None, provisioned_model_name="pt") + + def test_empty_provisioned_model_name_raises(self): + b = _builder() + with pytest.raises(ValueError, match="provisioned_model_name is required"): + b.create_provisioned_throughput( + model_id="arn:model", provisioned_model_name="" + ) + + +# ── _wait_for_provisioned_throughput_in_service ───────────────────────────── + + +class TestWaitForProvisionedThroughputInService: + def test_immediate_in_service(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService" + } + b._wait_for_provisioned_throughput_in_service("arn:pt") + b._bedrock_client.get_provisioned_model_throughput.assert_called_once_with( + provisionedModelId="arn:pt" + ) + + def test_polls_then_in_service(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.side_effect = [ + {"status": "Creating"}, + {"status": "Creating"}, + {"status": "InService"}, + ] + with patch(f"{MODULE}.time.sleep"): + b._wait_for_provisioned_throughput_in_service( + "arn:pt", poll_interval=1, max_wait=10 + ) + assert b._bedrock_client.get_provisioned_model_throughput.call_count == 3 + + def test_failed_raises(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "Failed", + "failureMessage": "Insufficient capacity", + } + with pytest.raises(RuntimeError, match="Insufficient capacity"): + b._wait_for_provisioned_throughput_in_service("arn:pt") + + def test_failed_unknown_reason(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "Failed" + } + with pytest.raises(RuntimeError, match="Unknown"): + b._wait_for_provisioned_throughput_in_service("arn:pt") + + def test_timeout_raises(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "Creating" + } + with patch(f"{MODULE}.time.sleep"): + with pytest.raises(RuntimeError, match="Timed out"): + b._wait_for_provisioned_throughput_in_service( + "arn:pt", poll_interval=1, max_wait=2 + ) From d6636a04ade2830a756e58ed62e1aef825b788e2 Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Tue, 26 May 2026 23:07:08 -0700 Subject: [PATCH 2/3] fix: use the importedModelName from the job response instead of ARN --- .../src/sagemaker/serve/bedrock_model_builder.py | 9 ++++++--- .../tests/unit/test_bedrock_model_builder.py | 15 +++++++++------ 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py b/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py index 663ed41579..535aab03b2 100644 --- a/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py @@ -212,16 +212,19 @@ def deploy( job_arn = import_response.get("jobArn") self._wait_for_import_job_complete(job_arn) - # Get the imported model ARN from the completed job + # Get the imported model name from the completed job. + # We use importedModelName (not importedModelArn) because + # CreateProvisionedModelThroughput accepts model names but not + # the imported-model ARN format. job_details = self._get_bedrock_client().get_model_import_job( jobIdentifier=job_arn ) - imported_model_arn = job_details.get("importedModelArn") + imported_model_id = job_details.get("importedModelName") # Create provisioned throughput pt_name = provisioned_model_name or f"{imported_model_name}-throughput" return self.create_provisioned_throughput( - model_id=imported_model_arn, + model_id=imported_model_id, provisioned_model_name=pt_name, model_units=model_units, commitment_duration=commitment_duration, diff --git a/sagemaker-serve/tests/unit/test_bedrock_model_builder.py b/sagemaker-serve/tests/unit/test_bedrock_model_builder.py index 5679b624d9..47a86aa087 100644 --- a/sagemaker-serve/tests/unit/test_bedrock_model_builder.py +++ b/sagemaker-serve/tests/unit/test_bedrock_model_builder.py @@ -470,7 +470,7 @@ def test_timeout_raises(self): class TestDeploy: def test_non_nova_full_chain(self): - """Non-Nova deploy: import job → wait → get model ARN → create PT → wait PT.""" + """Non-Nova deploy: import job → wait → get model name → create PT → wait PT.""" c = _make_container(s3_uri="s3://b/m.tar.gz") b = _builder() b.model_package = _make_model_package(c) @@ -479,7 +479,7 @@ def test_non_nova_full_chain(self): b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} b._bedrock_client.get_model_import_job.return_value = { "status": "Completed", - "importedModelArn": "arn:aws:bedrock:us-west-2:123:imported-model/m", + "importedModelName": "my-imported-model", } b._bedrock_client.create_provisioned_model_throughput.return_value = { "provisionedModelArn": "arn:aws:bedrock:us-west-2:123:provisioned-model/m-throughput", @@ -495,6 +495,9 @@ def test_non_nova_full_chain(self): b._bedrock_client.get_model_import_job.assert_called() b._bedrock_client.create_provisioned_model_throughput.assert_called_once() b._bedrock_client.get_provisioned_model_throughput.assert_called_once() + # Verify model_id passed to create_provisioned_model_throughput is the model name + pt_call_kwargs = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert pt_call_kwargs["modelId"] == "my-imported-model" assert result["provisionedModelArn"] == "arn:aws:bedrock:us-west-2:123:provisioned-model/m-throughput" def test_non_nova_default_provisioned_model_name(self): @@ -507,7 +510,7 @@ def test_non_nova_default_provisioned_model_name(self): b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} b._bedrock_client.get_model_import_job.return_value = { "status": "Completed", - "importedModelArn": "arn:model", + "importedModelName": "my-imported-model", } b._bedrock_client.create_provisioned_model_throughput.return_value = { "provisionedModelArn": "arn:pt", @@ -532,7 +535,7 @@ def test_non_nova_custom_provisioned_model_name(self): b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} b._bedrock_client.get_model_import_job.return_value = { "status": "Completed", - "importedModelArn": "arn:model", + "importedModelName": "my-imported-model", } b._bedrock_client.create_provisioned_model_throughput.return_value = { "provisionedModelArn": "arn:pt", @@ -562,7 +565,7 @@ def test_non_nova_with_model_units_and_commitment(self): b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} b._bedrock_client.get_model_import_job.return_value = { "status": "Completed", - "importedModelArn": "arn:model", + "importedModelName": "my-imported-model", } b._bedrock_client.create_provisioned_model_throughput.return_value = { "provisionedModelArn": "arn:pt", @@ -686,7 +689,7 @@ def test_non_nova_strips_none_params(self): b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn"} b._bedrock_client.get_model_import_job.return_value = { "status": "Completed", - "importedModelArn": "arn:model", + "importedModelName": "my-imported-model", } b._bedrock_client.create_provisioned_model_throughput.return_value = { "provisionedModelArn": "arn:pt", From 1fe1959f4126cd547f4e218800c5618ee87fd6c5 Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Mon, 1 Jun 2026 00:31:24 -0700 Subject: [PATCH 3/3] test: change OSS deploy to -- import with polling(OD), PT as separate method - deploy() for non-Nova models now waits for import job completion and returns job details. Model is ready for on-demand inference after deploy(). - Removed mandatory CreateProvisionedModelThroughput from deploy() flow. - create_provisioned_throughput() remains as a standalone public method for users who need dedicated throughput. - Updated unit tests to verify deploy() no longer calls PT APIs. - Added integ test for import job polling (test_deploy_oss_model_waits_for_import_completion). - Added skipped integ test for create_provisioned_throughput() pending PT MU quota approval. --- .../sagemaker/serve/bedrock_model_builder.py | 40 +-- .../test_bedrock_provisioned_throughput.py | 279 ++++++++++++++++++ .../tests/unit/test_bedrock_model_builder.py | 109 +------ 3 files changed, 300 insertions(+), 128 deletions(-) create mode 100644 sagemaker-serve/tests/integ/test_bedrock_provisioned_throughput.py diff --git a/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py b/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py index 535aab03b2..367ca58cd0 100644 --- a/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py @@ -118,18 +118,15 @@ def deploy( client_request_token: Optional[str] = None, imported_model_kms_key_id: Optional[str] = None, deployment_name: Optional[str] = None, - provisioned_model_name: Optional[str] = None, - model_units: int = 1, - commitment_duration: Optional[str] = None, - provisioned_model_tags: Optional[list] = None, ) -> Dict[str, Any]: """Deploy the model to Bedrock. Automatically detects if the model is a Nova model and uses the appropriate Bedrock API (create_custom_model for Nova, create_model_import_job for others). For Nova models, also creates a custom model deployment for inference. - For OSS models, creates a model import job, waits for completion, then creates - provisioned throughput and waits for it to become InService. + For OSS models, creates a model import job and waits for it to complete. + Once complete, the model is ready for on-demand inference. If provisioned + throughput is needed, use the separate create_provisioned_throughput() method. Args: job_name: Name for the model import job (non-Nova models only). @@ -143,25 +140,15 @@ def deploy( imported_model_kms_key_id: KMS key ID for encryption (non-Nova models only). deployment_name: Name for the deployment (Nova models only). If not provided, defaults to custom_model_name suffixed with '-deployment'. - provisioned_model_name: Name for the provisioned throughput resource - (non-Nova models only). If not provided, defaults to - imported_model_name suffixed with '-throughput'. - model_units: Number of model units for provisioned throughput (non-Nova - models only). Defaults to 1. - commitment_duration: Commitment duration for provisioned throughput - (non-Nova models only). Valid values: 'OneMonth', 'SixMonths'. - If not provided, no commitment is set (on-demand). - provisioned_model_tags: Tags for the provisioned throughput resource - (non-Nova models only). Returns: Response from Bedrock API. For Nova models, returns the create_custom_model_deployment response. For OSS models, returns - the create_provisioned_model_throughput response. + the get_model_import_job response after the job completes. Raises: ValueError: If model_package is not set or required parameters are missing. - RuntimeError: If the import job or provisioned throughput fails or times out. + RuntimeError: If the import job fails or times out. """ if not self.model_package: raise ValueError( @@ -212,24 +199,11 @@ def deploy( job_arn = import_response.get("jobArn") self._wait_for_import_job_complete(job_arn) - # Get the imported model name from the completed job. - # We use importedModelName (not importedModelArn) because - # CreateProvisionedModelThroughput accepts model names but not - # the imported-model ARN format. + # Return the completed job details job_details = self._get_bedrock_client().get_model_import_job( jobIdentifier=job_arn ) - imported_model_id = job_details.get("importedModelName") - - # Create provisioned throughput - pt_name = provisioned_model_name or f"{imported_model_name}-throughput" - return self.create_provisioned_throughput( - model_id=imported_model_id, - provisioned_model_name=pt_name, - model_units=model_units, - commitment_duration=commitment_duration, - tags=provisioned_model_tags, - ) + return job_details def create_deployment( self, diff --git a/sagemaker-serve/tests/integ/test_bedrock_provisioned_throughput.py b/sagemaker-serve/tests/integ/test_bedrock_provisioned_throughput.py new file mode 100644 index 0000000000..587699ff36 --- /dev/null +++ b/sagemaker-serve/tests/integ/test_bedrock_provisioned_throughput.py @@ -0,0 +1,279 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Integration tests for BedrockModelBuilder import job polling and provisioned throughput.""" +from __future__ import absolute_import + +import json +import time +import random +import logging +from urllib.parse import urlparse + +import boto3 +import pytest + +from sagemaker.core.helper.session_helper import Session, get_execution_role +from sagemaker.core.resources import TrainingJob +from sagemaker.serve.bedrock_model_builder import BedrockModelBuilder + +logger = logging.getLogger(__name__) + +AWS_REGION = "us-west-2" + + +@pytest.fixture(scope="module") +def training_job_name(): + """Training job name for testing (non-Nova, OSS model).""" + return "meta-textgeneration-llama-3-2-1b-instruct-sft-20251201172445" + + +@pytest.fixture(scope="module") +def role_arn(): + """IAM role ARN with Bedrock permissions.""" + return get_execution_role() + + +@pytest.fixture(scope="module") +def bedrock_client(): + """Create Bedrock client.""" + return boto3.client("bedrock", region_name=AWS_REGION) + + +@pytest.fixture(scope="module") +def s3_client(): + """Create S3 client.""" + return boto3.client("s3", region_name=AWS_REGION) + + +@pytest.fixture(scope="module") +def training_job(training_job_name): + """Get the training job.""" + return TrainingJob.get( + training_job_name=training_job_name, region=AWS_REGION + ) + + +def _setup_model_files(s3_artifacts_uri, s3_client): + """Setup required model files for Bedrock deployment. + + Bedrock model import requires HuggingFace-format files (config.json, + tokenizer.json, etc.) at the root of the S3 model artifacts path. + Training jobs often store these under checkpoints/hf_merged/, so we + copy them to the expected location. + + Args: + s3_artifacts_uri: The S3 URI that BedrockModelBuilder will use for import. + s3_client: boto3 S3 client. + """ + parsed = urlparse(s3_artifacts_uri) + bucket = parsed.netloc + base_prefix = parsed.path.lstrip("/").rstrip("/") + + hf_merged_prefix = f"{base_prefix}/checkpoints/hf_merged/" + root_prefix = f"{base_prefix}/" + + files_to_copy = [ + "config.json", + "tokenizer.json", + "tokenizer_config.json", + "model.safetensors", + ] + + for file in files_to_copy: + try: + s3_client.head_object(Bucket=bucket, Key=root_prefix + file) + logger.info("File already exists: s3://%s/%s%s", bucket, root_prefix, file) + except Exception: + try: + s3_client.copy_object( + Bucket=bucket, + CopySource={"Bucket": bucket, "Key": hf_merged_prefix + file}, + Key=root_prefix + file, + ) + logger.info("Copied %s to root", file) + except Exception as e: + logger.warning("Could not copy %s: %s", file, e) + + try: + s3_client.head_object(Bucket=bucket, Key=root_prefix + "added_tokens.json") + except Exception: + try: + s3_client.put_object( + Bucket=bucket, + Key=root_prefix + "added_tokens.json", + Body=json.dumps({}), + ContentType="application/json", + ) + logger.info("Created added_tokens.json") + except Exception as e: + logger.warning("Could not create added_tokens.json: %s", e) + + +class TestBedrockImportJobPolling: + """Test import job polling for OSS models (Option C: deploy only waits for import).""" + + @pytest.fixture(autouse=True) + def _setup(self, bedrock_client): + """Store bedrock client and track resources for cleanup.""" + self._bedrock_client = bedrock_client + self._imported_model_arn = None + yield + self._cleanup() + + def _cleanup(self): + """Clean up Bedrock resources created during the test.""" + if self._imported_model_arn: + try: + logger.info("Deleting imported model: %s", self._imported_model_arn) + self._bedrock_client.delete_imported_model( + modelIdentifier=self._imported_model_arn + ) + except Exception as e: + logger.warning("Failed to delete imported model: %s", e) + + @pytest.mark.slow + def test_deploy_oss_model_waits_for_import_completion( + self, training_job, role_arn, bedrock_client, s3_client + ): + """Test that deploy() waits for import job to complete and returns job details. + + This test verifies that BedrockModelBuilder.deploy() for non-Nova models: + 1. Creates a model import job + 2. Polls until the import job reaches Completed status + 3. Returns the completed job details (model is ready for on-demand invoke) + 4. Does NOT create provisioned throughput + """ + builder = BedrockModelBuilder(model=training_job) + assert builder.s3_model_artifacts is not None + + _setup_model_files(builder.s3_model_artifacts, s3_client) + + suffix = f"{int(time.time())}-{random.randint(1000, 9999)}" + job_name = f"test-import-poll-{suffix}" + imported_model_name = f"test-import-model-{suffix}" + + result = builder.deploy( + job_name=job_name, + imported_model_name=imported_model_name, + role_arn=role_arn, + ) + + # Verify the result is the completed job details + assert result["status"] == "Completed", ( + f"Expected Completed, got {result.get('status')}" + ) + assert "importedModelName" in result + assert "importedModelArn" in result or "jobArn" in result + + # Track for cleanup + self._imported_model_arn = result.get("importedModelArn") + + # Verify model can be found (it exists and is ready) + models = bedrock_client.list_imported_models() + model_names = [m["modelName"] for m in models.get("modelSummaries", [])] + assert imported_model_name in model_names + + +class TestBedrockProvisionedThroughput: + """Test create_provisioned_throughput as a standalone method.""" + + @pytest.fixture(autouse=True) + def _setup(self, bedrock_client): + """Store bedrock client and track resources for cleanup.""" + self._bedrock_client = bedrock_client + self._provisioned_model_arn = None + self._imported_model_arn = None + yield + self._cleanup() + + def _cleanup(self): + """Clean up Bedrock resources created during the test.""" + if self._provisioned_model_arn: + try: + logger.info("Deleting provisioned throughput: %s", self._provisioned_model_arn) + self._bedrock_client.delete_provisioned_model_throughput( + provisionedModelId=self._provisioned_model_arn + ) + except Exception as e: + logger.warning("Failed to delete provisioned throughput: %s", e) + + if self._imported_model_arn: + time.sleep(5) + try: + logger.info("Deleting imported model: %s", self._imported_model_arn) + self._bedrock_client.delete_imported_model( + modelIdentifier=self._imported_model_arn + ) + except Exception as e: + logger.warning("Failed to delete imported model: %s", e) + + @pytest.mark.skip( + reason="Requires Provisioned Throughput MU quota which is not available " + "in the CI account. PT quota must be requested per-model via Matador " + "(https://console.harmony.a2z.com/bedrock-matador/) with 1-4 week SLA. " + "Additionally, the current test training job is based on Llama 3.2 1B which " + "may not support PT in us-west-2. A training job based on a PT-eligible model " + "(e.g., meta.llama3-1-8b-instruct or amazon.titan-embed-image-v1) is needed. " + "Remove this skip once quota is granted and a compatible training job is available." + ) + def test_create_provisioned_throughput_after_import( + self, training_job, role_arn, bedrock_client, s3_client + ): + """Test create_provisioned_throughput() as a standalone method after import. + + This test verifies: + 1. Import a model using deploy() (waits for completion) + 2. Call create_provisioned_throughput() with the imported model name + 3. Polls until provisioned throughput reaches InService + 4. Returns the provisioned throughput response + """ + builder = BedrockModelBuilder(model=training_job) + assert builder.s3_model_artifacts is not None + + _setup_model_files(builder.s3_model_artifacts, s3_client) + + suffix = f"{int(time.time())}-{random.randint(1000, 9999)}" + job_name = f"test-pt-{suffix}" + imported_model_name = f"test-pt-model-{suffix}" + provisioned_model_name = f"test-pt-throughput-{suffix}" + + # Step 1: Deploy (import + wait) + deploy_result = builder.deploy( + job_name=job_name, + imported_model_name=imported_model_name, + role_arn=role_arn, + ) + assert deploy_result["status"] == "Completed" + self._imported_model_arn = deploy_result.get("importedModelArn") + model_id = deploy_result.get("importedModelName") + + # Step 2: Create provisioned throughput (separate call) + pt_result = builder.create_provisioned_throughput( + model_id=model_id, + provisioned_model_name=provisioned_model_name, + model_units=1, + ) + + # Step 3: Verify result + assert "provisionedModelArn" in pt_result, ( + f"Expected 'provisionedModelArn' in result, got keys: {list(pt_result.keys())}" + ) + self._provisioned_model_arn = pt_result["provisionedModelArn"] + + # Step 4: Verify provisioned throughput is InService + pt_response = bedrock_client.get_provisioned_model_throughput( + provisionedModelId=self._provisioned_model_arn + ) + assert pt_response["status"] == "InService", ( + f"Expected InService, got {pt_response['status']}" + ) diff --git a/sagemaker-serve/tests/unit/test_bedrock_model_builder.py b/sagemaker-serve/tests/unit/test_bedrock_model_builder.py index 47a86aa087..19b773391a 100644 --- a/sagemaker-serve/tests/unit/test_bedrock_model_builder.py +++ b/sagemaker-serve/tests/unit/test_bedrock_model_builder.py @@ -469,8 +469,8 @@ def test_timeout_raises(self): class TestDeploy: - def test_non_nova_full_chain(self): - """Non-Nova deploy: import job → wait → get model name → create PT → wait PT.""" + def test_non_nova_waits_for_import_and_returns_job_details(self): + """Non-Nova deploy: import job → wait → return job details.""" c = _make_container(s3_uri="s3://b/m.tar.gz") b = _builder() b.model_package = _make_model_package(c) @@ -480,12 +480,7 @@ def test_non_nova_full_chain(self): b._bedrock_client.get_model_import_job.return_value = { "status": "Completed", "importedModelName": "my-imported-model", - } - b._bedrock_client.create_provisioned_model_throughput.return_value = { - "provisionedModelArn": "arn:aws:bedrock:us-west-2:123:provisioned-model/m-throughput", - } - b._bedrock_client.get_provisioned_model_throughput.return_value = { - "status": "InService", + "importedModelArn": "arn:aws:bedrock:us-west-2:123:imported-model/abc", } with patch(f"{MODULE}.time.sleep"): @@ -493,70 +488,13 @@ def test_non_nova_full_chain(self): b._bedrock_client.create_model_import_job.assert_called_once() b._bedrock_client.get_model_import_job.assert_called() - b._bedrock_client.create_provisioned_model_throughput.assert_called_once() - b._bedrock_client.get_provisioned_model_throughput.assert_called_once() - # Verify model_id passed to create_provisioned_model_throughput is the model name - pt_call_kwargs = b._bedrock_client.create_provisioned_model_throughput.call_args[1] - assert pt_call_kwargs["modelId"] == "my-imported-model" - assert result["provisionedModelArn"] == "arn:aws:bedrock:us-west-2:123:provisioned-model/m-throughput" - - def test_non_nova_default_provisioned_model_name(self): - """Default provisioned model name is imported_model_name + '-throughput'.""" - c = _make_container(s3_uri="s3://b/m.tar.gz") - b = _builder() - b.model_package = _make_model_package(c) - b.s3_model_artifacts = "s3://b/m.tar.gz" - b._bedrock_client = Mock() - b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} - b._bedrock_client.get_model_import_job.return_value = { - "status": "Completed", - "importedModelName": "my-imported-model", - } - b._bedrock_client.create_provisioned_model_throughput.return_value = { - "provisionedModelArn": "arn:pt", - } - b._bedrock_client.get_provisioned_model_throughput.return_value = { - "status": "InService", - } - - with patch(f"{MODULE}.time.sleep"): - b.deploy(job_name="j", imported_model_name="my-model", role_arn="r") - - kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] - assert kw["provisionedModelName"] == "my-model-throughput" - - def test_non_nova_custom_provisioned_model_name(self): - """User can override provisioned model name.""" - c = _make_container(s3_uri="s3://b/m.tar.gz") - b = _builder() - b.model_package = _make_model_package(c) - b.s3_model_artifacts = "s3://b/m.tar.gz" - b._bedrock_client = Mock() - b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} - b._bedrock_client.get_model_import_job.return_value = { - "status": "Completed", - "importedModelName": "my-imported-model", - } - b._bedrock_client.create_provisioned_model_throughput.return_value = { - "provisionedModelArn": "arn:pt", - } - b._bedrock_client.get_provisioned_model_throughput.return_value = { - "status": "InService", - } - - with patch(f"{MODULE}.time.sleep"): - b.deploy( - job_name="j", - imported_model_name="m", - role_arn="r", - provisioned_model_name="custom-pt-name", - ) - - kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] - assert kw["provisionedModelName"] == "custom-pt-name" + # Should NOT call create_provisioned_model_throughput + b._bedrock_client.create_provisioned_model_throughput.assert_not_called() + assert result["status"] == "Completed" + assert result["importedModelName"] == "my-imported-model" - def test_non_nova_with_model_units_and_commitment(self): - """User can specify model_units and commitment_duration.""" + def test_non_nova_does_not_create_provisioned_throughput(self): + """deploy() for non-Nova should never call CreateProvisionedModelThroughput.""" c = _make_container(s3_uri="s3://b/m.tar.gz") b = _builder() b.model_package = _make_model_package(c) @@ -565,27 +503,14 @@ def test_non_nova_with_model_units_and_commitment(self): b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} b._bedrock_client.get_model_import_job.return_value = { "status": "Completed", - "importedModelName": "my-imported-model", - } - b._bedrock_client.create_provisioned_model_throughput.return_value = { - "provisionedModelArn": "arn:pt", - } - b._bedrock_client.get_provisioned_model_throughput.return_value = { - "status": "InService", + "importedModelName": "m", } with patch(f"{MODULE}.time.sleep"): - b.deploy( - job_name="j", - imported_model_name="m", - role_arn="r", - model_units=3, - commitment_duration="SixMonths", - ) + b.deploy(job_name="j", imported_model_name="m", role_arn="r") - kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] - assert kw["modelUnits"] == 3 - assert kw["commitmentDuration"] == "SixMonths" + b._bedrock_client.create_provisioned_model_throughput.assert_not_called() + b._bedrock_client.get_provisioned_model_throughput.assert_not_called() def test_nova_full_chain(self): c = _make_container(recipe_name="nova-micro", hub_content_name="nova") @@ -689,13 +614,7 @@ def test_non_nova_strips_none_params(self): b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn"} b._bedrock_client.get_model_import_job.return_value = { "status": "Completed", - "importedModelName": "my-imported-model", - } - b._bedrock_client.create_provisioned_model_throughput.return_value = { - "provisionedModelArn": "arn:pt", - } - b._bedrock_client.get_provisioned_model_throughput.return_value = { - "status": "InService", + "importedModelName": "m", } with patch(f"{MODULE}.time.sleep"):