From 3d55ead27aa76d1beac0ffd17308961431bb3ded Mon Sep 17 00:00:00 2001 From: Pranav Murthy Date: Tue, 21 Jan 2025 23:42:02 -0500 Subject: [PATCH 1/8] ModelDataUrl is made optional, applies to certain prebuilt containers and gen-ai dlami --- src/sagemaker/local/entities.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/sagemaker/local/entities.py b/src/sagemaker/local/entities.py index a21a375f54..079423e77d 100644 --- a/src/sagemaker/local/entities.py +++ b/src/sagemaker/local/entities.py @@ -615,6 +615,10 @@ def serve(self): self.container = _SageMakerContainer( instance_type, instance_count, image, self.local_session ) + + if "ModelDataUrl" not in self.primary_container.keys(): + self.primary_container["ModelDataUrl"] = None + self.container.serve( self.primary_container["ModelDataUrl"], self.primary_container["Environment"] ) From c235f9a98287095572abb9069a3c48008d292351 Mon Sep 17 00:00:00 2001 From: Pranav Murthy Date: Tue, 21 Jan 2025 23:43:03 -0500 Subject: [PATCH 2/8] update docker compose structure with nvidia --- src/sagemaker/local/image.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index ef24bb0d99..00c3bc5deb 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -340,7 +340,10 @@ def serve(self, model_dir, environment): self.container_root = self._create_tmp_folder() logger.info("creating hosting dir in %s", self.container_root) - volumes = self._prepare_serving_volumes(model_dir) + if model_dir is not None: + volumes = self._prepare_serving_volumes(model_dir) + else: + volumes = None # If the user script was passed as a file:// mount it to the container. 
if sagemaker.estimator.DIR_PARAM_NAME.upper() in environment: @@ -859,7 +862,7 @@ def _create_docker_host( if self.instance_type == "local_gpu": host_config["deploy"] = { "resources": { - "reservations": {"devices": [{"count": "all", "capabilities": ["gpu"]}]} + "reservations": {"devices": [{"driver": "nvidia", "count": "all", "capabilities": ["gpu"]}]} } } From fb584ba84833b66b6192da25c9559c92f619cd2f Mon Sep 17 00:00:00 2001 From: Pranav Murthy Date: Fri, 31 Jan 2025 18:38:55 -0500 Subject: [PATCH 3/8] add test scripts for all cpu based traditional ML/DL workloads --- .../local_mode/sample_inference_script.py | 30 ++ .../local_mode/sample_processing_script.py | 41 +++ .../local_mode/sample_training_script.py | 88 ++++++ .../local_mode/test_local_model_cpu_jobs.py | 291 ++++++++++++++++++ 4 files changed, 450 insertions(+) create mode 100644 tests/integ/sagemaker/local_mode/sample_inference_script.py create mode 100644 tests/integ/sagemaker/local_mode/sample_processing_script.py create mode 100644 tests/integ/sagemaker/local_mode/sample_training_script.py create mode 100644 tests/integ/sagemaker/local_mode/test_local_model_cpu_jobs.py diff --git a/tests/integ/sagemaker/local_mode/sample_inference_script.py b/tests/integ/sagemaker/local_mode/sample_inference_script.py new file mode 100644 index 0000000000..ef82e0e11e --- /dev/null +++ b/tests/integ/sagemaker/local_mode/sample_inference_script.py @@ -0,0 +1,30 @@ +import os + +# Install the required package +os.system("pip install transformers==4.18.0") + +from transformers import pipeline + +CSV_CONTENT_TYPE = 'text/csv' + + +def model_fn(model_dir): + sentiment_analysis = pipeline( + "sentiment-analysis", + model=model_dir, + tokenizer=model_dir, + return_all_scores=True + ) + return sentiment_analysis + + +def input_fn(serialized_input_data, content_type=CSV_CONTENT_TYPE): + if content_type == CSV_CONTENT_TYPE: + input_data = serialized_input_data.splitlines() + return input_data + else: + raise 
Exception('Requested unsupported ContentType in Accept: ' + content_type) + + +def predict_fn(input_data, model): + return model(input_data) diff --git a/tests/integ/sagemaker/local_mode/sample_processing_script.py b/tests/integ/sagemaker/local_mode/sample_processing_script.py new file mode 100644 index 0000000000..05828a1c6b --- /dev/null +++ b/tests/integ/sagemaker/local_mode/sample_processing_script.py @@ -0,0 +1,41 @@ +import pandas as pd +import numpy as np +import argparse +import os +from sklearn.preprocessing import OrdinalEncoder + +def _parse_args(): + + parser = argparse.ArgumentParser() + + # Data, model, and output directories + # model_dir is always passed in from SageMaker. By default this is a S3 path under the default bucket. + parser.add_argument('--filepath', type=str, default='/opt/ml/processing/input/') + parser.add_argument('--filename', type=str, default='bank-additional-full.csv') + parser.add_argument('--outputpath', type=str, default='/opt/ml/processing/output/') + parser.add_argument('--categorical_features', type=str, default='y, job, marital, education, default, housing, loan, contact, month, day_of_week, poutcome') + + return parser.parse_known_args() + +if __name__=="__main__": + # Process arguments + args, _ = _parse_args() + # Load data + df = pd.read_csv(os.path.join(args.filepath, args.filename)) + # Change the value . 
into _ + df = df.replace(regex=r'\.', value='_') + df = df.replace(regex=r'\_$', value='') + # Add two new indicators + df["no_previous_contact"] = (df["pdays"] == 999).astype(int) + df["not_working"] = df["job"].isin(["student", "retired", "unemployed"]).astype(int) + df = df.drop(['duration', 'emp.var.rate', 'cons.price.idx', 'cons.conf.idx', 'euribor3m', 'nr.employed'], axis=1) + # Encode the categorical features + df = pd.get_dummies(df) + # Train, test, validation split + train_data, validation_data, test_data = np.split(df.sample(frac=1, random_state=42), [int(0.7 * len(df)), int(0.9 * len(df))]) # Randomly sort the data then split out first 70%, second 20%, and last 10% + # Local store + pd.concat([train_data['y_yes'], train_data.drop(['y_yes','y_no'], axis=1)], axis=1).to_csv(os.path.join(args.outputpath, 'train/train.csv'), index=False, header=False) + pd.concat([validation_data['y_yes'], validation_data.drop(['y_yes','y_no'], axis=1)], axis=1).to_csv(os.path.join(args.outputpath, 'validation/validation.csv'), index=False, header=False) + test_data['y_yes'].to_csv(os.path.join(args.outputpath, 'test/test_y.csv'), index=False, header=False) + test_data.drop(['y_yes','y_no'], axis=1).to_csv(os.path.join(args.outputpath, 'test/test_x.csv'), index=False, header=False) + print("## Processing complete. Exiting.") \ No newline at end of file diff --git a/tests/integ/sagemaker/local_mode/sample_training_script.py b/tests/integ/sagemaker/local_mode/sample_training_script.py new file mode 100644 index 0000000000..9a86b8e1dc --- /dev/null +++ b/tests/integ/sagemaker/local_mode/sample_training_script.py @@ -0,0 +1,88 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. 
+# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from __future__ import print_function + +import argparse +import os + +import joblib +import pandas as pd +from sklearn import tree +from sklearn.metrics import mean_squared_error + +if __name__ == "__main__": + print("Training Started") + parser = argparse.ArgumentParser() + + # Hyperparameters are described here. In this simple example we are just including one hyperparameter. + parser.add_argument("--max_leaf_nodes", type=int, default=-1) + + # Sagemaker specific arguments. Defaults are set in the environment variables. + parser.add_argument("--output-data-dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"]) + parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"]) + parser.add_argument("--train", type=str, default=os.environ["SM_CHANNEL_TRAIN"]) + parser.add_argument("--validation", type=str, default=os.environ["SM_CHANNEL_VALIDATION"]) + + args = parser.parse_args() + print("Got Args: {}".format(args)) + + # Take the set of files and read them all into a single pandas dataframe + input_files = [os.path.join(args.train, file) for file in os.listdir(args.train)] + if len(input_files) == 0: + raise ValueError( + ( + "There are no files in {}.\n" + + "This usually indicates that the channel ({}) was incorrectly specified,\n" + + "the data specification in S3 was incorrectly specified or the role specified\n" + + "does not have permission to access the data." 
+ ).format(args.train, "train") + ) + raw_data = [pd.read_csv(file, header=None, engine="python") for file in input_files] + train_data = pd.concat(raw_data) + + # labels are in the first column + train_y = train_data.iloc[:, 0] + train_X = train_data.iloc[:, 1:] + + # Here we support a single hyperparameter, 'max_leaf_nodes'. Note that you can add as many + # as your training may require in the ArgumentParser above. + max_leaf_nodes = args.max_leaf_nodes + + # Now use scikit-learn's decision tree regression to train the model. + clf = tree.DecisionTreeRegressor(max_leaf_nodes=max_leaf_nodes) + clf = clf.fit(train_X, train_y) + + input_files = [os.path.join(args.validation, file) for file in os.listdir(args.validation)] + raw_data = [pd.read_csv(file, header=None, engine="python") for file in input_files] + validation_data = pd.concat(raw_data) + # labels are in the first column + validation_y = validation_data.iloc[:, 0] + validation_X = validation_data.iloc[:, 1:] + # + predictions = clf.predict(validation_X) + error = mean_squared_error(predictions, validation_y) + print(f"RMSE: {error}") + # Print the coefficients of the trained classifier, and save the coefficients + joblib.dump(clf, os.path.join(args.model_dir, "model.joblib")) + + print("Training Completed") + + +def model_fn(model_dir): + """Deserialize and return the fitted model + + Note that this should have the same name as the serialized model in the main method + """ + clf = joblib.load(os.path.join(model_dir, "model.joblib")) + return clf \ No newline at end of file diff --git a/tests/integ/sagemaker/local_mode/test_local_model_cpu_jobs.py b/tests/integ/sagemaker/local_mode/test_local_model_cpu_jobs.py new file mode 100644 index 0000000000..6b10d8016e --- /dev/null +++ b/tests/integ/sagemaker/local_mode/test_local_model_cpu_jobs.py @@ -0,0 +1,291 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). 
You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import time +from typing import Union + + +import os +import re +import pytest +import subprocess +import logging +import numpy as np +import pandas as pd +import sagemaker +import boto3 +from sklearn import datasets +from sklearn.model_selection import train_test_split +from sklearn.metrics import mean_squared_error +from pathlib import Path +from sagemaker.local import LocalSession +from sagemaker.processing import ( + ProcessingInput, + ProcessingOutput +) +from sagemaker.sklearn import SKLearn +from sagemaker.sklearn.processing import SKLearnProcessor +from sagemaker.deserializers import CSVDeserializer +from sagemaker.pytorch import PyTorchModel +from sagemaker.serializers import CSVSerializer + + +# Replace this role ARN with an appropriate role for your environment +ROLE = "arn:aws:iam::111111111111:role/service-role/AmazonSageMaker-ExecutionRole-20200101T000001" + + +def ensure_docker_compose_installed(): + """ + Downloads the Docker Compose plugin if not present, and verifies installation + by checking the output of 'docker compose version' matches the pattern: + 'Docker Compose version vX.Y.Z' + """ + + cli_plugins_path = Path.home() / ".docker" / "cli-plugins" + cli_plugins_path.mkdir(parents=True, exist_ok=True) + + compose_binary_path = cli_plugins_path / "docker-compose" + if not compose_binary_path.exists(): + subprocess.run( + [ + "curl", + "-SL", + "https://github.com/docker/compose/releases/download/v2.3.3/docker-compose-linux-x86_64", + "-o", + str(compose_binary_path), + ], + 
check=True, + ) + subprocess.run(["chmod", "+x", str(compose_binary_path)], check=True) + + # Verify Docker Compose version + try: + output = subprocess.check_output(["docker", "compose", "version"], stderr=subprocess.STDOUT) + output_decoded = output.decode("utf-8").strip() + logging.info(f"'docker compose version' output: {output_decoded}") + + # Example expected format: "Docker Compose version vxxx" + pattern = r"Docker Compose version+" + match = re.search(pattern, output_decoded) + assert ( + match is not None + ), f"Could not find a Docker Compose version string matching '{pattern}' in: {output_decoded}" + + except subprocess.CalledProcessError as e: + raise AssertionError(f"Failed to verify Docker Compose: {e}") + + +""" +Local Model: ProcessingJob +""" +@pytest.mark.local +def test_scikit_learn_local_processing(): + """ + Test local mode processing with a scikit-learn processor. + This uses the same logic as scikit_learn_local_processing.py but in a pytest test function. + + Requirements/Assumptions: + - Docker must be installed and running on the local machine. + - 'processing_script.py' must be in the current working directory (or specify the correct path). + - There should be some local input data if 'processing_script.py' needs it (see ProcessingInput below). + """ + ensure_docker_compose_installed() + + # 1. Create local session for testing + sagemaker_session = LocalSession() + sagemaker_session.config = {"local": {"local_code": True}} + + # 2. Define a scikit-learn processor in local mode + processor = SKLearnProcessor( + framework_version="1.2-1", + instance_count=1, + instance_type="local", + role=ROLE, + sagemaker_session=sagemaker_session + ) + + logging.warning("Starting local processing job.") + logging.warning("Note: the first run may take time to pull the required Docker image.") + + # 3. 
Run the processing job locally + # - Update 'source' and 'destination' paths based on your local folder setup + processor.run( + code="sample_processing_script.py", + inputs=[ + ProcessingInput( + source="s3://sagemaker-example-files-prod-us-east-1/datasets/tabular/uci_bank_marketing/bank-additional-full.csv", + destination="/opt/ml/processing/input" + ) + ], + outputs=[ + ProcessingOutput( + output_name="train_data", + source="/opt/ml/processing/output/train", + destination="./output_data/train", + ), + ProcessingOutput( + output_name="validation_data", + source="/opt/ml/processing/output/validation", + destination="./output_data/validation" + ), + ProcessingOutput( + output_name="test_data", + source="/opt/ml/processing/output/test", + destination="./output_data/test" + ), + ], + ) + assert True + + +""" +Local Model: Inference +""" +@pytest.mark.local +def test_pytorch_local_model_inference(): + """ + Test local mode inference for a TensorFlow NLP model using PyTorch. + This test deploys the model locally via Docker, performs an inference + on a sample image URL, and asserts that the output is received. + """ + ensure_docker_compose_installed() + + # 1. Create a local session for inference + sagemaker_session = LocalSession() + sagemaker_session.config = {"local": {"local_code": True}} + + # pre created model for inference + model_dir = 's3://aws-ml-blog/artifacts/pytorch-nlp-script-mode-local-model-inference/model.tar.gz' + # sample dummy inference + test_data = [ + "Never allow the same bug to bite you twice.", + "The best part of Amazon SageMaker is that it makes machine learning easy.", + "Amazon SageMaker Inference Recommender helps you choose the best available compute instance and configuration to deploy machine learning models for optimal inference performance and cost." 
+ ] + logging.warning(f'test_data: {test_data}') + + model = PyTorchModel( + model_data=model_dir, + framework_version='1.8', + # source_dir='inference', + py_version='py3', + entry_point='sample_inference_script.py', + role=ROLE, + sagemaker_session=sagemaker_session + ) + + logging.warning('Deploying endpoint in local mode') + logging.warning( + 'Note: if launching for the first time in local mode, container image download might take a few minutes to complete.') + predictor = model.deploy( + initial_instance_count=1, + instance_type='local', + container_startup_health_check_timeout=600 + ) + + # create a new CSV serializer and deserializer + predictor.serializer = CSVSerializer() + predictor.deserializer = CSVDeserializer() + + predictions = predictor.predict( + ",".join(test_data) + ) + logging.warning(f'predictions: {predictions}') + # delete endpoint, clean up and terminate + predictor.delete_endpoint(predictor.endpoint) + + # assert model response + assert type(predictions) == list, "Response return type is a List" + assert len(predictions) >= 1, "empty list returned" + + +def download_training_and_eval_data(): + logging.warning('Downloading training dataset') + + # Load California Housing dataset, then join labels and features + california = datasets.fetch_california_housing() + dataset = np.insert(california.data, 0, california.target, axis=1) + # Create directory and write csv + os.makedirs("./data/train", exist_ok=True) + os.makedirs("./data/validation", exist_ok=True) + os.makedirs("./data/test", exist_ok=True) + + train, other = train_test_split(dataset, test_size=0.3) + validation, test = train_test_split(other, test_size=0.5) + + np.savetxt("./data/train/california_train.csv", train, delimiter=",") + np.savetxt("./data/validation/california_validation.csv", validation, delimiter=",") + np.savetxt("./data/test/california_test.csv", test, delimiter=",") + + logging.warning('Downloading completed') + + +def do_inference_on_local_endpoint(predictor): + 
print(f'\nStarting Inference on endpoint (local).') + test_data = pd.read_csv("data/test/california_test.csv", header=None) + test_X = test_data.iloc[:, 1:] + test_y = test_data.iloc[:, 0] + predictions = predictor.predict(test_X.values) + logging.warning("Predictions: {}".format(predictions)) + logging.warning("Actual: {}".format(test_y.values)) + logging.warning(f"RMSE: {mean_squared_error(predictions, test_y.values)}") + return predictions, test_y.values, float(mean_squared_error(predictions, test_y.values)) + + +""" +Local Model: TrainingJob and Inference +""" +@pytest.mark.local +def test_sklearn_local_model_train_inference(): + + download_training_and_eval_data() + + logging.warning('Starting model training.') + logging.warning('Note: if launching for the first time in local mode, container image download might take a few minutes to complete.') + + # 1. Create a local session for inference + sagemaker_session = LocalSession() + sagemaker_session.config = {"local": {"local_code": True}} + + sklearn = SKLearn( + entry_point="sample_training_script.py", + # source_dir='training', + framework_version="1.2-1", + role=ROLE, + sagemaker_session=sagemaker_session, + instance_type="local", + hyperparameters={"max_leaf_nodes": 30}, + ) + + train_input = "file://./data/train/california_train.csv" + validation_input = "file://./data/validation/california_validation.csv" + + sklearn.fit({"train": train_input, "validation": validation_input}) + logging.warning('Completed model training') + logging.warning('Deploying endpoint in local mode') + + predictor = sklearn.deploy( + initial_instance_count=1, + instance_type="local", + container_startup_health_check_timeout=600 + ) + + # get predictions from local endpoint + test_preds, test_y, test_mse = do_inference_on_local_endpoint(predictor) + + logging.warning('About to delete the endpoint') + predictor.delete_endpoint() + + assert type(test_preds) == np.ndarray, f"predictions are not in a np.ndarray format: {test_preds}" + 
assert type(test_y) == np.ndarray, f"Y ground truth are not in a np.ndarray format: {test_y}" + assert type(test_mse) == float, f"MSE is not a number: {test_mse}" From 58e37d5fb1d222fb8d9e7433c067f9442945f2c1 Mon Sep 17 00:00:00 2001 From: Pranav Murthy Date: Fri, 31 Jan 2025 18:40:58 -0500 Subject: [PATCH 4/8] name refactor --- .../{test_local_model_cpu_jobs.py => test_local_mode_cpu_jobs.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/integ/sagemaker/local_mode/{test_local_model_cpu_jobs.py => test_local_mode_cpu_jobs.py} (100%) diff --git a/tests/integ/sagemaker/local_mode/test_local_model_cpu_jobs.py b/tests/integ/sagemaker/local_mode/test_local_mode_cpu_jobs.py similarity index 100% rename from tests/integ/sagemaker/local_mode/test_local_model_cpu_jobs.py rename to tests/integ/sagemaker/local_mode/test_local_mode_cpu_jobs.py From 1d7254c5cbe7a9345ffd272158b1f8977f40cf2e Mon Sep 17 00:00:00 2001 From: Pranav Murthy Date: Fri, 31 Jan 2025 18:42:13 -0500 Subject: [PATCH 5/8] EOF empty line --- tests/integ/sagemaker/local_mode/test_local_mode_cpu_jobs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integ/sagemaker/local_mode/test_local_mode_cpu_jobs.py b/tests/integ/sagemaker/local_mode/test_local_mode_cpu_jobs.py index 6b10d8016e..64f82f9ae6 100644 --- a/tests/integ/sagemaker/local_mode/test_local_mode_cpu_jobs.py +++ b/tests/integ/sagemaker/local_mode/test_local_mode_cpu_jobs.py @@ -289,3 +289,4 @@ def test_sklearn_local_model_train_inference(): assert type(test_preds) == np.ndarray, f"predictions are not in a np.ndarray format: {test_preds}" assert type(test_y) == np.ndarray, f"Y ground truth are not in a np.ndarray format: {test_y}" assert type(test_mse) == float, f"MSE is not a number: {test_mse}" + From ce2639c76f8dbc3d3de94a084cd332a2d4327b36 Mon Sep 17 00:00:00 2001 From: Pranav Murthy Date: Sat, 1 Feb 2025 17:30:48 -0500 Subject: [PATCH 6/8] adding test cases for GPU local model inference with HF TGI and DJL lmi test cases 
--- .../local_mode/test_local_mode_gpu_jobs.py | 233 ++++++++++++++++++ 1 file changed, 233 insertions(+) create mode 100644 tests/integ/sagemaker/local_mode/test_local_mode_gpu_jobs.py diff --git a/tests/integ/sagemaker/local_mode/test_local_mode_gpu_jobs.py b/tests/integ/sagemaker/local_mode/test_local_mode_gpu_jobs.py new file mode 100644 index 0000000000..d5333ee5bb --- /dev/null +++ b/tests/integ/sagemaker/local_mode/test_local_mode_gpu_jobs.py @@ -0,0 +1,233 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import +import time +from typing import Union + + +import os +import re +import pytest +import subprocess +import logging +import sagemaker +import boto3 +import urllib3 +from pathlib import Path +from sagemaker.huggingface import ( + HuggingFaceModel, + get_huggingface_llm_image_uri +) +from sagemaker.deserializers import JSONDeserializer +from sagemaker.local import LocalSession +from sagemaker.serializers import JSONSerializer + + +# Replace this role ARN with an appropriate role for your environment +ROLE = "arn:aws:iam::111111111111:role/service-role/AmazonSageMaker-ExecutionRole-20200101T000001" + + +def ensure_docker_compose_installed(): + """ + Downloads the Docker Compose plugin if not present, and verifies installation + by checking the output of 'docker compose version' matches the pattern: + 'Docker Compose version vX.Y.Z' + """ + + cli_plugins_path = Path.home() / ".docker" / "cli-plugins" + cli_plugins_path.mkdir(parents=True, exist_ok=True) + + compose_binary_path = cli_plugins_path / "docker-compose" + if not compose_binary_path.exists(): + subprocess.run( + [ + "curl", + "-SL", + "https://github.com/docker/compose/releases/download/v2.3.3/docker-compose-linux-x86_64", + "-o", + str(compose_binary_path), + ], + check=True, + ) + subprocess.run(["chmod", "+x", str(compose_binary_path)], check=True) + + # Verify Docker Compose version + try: + output = subprocess.check_output(["docker", "compose", "version"], stderr=subprocess.STDOUT) + output_decoded = output.decode("utf-8").strip() + logging.info(f"'docker compose version' output: {output_decoded}") + + # Example expected format: "Docker Compose version vxxx" + pattern = r"Docker Compose version+" + match = re.search(pattern, output_decoded) + assert ( + match is not None + ), f"Could not find a Docker Compose version string matching '{pattern}' in: {output_decoded}" + + except subprocess.CalledProcessError as e: + raise AssertionError(f"Failed to verify Docker 
Compose: {e}") + + +""" +Local Model: HuggingFace LLM Inference +""" +@pytest.mark.local +def test_huggingfacellm_local_model_inference(): + """ + Test local mode inference with DJL-LMI inference containers + without a model_data path provided at runtime. This test should + be run on a GPU only machine with instance set to local_gpu. + """ + ensure_docker_compose_installed() + + # 1. Create a local session for inference + sagemaker_session = LocalSession() + sagemaker_session.config = {"local": {"local_code": True}} + + djllmi_model = sagemaker.Model( + image_uri="763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124", + env={ + "HF_MODEL_ID": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", + "OPTION_MAX_MODEL_LEN": "10000", + "OPTION_GPU_MEMORY_UTILIZATION": "0.95", + "OPTION_ENABLE_STREAMING": "false", + "OPTION_ROLLING_BATCH": "auto", + "OPTION_MODEL_LOADING_TIMEOUT": "3600", + "OPTION_PAGED_ATTENTION": "false", + "OPTION_DTYPE": "fp16", + }, + role=ROLE, + sagemaker_session=sagemaker_session + ) + + logging.warning('Deploying endpoint in local mode') + logging.warning( + 'Note: if launching for the first time in local mode, container image download might take a few minutes to complete.' + ) + + endpoint_name = "test-djl" + djllmi_model.deploy( + endpoint_name=endpoint_name, + initial_instance_count=1, + instance_type="local_gpu", + container_startup_health_check_timeout=600, + ) + predictor = sagemaker.Predictor( + endpoint_name=endpoint_name, + sagemaker_session=sagemaker_session, + serializer=JSONSerializer(), + deserializer=JSONDeserializer(), + ) + test_response = predictor.predict( + { + "inputs": """<|begin_of_text|> + <|start_header_id|>system<|end_header_id|> + You are a helpful assistant that thinks and reasons before answering. + <|eot_id|> + <|start_header_id|>user<|end_header_id|> + What's 2x2? 
+ <|eot_id|> + + <|start_header_id|>assistant<|end_header_id|> + """ + } + ) + logging.warning(test_response) + gen_text = test_response['generated_text'] + logging.warning(f"\n=======\nmodel response: {gen_text}\n=======\n") + + assert type(test_response) == dict, f"invalid model response format: {gen_text}" + assert type(gen_text) == str, f"assistant response format: {gen_text}" + + logging.warning('About to delete the endpoint') + predictor.delete_endpoint() + + +""" +Local Model: HuggingFace TGI Inference +""" +@pytest.mark.local +def test_huggingfacetgi_local_model_inference(): + """ + Test local mode inference with HuggingFace TGI inference containers + without a model_data path provided at runtime. This test should + be run on a GPU only machine with instance set to local_gpu. + """ + ensure_docker_compose_installed() + + # 1. Create a local session for inference + sagemaker_session = LocalSession() + sagemaker_session.config = {"local": {"local_code": True}} + + huggingface_model = HuggingFaceModel( + image_uri=get_huggingface_llm_image_uri( + "huggingface", + version="2.3.1" + ), + env={ + "HF_MODEL_ID": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MESSAGES_API_ENABLED": "true", + "OPTION_ENTRYPOINT": "inference.py", + "SAGEMAKER_ENV": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + "SAGEMAKER_PROGRAM": "inference.py", + "SM_NUM_GPUS": "1", + "MAX_TOTAL_TOKENS": "1024", + "MAX_INPUT_TOKENS": "800", + "MAX_BATCH_PREFILL_TOKENS": "900", + "DTYPE": "bfloat16", + "PORT": "8080" + }, + role=ROLE, + sagemaker_session=sagemaker_session + ) + + logging.warning('Deploying endpoint in local mode') + logging.warning( + 'Note: if launching for the first time in local mode, container image download might take a few minutes to complete.' 
+ ) + + endpoint_name = "test-hf" + huggingface_model.deploy( + endpoint_name=endpoint_name, + initial_instance_count=1, + instance_type="local_gpu", + container_startup_health_check_timeout=600, + ) + predictor = sagemaker.Predictor( + endpoint_name=endpoint_name, + sagemaker_session=sagemaker_session, + serializer=JSONSerializer(), + deserializer=JSONDeserializer(), + ) + test_response = predictor.predict( + { + "messages": [ + {"role": "system", "content": "You are a helpful assistant." }, + {"role": "user", "content": "What is 2x2?"} + ] + } + ) + logging.warning(test_response) + gen_text = test_response['choices'][0]['message'] + logging.warning(f"\n=======\nmodel response: {gen_text}\n=======\n") + + assert type(gen_text) == dict, f"invalid model response: {gen_text}" + assert gen_text['role'] == 'assistant', f"assistant response missing: {gen_text}" + + logging.warning('About to delete the endpoint') + predictor.delete_endpoint() + + + From aabd04cf23f90ea99360891842b4b0d8c7c7cf5e Mon Sep 17 00:00:00 2001 From: Pranav Murthy Date: Sat, 1 Feb 2025 17:31:23 -0500 Subject: [PATCH 7/8] fix missing sagemaker-local dir when model data is set to None --- src/sagemaker/local/image.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index 9f550e33c4..80854c20eb 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -344,6 +344,12 @@ def serve(self, model_dir, environment): volumes = self._prepare_serving_volumes(model_dir) else: volumes = None + # if model_data is None, then force create ../sagemaker-local under + # container root + os.makedirs( + os.path.join(self.container_root, self.hosts[0]), + exist_ok=True + ) # If the user script was passed as a file:// mount it to the container. 
if sagemaker.estimator.DIR_PARAM_NAME.upper() in environment: From a1e3b9dd541e062d03fa650984777f5418ba3cdf Mon Sep 17 00:00:00 2001 From: Pranav Murthy Date: Sat, 1 Feb 2025 17:50:56 -0500 Subject: [PATCH 8/8] fix minor function doc string --- .../local_mode/test_local_mode_cpu_jobs.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/integ/sagemaker/local_mode/test_local_mode_cpu_jobs.py b/tests/integ/sagemaker/local_mode/test_local_mode_cpu_jobs.py index 64f82f9ae6..795881a941 100644 --- a/tests/integ/sagemaker/local_mode/test_local_mode_cpu_jobs.py +++ b/tests/integ/sagemaker/local_mode/test_local_mode_cpu_jobs.py @@ -92,12 +92,9 @@ def ensure_docker_compose_installed(): def test_scikit_learn_local_processing(): """ Test local mode processing with a scikit-learn processor. - This uses the same logic as scikit_learn_local_processing.py but in a pytest test function. - - Requirements/Assumptions: - - Docker must be installed and running on the local machine. - - 'processing_script.py' must be in the current working directory (or specify the correct path). - - There should be some local input data if 'processing_script.py' needs it (see ProcessingInput below). + This uses the same logic as scikit_learn_local_processing.py but in + a pytest test function. This test deploys the model locally via Docker, + and asserts that the output is received. """ ensure_docker_compose_installed() @@ -154,9 +151,9 @@ def test_scikit_learn_local_processing(): @pytest.mark.local def test_pytorch_local_model_inference(): """ - Test local mode inference for a TensorFlow NLP model using PyTorch. + Test local mode inference with NLP model using PyTorch. This test deploys the model locally via Docker, performs an inference - on a sample image URL, and asserts that the output is received. + on a sample dataset, and asserts that the output is received. 
""" ensure_docker_compose_installed() @@ -247,6 +244,11 @@ def do_inference_on_local_endpoint(predictor): """ @pytest.mark.local def test_sklearn_local_model_train_inference(): + """ + Test local mode training and inference with sagemaker SKLearn. + This test runs the model locally via Docker, performs an inference with + sample dataset, and asserts that the output is received. + """ download_training_and_eval_data()