diff --git a/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json b/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json
index a7640e3f46..1dbcf64d36 100644
--- a/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json
+++ b/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json
@@ -43236,6 +43236,10 @@
"EvaluatorArn":{
"shape":"EvaluatorArn",
"documentation":"<p>The evaluator Amazon Resource Name (ARN) used as reward function or reward prompt.</p>"
+ },
+ "SequenceLength":{
+ "shape":"SequenceLength",
+ "documentation":"<p>The sequence length for the training job.</p>"
}
},
"documentation":"<p>The configuration for the serverless training job.</p>"
@@ -43247,6 +43251,19 @@
"Evaluation"
]
},
+ "SequenceLength":{
+ "type":"string",
+ "enum":[
+ "1K",
+ "2K",
+ "4K",
+ "8K",
+ "16K",
+ "32K",
+ "64K",
+ "128K"
+ ]
+ },
"ServerlessMaxConcurrency":{
"type":"integer",
"box":true,
diff --git a/sagemaker-core/src/sagemaker/core/shapes/shapes.py b/sagemaker-core/src/sagemaker/core/shapes/shapes.py
index 8f99bcba8c..9749b2a4b5 100644
--- a/sagemaker-core/src/sagemaker/core/shapes/shapes.py
+++ b/sagemaker-core/src/sagemaker/core/shapes/shapes.py
@@ -9588,6 +9588,7 @@ class ServerlessJobConfig(Base):
peft: The parameter-efficient fine-tuning configuration.
evaluation_type: The evaluation job type. Required when serverless job type is Evaluation.
evaluator_arn: The evaluator Amazon Resource Name (ARN) used as reward function or reward prompt.
+ sequence_length: The sequence length for the training job.
"""
base_model_arn: StrPipeVar
@@ -9597,6 +9598,7 @@ class ServerlessJobConfig(Base):
peft: Optional[StrPipeVar] = Unassigned()
evaluation_type: Optional[StrPipeVar] = Unassigned()
evaluator_arn: Optional[StrPipeVar] = Unassigned()
+ sequence_length: Optional[StrPipeVar] = Unassigned()
class MlflowConfig(Base):
diff --git a/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py b/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py
index d7ad54ee25..f34a282348 100644
--- a/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py
+++ b/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py
@@ -15877,6 +15877,7 @@
{"name": "Peft", "shape": "Peft", "type": "string"},
{"name": "EvaluationType", "shape": "EvaluationType", "type": "string"},
{"name": "EvaluatorArn", "shape": "EvaluatorArn", "type": "string"},
+ {"name": "SequenceLength", "shape": "SequenceLength", "type": "string"},
],
"type": "structure",
},
diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py
index 0ea74ee207..0c04b29d2a 100644
--- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py
+++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py
@@ -12,7 +12,13 @@
from sagemaker.core.helper.session_helper import Session
from sagemaker.train.common_utils.recipe_utils import _get_hub_content_metadata
from sagemaker.train.common import TrainingType, CustomizationTechnique, JOB_TYPE, FineTuningOptions
-from sagemaker.core.shapes import ServerlessJobConfig, Channel, DataSource, ModelPackageConfig, MlflowConfig
+from sagemaker.core.shapes import (
+ ServerlessJobConfig,
+ Channel,
+ DataSource,
+ ModelPackageConfig,
+ MlflowConfig,
+)
from sagemaker.train.configs import InputData, OutputDataConfig
from sagemaker.train.defaults import TrainDefaults
from sagemaker.train.constants import get_sagemaker_hub_name
@@ -20,11 +26,17 @@
logger = logging.getLogger(__name__)
# Region mappings for model availability
-OPEN_WEIGHTS_REGIONS = ['us-east-1', 'us-west-2', 'ap-northeast-1', 'eu-west-1'] # IAD, PDX, NRT, DUB
-NOVA_REGIONS = ['us-east-1', 'us-west-2'] # IAD, PDX
+OPEN_WEIGHTS_REGIONS = [
+ "us-east-1",
+ "us-west-2",
+ "ap-northeast-1",
+ "eu-west-1",
+] # IAD, PDX, NRT, DUB
+NOVA_REGIONS = ["us-east-1", "us-west-2"] # IAD, PDX
# Constants
DEFAULT_REGION = "us-west-2"
+
def _validate_model_region_availability(model_name: str, region_name: str):
"""Validate if the model is available in the specified region."""
if "nova" in model_name.lower():
@@ -48,26 +60,24 @@ def _validate_model_region_availability(model_name: str, region_name: str):
)
-
-
def _get_beta_session():
"""Create a SageMaker session with beta endpoint for demo purposes."""
- sm_client = boto3.client('sagemaker', region_name=DEFAULT_REGION)
+ sm_client = boto3.client("sagemaker", region_name=DEFAULT_REGION)
return Session(sagemaker_client=sm_client)
def _read_domain_id_from_metadata() -> Optional[str]:
"""Read domain ID from Studio metadata file.
-
+
This is the standard location for domain information in Studio with Spaces.
Returns None if not running in Studio or if metadata file doesn't exist.
"""
try:
- metadata_path = '/opt/ml/metadata/resource-metadata.json'
+ metadata_path = "/opt/ml/metadata/resource-metadata.json"
if os.path.exists(metadata_path):
- with open(metadata_path, 'r') as f:
+ with open(metadata_path, "r") as f:
metadata = json.load(f)
- return metadata.get('DomainId')
+ return metadata.get("DomainId")
except Exception as e:
logger.debug(f"Could not read Studio metadata file: {e}")
return None
@@ -75,78 +85,88 @@ def _read_domain_id_from_metadata() -> Optional[str]:
def _get_current_domain_id(sagemaker_session) -> Optional[str]:
"""Get current SageMaker Studio domain ID.
-
+
Checks multiple sources in order of reliability:
1. Studio metadata file (Studio with Spaces - newer architecture)
2. User profile ARN (Studio Classic with User Profiles - legacy)
-
+
Returns None if not running in a Studio environment with domain.
"""
# Try metadata file first (Studio with Spaces)
domain_id = _read_domain_id_from_metadata()
if domain_id:
return domain_id
-
+
# Fallback to original logic (Studio Classic with User Profiles)
try:
user_profile_arn = sagemaker_session.get_caller_identity_arn()
- if user_profile_arn and 'user-profile' in user_profile_arn:
+ if user_profile_arn and "user-profile" in user_profile_arn:
# ARN format: arn:aws:sagemaker:region:account:user-profile/domain-id/profile-name
- return user_profile_arn.split('/')[1]
+ return user_profile_arn.split("/")[1]
except Exception as e:
logger.debug(f"Could not extract domain ID from user profile ARN: {e}")
-
+
return None
-def _resolve_mlflow_resource_arn(sagemaker_session, mlflow_resource_arn: Optional[str] = None) -> Optional[str]:
+def _resolve_mlflow_resource_arn(
+ sagemaker_session, mlflow_resource_arn: Optional[str] = None
+) -> Optional[str]:
"""Resolve MLflow resource ARN using default experience logic."""
if mlflow_resource_arn:
return mlflow_resource_arn
-
+
try:
-
+
mlflow_apps = MlflowApp.get_all(
session=sagemaker_session.boto_session,
- region=sagemaker_session.boto_session.region_name
+ region=sagemaker_session.boto_session.region_name,
)
-
+
mlflow_apps_list = list(mlflow_apps)
current_domain_id = _get_current_domain_id(sagemaker_session)
-
+
# Check for domain match
if current_domain_id:
- domain_match = next((app for app in mlflow_apps_list
- if isinstance(app.default_domain_id_list, list)
- and current_domain_id in app.default_domain_id_list), None)
+ domain_match = next(
+ (
+ app
+ for app in mlflow_apps_list
+ if isinstance(app.default_domain_id_list, list)
+ and current_domain_id in app.default_domain_id_list
+ ),
+ None,
+ )
if domain_match:
logger.info("Using domain-matched MLflow app: %s", domain_match.arn)
return domain_match.arn
-
+
# Check for account default
- account_default = next((app for app in mlflow_apps_list
- if app.account_default_status == "ENABLED"), None)
+ account_default = next(
+ (app for app in mlflow_apps_list if app.account_default_status == "ENABLED"), None
+ )
if account_default:
logger.info("Using account default MLflow app: %s", account_default.arn)
return account_default.arn
-
+
# Use first available with ready status
if mlflow_apps_list:
- ready_app = next((app for app in mlflow_apps_list
- if app.status in ["Created", "Updated"]), None)
+ ready_app = next(
+ (app for app in mlflow_apps_list if app.status in ["Created", "Updated"]), None
+ )
if ready_app:
logger.info("Using first available ready MLflow app: %s", ready_app.arn)
return ready_app.arn
-
+
# Create new app
new_app = _create_mlflow_app(sagemaker_session)
if new_app:
logger.info("Created new MLflow app: %s", new_app.arn)
return new_app.arn
-
+
logger.warning("Failed to create MLflow app. MLflow tracking disabled.")
return None
-
+
except Exception as e:
logger.error("Error resolving MLflow resource ARN: %s", e)
return None
@@ -156,45 +176,46 @@ def _create_mlflow_app(sagemaker_session) -> Optional[MlflowApp]:
"""Create a new MLflow app with minimal configuration."""
try:
app_name = f"finetune-mlflow-{int(time.time())}"
- account_id = sagemaker_session.boto_session.client('sts').get_caller_identity()['Account']
+ account_id = sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"]
region = sagemaker_session.boto_session.region_name
artifact_store_uri = f"s3://sagemaker-{region}-{account_id}/mlflow-artifacts"
role_arn = TrainDefaults.get_role(role=None, sagemaker_session=sagemaker_session)
-
+
# Ensure S3 bucket and prefix exist
- s3_client = sagemaker_session.boto_session.client('s3')
+ s3_client = sagemaker_session.boto_session.client("s3")
bucket_name = f"sagemaker-{region}-{account_id}"
-
+
try:
# Check if prefix exists
- response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix="mlflow-artifacts/", MaxKeys=1)
- if 'Contents' not in response:
+ response = s3_client.list_objects_v2(
+ Bucket=bucket_name, Prefix="mlflow-artifacts/", MaxKeys=1
+ )
+ if "Contents" not in response:
s3_client.put_object(Bucket=bucket_name, Key="mlflow-artifacts/")
except s3_client.exceptions.NoSuchBucket:
# Bucket doesn't exist, create bucket and prefix
- if region == 'us-east-1':
+ if region == "us-east-1":
s3_client.create_bucket(Bucket=bucket_name)
else:
s3_client.create_bucket(
- Bucket=bucket_name,
- CreateBucketConfiguration={'LocationConstraint': region}
+ Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region}
)
s3_client.put_object(Bucket=bucket_name, Key="mlflow-artifacts/")
-
+
new_app = MlflowApp.create(
name=app_name,
artifact_store_uri=artifact_store_uri,
role_arn=role_arn,
account_default_status="DISABLED",
session=sagemaker_session.boto_session,
- region=region
+ region=region,
)
-
+
# Wait for app to reach Created/Updated state
max_wait_time = 600 # 10 minutes
- poll_interval = 10 # 10 seconds
+ poll_interval = 10 # 10 seconds
start_time = time.time()
-
+
while time.time() - start_time < max_wait_time:
new_app.refresh()
if new_app.status in ["Created", "Updated"]:
@@ -202,18 +223,18 @@ def _create_mlflow_app(sagemaker_session) -> Optional[MlflowApp]:
elif new_app.status in ["Failed", "Stopped"]:
# Get detailed error from MLflow app
error_msg = f"MLflow app creation failed with status: {new_app.status}"
- if hasattr(new_app, 'failure_reason') and new_app.failure_reason:
+ if hasattr(new_app, "failure_reason") and new_app.failure_reason:
error_msg += f". Reason: {new_app.failure_reason}"
raise RuntimeError(error_msg)
time.sleep(poll_interval)
-
+
# Timeout case - get current status and any error details
new_app.refresh()
error_msg = f"MLflow app creation failed. Current status: {new_app.status}"
- if hasattr(new_app, 'failure_reason') and new_app.failure_reason:
+ if hasattr(new_app, "failure_reason") and new_app.failure_reason:
error_msg += f". Reason: {new_app.failure_reason}"
raise RuntimeError(error_msg)
-
+
except Exception as e:
logger.error("Failed to create MLflow app: %s", e)
return None
@@ -229,14 +250,18 @@ def _validate_dataset_arn(dataset: str, param_name: str):
def _validate_evaluator_arn(evaluator_arn: str, param_name: str):
"""Validate that evaluator_arn is in correct ARN format."""
arn_pattern = r"^arn:aws:sagemaker:[^:]+:\d+:hub-content/[^/]+/JsonDoc/[^/]+/[\d\.]+$"
- if not evaluator_arn.startswith("arn:aws:sagemaker:") or not re.match(arn_pattern, evaluator_arn):
+ if not evaluator_arn.startswith("arn:aws:sagemaker:") or not re.match(
+ arn_pattern, evaluator_arn
+ ):
raise ValueError(f"{param_name} must be a valid SageMaker hub-content evaluator ARN")
def _validate_model_package_group_requirement(model, model_package_group_name):
"""Validate model_package_group_name when source_model_package_arn is not available."""
if not isinstance(model, ModelPackage) and not model_package_group_name:
- raise ValueError("model_package_group_name must be provided when source_model_package_arn is not available")
+ raise ValueError(
+ "model_package_group_name must be provided when source_model_package_arn is not available"
+ )
def _resolve_model_package_group_arn(model_package_group_name_or_arn, sagemaker_session) -> str:
@@ -244,7 +269,7 @@ def _resolve_model_package_group_arn(model_package_group_name_or_arn, sagemaker_
if isinstance(model_package_group_name_or_arn, str):
# Check if it's already an ARN using pattern matching
arn_pattern = r"^arn:aws:sagemaker:[^:]+:\d+:model-package-group/[^/]+$"
-
+
if re.match(arn_pattern, model_package_group_name_or_arn):
# It's already an ARN
return model_package_group_name_or_arn
@@ -253,7 +278,7 @@ def _resolve_model_package_group_arn(model_package_group_name_or_arn, sagemaker_
model_package_group = ModelPackageGroup.get(
model_package_group_name=model_package_group_name_or_arn,
session=sagemaker_session.boto_session,
- region=sagemaker_session.boto_session.region_name
+ region=sagemaker_session.boto_session.region_name,
)
return model_package_group.model_package_group_arn
else:
@@ -263,7 +288,7 @@ def _resolve_model_package_group_arn(model_package_group_name_or_arn, sagemaker_
def _get_default_s3_output_path(sagemaker_session) -> str:
"""Generate default S3 output path: s3://sagemaker-<region>-<account_id>/output"""
- account_id = sagemaker_session.boto_session.client('sts').get_caller_identity()['Account']
+ account_id = sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"]
region = sagemaker_session.boto_session.region_name
return f"s3://sagemaker-{region}-{account_id}/output"
@@ -295,17 +320,21 @@ def _resolve_model_name(model_package) -> str:
if model_package:
try:
# Extract base model from InferenceSpecification
- if (model_package.inference_specification and
- model_package.inference_specification.containers):
+ if (
+ model_package.inference_specification
+ and model_package.inference_specification.containers
+ ):
container = model_package.inference_specification.containers[0]
- if hasattr(container, 'base_model') and container.base_model:
+ if hasattr(container, "base_model") and container.base_model:
return container.base_model.hub_content_name
-
- raise ValueError("Continued fine tuning is only allowed on model packages fine tuned with sagemaker 1p models")
+
+ raise ValueError(
+ "Continued fine tuning is only allowed on model packages fine tuned with sagemaker 1p models"
+ )
except Exception as e:
logger.error("Failed to resolve model_name from model package: %s", e)
raise
-
+
raise ValueError("model name or package must be provided")
@@ -318,10 +347,44 @@ def _resolve_model_package_arn(model_package) -> Optional[str]:
return None
-def _get_fine_tuning_options_and_model_arn(model_name: str, customization_technique: str, training_type, sagemaker_session,
- hub_name: Optional[str] = None) -> tuple:
+def _parse_context_length(value) -> int:
+ """Parse a context length value like '8K', '32K', '128K' into an integer (e.g., 8192).
+
+ Returns 0 if value is None or unparseable.
+ """
+ if not value:
+ return 0
+ value = str(value).strip().upper()
+ if value.endswith("K"):
+ try:
+ return int(value[:-1]) * 1024
+ except ValueError:
+ return 0
+ try:
+ return int(value)
+ except ValueError:
+ return 0
+
+
+def _get_fine_tuning_options_and_model_arn(
+ model_name: str,
+ customization_technique: str,
+ training_type,
+ sagemaker_session,
+ sequence_length=None,
+ hub_name: Optional[str] = None,
+) -> tuple:
"""Get fine-tuning options and model ARN for given customization technique.
-
+
+ Args:
+ model_name: Name of the model in the hub.
+ customization_technique: Technique (e.g., "SFT", "DPO", "RLVR", "RLAIF").
+ training_type: TrainingType enum or string ("LORA", "FULL").
+ sagemaker_session: SageMaker session for API calls.
+ sequence_length: Optional sequence length (e.g., "8K"). When provided, keeps only
+ recipes whose SequenceLength is >= the requested value.
+ hub_name: Hub name (default: "SageMakerPublicHub").
+
Returns:
tuple: (FineTuningOptions, model_arn, is_gated_model)
"""
@@ -332,42 +395,95 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni
hub_content = _get_hub_content_metadata(
hub_name=hub_name,
- hub_content_type="Model",
+ hub_content_type="Model",
hub_content_name=model_name,
session=sagemaker_session.boto_session,
- region=sagemaker_session.boto_session.region_name
+ region=sagemaker_session.boto_session.region_name,
)
-
- model_arn = hub_content.get('hub_content_arn')
- document = hub_content.get('hub_content_document')
-
+
+ model_arn = hub_content.get("hub_content_arn")
+ document = hub_content.get("hub_content_document")
+
# Check if model is gated
is_gated_model = document.get("GatedBucket", False)
-
+
recipe_collection = document.get("RecipeCollection", [])
-
+
# Filter recipes by customization technique
- matching_recipes = [r for r in recipe_collection if r.get("CustomizationTechnique") == customization_technique]
-
+ matching_recipes = [
+ r
+ for r in recipe_collection
+ if r.get("CustomizationTechnique") == customization_technique
+ ]
+
if not matching_recipes:
- raise ValueError(f"No recipes found for customization technique: {customization_technique}")
-
+ raise ValueError(
+ f"No recipes found for customization technique: {customization_technique}"
+ )
+
# Filter recipes that have SmtjRecipeTemplateS3Uri key
recipes_with_template = [r for r in matching_recipes if r.get("SmtjRecipeTemplateS3Uri")]
-
+
if not recipes_with_template:
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique}")
+ # Filter by SequenceLength before recipe selection if sequence_length is requested
+ if sequence_length:
+ requested = _parse_context_length(sequence_length)
+ candidates_with_context = [r for r in recipes_with_template if r.get("SequenceLength")]
+ if candidates_with_context:
+ filtered = [
+ r
+ for r in candidates_with_context
+ if _parse_context_length(r.get("SequenceLength")) >= requested
+ ]
+ if filtered:
+ filtered.sort(key=lambda r: _parse_context_length(r.get("SequenceLength")))
+ recipes_with_template = filtered
+ else:
+ available = sorted(
+ set(r.get("SequenceLength") for r in candidates_with_context)
+ )
+ raise ValueError(
+ f"No recipes found with SequenceLength >= {sequence_length}. "
+ f"Available sequence lengths: {available}"
+ )
+ else:
+ raise ValueError(
+ f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}, "
+ f"and sequence length:{sequence_length}"
+ )
+
# Select recipe based on training type
# Collect override_params from ALL matching recipes (standard + subscription)
recipe = None
- if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
- recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None)
- elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL":
- recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None)
+ if (
+ isinstance(training_type, TrainingType) and training_type == TrainingType.LORA
+ ) or training_type == "LORA":
+ recipe = next(
+ (
+ r
+ for r in recipes_with_template
+ if r.get("Peft") and not r.get("IsSubscriptionModel")
+ ),
+ None,
+ )
+ elif (
+ isinstance(training_type, TrainingType) and training_type == TrainingType.FULL
+ ) or training_type == "FULL":
+ recipe = next(
+ (
+ r
+ for r in recipes_with_template
+ if not r.get("Peft") and not r.get("IsSubscriptionModel")
+ ),
+ None,
+ )
if not recipe:
- raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}")
+ raise ValueError(
+ f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}"
+ )
# Start with the standard recipe's override_params
options_dict = {}
@@ -380,14 +496,33 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni
options_dict = json.loads(obj["Body"].read())
# Auto-detect and merge subscription recipe's override_params if available
- if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
- sub_recipe = next((r for r in recipes_with_template if r.get("Peft") and r.get("IsSubscriptionModel")), None)
+ if (
+ isinstance(training_type, TrainingType) and training_type == TrainingType.LORA
+ ) or training_type == "LORA":
+ sub_recipe = next(
+ (
+ r
+ for r in recipes_with_template
+ if r.get("Peft") and r.get("IsSubscriptionModel")
+ ),
+ None,
+ )
else:
- sub_recipe = next((r for r in recipes_with_template if not r.get("Peft") and r.get("IsSubscriptionModel")), None)
+ sub_recipe = next(
+ (
+ r
+ for r in recipes_with_template
+ if not r.get("Peft") and r.get("IsSubscriptionModel")
+ ),
+ None,
+ )
if sub_recipe and sub_recipe.get("SmtjOverrideParamsS3Uri"):
try:
- sub_s3_uri = sub_recipe["SmtjOverrideParamsS3Uri"].replace("{customer_id}", sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"])
+ sub_s3_uri = sub_recipe["SmtjOverrideParamsS3Uri"].replace(
+ "{customer_id}",
+ sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"],
+ )
sub_uri_path = sub_s3_uri.replace("s3://", "")
# Handle access point ARN URIs
if sub_uri_path.startswith("arn:"):
@@ -405,73 +540,77 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni
if k not in options_dict:
v_copy = v.copy() if isinstance(v, dict) else v
if isinstance(v_copy, dict):
- v_copy['default'] = None # No default — won't appear in to_dict() unless set
+ v_copy["default"] = (
+ None # No default — won't appear in to_dict() unless set
+ )
options_dict[k] = v_copy
except Exception as e:
- logger.debug(f"Could not fetch subscription recipe override_params: {type(e).__name__}: {e}")
+ logger.debug(
+ f"Could not fetch subscription recipe override_params: {type(e).__name__}: {e}"
+ )
if options_dict:
return FineTuningOptions(options_dict), model_arn, is_gated_model
else:
return FineTuningOptions({}), model_arn, is_gated_model
-
+
except Exception as e:
logger.error("Exception getting fine-tuning options: %s", e)
raise
-def _create_input_channels(dataset: str, content_type: Optional[str] = None,
- input_compression_type: Optional[str] = None,
- record_wrapper_type: Optional[str] = None,
- input_mode: Optional[str] = None):
+def _create_input_channels(
+ dataset: str,
+ content_type: Optional[str] = None,
+ input_compression_type: Optional[str] = None,
+ record_wrapper_type: Optional[str] = None,
+ input_mode: Optional[str] = None,
+):
"""Create input channels from dataset (S3 URI or dataset ARN).
-
+
Args:
dataset: S3 URI (s3://bucket/key) or dataset ARN (arn:aws:sagemaker:...)
-
+
Returns:
list: List of Channel objects
"""
channels = []
-
if dataset.startswith("s3://"):
# S3 URI - create S3DataSource
data_source = DataSource(
s3_data_source={
"s3_uri": dataset,
"s3_data_type": "S3Prefix",
- "s3_data_distribution_type": "FullyReplicated"
+ "s3_data_distribution_type": "FullyReplicated",
}
)
else:
# Dataset ARN - validate and create dataset source
_validate_dataset_arn(dataset, "dataset")
- data_source = DataSource(
- dataset_source={"dataset_arn": dataset}
- )
-
+ data_source = DataSource(dataset_source={"dataset_arn": dataset})
+
channel = Channel(
channel_name="train",
data_source=data_source,
content_type=content_type,
compression_type=input_compression_type,
record_wrapper_type=record_wrapper_type,
- input_mode=input_mode
- )
+ input_mode=input_mode,
+ )
channels.append(channel)
-
+
return channels
def _resolve_model_and_name(model, sagemaker_session=None):
"""Resolve model and extract model name from string, ARN, or ModelPackage object.
-
+
Args:
model: Can be a model name (str), model package ARN (str), or ModelPackage object
sagemaker_session: SageMaker session for API calls (required for ARN resolution)
-
+
Returns:
tuple: (resolved_model, model_name)
"""
@@ -481,14 +620,15 @@ def _resolve_model_and_name(model, sagemaker_session=None):
region_name = sagemaker_session.boto_region_name
else:
# Try to get region from SAGEMAKER_REGION env var, then boto3 session, then AWS_DEFAULT_REGION
- region_name = os.environ.get('SAGEMAKER_REGION')
+ region_name = os.environ.get("SAGEMAKER_REGION")
if not region_name:
try:
import boto3
- region_name = boto3.Session().region_name or os.environ.get('AWS_DEFAULT_REGION')
+
+ region_name = boto3.Session().region_name or os.environ.get("AWS_DEFAULT_REGION")
except:
pass
-
+
if isinstance(model, str):
# Check if it's a model package ARN
if model.startswith("arn:aws:sagemaker:") and ":model-package/" in model:
@@ -496,7 +636,7 @@ def _resolve_model_and_name(model, sagemaker_session=None):
model_package = ModelPackage.get(
model_package_name=model,
session=sagemaker_session.boto_session if sagemaker_session else None,
- region=sagemaker_session.boto_session.region_name if sagemaker_session else None
+ region=sagemaker_session.boto_session.region_name if sagemaker_session else None,
)
model_name = _resolve_model_name(model_package)
# Validate region availability
@@ -518,23 +658,34 @@ def _resolve_model_and_name(model, sagemaker_session=None):
return model, model_name
-def _create_serverless_config(model_arn, customization_technique,
- training_type, accept_eula, evaluator_arn=None, job_type=JOB_TYPE) -> Optional['ServerlessJobConfig']:
+def _create_serverless_config(
+ model_arn,
+ customization_technique,
+ training_type,
+ accept_eula,
+ evaluator_arn=None,
+ sequence_length=None,
+ job_type=JOB_TYPE,
+) -> Optional["ServerlessJobConfig"]:
"""Create serverless job configuration for fine-tuning.
-
+
Args:
model_arn: ARN of the base model
customization_technique: Technique used (e.g., "SFT", "DPO", "RLVR", "RLAIF")
training_type: Training type (TrainingType enum or string)
accept_eula: Boolean indicating if EULA is accepted
evaluator_arn: Optional evaluator ARN for RLVR/RLAIF
+ sequence_length: Optional sequence length enum value (e.g., "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K")
job_type: Type of job (default: "FineTuning")
-
+
Returns:
ServerlessJobConfig object or None if required parameters are missing
"""
- peft = None if (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) \
+ peft = (
+ None
+ if (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL)
else (training_type.value if isinstance(training_type, TrainingType) else training_type)
+ )
# Create ServerlessJobConfig using shapes
serverless_config = ServerlessJobConfig(
@@ -543,7 +694,8 @@ def _create_serverless_config(model_arn, customization_technique,
customization_technique=customization_technique,
peft=peft,
evaluator_arn=evaluator_arn,
- accept_eula=accept_eula
+ accept_eula=accept_eula,
+ sequence_length=sequence_length,
)
return serverless_config
@@ -551,44 +703,41 @@ def _create_serverless_config(model_arn, customization_technique,
def _create_input_data_config(training_dataset, validation_dataset=None):
"""Create input data configuration from training and validation datasets.
-
+
Args:
training_dataset: Training dataset (method parameter takes priority over class attribute)
validation_dataset: Validation dataset (method parameter takes priority over class attribute)
-
+
Returns:
List of InputData objects for training job configuration
"""
# Extract and validate training dataset
final_training_dataset = _extract_dataset_source(training_dataset, "training_dataset")
-
- input_data_config = [
- InputData(channel_name="train", data_source=final_training_dataset)
- ]
-
+
+ input_data_config = [InputData(channel_name="train", data_source=final_training_dataset)]
+
# Add validation dataset if provided
if validation_dataset:
final_validation_dataset = _extract_dataset_source(validation_dataset, "validation_dataset")
input_data_config.append(
InputData(channel_name="validation", data_source=final_validation_dataset)
)
-
- return input_data_config
+ return input_data_config
def _create_model_package_config(model_package_group_name, model, sagemaker_session):
"""Create model package configuration with resolved ARNs.
-
+
Args:
model_package_group_name: Model package group name to resolve
model: Model object (used to resolve source model package ARN if it's a ModelPackage)
sagemaker_session: SageMaker session for API calls
-
+
Returns:
ModelPackageConfig object or None if no model package group name provided
"""
-
+
model_package_group_arn = None
if model_package_group_name:
model_package_group_arn = _resolve_model_package_group_arn(
@@ -605,22 +754,21 @@ def _create_model_package_config(model_package_group_name, model, sagemaker_sess
)
-
-def _create_mlflow_config(sagemaker_session, mlflow_resource_arn=None,
- mlflow_experiment_name=None, mlflow_run_name=None):
+def _create_mlflow_config(
+ sagemaker_session, mlflow_resource_arn=None, mlflow_experiment_name=None, mlflow_run_name=None
+):
"""Create MLflow configuration with resolved resource ARN.
-
+
Args:
sagemaker_session: SageMaker session for resolving MLflow ARN
mlflow_resource_arn: MLflow resource ARN (if None, uses default experience)
mlflow_experiment_name: MLflow experiment name
mlflow_run_name: MLflow run name
-
+
Returns:
MlflowConfig object or None if no MLflow resource ARN is resolved
"""
-
# Derive mlflow_resource_arn with default experience
resolved_mlflow_arn = _resolve_mlflow_resource_arn(sagemaker_session, mlflow_resource_arn)
logger.info(f"MLflow resource ARN: {resolved_mlflow_arn}")
@@ -633,18 +781,18 @@ def _create_mlflow_config(sagemaker_session, mlflow_resource_arn=None,
mlflow_experiment_name=mlflow_experiment_name,
mlflow_run_name=mlflow_run_name,
)
-
+
return mlflow_config
-def _create_output_config(sagemaker_session,s3_output_path=None, kms_key_id=None):
+def _create_output_config(sagemaker_session, s3_output_path=None, kms_key_id=None):
"""Create output data configuration with default S3 path if needed.
-
+
Args:
s3_output_path: S3 output path (if None, generates default path)
sagemaker_session: SageMaker session for generating default path
kms_key_id: Optional KMS key ID for encryption
-
+
Returns:
OutputDataConfig object
"""
@@ -652,7 +800,7 @@ def _create_output_config(sagemaker_session,s3_output_path=None, kms_key_id=None
# Use default S3 output path if none provided
if s3_output_path is None:
s3_output_path = _get_default_s3_output_path(sagemaker_session)
-
+
# Validate S3 path exists
_validate_s3_path_exists(s3_output_path, sagemaker_session)
@@ -662,16 +810,16 @@ def _create_output_config(sagemaker_session,s3_output_path=None, kms_key_id=None
)
-def _convert_input_data_to_channels(input_data_config ):
+def _convert_input_data_to_channels(input_data_config):
"""Convert InputData objects to Channel objects with S3 and dataset ARN support.
-
+
Args:
input_data_config: List of InputData objects
-
+
Returns:
List of Channel objects
"""
-
+
channels = []
for input_data in input_data_config:
if input_data.data_source.startswith("s3://"):
@@ -680,21 +828,19 @@ def _convert_input_data_to_channels(input_data_config ):
s3_data_source={
"s3_uri": input_data.data_source,
"s3_data_type": "S3Prefix",
- "s3_data_distribution_type": "FullyReplicated"
+ "s3_data_distribution_type": "FullyReplicated",
}
)
else:
# Dataset ARN - create dataset source
- data_source = DataSource(
- dataset_source={"dataset_arn": input_data.data_source}
- )
+ data_source = DataSource(dataset_source={"dataset_arn": input_data.data_source})
channel = Channel(
channel_name=input_data.channel_name,
data_source=data_source,
)
channels.append(channel)
-
+
return channels
@@ -703,41 +849,45 @@ def _validate_and_resolve_model_package_group(model, model_package_group_name):
# If model_package_group_name is already provided, return it as-is
if model_package_group_name:
return model_package_group_name
-
+
# Try to resolve from ModelPackage if available
if isinstance(model, ModelPackage):
return model.model_package_group_name
-
+
# Only validate if model_package_group_name is None and model is not ModelPackage
- raise ValueError("model_package_group_name must be provided when model given is "
- "not a ModelPackage artifact/not continued finetuning")
+ raise ValueError(
+ "model_package_group_name must be provided when model given is "
+ "not a ModelPackage artifact/not continued finetuning"
+ )
def _validate_eula_for_gated_model(model, accept_eula, is_gated_model):
"""Validate EULA acceptance for gated models.
-
+
Args:
model: Original model input (string, ARN, or ModelPackage)
accept_eula: Boolean indicating if EULA is accepted
is_gated_model: Boolean indicating if the model is gated
-
+
Returns:
bool: True if EULA is accepted (either explicitly or by default for ARN/ModelPackage)
-
+
Raises:
ValueError: If model is gated but accept_eula is False
"""
# For ModelPackage/ARN inputs, EULA is assumed accepted by default
- if isinstance(model, ModelPackage) or (isinstance(model, str) and model.startswith("arn:aws:sagemaker:")):
+ if isinstance(model, ModelPackage) or (
+ isinstance(model, str) and model.startswith("arn:aws:sagemaker:")
+ ):
return True
-
+
# Validate EULA acceptance for gated models
if is_gated_model and not accept_eula:
raise ValueError(
f"Model '{model}' is a gated model and requires EULA acceptance. "
"Please set accept_eula=True to proceed with training."
)
-
+
return accept_eula
@@ -745,14 +895,14 @@ def _validate_s3_path_exists(s3_path: str, sagemaker_session):
"""Validate S3 path and create bucket/prefix if they don't exist."""
if not s3_path.startswith("s3://"):
raise ValueError(f"Invalid S3 path format: {s3_path}")
-
+
# Parse S3 URI
s3_parts = s3_path.replace("s3://", "").split("/", 1)
bucket_name = s3_parts[0]
prefix = s3_parts[1] if len(s3_parts) > 1 else ""
-
- s3_client = sagemaker_session.boto_session.client('s3')
-
+
+ s3_client = sagemaker_session.boto_session.client("s3")
+
try:
# Check if bucket exists, create if it doesn't
try:
@@ -761,25 +911,24 @@ def _validate_s3_path_exists(s3_path: str, sagemaker_session):
if "NoSuchBucket" in str(e) or "Not Found" in str(e):
# Create bucket
region = sagemaker_session.boto_region_name
- if region == 'us-east-1':
+ if region == "us-east-1":
s3_client.create_bucket(Bucket=bucket_name)
else:
s3_client.create_bucket(
- Bucket=bucket_name,
- CreateBucketConfiguration={'LocationConstraint': region}
+ Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region}
)
else:
raise
-
+
# If prefix is provided, check if it exists, create if it doesn't
if prefix:
response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=prefix, MaxKeys=1)
- if 'Contents' not in response:
+ if "Contents" not in response:
# Create the prefix by putting an empty object
- if not prefix.endswith('/'):
- prefix += '/'
- s3_client.put_object(Bucket=bucket_name, Key=prefix, Body=b'')
-
+ if not prefix.endswith("/"):
+ prefix += "/"
+ s3_client.put_object(Bucket=bucket_name, Key=prefix, Body=b"")
+
except Exception as e:
raise ValueError(f"Failed to validate/create S3 path '{s3_path}': {str(e)}")
@@ -787,6 +936,7 @@ def _validate_s3_path_exists(s3_path: str, sagemaker_session):
def _validate_hyperparameter_values(hyperparameters: dict):
"""Validate hyperparameter values for allowed characters."""
import re
+
allowed_chars = r"^[a-zA-Z0-9/_.:,\-\s'\"\[\]]*$"
for key, value in hyperparameters.items():
if isinstance(value, str) and not re.match(allowed_chars, value):
diff --git a/sagemaker-train/src/sagemaker/train/dpo_trainer.py b/sagemaker-train/src/sagemaker/train/dpo_trainer.py
index bd5d9a11bd..6a6f3f07bd 100644
--- a/sagemaker-train/src/sagemaker/train/dpo_trainer.py
+++ b/sagemaker-train/src/sagemaker/train/dpo_trainer.py
@@ -19,7 +19,7 @@
_create_mlflow_config,
_create_model_package_config,
_validate_eula_for_gated_model,
- _validate_hyperparameter_values
+ _validate_hyperparameter_values,
)
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
@@ -53,19 +53,19 @@ class DPOTrainer(BaseTrainer):
model="meta-llama/Llama-2-7b-hf",
model_package_group="my-dpo-models"
)
-
+
# Create training job (non-blocking)
training_job = trainer.train(
training_dataset="s3://bucket/preference_data.jsonl",
wait=False
)
-
+
# Wait for completion
training_job.wait()
-
+
# Refresh job status
training_job.refresh()
-
+
# Get the fine-tuned model package ARN
model_package_arn = training_job.output_model_package_arn
@@ -100,31 +100,38 @@ class DPOTrainer(BaseTrainer):
stopping_condition (Optional[StoppingCondition]):
The stopping condition to override training runtime limit.
If not specified, uses SageMaker service default (24 hours for serverless training).
+ sequence_length (Optional[str]):
+ The sequence length for the training job. Valid values are
+ "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K".
+ If not specified, the service will use default recipe selection behavior.
"""
+
def __init__(
- self,
- model: Union[str, ModelPackage],
- training_type: Union[TrainingType, str] = TrainingType.LORA,
- model_package_group: Optional[Union[str, ModelPackageGroup]] = None,
- mlflow_resource_arn: Optional[str] = None,
- mlflow_experiment_name: Optional[str] = None,
- mlflow_run_name: Optional[str] = None,
- training_dataset: Optional[Union[str, DataSet]] = None,
- validation_dataset: Optional[Union[str, DataSet]] = None,
- s3_output_path: Optional[str] = None,
- kms_key_id: Optional[str] = None,
- networking: Optional[VpcConfig] = None,
- accept_eula: bool = False,
- stopping_condition: Optional[StoppingCondition] = None,
- **kwargs,
+ self,
+ model: Union[str, ModelPackage],
+ training_type: Union[TrainingType, str] = TrainingType.LORA,
+ model_package_group: Optional[Union[str, ModelPackageGroup]] = None,
+ mlflow_resource_arn: Optional[str] = None,
+ mlflow_experiment_name: Optional[str] = None,
+ mlflow_run_name: Optional[str] = None,
+ training_dataset: Optional[Union[str, DataSet]] = None,
+ validation_dataset: Optional[Union[str, DataSet]] = None,
+ s3_output_path: Optional[str] = None,
+ kms_key_id: Optional[str] = None,
+ networking: Optional[VpcConfig] = None,
+ accept_eula: bool = False,
+ stopping_condition: Optional[StoppingCondition] = None,
+ sequence_length: Optional[str] = None,
+ **kwargs,
):
super().__init__(**kwargs)
-
+
self.model, self._model_name = _resolve_model_and_name(model, self.sagemaker_session)
self.training_type = training_type
- self.model_package_group = _validate_and_resolve_model_package_group(model,
- model_package_group)
+ self.model_package_group = _validate_and_resolve_model_package_group(
+ model, model_package_group
+ )
self.mlflow_resource_arn = mlflow_resource_arn
self.mlflow_experiment_name = mlflow_experiment_name
self.mlflow_run_name = mlflow_run_name
@@ -134,19 +141,23 @@ def __init__(
self.kms_key_id = kms_key_id
self.networking = networking
self.stopping_condition = stopping_condition
+ self.sequence_length = sequence_length
# Initialize fine-tuning options with beta session fallback
- self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name,
- CustomizationTechnique.DPO.value,
- self.training_type,
- self.sagemaker_session or TrainDefaults.get_sagemaker_session(
- sagemaker_session=self.sagemaker_session
-
- ))
-
+ self.hyperparameters, self._model_arn, is_gated_model = (
+ _get_fine_tuning_options_and_model_arn(
+ self._model_name,
+ CustomizationTechnique.DPO.value,
+ self.training_type,
+ self.sagemaker_session
+ or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session),
+ sequence_length=self.sequence_length,
+ )
+ )
+
# Process hyperparameters
self._process_hyperparameters()
-
+
# Validate and set EULA acceptance
self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model)
@@ -154,35 +165,37 @@ def _process_hyperparameters(self):
"""Remove hyperparameter keys that are handled by constructor inputs."""
if self.hyperparameters:
# Remove keys that are handled by constructor inputs
- if hasattr(self.hyperparameters, 'data_path'):
- delattr(self.hyperparameters, 'data_path')
- self.hyperparameters._specs.pop('data_path', None)
- if hasattr(self.hyperparameters, 'output_path'):
- delattr(self.hyperparameters, 'output_path')
- self.hyperparameters._specs.pop('output_path', None)
- if hasattr(self.hyperparameters, 'data_s3_path'):
- delattr(self.hyperparameters, 'data_s3_path')
- self.hyperparameters._specs.pop('data_s3_path', None)
- if hasattr(self.hyperparameters, 'output_s3_path'):
- delattr(self.hyperparameters, 'output_s3_path')
- self.hyperparameters._specs.pop('output_s3_path', None)
- if hasattr(self.hyperparameters, 'training_data_name'):
- delattr(self.hyperparameters, 'training_data_name')
- self.hyperparameters._specs.pop('training_data_name', None)
- if hasattr(self.hyperparameters, 'validation_data_name'):
- delattr(self.hyperparameters, 'validation_data_name')
- self.hyperparameters._specs.pop('validation_data_name', None)
- if hasattr(self.hyperparameters, 'validation_data_path'):
- delattr(self.hyperparameters, 'validation_data_path')
- self.hyperparameters._specs.pop('validation_data_path', None)
+ if hasattr(self.hyperparameters, "data_path"):
+ delattr(self.hyperparameters, "data_path")
+ self.hyperparameters._specs.pop("data_path", None)
+ if hasattr(self.hyperparameters, "output_path"):
+ delattr(self.hyperparameters, "output_path")
+ self.hyperparameters._specs.pop("output_path", None)
+ if hasattr(self.hyperparameters, "data_s3_path"):
+ delattr(self.hyperparameters, "data_s3_path")
+ self.hyperparameters._specs.pop("data_s3_path", None)
+ if hasattr(self.hyperparameters, "output_s3_path"):
+ delattr(self.hyperparameters, "output_s3_path")
+ self.hyperparameters._specs.pop("output_s3_path", None)
+ if hasattr(self.hyperparameters, "training_data_name"):
+ delattr(self.hyperparameters, "training_data_name")
+ self.hyperparameters._specs.pop("training_data_name", None)
+ if hasattr(self.hyperparameters, "validation_data_name"):
+ delattr(self.hyperparameters, "validation_data_name")
+ self.hyperparameters._specs.pop("validation_data_name", None)
+ if hasattr(self.hyperparameters, "validation_data_path"):
+ delattr(self.hyperparameters, "validation_data_path")
+ self.hyperparameters._specs.pop("validation_data_path", None)
@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="DPOTrainer.train")
- def train(self,
- training_dataset: Optional[Union[str, DataSet]] = None,
- validation_dataset: Optional[Union[str, DataSet]] = None,
- wait: bool = True,
- wait_timeout: Optional[int] = None,
- poll: int = 5):
+ def train(
+ self,
+ training_dataset: Optional[Union[str, DataSet]] = None,
+ validation_dataset: Optional[Union[str, DataSet]] = None,
+ wait: bool = True,
+ wait_timeout: Optional[int] = None,
+ poll: int = 5,
+ ):
"""Execute the DPO training job.
Parameters:
@@ -215,24 +228,26 @@ def train(self,
logger.info(f"Training Job Name: {current_training_job_name}")
print(f"Training Job Name: {current_training_job_name}")
- #data
- input_data_config = _create_input_data_config(training_dataset or self.training_dataset,
- validation_dataset or self.validation_dataset
- )
+ # data
+ input_data_config = _create_input_data_config(
+ training_dataset or self.training_dataset, validation_dataset or self.validation_dataset
+ )
channels = _convert_input_data_to_channels(input_data_config)
output_config = _create_output_config(
s3_output_path=self.s3_output_path,
sagemaker_session=sagemaker_session,
- kms_key_id=self.kms_key_id
+ kms_key_id=self.kms_key_id,
)
- serverless_config = _create_serverless_config(model_arn=self._model_arn,
- customization_technique=CustomizationTechnique.DPO.value,
- training_type=self.training_type,
- accept_eula=self.accept_eula,
- job_type=JOB_TYPE
- )
+ serverless_config = _create_serverless_config(
+ model_arn=self._model_arn,
+ customization_technique=CustomizationTechnique.DPO.value,
+ training_type=self.training_type,
+ accept_eula=self.accept_eula,
+ sequence_length=self.sequence_length,
+ job_type=JOB_TYPE,
+ )
mlflow_config = _create_mlflow_config(
sagemaker_session,
@@ -247,7 +262,7 @@ def train(self,
model_package_config = _create_model_package_config(
model_package_group_name=self.model_package_group,
model=self.model,
- sagemaker_session=sagemaker_session
+ sagemaker_session=sagemaker_session,
)
vpc_config = self.networking if self.networking else None
@@ -268,7 +283,7 @@ def train(self,
"region": sagemaker_session.boto_session.region_name,
"tags": tags,
}
-
+
# Only pass stopping_condition if explicitly provided by user
if self.stopping_condition is not None:
create_args["stopping_condition"] = self.stopping_condition
@@ -282,15 +297,15 @@ def train(self,
if wait:
from sagemaker.train.common_utils.trainer_wait import wait as _wait
from sagemaker.core.utils.exceptions import TimeoutExceededError
- try :
+
+ try:
wait_kwargs = {}
if wait_timeout is not None:
- wait_kwargs['timeout'] = wait_timeout
- wait_kwargs['poll'] = poll
+ wait_kwargs["timeout"] = wait_timeout
+ wait_kwargs["poll"] = poll
_wait(training_job, **wait_kwargs)
except TimeoutExceededError as e:
logger.error("Error: %s", e)
self.latest_training_job = training_job
return training_job
-
diff --git a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py
index f2d8460989..09084359f1 100644
--- a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py
+++ b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py
@@ -2,7 +2,12 @@
import logging
from sagemaker.train.base_trainer import BaseTrainer
from sagemaker.train.common import TrainingType, CustomizationTechnique, JOB_TYPE
-from sagemaker.core.resources import TrainingJob, ModelPackageGroup, MlflowTrackingServer, ModelPackage
+from sagemaker.core.resources import (
+ TrainingJob,
+ ModelPackageGroup,
+ MlflowTrackingServer,
+ ModelPackage,
+)
from sagemaker.core.shapes import VpcConfig
from sagemaker.train.defaults import TrainDefaults
from sagemaker.train.utils import _get_unique_name, _get_studio_tags
@@ -23,7 +28,7 @@
_create_mlflow_config,
_create_model_package_config,
_validate_eula_for_gated_model,
- _validate_hyperparameter_values
+ _validate_hyperparameter_values,
)
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
@@ -60,19 +65,19 @@ class RLAIFTrainer(BaseTrainer):
reward_model_id="reward-model-id",
reward_prompt="Rate the helpfulness of this response on a scale of 1-10"
)
-
+
# Create training job (non-blocking)
training_job = trainer.train(
training_dataset="s3://bucket/rlaif_data.jsonl",
wait=False
)
-
+
# Wait for completion
training_job.wait()
-
+
# Refresh job status
training_job.refresh()
-
+
# Get the fine-tuned model package ARN
model_package_arn = training_job.output_model_package_arn
@@ -114,6 +119,10 @@ class RLAIFTrainer(BaseTrainer):
stopping_condition (Optional[StoppingCondition]):
The stopping condition to override training runtime limit.
If not specified, uses SageMaker service default (24 hours for serverless training).
+ sequence_length (Optional[str]):
+ The sequence length for the training job. Valid values are
+ "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K".
+ If not specified, the service will use default recipe selection behavior.
"""
def __init__(
@@ -135,6 +144,7 @@ def __init__(
networking: Optional[VpcConfig] = None,
accept_eula: bool = False,
stopping_condition: Optional[StoppingCondition] = None,
+ sequence_length: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -143,8 +153,9 @@ def __init__(
self.model, self._model_name = _resolve_model_and_name(model, self.sagemaker_session)
self.training_type = training_type
- self.model_package_group = _validate_and_resolve_model_package_group(model,
- model_package_group)
+ self.model_package_group = _validate_and_resolve_model_package_group(
+ model, model_package_group
+ )
self.reward_model_id = self._validate_reward_model_id(reward_model_id)
self.reward_prompt = reward_prompt
self.mlflow_resource_arn = mlflow_resource_arn
@@ -156,18 +167,23 @@ def __init__(
self.kms_key_id = kms_key_id
self.networking = networking
self.stopping_condition = stopping_condition
+ self.sequence_length = sequence_length
# Initialize fine-tuning options with beta session fallback
- self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name,
- CustomizationTechnique.RLAIF.value,
- self.training_type,
- self.sagemaker_session or TrainDefaults.get_sagemaker_session(
- sagemaker_session=self.sagemaker_session
- ))
-
+ self.hyperparameters, self._model_arn, is_gated_model = (
+ _get_fine_tuning_options_and_model_arn(
+ self._model_name,
+ CustomizationTechnique.RLAIF.value,
+ self.training_type,
+ self.sagemaker_session
+ or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session),
+ sequence_length=self.sequence_length,
+ )
+ )
+
# Validate and set EULA acceptance
self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model)
-
+
# Process reward_prompt parameter
self._process_hyperparameters()
@@ -181,23 +197,33 @@ def _validate_reward_model_id(self, reward_model_id):
f"Invalid reward_model_id '{reward_model_id}'. "
f"Available models are: {list(_ALLOWED_REWARD_MODEL_IDS.keys())}"
)
-
+
# Check region compatibility
- session = self.sagemaker_session if hasattr(self, 'sagemaker_session') and self.sagemaker_session else TrainDefaults.get_sagemaker_session()
+ session = (
+ self.sagemaker_session
+ if hasattr(self, "sagemaker_session") and self.sagemaker_session
+ else TrainDefaults.get_sagemaker_session()
+ )
current_region = session.boto_region_name
allowed_regions = _ALLOWED_REWARD_MODEL_IDS[reward_model_id]
-
+
if current_region not in allowed_regions:
raise ValueError(
f"Reward model '{reward_model_id}' is not available in region '{current_region}'. "
f"Available regions for this model: {allowed_regions}"
)
-
+
return reward_model_id
-
@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLAIFTrainer.train")
- def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None, poll: int = 5):
+ def train(
+ self,
+ training_dataset: Optional[Union[str, DataSet]] = None,
+ validation_dataset: Optional[Union[str, DataSet]] = None,
+ wait: bool = True,
+ wait_timeout: Optional[int] = None,
+ poll: int = 5,
+ ):
"""Execute the RLAIF training job.
Parameters:
@@ -229,26 +255,28 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
logger.info(f"Training Job Name: {current_training_job_name}")
- #data
- input_data_config = _create_input_data_config(training_dataset or self.training_dataset,
- validation_dataset or self.validation_dataset
- )
+ # data
+ input_data_config = _create_input_data_config(
+ training_dataset or self.training_dataset, validation_dataset or self.validation_dataset
+ )
channels = _convert_input_data_to_channels(input_data_config)
output_config = _create_output_config(
s3_output_path=self.s3_output_path,
sagemaker_session=sagemaker_session,
- kms_key_id=self.kms_key_id
+ kms_key_id=self.kms_key_id,
)
- evaluator_arn = getattr(self, '_evaluator_arn', None)
- serverless_config = _create_serverless_config(model_arn=self._model_arn,
- customization_technique=CustomizationTechnique.RLAIF.value,
- training_type=self.training_type,
- accept_eula=self.accept_eula,
- evaluator_arn=evaluator_arn,
- job_type=JOB_TYPE
- )
+ evaluator_arn = getattr(self, "_evaluator_arn", None)
+ serverless_config = _create_serverless_config(
+ model_arn=self._model_arn,
+ customization_technique=CustomizationTechnique.RLAIF.value,
+ training_type=self.training_type,
+ accept_eula=self.accept_eula,
+ evaluator_arn=evaluator_arn,
+ sequence_length=self.sequence_length,
+ job_type=JOB_TYPE,
+ )
mlflow_config = _create_mlflow_config(
sagemaker_session,
@@ -264,7 +292,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
model_package_config = _create_model_package_config(
model_package_group_name=self.model_package_group,
model=self.model,
- sagemaker_session=sagemaker_session
+ sagemaker_session=sagemaker_session,
)
vpc_config = self.networking if self.networking else None
@@ -285,7 +313,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
"region": sagemaker_session.boto_session.region_name,
"tags": tags,
}
-
+
# Only pass stopping_condition if explicitly provided by user
if self.stopping_condition is not None:
create_args["stopping_condition"] = self.stopping_condition
@@ -299,11 +327,12 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
if wait:
from sagemaker.train.common_utils.trainer_wait import wait as _wait
from sagemaker.core.utils.exceptions import TimeoutExceededError
- try :
+
+ try:
wait_kwargs = {}
if wait_timeout is not None:
- wait_kwargs['timeout'] = wait_timeout
- wait_kwargs['poll'] = poll
+ wait_kwargs["timeout"] = wait_timeout
+ wait_kwargs["poll"] = poll
_wait(training_job, **wait_kwargs)
except TimeoutExceededError as e:
logger.error("Error: %s", e)
@@ -313,27 +342,31 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
def _process_hyperparameters(self):
"""Update hyperparameters based on constructor inputs and process reward_prompt."""
- if not self.hyperparameters or not hasattr(self.hyperparameters, '_specs') or not self.hyperparameters._specs:
+ if (
+ not self.hyperparameters
+ or not hasattr(self.hyperparameters, "_specs")
+ or not self.hyperparameters._specs
+ ):
return
-
+
# Remove keys that are handled by constructor inputs
- if hasattr(self.hyperparameters, 'output_path'):
- delattr(self.hyperparameters, 'output_path')
- self.hyperparameters._specs.pop('output_path', None)
- if hasattr(self.hyperparameters, 'data_path'):
- delattr(self.hyperparameters, 'data_path')
- self.hyperparameters._specs.pop('data_path', None)
- if hasattr(self.hyperparameters, 'validation_data_path'):
- delattr(self.hyperparameters, 'validation_data_path')
- self.hyperparameters._specs.pop('validation_data_path', None)
-
+ if hasattr(self.hyperparameters, "output_path"):
+ delattr(self.hyperparameters, "output_path")
+ self.hyperparameters._specs.pop("output_path", None)
+ if hasattr(self.hyperparameters, "data_path"):
+ delattr(self.hyperparameters, "data_path")
+ self.hyperparameters._specs.pop("data_path", None)
+ if hasattr(self.hyperparameters, "validation_data_path"):
+ delattr(self.hyperparameters, "validation_data_path")
+ self.hyperparameters._specs.pop("validation_data_path", None)
+
# Update judge_model_id if reward_model_id is provided
- if hasattr(self, 'reward_model_id') and self.reward_model_id:
+ if hasattr(self, "reward_model_id") and self.reward_model_id:
judge_model_value = f"bedrock/{self.reward_model_id}"
self.hyperparameters.judge_model_id = judge_model_value
-
+
# Process reward_prompt parameter
- if hasattr(self, 'reward_prompt') and self.reward_prompt:
+ if hasattr(self, "reward_prompt") and self.reward_prompt:
if isinstance(self.reward_prompt, str):
if self.reward_prompt.startswith("Builtin"):
# Handle builtin reward prompts
@@ -343,9 +376,9 @@ def _process_hyperparameters(self):
self._process_non_builtin_reward_prompt()
else:
# Handle evaluator object
- if hasattr(self.hyperparameters, 'judge_prompt_template'):
- delattr(self.hyperparameters, 'judge_prompt_template')
- self.hyperparameters._specs.pop('judge_prompt_template', None)
+ if hasattr(self.hyperparameters, "judge_prompt_template"):
+ delattr(self.hyperparameters, "judge_prompt_template")
+ self.hyperparameters._specs.pop("judge_prompt_template", None)
evaluator_arn = _extract_evaluator_arn(self.reward_prompt, "reward_prompt")
self._evaluator_arn = evaluator_arn
@@ -353,10 +386,10 @@ def _process_hyperparameters(self):
def _process_non_builtin_reward_prompt(self):
"""Process non-builtin reward prompt (ARN or hub content name)."""
# Remove judge_prompt_template for non-builtin prompts
- if hasattr(self.hyperparameters, 'judge_prompt_template'):
- delattr(self.hyperparameters, 'judge_prompt_template')
- self.hyperparameters._specs.pop('judge_prompt_template', None)
-
+ if hasattr(self.hyperparameters, "judge_prompt_template"):
+ delattr(self.hyperparameters, "judge_prompt_template")
+ self.hyperparameters._specs.pop("judge_prompt_template", None)
+
if self.reward_prompt.startswith("arn:aws:sagemaker:"):
# Validate and assign ARN
evaluator_arn = _extract_evaluator_arn(self.reward_prompt, "reward_prompt")
@@ -364,39 +397,39 @@ def _process_non_builtin_reward_prompt(self):
else:
try:
session = TrainDefaults.get_sagemaker_session(
- sagemaker_session=self.sagemaker_session
- )
+ sagemaker_session=self.sagemaker_session
+ )
hub_content = _get_hub_content_metadata(
hub_name=get_sagemaker_hub_name(),
hub_content_type="JsonDoc",
hub_content_name=self.reward_prompt,
session=session.boto_session,
- region=session.boto_session.region_name
+ region=session.boto_session.region_name,
)
# Store ARN for evaluator_arn
self._evaluator_arn = hub_content.hub_content_arn
except Exception as e:
- raise ValueError(f"Custom prompt '{self.reward_prompt}' not found in HubContent: {e}")
-
-
+ raise ValueError(
+ f"Custom prompt '{self.reward_prompt}' not found in HubContent: {e}"
+ )
def _update_judge_prompt_template_direct(self, reward_prompt):
"""Update judge_prompt_template based on Builtin reward function."""
# Get available templates from hyperparameters specs
- judge_prompt_spec = self.hyperparameters._specs.get('judge_prompt_template', {})
- available_templates = judge_prompt_spec.get('enum', [])
-
+ judge_prompt_spec = self.hyperparameters._specs.get("judge_prompt_template", {})
+ available_templates = judge_prompt_spec.get("enum", [])
+
if not available_templates:
# If no enum found, use the current value as the only available option
- current_value = getattr(self.hyperparameters, 'judge_prompt_template', None)
+ current_value = getattr(self.hyperparameters, "judge_prompt_template", None)
if current_value:
available_templates = [current_value]
else:
return
-
+
# Extract template name after "Builtin." and convert to lowercase
template_name = reward_prompt.split(".", 1)[1].lower()
-
+
# Find matching template by extracting filename without extension
matching_template = None
for template in available_templates:
@@ -404,14 +437,15 @@ def _update_judge_prompt_template_direct(self, reward_prompt):
if template_filename == template_name:
matching_template = template
break
-
+
if matching_template:
self.hyperparameters.judge_prompt_template = matching_template
else:
- available_options = [f"Builtin.{t.split('/')[-1].replace('.jinja', '')}" for t in available_templates]
+ available_options = [
+ f"Builtin.{t.split('/')[-1].replace('.jinja', '')}" for t in available_templates
+ ]
raise ValueError(
f"Selected reward function option '{reward_prompt}' is not available. "
f"Choose one from the available options: {available_options}. "
f"Example: reward_prompt='Builtin.summarize'"
)
-
diff --git a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py
index 333a93fc55..3abcbbf47e 100644
--- a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py
+++ b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py
@@ -2,7 +2,12 @@
import logging
from sagemaker.train.base_trainer import BaseTrainer
from sagemaker.train.common import TrainingType, CustomizationTechnique, JOB_TYPE
-from sagemaker.core.resources import TrainingJob, ModelPackageGroup, MlflowTrackingServer, ModelPackage
+from sagemaker.core.resources import (
+ TrainingJob,
+ ModelPackageGroup,
+ MlflowTrackingServer,
+ ModelPackage,
+)
from sagemaker.core.shapes import VpcConfig
from sagemaker.train.defaults import TrainDefaults
from sagemaker.train.utils import _get_unique_name, _get_studio_tags
@@ -21,7 +26,7 @@
_create_mlflow_config,
_create_model_package_config,
_validate_eula_for_gated_model,
- _validate_hyperparameter_values
+ _validate_hyperparameter_values,
)
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
@@ -56,19 +61,19 @@ class RLVRTrainer(BaseTrainer):
model_package_group="my-rlvr-models",
custom_reward_function="arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/JsonDoc/my-evaluator/1.0"
)
-
+
# Create training job (non-blocking)
training_job = trainer.train(
training_dataset="s3://bucket/rlvr_data.jsonl",
wait=False
)
-
+
# Wait for completion
training_job.wait()
-
+
# Refresh job status
training_job.refresh()
-
+
# Get the fine-tuned model package ARN
model_package_arn = training_job.output_model_package_arn
@@ -106,6 +111,10 @@ class RLVRTrainer(BaseTrainer):
stopping_condition (Optional[StoppingCondition]):
The stopping condition to override training runtime limit.
If not specified, uses SageMaker service default (24 hours for serverless training).
+ sequence_length (Optional[str]):
+ The sequence length for the training job. Valid values are
+ "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K".
+ If not specified, the service will use default recipe selection behavior.
"""
def __init__(
@@ -126,6 +135,7 @@ def __init__(
networking: Optional[VpcConfig] = None,
accept_eula: bool = False,
stopping_condition: Optional[StoppingCondition] = None,
+ sequence_length: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -134,8 +144,9 @@ def __init__(
self.model, self._model_name = _resolve_model_and_name(model, self.sagemaker_session)
self.training_type = training_type
- self.model_package_group = _validate_and_resolve_model_package_group(model,
- model_package_group)
+ self.model_package_group = _validate_and_resolve_model_package_group(
+ model, model_package_group
+ )
self.custom_reward_function = custom_reward_function
self.mlflow_resource_arn = mlflow_resource_arn
self.mlflow_experiment_name = mlflow_experiment_name
@@ -146,18 +157,23 @@ def __init__(
self.kms_key_id = kms_key_id
self.networking = networking
self.stopping_condition = stopping_condition
+ self.sequence_length = sequence_length
# Initialize fine-tuning options with beta session fallback
- self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name,
- CustomizationTechnique.RLVR.value,
- self.training_type,
- self.sagemaker_session or TrainDefaults.get_sagemaker_session(
- sagemaker_session=self.sagemaker_session
- ))
-
+ self.hyperparameters, self._model_arn, is_gated_model = (
+ _get_fine_tuning_options_and_model_arn(
+ self._model_name,
+ CustomizationTechnique.RLVR.value,
+ self.training_type,
+ self.sagemaker_session
+ or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session),
+ sequence_length=self.sequence_length,
+ )
+ )
+
# Remove constructor-handled hyperparameters
self._process_hyperparameters()
-
+
# Validate and set EULA acceptance
self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model)
@@ -165,28 +181,34 @@ def _process_hyperparameters(self):
"""Remove hyperparameter keys that are handled by constructor inputs."""
if self.hyperparameters:
# Remove keys that are handled by constructor inputs
- if hasattr(self.hyperparameters, 'data_s3_path'):
- delattr(self.hyperparameters, 'data_s3_path')
- self.hyperparameters._specs.pop('data_s3_path', None)
- if hasattr(self.hyperparameters, 'reward_lambda_arn'):
- delattr(self.hyperparameters, 'reward_lambda_arn')
- self.hyperparameters._specs.pop('reward_lambda_arn', None)
- if hasattr(self.hyperparameters, 'preset_reward_function'):
- delattr(self.hyperparameters, 'preset_reward_function')
- self.hyperparameters._specs.pop('preset_reward_function', None)
- if hasattr(self.hyperparameters, 'data_path'):
- delattr(self.hyperparameters, 'data_path')
- self.hyperparameters._specs.pop('data_path', None)
- if hasattr(self.hyperparameters, 'validation_data_path'):
- delattr(self.hyperparameters, 'validation_data_path')
- self.hyperparameters._specs.pop('validation_data_path', None)
- if hasattr(self.hyperparameters, 'output_path'):
- delattr(self.hyperparameters, 'output_path')
- self.hyperparameters._specs.pop('output_path', None)
+ if hasattr(self.hyperparameters, "data_s3_path"):
+ delattr(self.hyperparameters, "data_s3_path")
+ self.hyperparameters._specs.pop("data_s3_path", None)
+ if hasattr(self.hyperparameters, "reward_lambda_arn"):
+ delattr(self.hyperparameters, "reward_lambda_arn")
+ self.hyperparameters._specs.pop("reward_lambda_arn", None)
+ if hasattr(self.hyperparameters, "preset_reward_function"):
+ delattr(self.hyperparameters, "preset_reward_function")
+ self.hyperparameters._specs.pop("preset_reward_function", None)
+ if hasattr(self.hyperparameters, "data_path"):
+ delattr(self.hyperparameters, "data_path")
+ self.hyperparameters._specs.pop("data_path", None)
+ if hasattr(self.hyperparameters, "validation_data_path"):
+ delattr(self.hyperparameters, "validation_data_path")
+ self.hyperparameters._specs.pop("validation_data_path", None)
+ if hasattr(self.hyperparameters, "output_path"):
+ delattr(self.hyperparameters, "output_path")
+ self.hyperparameters._specs.pop("output_path", None)
@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLVRTrainer.train")
- def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
- validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None, poll: int = 5):
+ def train(
+ self,
+ training_dataset: Optional[Union[str, DataSet]] = None,
+ validation_dataset: Optional[Union[str, DataSet]] = None,
+ wait: bool = True,
+ wait_timeout: Optional[int] = None,
+ poll: int = 5,
+ ):
"""Execute the RLVR training job.
Parameters:
@@ -219,27 +241,33 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
logger.info(f"Training Job Name: {current_training_job_name}")
- #data
- input_data_config = _create_input_data_config(training_dataset or self.training_dataset,
- validation_dataset or self.validation_dataset
- )
+ # data
+ input_data_config = _create_input_data_config(
+ training_dataset or self.training_dataset, validation_dataset or self.validation_dataset
+ )
channels = _convert_input_data_to_channels(input_data_config)
output_config = _create_output_config(
s3_output_path=self.s3_output_path,
sagemaker_session=sagemaker_session,
- kms_key_id=self.kms_key_id
+ kms_key_id=self.kms_key_id,
)
# Extract and validate evaluator ARN
- evaluator_arn = _extract_evaluator_arn(self.custom_reward_function) if self.custom_reward_function else None
- serverless_config = _create_serverless_config(model_arn=self._model_arn,
- customization_technique=CustomizationTechnique.RLVR.value,
- training_type=self.training_type,
- accept_eula=self.accept_eula,
- evaluator_arn=evaluator_arn,
- job_type=JOB_TYPE
- )
+ evaluator_arn = (
+ _extract_evaluator_arn(self.custom_reward_function)
+ if self.custom_reward_function
+ else None
+ )
+ serverless_config = _create_serverless_config(
+ model_arn=self._model_arn,
+ customization_technique=CustomizationTechnique.RLVR.value,
+ training_type=self.training_type,
+ accept_eula=self.accept_eula,
+ evaluator_arn=evaluator_arn,
+ sequence_length=self.sequence_length,
+ job_type=JOB_TYPE,
+ )
mlflow_config = _create_mlflow_config(
sagemaker_session,
mlflow_resource_arn=self.mlflow_resource_arn,
@@ -248,14 +276,14 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
)
final_hyperparameters = self.hyperparameters.to_dict()
-
+
# Validate hyperparameter values
_validate_hyperparameter_values(final_hyperparameters)
model_package_config = _create_model_package_config(
model_package_group_name=self.model_package_group,
model=self.model,
- sagemaker_session=sagemaker_session
+ sagemaker_session=sagemaker_session,
)
vpc_config = self.networking if self.networking else None
@@ -276,7 +304,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
"region": sagemaker_session.boto_session.region_name,
"tags": tags,
}
-
+
# Only pass stopping_condition if explicitly provided by user
if self.stopping_condition is not None:
create_args["stopping_condition"] = self.stopping_condition
@@ -290,11 +318,12 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
if wait:
from sagemaker.train.common_utils.trainer_wait import wait as _wait
from sagemaker.core.utils.exceptions import TimeoutExceededError
+
try:
wait_kwargs = {}
if wait_timeout is not None:
- wait_kwargs['timeout'] = wait_timeout
- wait_kwargs['poll'] = poll
+ wait_kwargs["timeout"] = wait_timeout
+ wait_kwargs["poll"] = poll
_wait(training_job, **wait_kwargs)
except TimeoutExceededError as e:
logger.error("Error: %s", e)
diff --git a/sagemaker-train/src/sagemaker/train/sft_trainer.py b/sagemaker-train/src/sagemaker/train/sft_trainer.py
index 233f169d0f..9d47e9742a 100644
--- a/sagemaker-train/src/sagemaker/train/sft_trainer.py
+++ b/sagemaker-train/src/sagemaker/train/sft_trainer.py
@@ -20,7 +20,7 @@
_create_mlflow_config,
_create_model_package_config,
_validate_eula_for_gated_model,
- _validate_hyperparameter_values
+ _validate_hyperparameter_values,
)
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
@@ -55,19 +55,19 @@ class SFTTrainer(BaseTrainer):
model="meta-llama/Llama-2-7b-hf",
model_package_group="my-fine-tuned-models"
)
-
+
# Create training job (non-blocking)
training_job = trainer.train(
training_dataset="s3://bucket/train.jsonl",
wait=False
)
-
+
# Wait for completion
training_job.wait()
-
+
# Refresh job status
training_job.refresh()
-
+
# Get the fine-tuned model artifacts ARN
model_package_arn = training_job.output_model_package_arn
@@ -102,6 +102,10 @@ class SFTTrainer(BaseTrainer):
stopping_condition (Optional[StoppingCondition]):
The stopping condition to override training runtime limit.
If not specified, uses SageMaker service default (24 hours for serverless training).
+ sequence_length (Optional[str]):
+ The sequence length for the training job. Valid values are
+ "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K".
+ If not specified, the service will use default recipe selection behavior.
"""
def __init__(
@@ -119,16 +123,18 @@ def __init__(
networking: Optional[VpcConfig] = None,
accept_eula: Optional[bool] = False,
stopping_condition: Optional[StoppingCondition] = None,
+ sequence_length: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
-
+
# Resolve model and model name
self.model, self._model_name = _resolve_model_and_name(model, self.sagemaker_session)
self.training_type = training_type
- self.model_package_group = _validate_and_resolve_model_package_group(model,
- model_package_group)
+ self.model_package_group = _validate_and_resolve_model_package_group(
+ model, model_package_group
+ )
self.mlflow_resource_arn = mlflow_resource_arn
self.mlflow_experiment_name = mlflow_experiment_name
self.mlflow_run_name = mlflow_run_name
@@ -138,18 +144,23 @@ def __init__(
self.kms_key_id = kms_key_id
self.networking = networking
self.stopping_condition = stopping_condition
+ self.sequence_length = sequence_length
# Initialize fine-tuning options with beta session fallback
- self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name,
- CustomizationTechnique.SFT.value,
- self.training_type,
- self.sagemaker_session or TrainDefaults.get_sagemaker_session(
- sagemaker_session=self.sagemaker_session
- ))
-
+ self.hyperparameters, self._model_arn, is_gated_model = (
+ _get_fine_tuning_options_and_model_arn(
+ self._model_name,
+ CustomizationTechnique.SFT.value,
+ self.training_type,
+ self.sagemaker_session
+ or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session),
+ sequence_length=self.sequence_length,
+ )
+ )
+
# Process hyperparameters
self._process_hyperparameters()
-
+
# Validate and set EULA acceptance
self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model)
@@ -157,30 +168,37 @@ def _process_hyperparameters(self):
"""Remove hyperparameter keys that are handled by constructor inputs."""
if self.hyperparameters:
# Remove keys that are handled by constructor inputs
- if hasattr(self.hyperparameters, 'data_path'):
- delattr(self.hyperparameters, 'data_path')
- self.hyperparameters._specs.pop('data_path', None)
- if hasattr(self.hyperparameters, 'output_path'):
- delattr(self.hyperparameters, 'output_path')
- self.hyperparameters._specs.pop('output_path', None)
- if hasattr(self.hyperparameters, 'data_s3_path'):
- delattr(self.hyperparameters, 'data_s3_path')
- self.hyperparameters._specs.pop('data_s3_path', None)
- if hasattr(self.hyperparameters, 'output_s3_path'):
- delattr(self.hyperparameters, 'output_s3_path')
- self.hyperparameters._specs.pop('output_s3_path', None)
- if hasattr(self.hyperparameters, 'training_data_name'):
- delattr(self.hyperparameters, 'training_data_name')
- self.hyperparameters._specs.pop('training_data_name', None)
- if hasattr(self.hyperparameters, 'validation_data_name'):
- delattr(self.hyperparameters, 'validation_data_name')
- self.hyperparameters._specs.pop('validation_data_name', None)
- if hasattr(self.hyperparameters, 'validation_data_path'):
- delattr(self.hyperparameters, 'validation_data_path')
- self.hyperparameters._specs.pop('validation_data_path', None)
+ if hasattr(self.hyperparameters, "data_path"):
+ delattr(self.hyperparameters, "data_path")
+ self.hyperparameters._specs.pop("data_path", None)
+ if hasattr(self.hyperparameters, "output_path"):
+ delattr(self.hyperparameters, "output_path")
+ self.hyperparameters._specs.pop("output_path", None)
+ if hasattr(self.hyperparameters, "data_s3_path"):
+ delattr(self.hyperparameters, "data_s3_path")
+ self.hyperparameters._specs.pop("data_s3_path", None)
+ if hasattr(self.hyperparameters, "output_s3_path"):
+ delattr(self.hyperparameters, "output_s3_path")
+ self.hyperparameters._specs.pop("output_s3_path", None)
+ if hasattr(self.hyperparameters, "training_data_name"):
+ delattr(self.hyperparameters, "training_data_name")
+ self.hyperparameters._specs.pop("training_data_name", None)
+ if hasattr(self.hyperparameters, "validation_data_name"):
+ delattr(self.hyperparameters, "validation_data_name")
+ self.hyperparameters._specs.pop("validation_data_name", None)
+ if hasattr(self.hyperparameters, "validation_data_path"):
+ delattr(self.hyperparameters, "validation_data_path")
+ self.hyperparameters._specs.pop("validation_data_path", None)
@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="SFTTrainer.train")
- def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None, poll: int = 5):
+ def train(
+ self,
+ training_dataset: Optional[Union[str, DataSet]] = None,
+ validation_dataset: Optional[Union[str, DataSet]] = None,
+ wait: bool = True,
+ wait_timeout: Optional[int] = None,
+ poll: int = 5,
+ ):
"""Execute the SFT training job.
Parameters:
@@ -213,24 +231,26 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
logger.info(f"Training Job Name: {current_training_job_name}")
- #data
- input_data_config = _create_input_data_config(training_dataset or self.training_dataset,
- validation_dataset or self.validation_dataset
- )
+ # data
+ input_data_config = _create_input_data_config(
+ training_dataset or self.training_dataset, validation_dataset or self.validation_dataset
+ )
channels = _convert_input_data_to_channels(input_data_config)
output_config = _create_output_config(
s3_output_path=self.s3_output_path,
sagemaker_session=sagemaker_session,
- kms_key_id=self.kms_key_id
+ kms_key_id=self.kms_key_id,
)
- serverless_config = _create_serverless_config(model_arn=self._model_arn,
- customization_technique=CustomizationTechnique.SFT.value,
- training_type=self.training_type,
- accept_eula=self.accept_eula,
- job_type=JOB_TYPE
- )
+ serverless_config = _create_serverless_config(
+ model_arn=self._model_arn,
+ customization_technique=CustomizationTechnique.SFT.value,
+ training_type=self.training_type,
+ accept_eula=self.accept_eula,
+ sequence_length=self.sequence_length,
+ job_type=JOB_TYPE,
+ )
mlflow_config = _create_mlflow_config(
sagemaker_session,
mlflow_resource_arn=self.mlflow_resource_arn,
@@ -239,14 +259,14 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
)
final_hyperparameters = self.hyperparameters.to_dict()
-
+
# Validate hyperparameter values
_validate_hyperparameter_values(final_hyperparameters)
model_package_config = _create_model_package_config(
model_package_group_name=self.model_package_group,
model=self.model,
- sagemaker_session=sagemaker_session
+ sagemaker_session=sagemaker_session,
)
vpc_config = self.networking if self.networking else None
@@ -267,7 +287,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
"region": sagemaker_session.boto_session.region_name,
"tags": tags,
}
-
+
# Only pass stopping_condition if explicitly provided by user
if self.stopping_condition is not None:
create_args["stopping_condition"] = self.stopping_condition
@@ -281,16 +301,15 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
if wait:
from sagemaker.train.common_utils.trainer_wait import wait as _wait
from sagemaker.core.utils.exceptions import TimeoutExceededError
- try :
+
+ try:
wait_kwargs = {}
if wait_timeout is not None:
- wait_kwargs['timeout'] = wait_timeout
- wait_kwargs['poll'] = poll
+ wait_kwargs["timeout"] = wait_timeout
+ wait_kwargs["poll"] = poll
_wait(training_job, **wait_kwargs)
except TimeoutExceededError as e:
logger.error("Error: %s", e)
self._latest_training_job = training_job
return training_job
-
-
diff --git a/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py b/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py
index 93be84a738..b09b608ae3 100644
--- a/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py
+++ b/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py
@@ -21,11 +21,12 @@
from sagemaker.train.sft_trainer import SFTTrainer
from sagemaker.train.common import TrainingType
+
@pytest.mark.gpu_intensive
def test_sft_trainer_lora_complete_workflow(sagemaker_session):
"""Test complete SFT training workflow with LORA."""
unique_id = f"{int(time.time())}-{random.randint(1000, 9999)}"
-
+
sft_trainer = SFTTrainer(
model="meta-textgeneration-llama-3-2-1b-instruct",
training_type=TrainingType.LORA,
@@ -35,27 +36,27 @@ def test_sft_trainer_lora_complete_workflow(sagemaker_session):
accept_eula=True,
base_job_name=f"sft-lora-integ-{unique_id}",
)
-
+
# Create training job
training_job = sft_trainer.train(wait=False)
-
+
# Manual wait loop to avoid resource_config issue
max_wait_time = 3600 # 1 hour timeout
- poll_interval = 30 # Check every 30 seconds
+ poll_interval = 30 # Check every 30 seconds
start_time = time.time()
-
+
while time.time() - start_time < max_wait_time:
training_job.refresh()
status = training_job.training_job_status
-
+
if status in ["Completed", "Failed", "Stopped"]:
break
-
+
time.sleep(poll_interval)
-
+
# Verify job completed successfully
assert training_job.training_job_status == "Completed"
- assert hasattr(training_job, 'output_model_package_arn')
+ assert hasattr(training_job, "output_model_package_arn")
assert training_job.output_model_package_arn is not None
@@ -73,26 +74,26 @@ def test_sft_trainer_with_validation_dataset(sagemaker_session):
accept_eula=True,
base_job_name=f"sft-val-integ-{unique_id}",
)
-
+
training_job = sft_trainer.train(wait=False)
-
+
# Manual wait loop
max_wait_time = 3600
poll_interval = 30
start_time = time.time()
-
+
while time.time() - start_time < max_wait_time:
training_job.refresh()
status = training_job.training_job_status
-
+
if status in ["Completed", "Failed", "Stopped"]:
break
-
+
time.sleep(poll_interval)
-
+
# Verify job completed successfully
assert training_job.training_job_status == "Completed"
- assert hasattr(training_job, 'output_model_package_arn')
+ assert hasattr(training_job, "output_model_package_arn")
@pytest.mark.gpu_intensive
@@ -104,7 +105,7 @@ def test_sft_trainer_nova_workflow(sagemaker_session_us_east_1):
unique_id = f"{int(time.time())}-{random.randint(1000, 9999)}"
sft_trainer_nova = SFTTrainer(
model="nova-textgeneration-lite-v2",
- training_type=TrainingType.LORA,
+ training_type=TrainingType.LORA,
model_package_group="sdk-test-finetuned-models",
mlflow_experiment_name="test-nova-finetuned-models-exp",
mlflow_run_name="test-nova-finetuned-models-run",
@@ -113,25 +114,61 @@ def test_sft_trainer_nova_workflow(sagemaker_session_us_east_1):
sagemaker_session=sagemaker_session_us_east_1,
base_job_name=f"sft-nova-integ-{unique_id}",
)
-
+
# Create training job
training_job = sft_trainer_nova.train(wait=False)
-
+
# Manual wait loop
max_wait_time = 3600 # 1 hour timeout
- poll_interval = 30 # Check every 30 seconds
+ poll_interval = 30 # Check every 30 seconds
start_time = time.time()
-
+
while time.time() - start_time < max_wait_time:
training_job.refresh()
status = training_job.training_job_status
-
+
if status in ["Completed", "Failed", "Stopped"]:
break
-
+
time.sleep(poll_interval)
-
+
# Verify job completed successfully
assert training_job.training_job_status == "Completed"
- assert hasattr(training_job, 'output_model_package_arn')
+ assert hasattr(training_job, "output_model_package_arn")
+ assert training_job.output_model_package_arn is not None
+
+
+@pytest.mark.gpu_intensive
+def test_sft_trainer_lora_with_sequence_length(sagemaker_session):
+ """Test SFT training workflow with LORA and sequence_length specified."""
+ unique_id = f"{int(time.time())}-{random.randint(1000, 9999)}"
+
+ sft_trainer = SFTTrainer(
+ model="meta-textgeneration-llama-3-2-1b-instruct",
+ training_type=TrainingType.LORA,
+ model_package_group="arn:aws:sagemaker:us-west-2:729646638167:model-package-group/sdk-test-finetuned-models",
+ training_dataset="s3://mc-flows-sdk-testing/input_data/sft/sample_data_256_final.jsonl",
+ s3_output_path="s3://mc-flows-sdk-testing/output/",
+ accept_eula=True,
+ sequence_length="8K",
+ base_job_name=f"sft-seqlen-integ-{unique_id}",
+ )
+
+ training_job = sft_trainer.train(wait=False)
+
+ max_wait_time = 3600
+ poll_interval = 30
+ start_time = time.time()
+
+ while time.time() - start_time < max_wait_time:
+ training_job.refresh()
+ status = training_job.training_job_status
+
+ if status in ["Completed", "Failed", "Stopped"]:
+ break
+
+ time.sleep(poll_interval)
+
+ assert training_job.training_job_status == "Completed"
+ assert hasattr(training_job, "output_model_package_arn")
assert training_job.output_model_package_arn is not None
diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py
index 7a63e36234..0f5ff6f92e 100644
--- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py
+++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py
@@ -27,7 +27,8 @@
_create_mlflow_config,
_validate_eula_for_gated_model,
_validate_model_region_availability,
- _validate_s3_path_exists
+ _validate_s3_path_exists,
+ _parse_context_length,
)
from sagemaker.core.resources import ModelPackage, ModelPackageGroup
from sagemaker.ai_registry.dataset import DataSet
@@ -37,47 +38,53 @@
class TestFinetuneUtils:
- @patch('sagemaker.train.common_utils.finetune_utils.boto3.client')
- @patch('sagemaker.train.common_utils.finetune_utils.Session')
+ @patch("sagemaker.train.common_utils.finetune_utils.boto3.client")
+ @patch("sagemaker.train.common_utils.finetune_utils.Session")
def test__get_beta_session(self, mock_session, mock_boto_client):
mock_client = Mock()
mock_boto_client.return_value = mock_client
mock_sagemaker_session = Mock()
mock_session.return_value = mock_sagemaker_session
-
+
result = _get_beta_session()
-
+
assert result == mock_sagemaker_session
mock_boto_client.assert_called_once()
def test_get_current_domain_id_with_studio_arn(self):
mock_session = Mock()
- mock_session.get_caller_identity_arn.return_value = "arn:aws:sts::123456789012:assumed-role/SageMakerStudioExecutionRole/SageMaker"
-
+ mock_session.get_caller_identity_arn.return_value = (
+ "arn:aws:sts::123456789012:assumed-role/SageMakerStudioExecutionRole/SageMaker"
+ )
+
result = _get_current_domain_id(mock_session)
-
+
assert result is None
def test_get_current_domain_id_with_domain_arn(self):
mock_session = Mock()
- mock_session.get_caller_identity_arn.return_value = "arn:aws:sagemaker:us-east-1:123456789012:user-profile/d-123456789/test-user"
-
+ mock_session.get_caller_identity_arn.return_value = (
+ "arn:aws:sagemaker:us-east-1:123456789012:user-profile/d-123456789/test-user"
+ )
+
result = _get_current_domain_id(mock_session)
-
+
assert result == "d-123456789"
def test__resolve_mlflow_resource_arn_with_provided_arn(self):
mock_session = Mock()
provided_arn = "arn:aws:mlflow:us-east-1:123456789012:tracking-server/test"
-
+
result = _resolve_mlflow_resource_arn(mock_session, provided_arn)
-
+
assert result == provided_arn
- @patch('sagemaker.train.common_utils.finetune_utils._get_current_domain_id')
- @patch('sagemaker.train.common_utils.finetune_utils._create_mlflow_app')
- @patch('sagemaker.core.resources.MlflowApp.get_all')
- def test__resolve_mlflow_resource_arn_creates_new_app(self, mock_get_all, mock_create_app, mock_get_domain):
+ @patch("sagemaker.train.common_utils.finetune_utils._get_current_domain_id")
+ @patch("sagemaker.train.common_utils.finetune_utils._create_mlflow_app")
+ @patch("sagemaker.core.resources.MlflowApp.get_all")
+ def test__resolve_mlflow_resource_arn_creates_new_app(
+ self, mock_get_all, mock_create_app, mock_get_domain
+ ):
mock_session = Mock()
mock_session.boto_session.region_name = "us-east-1"
mock_get_domain.return_value = "d-123456789"
@@ -85,13 +92,13 @@ def test__resolve_mlflow_resource_arn_creates_new_app(self, mock_get_all, mock_c
mock_app = Mock()
mock_app.arn = "arn:aws:mlflow:us-east-1:123456789012:tracking-server/new-app"
mock_create_app.return_value = mock_app
-
+
result = _resolve_mlflow_resource_arn(mock_session, None)
-
+
assert result == mock_app.arn
- @patch('sagemaker.train.common_utils.finetune_utils.TrainDefaults.get_role')
- @patch('sagemaker.core.resources.MlflowApp.create')
+ @patch("sagemaker.train.common_utils.finetune_utils.TrainDefaults.get_role")
+ @patch("sagemaker.core.resources.MlflowApp.create")
def test_create_mlflow_app_success(self, mock_create, mock_get_role):
mock_session = Mock()
mock_session.region_name = "us-east-1"
@@ -99,63 +106,67 @@ def test_create_mlflow_app_success(self, mock_create, mock_get_role):
mock_sts_client.get_caller_identity.return_value = {"Account": "123456789012"}
mock_s3_client = Mock()
mock_s3_client.list_objects_v2.return_value = {"Contents": [{"Key": "mlflow-artifacts/"}]}
-
+
def mock_client(service_name):
- if service_name == 'sts':
+ if service_name == "sts":
return mock_sts_client
- elif service_name == 's3':
+ elif service_name == "s3":
return mock_s3_client
return Mock()
-
+
mock_session.boto_session.client.side_effect = mock_client
mock_session.boto_session.region_name = "us-east-1"
mock_get_role.return_value = "arn:aws:iam::123456789012:role/test-role"
mock_app = Mock()
mock_app.status = "Created"
mock_create.return_value = mock_app
-
+
result = _create_mlflow_app(mock_session)
-
+
assert result == mock_app
mock_create.assert_called_once()
mock_app.refresh.assert_called()
- @patch('sagemaker.core.resources.MlflowApp.create')
+ @patch("sagemaker.core.resources.MlflowApp.create")
def test_create_mlflow_app_failure(self, mock_create):
mock_session = Mock()
mock_create.side_effect = Exception("Creation failed")
-
+
result = _create_mlflow_app(mock_session)
-
+
assert result is None
def test__validate_dataset_arn_valid(self):
valid_arn = "arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/DataSet/test-dataset/1.0"
-
+
# Should not raise exception
_validate_dataset_arn(valid_arn, "test_dataset")
def test__validate_dataset_arn_invalid(self):
invalid_arn = "invalid-arn"
-
- with pytest.raises(ValueError, match="test_dataset must be a valid SageMaker hub-content DataSet ARN"):
+
+ with pytest.raises(
+ ValueError, match="test_dataset must be a valid SageMaker hub-content DataSet ARN"
+ ):
_validate_dataset_arn(invalid_arn, "test_dataset")
def test_validate_evaluator_arn_valid(self):
valid_arn = "arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/JsonDoc/test-evaluator/1.0"
-
+
# Should not raise exception
_validate_evaluator_arn(valid_arn, "test_evaluator")
def test_validate_evaluator_arn_invalid(self):
invalid_arn = "invalid-arn"
-
- with pytest.raises(ValueError, match="test_evaluator must be a valid SageMaker hub-content evaluator ARN"):
+
+ with pytest.raises(
+ ValueError, match="test_evaluator must be a valid SageMaker hub-content evaluator ARN"
+ ):
_validate_evaluator_arn(invalid_arn, "test_evaluator")
def test__validate_model_package_group_requirement_with_model_package(self):
model_package = Mock(spec=ModelPackage)
-
+
# Should not raise exception
_validate_model_package_group_requirement(model_package, None)
@@ -163,33 +174,37 @@ def test__validate_model_package_group_requirement_without_group_name(self):
with pytest.raises(ValueError, match="model_package_group_name must be provided"):
_validate_model_package_group_requirement("string-model", None)
- @patch('sagemaker.core.resources.ModelPackageGroup.get')
+ @patch("sagemaker.core.resources.ModelPackageGroup.get")
def test__resolve_model_package_group_arn_with_name(self, mock_get):
mock_session = Mock()
mock_session.boto_session.region_name = "us-east-1"
mock_group = Mock()
- mock_group.model_package_group_arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package-group/test-group"
+ mock_group.model_package_group_arn = (
+ "arn:aws:sagemaker:us-east-1:123456789012:model-package-group/test-group"
+ )
mock_get.return_value = mock_group
-
+
result = _resolve_model_package_group_arn("test-group", mock_session)
-
+
assert result == mock_group.model_package_group_arn
def test__resolve_model_package_group_arn_with_arn(self):
mock_session = Mock()
arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package-group/test-group"
-
+
result = _resolve_model_package_group_arn(arn, mock_session)
-
+
assert result == arn
def test__resolve_model_package_group_arn_with_object(self):
mock_session = Mock()
mock_group = Mock(spec=ModelPackageGroup)
- mock_group.model_package_group_arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package-group/test-group"
-
+ mock_group.model_package_group_arn = (
+ "arn:aws:sagemaker:us-east-1:123456789012:model-package-group/test-group"
+ )
+
result = _resolve_model_package_group_arn(mock_group, mock_session)
-
+
assert result == mock_group.model_package_group_arn
def test__get_default_s3_output_path(self):
@@ -198,50 +213,50 @@ def test__get_default_s3_output_path(self):
mock_sts_client.get_caller_identity.return_value = {"Account": "123456789012"}
mock_session.boto_session.client.return_value = mock_sts_client
mock_session.boto_session.region_name = "us-east-1"
-
+
result = _get_default_s3_output_path(mock_session)
-
+
assert result == "s3://sagemaker-us-east-1-123456789012/output"
def test__extract_dataset_source_s3_uri(self):
s3_uri = "s3://bucket/dataset"
-
+
result = _extract_dataset_source(s3_uri, "test_dataset")
-
+
assert result == s3_uri
- @patch('sagemaker.train.common_utils.finetune_utils._validate_dataset_arn')
+ @patch("sagemaker.train.common_utils.finetune_utils._validate_dataset_arn")
def test__extract_dataset_source_arn(self, mock_validate):
arn = "arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/DataSet/test/1.0"
-
+
result = _extract_dataset_source(arn, "test_dataset")
-
+
assert result == arn
mock_validate.assert_called_once_with(arn, "test_dataset")
def test__extract_dataset_source_dataset_object(self):
mock_dataset = Mock(spec=DataSet)
mock_dataset.arn = "arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/DataSet/test/1.0"
-
+
result = _extract_dataset_source(mock_dataset, "test_dataset")
-
+
assert result == mock_dataset.arn
- @patch('sagemaker.train.common_utils.finetune_utils._validate_evaluator_arn')
+ @patch("sagemaker.train.common_utils.finetune_utils._validate_evaluator_arn")
def test_extract_evaluator_arn_string(self, mock_validate):
arn = "arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/JsonDoc/test/1.0"
-
+
result = _extract_evaluator_arn(arn, "test_evaluator")
-
+
assert result == arn
mock_validate.assert_called_once_with(arn, "test_evaluator")
def test_extract_evaluator_arn_object(self):
mock_evaluator = Mock()
mock_evaluator.arn = "arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/JsonDoc/test/1.0"
-
+
result = _extract_evaluator_arn(mock_evaluator, "test_evaluator")
-
+
assert result == mock_evaluator.arn
def test__resolve_model_name_with_model_package(self):
@@ -251,9 +266,9 @@ def test__resolve_model_name_with_model_package(self):
mock_base_model.hub_content_name = "test-model"
mock_container.base_model = mock_base_model
mock_model_package.inference_specification.containers = [mock_container]
-
+
result = _resolve_model_name(mock_model_package)
-
+
assert result == "test-model"
def test__resolve_model_name_with_none(self):
@@ -264,41 +279,41 @@ def test__resolve_model_package_arn_success(self):
mock_model_package = Mock()
expected_arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package/test-package"
mock_model_package.model_package_arn = expected_arn
-
+
result = _resolve_model_package_arn(mock_model_package)
-
+
assert result == expected_arn
def test__resolve_model_package_arn_failure(self):
mock_model_package = Mock()
mock_model_package.model_package_arn = None
-
+
result = _resolve_model_package_arn(mock_model_package)
-
+
assert result is None
- @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
- @patch('boto3.client')
+ @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata")
+ @patch("boto3.client")
def test__get_fine_tuning_options_and_model_arn(self, mock_boto_client, mock_get_hub_content):
mock_session = Mock()
mock_session.boto_session.region_name = "us-east-1"
-
+
# Mock hub content metadata
mock_get_hub_content.return_value = {
- 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
- 'hub_content_document': {
+ "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
+ "hub_content_document": {
"GatedBucket": False,
"RecipeCollection": [
{
"CustomizationTechnique": "SFT",
"SmtjRecipeTemplateS3Uri": "s3://bucket/template.json",
"SmtjOverrideParamsS3Uri": "s3://bucket/params.json",
- "Peft": True
+ "Peft": True,
}
- ]
- }
+ ],
+ },
}
-
+
# Mock S3 client
mock_s3_client = Mock()
mock_boto_client.return_value = mock_s3_client
@@ -307,9 +322,9 @@ def test__get_fine_tuning_options_and_model_arn(self, mock_boto_client, mock_get
}
mock_session.boto_session.client.return_value = mock_s3_client
mock_session.boto_session.client.return_value = mock_s3_client
-
+
result = _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session)
-
+
# Handle case where function might return None
if result is not None:
options, model_arn, is_gated_model = result
@@ -322,7 +337,7 @@ def test__get_fine_tuning_options_and_model_arn(self, mock_boto_client, mock_get
def test_create_input_channels_s3_uri(self):
result = _create_input_channels("s3://bucket/data", "application/json")
-
+
assert len(result) == 1
assert result[0].channel_name == "train"
assert result[0].data_source.s3_data_source.s3_uri == "s3://bucket/data"
@@ -330,9 +345,9 @@ def test_create_input_channels_s3_uri(self):
def test_create_input_channels_dataset_arn(self):
arn = "arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/DataSet/test/1.0"
-
+
result = _create_input_channels(arn)
-
+
assert len(result) == 1
assert result[0].channel_name == "train"
assert result[0].data_source.dataset_source.dataset_arn == arn
@@ -340,24 +355,24 @@ def test_create_input_channels_dataset_arn(self):
def test__validate_and_resolve_model_package_group_with_provided_name(self):
model = "test-model"
group_name = "test-group"
-
+
result = _validate_and_resolve_model_package_group(model, group_name)
-
+
assert result == group_name
def test__validate_and_resolve_model_package_group_from_model_package(self):
mock_model = Mock(spec=ModelPackage)
mock_model.model_package_group_name = "extracted-group"
-
+
result = _validate_and_resolve_model_package_group(mock_model, None)
-
+
assert result == "extracted-group"
def test__validate_and_resolve_model_package_group_missing_both(self):
with pytest.raises(ValueError, match="model_package_group_name must be provided"):
_validate_and_resolve_model_package_group("string-model", None)
- @patch('sagemaker.core.resources.ModelPackage.get')
+ @patch("sagemaker.core.resources.ModelPackage.get")
def test__resolve_model_and_name_with_model_package_arn(self, mock_get):
mock_session = Mock()
mock_session.boto_region_name = "us-east-1" # Set valid region
@@ -369,15 +384,17 @@ def test__resolve_model_and_name_with_model_package_arn(self, mock_get):
mock_model_package.inference_specification = Mock()
mock_model_package.inference_specification.containers = [mock_container]
mock_get.return_value = mock_model_package
-
- model, name = _resolve_model_and_name("arn:aws:sagemaker:us-east-1:123456789012:model-package/test", mock_session)
-
+
+ model, name = _resolve_model_and_name(
+ "arn:aws:sagemaker:us-east-1:123456789012:model-package/test", mock_session
+ )
+
assert model == mock_model_package
assert name == "test-model"
def test__resolve_model_and_name_with_string(self):
model, name = _resolve_model_and_name("test-model")
-
+
assert model == "test-model"
assert name == "test-model"
@@ -389,15 +406,15 @@ def test__resolve_model_and_name_with_model_package_object(self):
mock_container.base_model = mock_base_model
mock_model_package.inference_specification = Mock()
mock_model_package.inference_specification.containers = [mock_container]
-
+
model, name = _resolve_model_and_name(mock_model_package)
-
+
assert model == mock_model_package
assert name == "test-model"
def test__create_serverless_config_with_lora(self):
config = _create_serverless_config("model-arn", "SFT", TrainingType.LORA, accept_eula=True)
-
+
assert config.job_type == "FineTuning"
assert config.base_model_arn == "model-arn"
assert config.customization_technique == "SFT"
@@ -405,14 +422,13 @@ def test__create_serverless_config_with_lora(self):
def test__create_serverless_config_with_full(self):
config = _create_serverless_config("model-arn", "SFT", TrainingType.FULL, accept_eula=True)
-
+
assert config.peft is None
def test__create_input_data_config(self):
-
config = _create_input_data_config("s3://bucket/train", "s3://bucket/val")
-
+
assert len(config) == 2
assert config[0].channel_name == "train"
assert config[1].channel_name == "validation"
@@ -421,30 +437,34 @@ def test__create_model_package_config(self):
mock_session = Mock()
mock_model = Mock(spec=ModelPackage)
mock_model.model_package_arn = "source-arn"
-
- with patch('sagemaker.train.common_utils.finetune_utils._resolve_model_package_group_arn') as mock_resolve:
+
+ with patch(
+ "sagemaker.train.common_utils.finetune_utils._resolve_model_package_group_arn"
+ ) as mock_resolve:
mock_resolve.return_value = "group-arn"
config = _create_model_package_config("test-group", mock_model, mock_session)
-
+
assert config.model_package_group_arn == "group-arn"
assert config.source_model_package_arn == "source-arn"
def test__create_mlflow_config(self):
mock_session = Mock()
-
- with patch('sagemaker.train.common_utils.finetune_utils._resolve_mlflow_resource_arn') as mock_resolve:
+
+ with patch(
+ "sagemaker.train.common_utils.finetune_utils._resolve_mlflow_resource_arn"
+ ) as mock_resolve:
mock_resolve.return_value = "mlflow-arn"
config = _create_mlflow_config(mock_session, mlflow_experiment_name="test-exp")
-
+
assert config.mlflow_resource_arn == "mlflow-arn"
assert config.mlflow_experiment_name == "test-exp"
- @patch('sagemaker.train.common_utils.finetune_utils._validate_s3_path_exists')
+ @patch("sagemaker.train.common_utils.finetune_utils._validate_s3_path_exists")
def test__create_output_config(self, mock_validate_s3):
mock_session = Mock()
-
+
config = _create_output_config(mock_session, "s3://bucket/output", "kms-key")
-
+
assert config.s3_output_path == "s3://bucket/output"
assert config.kms_key_id == "kms-key"
mock_validate_s3.assert_called_once_with("s3://bucket/output", mock_session)
@@ -453,22 +473,23 @@ def test__convert_input_data_to_channels(self):
input_data = [InputData(channel_name="train", data_source="s3://bucket/data")]
channels = _convert_input_data_to_channels(input_data)
-
+
assert len(channels) == 1
assert channels[0].channel_name == "train"
def test__validate_eula_for_gated_model_with_model_package(self):
"""Test EULA validation returns True for ModelPackage input"""
from sagemaker.core.resources import ModelPackage
+
model_package = Mock(spec=ModelPackage)
-
+
result = _validate_eula_for_gated_model(model_package, False, True)
assert result == True
def test__validate_eula_for_gated_model_with_arn(self):
"""Test EULA validation returns True for ARN input"""
model_arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package/test/1"
-
+
result = _validate_eula_for_gated_model(model_arn, False, True)
assert result == True
@@ -495,7 +516,9 @@ def test__validate_model_region_availability_nova_valid_region(self):
def test__validate_model_region_availability_nova_invalid_region(self):
"""Test Nova model validation fails for invalid region"""
- with pytest.raises(ValueError, match="Region 'eu-west-1' does not support model customization"):
+ with pytest.raises(
+ ValueError, match="Region 'eu-west-1' does not support model customization"
+ ):
_validate_model_region_availability("nova-textgeneration-lite-v2", "eu-west-1")
def test__validate_model_region_availability_open_weights_valid_region(self):
@@ -505,57 +528,63 @@ def test__validate_model_region_availability_open_weights_valid_region(self):
def test__validate_model_region_availability_open_weights_invalid_region(self):
"""Test open weights model validation fails for invalid region"""
- with pytest.raises(ValueError, match="Region 'us-west-1' does not support model customization"):
+ with pytest.raises(
+ ValueError, match="Region 'us-west-1' does not support model customization"
+ ):
_validate_model_region_availability("meta-textgeneration-llama-3-2-1b", "us-west-1")
def test__validate_s3_path_exists_invalid_format(self):
"""Test S3 path validation fails for invalid format"""
mock_session = Mock()
-
+
with pytest.raises(ValueError, match="Invalid S3 path format"):
_validate_s3_path_exists("invalid-path", mock_session)
- @patch('boto3.client')
+ @patch("boto3.client")
def test__validate_s3_path_exists_bucket_only_success(self, mock_boto_client):
"""Test S3 path validation succeeds for bucket-only path"""
mock_session = Mock()
mock_s3_client = Mock()
mock_session.boto_session.client.return_value = mock_s3_client
-
+
_validate_s3_path_exists("s3://test-bucket", mock_session)
-
+
mock_s3_client.head_bucket.assert_called_once_with(Bucket="test-bucket")
- @patch('boto3.client')
+ @patch("boto3.client")
def test__validate_s3_path_exists_with_prefix_exists(self, mock_boto_client):
"""Test S3 path validation succeeds when prefix exists"""
mock_session = Mock()
mock_s3_client = Mock()
mock_session.boto_session.client.return_value = mock_s3_client
mock_s3_client.list_objects_v2.return_value = {"Contents": [{"Key": "prefix/file.txt"}]}
-
+
_validate_s3_path_exists("s3://test-bucket/prefix/", mock_session)
-
+
mock_s3_client.head_bucket.assert_called_once_with(Bucket="test-bucket")
- mock_s3_client.list_objects_v2.assert_called_once_with(Bucket="test-bucket", Prefix="prefix/", MaxKeys=1)
+ mock_s3_client.list_objects_v2.assert_called_once_with(
+ Bucket="test-bucket", Prefix="prefix/", MaxKeys=1
+ )
- @patch('boto3.client')
+ @patch("boto3.client")
def test__validate_s3_path_exists_with_prefix_not_exists(self, mock_boto_client):
"""Test S3 path validation creates prefix when it doesn't exist"""
mock_session = Mock()
mock_s3_client = Mock()
mock_session.boto_session.client.return_value = mock_s3_client
mock_s3_client.list_objects_v2.return_value = {} # No contents
-
- _validate_s3_path_exists("s3://test-bucket/prefix", mock_session)
-
- mock_s3_client.head_bucket.assert_called_once_with(Bucket="test-bucket")
- mock_s3_client.list_objects_v2.assert_called_once_with(Bucket="test-bucket", Prefix="prefix", MaxKeys=1)
- mock_s3_client.put_object.assert_called_once_with(Bucket="test-bucket", Key="prefix/", Body=b'')
+ _validate_s3_path_exists("s3://test-bucket/prefix", mock_session)
+ mock_s3_client.head_bucket.assert_called_once_with(Bucket="test-bucket")
+ mock_s3_client.list_objects_v2.assert_called_once_with(
+ Bucket="test-bucket", Prefix="prefix", MaxKeys=1
+ )
+ mock_s3_client.put_object.assert_called_once_with(
+ Bucket="test-bucket", Key="prefix/", Body=b""
+ )
- @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
+ @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata")
def test__get_fine_tuning_options_with_subscription_recipe_enabled(self, mock_get_hub_content):
"""When and user is subscribed, datamix HPs are available."""
mock_session = Mock()
@@ -563,34 +592,40 @@ def test__get_fine_tuning_options_with_subscription_recipe_enabled(self, mock_ge
mock_s3 = Mock()
mock_sts = Mock()
mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
- mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts
+ mock_session.boto_session.client.side_effect = lambda service, **kwargs: (
+ mock_s3 if service == "s3" else mock_sts
+ )
mock_get_hub_content.return_value = {
- 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
- 'hub_content_document': {
+ "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
+ "hub_content_document": {
"GatedBucket": False,
"RecipeCollection": [
{
"CustomizationTechnique": "SFT",
"SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml",
"SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json",
- "Name": "standard_sft"
+ "Name": "standard_sft",
},
{
"CustomizationTechnique": "SFT",
"SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-123456789012/source/template.yaml",
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json",
"Name": "datamix_sft",
- "IsSubscriptionModel": True
- }
- ]
- }
+ "IsSubscriptionModel": True,
+ },
+ ],
+ },
}
# Standard recipe returns base params
- standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}})
+ standard_params = json.dumps(
+ {"max_steps": {"type": "integer", "required": True, "default": 100}}
+ )
# Subscription recipe returns datamix params
- datamix_params = json.dumps({"customer_data_percent": {"type": "integer", "required": False, "default": 50}})
+ datamix_params = json.dumps(
+ {"customer_data_percent": {"type": "integer", "required": False, "default": 50}}
+ )
mock_s3.get_object.side_effect = [
{"Body": Mock(read=Mock(return_value=standard_params.encode()))},
@@ -598,15 +633,22 @@ def test__get_fine_tuning_options_with_subscription_recipe_enabled(self, mock_ge
]
options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
- "test-model", "SFT", "FULL", mock_session,
+ "test-model",
+ "SFT",
+ "FULL",
+ mock_session,
)
assert "max_steps" in options._specs
assert "customer_data_percent" in options._specs
- assert options._specs["customer_data_percent"]["default"] is None # defaults are None so they dont serialize unless explicitly set
-
- @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
- def test__get_fine_tuning_options_subscription_disabled_no_datamix_hps(self, mock_get_hub_content):
+ assert (
+ options._specs["customer_data_percent"]["default"] is None
+ ) # defaults are None so they dont serialize unless explicitly set
+
+ @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata")
+ def test__get_fine_tuning_options_subscription_disabled_no_datamix_hps(
+ self, mock_get_hub_content
+ ):
"""When (default), datamix HPs are NOT available."""
mock_session = Mock()
mock_session.boto_session.region_name = "us-east-1"
@@ -614,70 +656,83 @@ def test__get_fine_tuning_options_subscription_disabled_no_datamix_hps(self, moc
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3
mock_get_hub_content.return_value = {
- 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
- 'hub_content_document': {
+ "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
+ "hub_content_document": {
"GatedBucket": False,
"RecipeCollection": [
{
"CustomizationTechnique": "SFT",
"SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml",
"SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json",
- "Name": "standard_sft"
+ "Name": "standard_sft",
},
{
"CustomizationTechnique": "SFT",
"SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/template.yaml",
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json",
"Name": "datamix_sft",
- "IsSubscriptionModel": True
- }
- ]
- }
+ "IsSubscriptionModel": True,
+ },
+ ],
+ },
}
- standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}})
- mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=standard_params.encode()))}
+ standard_params = json.dumps(
+ {"max_steps": {"type": "integer", "required": True, "default": 100}}
+ )
+ mock_s3.get_object.return_value = {
+ "Body": Mock(read=Mock(return_value=standard_params.encode()))
+ }
options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
- "test-model", "SFT", "FULL", mock_session,
+ "test-model",
+ "SFT",
+ "FULL",
+ mock_session,
)
assert "max_steps" in options._specs
assert "customer_data_percent" not in options._specs
- @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
- def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self, mock_get_hub_content):
+ @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata")
+ def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(
+ self, mock_get_hub_content
+ ):
"""When but user is NOT subscribed, falls back gracefully."""
mock_session = Mock()
mock_session.boto_session.region_name = "us-east-1"
mock_s3 = Mock()
mock_sts = Mock()
mock_sts.get_caller_identity.return_value = {"Account": "999999999999"}
- mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts
+ mock_session.boto_session.client.side_effect = lambda service, **kwargs: (
+ mock_s3 if service == "s3" else mock_sts
+ )
mock_get_hub_content.return_value = {
- 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
- 'hub_content_document': {
+ "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
+ "hub_content_document": {
"GatedBucket": False,
"RecipeCollection": [
{
"CustomizationTechnique": "SFT",
"SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml",
"SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json",
- "Name": "standard_sft"
+ "Name": "standard_sft",
},
{
"CustomizationTechnique": "SFT",
"SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/template.yaml",
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json",
"Name": "datamix_sft",
- "IsSubscriptionModel": True
- }
- ]
- }
+ "IsSubscriptionModel": True,
+ },
+ ],
+ },
}
- standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}})
+ standard_params = json.dumps(
+ {"max_steps": {"type": "integer", "required": True, "default": 100}}
+ )
# First call succeeds (standard recipe), second call fails (access denied)
mock_s3.get_object.side_effect = [
{"Body": Mock(read=Mock(return_value=standard_params.encode()))},
@@ -685,9 +740,289 @@ def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self,
]
options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
- "test-model", "SFT", "FULL", mock_session,
+ "test-model",
+ "SFT",
+ "FULL",
+ mock_session,
)
# Should still have standard params, just not datamix ones
assert "max_steps" in options._specs
assert "customer_data_percent" not in options._specs
+
+ def test__create_serverless_config_with_sequence_length(self):
+ config = _create_serverless_config(
+ "model-arn", "SFT", TrainingType.LORA, accept_eula=True, sequence_length="8K"
+ )
+
+ assert config.sequence_length == "8K"
+ assert config.base_model_arn == "model-arn"
+
+ def test__create_serverless_config_without_sequence_length(self):
+ config = _create_serverless_config("model-arn", "SFT", TrainingType.LORA, accept_eula=True)
+
+ assert config.sequence_length is None
+
+ def test__parse_context_length_with_k_suffix(self):
+ assert _parse_context_length("8K") == 8192
+ assert _parse_context_length("32K") == 32768
+ assert _parse_context_length("128K") == 131072
+
+ def test__parse_context_length_with_lowercase(self):
+ assert _parse_context_length("8k") == 8192
+
+ def test__parse_context_length_with_integer(self):
+ assert _parse_context_length("4096") == 4096
+
+ def test__parse_context_length_with_none(self):
+ assert _parse_context_length(None) == 0
+
+ def test__parse_context_length_with_empty(self):
+ assert _parse_context_length("") == 0
+
+ @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata")
+ def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_get_hub_content):
+ mock_session = Mock()
+ mock_session.boto_session.region_name = "us-east-1"
+ mock_s3 = Mock()
+ mock_s3.get_object.return_value = {
+ "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 32768}}'))
+ }
+ mock_session.boto_session.client.return_value = mock_s3
+
+ mock_get_hub_content.return_value = {
+ "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
+ "hub_content_document": {
+ "GatedBucket": False,
+ "RecipeCollection": [
+ {
+ "CustomizationTechnique": "SFT",
+ "SmtjRecipeTemplateS3Uri": "s3://bucket/template-4k.json",
+ "SmtjOverrideParamsS3Uri": "s3://bucket/params-4k.json",
+ "Peft": True,
+ "SequenceLength": "4K",
+ },
+ {
+ "CustomizationTechnique": "SFT",
+ "SmtjRecipeTemplateS3Uri": "s3://bucket/template-32k.json",
+ "SmtjOverrideParamsS3Uri": "s3://bucket/params-32k.json",
+ "Peft": True,
+ "SequenceLength": "32K",
+ },
+ ],
+ },
+ }
+
+ result = _get_fine_tuning_options_and_model_arn(
+ "test-model", "SFT", "LORA", mock_session, sequence_length="8K"
+ )
+
+ if result is not None:
+ options, model_arn, is_gated_model = result
+ # Should pick the 32K recipe (smallest >= 8K)
+ mock_s3.get_object.assert_called_once()
+ call_args = mock_s3.get_object.call_args[1]
+ assert "params-32k" in call_args["Key"]
+
+ @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata")
+ def test__get_fine_tuning_options_raises_when_no_sufficient_context_length(
+ self, mock_get_hub_content
+ ):
+ mock_session = Mock()
+ mock_session.boto_session.region_name = "us-east-1"
+
+ mock_get_hub_content.return_value = {
+ "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
+ "hub_content_document": {
+ "GatedBucket": False,
+ "RecipeCollection": [
+ {
+ "CustomizationTechnique": "SFT",
+ "SmtjRecipeTemplateS3Uri": "s3://bucket/template-4k.json",
+ "SmtjOverrideParamsS3Uri": "s3://bucket/params-4k.json",
+ "Peft": True,
+ "SequenceLength": "4K",
+ }
+ ],
+ },
+ }
+
+ # Requesting 128K but only 4K available — should raise
+ with pytest.raises(ValueError, match="No recipes found with SequenceLength >= 128K"):
+ _get_fine_tuning_options_and_model_arn(
+ "test-model", "SFT", "LORA", mock_session, sequence_length="128K"
+ )
+
+ def test__parse_context_length_with_invalid_k_value(self):
+ assert _parse_context_length("abcK") == 0
+
+ def test__parse_context_length_with_non_numeric_string(self):
+ assert _parse_context_length("hello") == 0
+
+ @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata")
+ def test__get_fine_tuning_options_raises_when_no_recipes_have_sequence_length(
+ self, mock_get_hub_content
+ ):
+ mock_session = Mock()
+ mock_session.boto_session.region_name = "us-east-1"
+
+ mock_get_hub_content.return_value = {
+ "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
+ "hub_content_document": {
+ "GatedBucket": False,
+ "RecipeCollection": [
+ {
+ "CustomizationTechnique": "SFT",
+ "SmtjRecipeTemplateS3Uri": "s3://bucket/template.json",
+ "SmtjOverrideParamsS3Uri": "s3://bucket/params.json",
+ "Peft": True,
+ }
+ ],
+ },
+ }
+
+ with pytest.raises(ValueError, match="and sequence length"):
+ _get_fine_tuning_options_and_model_arn(
+ "test-model", "SFT", "LORA", mock_session, sequence_length="8K"
+ )
+
+ @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata")
+ def test__get_fine_tuning_options_filters_by_sequence_length_full_training(
+ self, mock_get_hub_content
+ ):
+ mock_session = Mock()
+ mock_session.boto_session.region_name = "us-east-1"
+ mock_s3 = Mock()
+ mock_s3.get_object.return_value = {
+ "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 8192}}'))
+ }
+ mock_session.boto_session.client.return_value = mock_s3
+
+ mock_get_hub_content.return_value = {
+ "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
+ "hub_content_document": {
+ "GatedBucket": False,
+ "RecipeCollection": [
+ {
+ "CustomizationTechnique": "SFT",
+ "SmtjRecipeTemplateS3Uri": "s3://bucket/template-8k.json",
+ "SmtjOverrideParamsS3Uri": "s3://bucket/params-8k.json",
+ "SequenceLength": "8K",
+ },
+ {
+ "CustomizationTechnique": "SFT",
+ "SmtjRecipeTemplateS3Uri": "s3://bucket/template-32k.json",
+ "SmtjOverrideParamsS3Uri": "s3://bucket/params-32k.json",
+ "SequenceLength": "32K",
+ },
+ ],
+ },
+ }
+
+ result = _get_fine_tuning_options_and_model_arn(
+ "test-model", "SFT", "FULL", mock_session, sequence_length="8K"
+ )
+
+ if result is not None:
+ options, model_arn, is_gated_model = result
+ mock_s3.get_object.assert_called_once()
+ call_args = mock_s3.get_object.call_args[1]
+ assert "params-8k" in call_args["Key"]
+
+ @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata")
+ def test__get_fine_tuning_options_selects_smallest_sufficient_sequence_length(
+ self, mock_get_hub_content
+ ):
+ mock_session = Mock()
+ mock_session.boto_session.region_name = "us-east-1"
+ mock_s3 = Mock()
+ mock_s3.get_object.return_value = {
+ "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 16384}}'))
+ }
+ mock_session.boto_session.client.return_value = mock_s3
+
+ mock_get_hub_content.return_value = {
+ "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
+ "hub_content_document": {
+ "GatedBucket": False,
+ "RecipeCollection": [
+ {
+ "CustomizationTechnique": "SFT",
+ "SmtjRecipeTemplateS3Uri": "s3://bucket/template-4k.json",
+ "SmtjOverrideParamsS3Uri": "s3://bucket/params-4k.json",
+ "Peft": True,
+ "SequenceLength": "4K",
+ },
+ {
+ "CustomizationTechnique": "SFT",
+ "SmtjRecipeTemplateS3Uri": "s3://bucket/template-16k.json",
+ "SmtjOverrideParamsS3Uri": "s3://bucket/params-16k.json",
+ "Peft": True,
+ "SequenceLength": "16K",
+ },
+ {
+ "CustomizationTechnique": "SFT",
+ "SmtjRecipeTemplateS3Uri": "s3://bucket/template-128k.json",
+ "SmtjOverrideParamsS3Uri": "s3://bucket/params-128k.json",
+ "Peft": True,
+ "SequenceLength": "128K",
+ },
+ ],
+ },
+ }
+
+ result = _get_fine_tuning_options_and_model_arn(
+ "test-model", "SFT", "LORA", mock_session, sequence_length="8K"
+ )
+
+ if result is not None:
+ options, model_arn, is_gated_model = result
+ # Should pick 16K (smallest >= 8K), not 128K
+ mock_s3.get_object.assert_called_once()
+ call_args = mock_s3.get_object.call_args[1]
+ assert "params-16k" in call_args["Key"]
+
+ @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata")
+ def test__get_fine_tuning_options_without_sequence_length_uses_first_recipe(
+ self, mock_get_hub_content
+ ):
+ """Verify that when no sequence_length is provided, existing behavior is unchanged."""
+ mock_session = Mock()
+ mock_session.boto_session.region_name = "us-east-1"
+ mock_s3 = Mock()
+ mock_s3.get_object.return_value = {
+ "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 4096}}'))
+ }
+ mock_session.boto_session.client.return_value = mock_s3
+
+ mock_get_hub_content.return_value = {
+ "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
+ "hub_content_document": {
+ "GatedBucket": False,
+ "RecipeCollection": [
+ {
+ "CustomizationTechnique": "SFT",
+ "SmtjRecipeTemplateS3Uri": "s3://bucket/template-first.json",
+ "SmtjOverrideParamsS3Uri": "s3://bucket/params-first.json",
+ "Peft": True,
+ "SequenceLength": "4K",
+ },
+ {
+ "CustomizationTechnique": "SFT",
+ "SmtjRecipeTemplateS3Uri": "s3://bucket/template-second.json",
+ "SmtjOverrideParamsS3Uri": "s3://bucket/params-second.json",
+ "Peft": True,
+ "SequenceLength": "32K",
+ },
+ ],
+ },
+ }
+
+ result = _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session)
+
+ if result is not None:
+ options, model_arn, is_gated_model = result
+ # Without sequence_length, should pick the first matching recipe
+ mock_s3.get_object.assert_called_once()
+ call_args = mock_s3.get_object.call_args[1]
+ assert "params-first" in call_args["Key"]
diff --git a/sagemaker-train/tests/unit/train/test_dpo_trainer.py b/sagemaker-train/tests/unit/train/test_dpo_trainer.py
index 1b70e0bf89..7648b46e35 100644
--- a/sagemaker-train/tests/unit/train/test_dpo_trainer.py
+++ b/sagemaker-train/tests/unit/train/test_dpo_trainer.py
@@ -506,4 +506,67 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_
trainer = DPOTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train")
trainer.train(wait=False, wait_timeout=600)
- mock_wait.assert_not_called()
\ No newline at end of file
+ mock_wait.assert_not_called()
+
+ @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group')
+ @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn')
+ def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group):
+ mock_validate_group.return_value = "test-group"
+ mock_hyperparams = Mock()
+ mock_hyperparams.to_dict.return_value = {}
+ mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False)
+ trainer = DPOTrainer(model="test-model", model_package_group="test-group")
+ assert trainer.sequence_length is None
+
+ @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group')
+ @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn')
+ def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group):
+ mock_validate_group.return_value = "test-group"
+ mock_hyperparams = Mock()
+ mock_hyperparams.to_dict.return_value = {}
+ mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False)
+ trainer = DPOTrainer(model="test-model", model_package_group="test-group", sequence_length="8K")
+ assert trainer.sequence_length == "8K"
+
+ @patch('sagemaker.train.dpo_trainer._resolve_model_and_name')
+ @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn')
+ @patch('sagemaker.train.dpo_trainer.TrainDefaults.get_role')
+ @patch('sagemaker.train.dpo_trainer.TrainDefaults.get_sagemaker_session')
+ @patch('sagemaker.train.dpo_trainer._get_unique_name')
+ @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group')
+ @patch('sagemaker.train.dpo_trainer._create_input_data_config')
+ @patch('sagemaker.train.dpo_trainer._convert_input_data_to_channels')
+ @patch('sagemaker.train.dpo_trainer._create_output_config')
+ @patch('sagemaker.train.dpo_trainer._create_serverless_config')
+ @patch('sagemaker.train.dpo_trainer._create_mlflow_config')
+ @patch('sagemaker.train.dpo_trainer._create_model_package_config')
+ @patch('sagemaker.core.resources.TrainingJob.create')
+ def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create,
+ mock_model_package_config, mock_mlflow_config, mock_serverless_config,
+ mock_output_config, mock_convert_channels, mock_input_config,
+ mock_validate_group, mock_unique_name, mock_get_sagemaker_session,
+ mock_get_role, mock_get_options, mock_resolve_model):
+ mock_validate_group.return_value = "test-group"
+ mock_resolve_model.return_value = ("test-model", "test-model")
+ mock_get_sagemaker_session.return_value = Mock()
+ mock_fine_tuning_options = Mock()
+ mock_fine_tuning_options.to_dict.return_value = {}
+ mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False)
+ mock_get_role.return_value = "test-role"
+ mock_unique_name.return_value = "test-job-name"
+ mock_input_config.return_value = [Mock()]
+ mock_convert_channels.return_value = [Mock()]
+ mock_output_config.return_value = Mock()
+ mock_serverless_config.return_value = Mock()
+ mock_mlflow_config.return_value = Mock()
+ mock_model_package_config.return_value = Mock()
+ mock_training_job = Mock()
+ mock_training_job_create.return_value = mock_training_job
+
+ trainer = DPOTrainer(model="test-model", model_package_group="test-group",
+ training_dataset="s3://bucket/train", sequence_length="16K")
+ trainer.train(wait=False)
+
+ mock_serverless_config.assert_called_once()
+ call_kwargs = mock_serverless_config.call_args[1]
+ assert call_kwargs["sequence_length"] == "16K"
diff --git a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py
index e5666883e8..6811c45540 100644
--- a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py
+++ b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py
@@ -682,4 +682,70 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_
trainer = RLAIFTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train")
trainer.train(wait=False, wait_timeout=600)
- mock_wait.assert_not_called()
\ No newline at end of file
+ mock_wait.assert_not_called()
+
+    @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group')
+    @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn')
+    def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group):
+        """Constructing the trainer without ``sequence_length`` leaves the attribute as ``None``."""
+        stub_options = Mock(**{"to_dict.return_value": {}})
+        stub_options._specs = {}
+        mock_finetuning_options.return_value = (stub_options, "model-arn", False)
+        mock_validate_group.return_value = "test-group"
+        built = RLAIFTrainer(model="test-model", model_package_group="test-group")
+        assert built.sequence_length is None
+
+    @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group')
+    @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn')
+    def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group):
+        """A ``sequence_length`` given to the constructor is exposed unchanged on the trainer."""
+        stub_options = Mock(**{"to_dict.return_value": {}})
+        stub_options._specs = {}
+        mock_finetuning_options.return_value = (stub_options, "model-arn", False)
+        mock_validate_group.return_value = "test-group"
+        built = RLAIFTrainer(model="test-model", model_package_group="test-group", sequence_length="128K")
+        assert built.sequence_length == "128K"
+
+    @patch('sagemaker.train.rlaif_trainer._resolve_model_and_name')
+    @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn')
+    @patch('sagemaker.train.rlaif_trainer.TrainDefaults.get_role')
+    @patch('sagemaker.train.rlaif_trainer.TrainDefaults.get_sagemaker_session')
+    @patch('sagemaker.train.rlaif_trainer._get_unique_name')
+    @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group')
+    @patch('sagemaker.train.rlaif_trainer._create_input_data_config')
+    @patch('sagemaker.train.rlaif_trainer._convert_input_data_to_channels')
+    @patch('sagemaker.train.rlaif_trainer._create_output_config')
+    @patch('sagemaker.train.rlaif_trainer._create_serverless_config')
+    @patch('sagemaker.train.rlaif_trainer._create_mlflow_config')
+    @patch('sagemaker.train.rlaif_trainer._create_model_package_config')
+    @patch('sagemaker.core.resources.TrainingJob.create')
+    def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create,
+            mock_model_package_config, mock_mlflow_config, mock_serverless_config,
+            mock_output_config, mock_convert_channels, mock_input_config,
+            mock_validate_group, mock_unique_name, mock_get_sagemaker_session,
+            mock_get_role, mock_get_options, mock_resolve_model):
+        """``train()`` calls the patched ``_create_serverless_config`` once with ``sequence_length="64K"``."""
+        # Fine-tuning options stub whose to_dict() returns {} and whose _specs is empty.
+        options_stub = Mock(**{"to_dict.return_value": {}})
+        options_stub._specs = {}
+        mock_get_options.return_value = (options_stub, "model-arn", False)
+
+        # Fixed return values for the model / group / role / job-name helpers.
+        mock_resolve_model.return_value = ("test-model", "test-model")
+        mock_validate_group.return_value = "test-group"
+        mock_get_role.return_value = "test-role"
+        mock_unique_name.return_value = "test-job-name"
+
+        # Every remaining patched collaborator hands back its own fresh Mock.
+        for single in (mock_get_sagemaker_session, mock_output_config, mock_serverless_config,
+                       mock_mlflow_config, mock_model_package_config, mock_training_job_create):
+            single.return_value = Mock()
+        for listed in (mock_input_config, mock_convert_channels):
+            listed.return_value = [Mock()]
+
+        RLAIFTrainer(model="test-model", model_package_group="test-group",
+                     training_dataset="s3://bucket/train", sequence_length="64K").train(wait=False)
+
+        mock_serverless_config.assert_called_once()
+        _, passed_kwargs = mock_serverless_config.call_args
+        assert passed_kwargs["sequence_length"] == "64K"
diff --git a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py
index 320b81555d..b4c01385e2 100644
--- a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py
+++ b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py
@@ -509,4 +509,67 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_
trainer = RLVRTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train")
trainer.train(wait=False, wait_timeout=600)
- mock_wait.assert_not_called()
\ No newline at end of file
+ mock_wait.assert_not_called()
+
+    @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group')
+    @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn')
+    def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group):
+        """Constructing the trainer without ``sequence_length`` leaves the attribute as ``None``."""
+        stub_options = Mock(**{"to_dict.return_value": {}})
+        mock_finetuning_options.return_value = (stub_options, "model-arn", False)
+        mock_validate_group.return_value = "test-group"
+        built = RLVRTrainer(model="test-model", model_package_group="test-group")
+        assert built.sequence_length is None
+
+    @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group')
+    @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn')
+    def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group):
+        """A ``sequence_length`` given to the constructor is exposed unchanged on the trainer."""
+        stub_options = Mock(**{"to_dict.return_value": {}})
+        mock_finetuning_options.return_value = (stub_options, "model-arn", False)
+        mock_validate_group.return_value = "test-group"
+        built = RLVRTrainer(model="test-model", model_package_group="test-group", sequence_length="32K")
+        assert built.sequence_length == "32K"
+
+    @patch('sagemaker.train.rlvr_trainer._resolve_model_and_name')
+    @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn')
+    @patch('sagemaker.train.rlvr_trainer.TrainDefaults.get_role')
+    @patch('sagemaker.train.rlvr_trainer.TrainDefaults.get_sagemaker_session')
+    @patch('sagemaker.train.rlvr_trainer._get_unique_name')
+    @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group')
+    @patch('sagemaker.train.rlvr_trainer._create_input_data_config')
+    @patch('sagemaker.train.rlvr_trainer._convert_input_data_to_channels')
+    @patch('sagemaker.train.rlvr_trainer._create_output_config')
+    @patch('sagemaker.train.rlvr_trainer._create_serverless_config')
+    @patch('sagemaker.train.rlvr_trainer._create_mlflow_config')
+    @patch('sagemaker.train.rlvr_trainer._create_model_package_config')
+    @patch('sagemaker.core.resources.TrainingJob.create')
+    def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create,
+            mock_model_package_config, mock_mlflow_config, mock_serverless_config,
+            mock_output_config, mock_convert_channels, mock_input_config,
+            mock_validate_group, mock_unique_name, mock_get_sagemaker_session,
+            mock_get_role, mock_get_options, mock_resolve_model):
+        """``train()`` calls the patched ``_create_serverless_config`` once with ``sequence_length="4K"``."""
+        # Fine-tuning options stub whose to_dict() returns an empty dict.
+        options_stub = Mock(**{"to_dict.return_value": {}})
+        mock_get_options.return_value = (options_stub, "model-arn", False)
+
+        # Fixed return values for the model / group / role / job-name helpers.
+        mock_resolve_model.return_value = ("test-model", "test-model")
+        mock_validate_group.return_value = "test-group"
+        mock_get_role.return_value = "test-role"
+        mock_unique_name.return_value = "test-job-name"
+
+        # Every remaining patched collaborator hands back its own fresh Mock.
+        for single in (mock_get_sagemaker_session, mock_output_config, mock_serverless_config,
+                       mock_mlflow_config, mock_model_package_config, mock_training_job_create):
+            single.return_value = Mock()
+        for listed in (mock_input_config, mock_convert_channels):
+            listed.return_value = [Mock()]
+
+        RLVRTrainer(model="test-model", model_package_group="test-group",
+                    training_dataset="s3://bucket/train", sequence_length="4K").train(wait=False)
+
+        mock_serverless_config.assert_called_once()
+        _, passed_kwargs = mock_serverless_config.call_args
+        assert passed_kwargs["sequence_length"] == "4K"
diff --git a/sagemaker-train/tests/unit/train/test_sft_trainer.py b/sagemaker-train/tests/unit/train/test_sft_trainer.py
index 108990f839..01fc21f4bd 100644
--- a/sagemaker-train/tests/unit/train/test_sft_trainer.py
+++ b/sagemaker-train/tests/unit/train/test_sft_trainer.py
@@ -520,4 +520,67 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_
trainer = SFTTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train")
trainer.train(wait=False, wait_timeout=600)
- mock_wait.assert_not_called()
\ No newline at end of file
+ mock_wait.assert_not_called()
+
+    @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group')
+    @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn')
+    def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group):
+        """Constructing the trainer without ``sequence_length`` leaves the attribute as ``None``."""
+        stub_options = Mock(**{"to_dict.return_value": {}})
+        mock_finetuning_options.return_value = (stub_options, "model-arn", False)
+        mock_validate_group.return_value = "test-group"
+        built = SFTTrainer(model="test-model", model_package_group="test-group")
+        assert built.sequence_length is None
+
+    @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group')
+    @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn')
+    def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group):
+        """A ``sequence_length`` given to the constructor is exposed unchanged on the trainer."""
+        stub_options = Mock(**{"to_dict.return_value": {}})
+        mock_finetuning_options.return_value = (stub_options, "model-arn", False)
+        mock_validate_group.return_value = "test-group"
+        built = SFTTrainer(model="test-model", model_package_group="test-group", sequence_length="8K")
+        assert built.sequence_length == "8K"
+
+    @patch('sagemaker.train.sft_trainer._resolve_model_and_name')
+    @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn')
+    @patch('sagemaker.train.sft_trainer.TrainDefaults.get_role')
+    @patch('sagemaker.train.sft_trainer.TrainDefaults.get_sagemaker_session')
+    @patch('sagemaker.train.sft_trainer._get_unique_name')
+    @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group')
+    @patch('sagemaker.train.sft_trainer._create_input_data_config')
+    @patch('sagemaker.train.sft_trainer._convert_input_data_to_channels')
+    @patch('sagemaker.train.sft_trainer._create_output_config')
+    @patch('sagemaker.train.sft_trainer._create_serverless_config')
+    @patch('sagemaker.train.sft_trainer._create_mlflow_config')
+    @patch('sagemaker.train.sft_trainer._create_model_package_config')
+    @patch('sagemaker.core.resources.TrainingJob.create')
+    def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create,
+            mock_model_package_config, mock_mlflow_config, mock_serverless_config,
+            mock_output_config, mock_convert_channels, mock_input_config,
+            mock_validate_group, mock_unique_name, mock_get_sagemaker_session,
+            mock_get_role, mock_get_options, mock_resolve_model):
+        """``train()`` calls the patched ``_create_serverless_config`` once with ``sequence_length="16K"``."""
+        # Fine-tuning options stub whose to_dict() returns an empty dict.
+        options_stub = Mock(**{"to_dict.return_value": {}})
+        mock_get_options.return_value = (options_stub, "model-arn", False)
+
+        # Fixed return values for the model / group / role / job-name helpers.
+        mock_resolve_model.return_value = ("test-model", "test-model")
+        mock_validate_group.return_value = "test-group"
+        mock_get_role.return_value = "test-role"
+        mock_unique_name.return_value = "test-job-name"
+
+        # Every remaining patched collaborator hands back its own fresh Mock.
+        for single in (mock_get_sagemaker_session, mock_output_config, mock_serverless_config,
+                       mock_mlflow_config, mock_model_package_config, mock_training_job_create):
+            single.return_value = Mock()
+        for listed in (mock_input_config, mock_convert_channels):
+            listed.return_value = [Mock()]
+
+        SFTTrainer(model="test-model", model_package_group="test-group",
+                   training_dataset="s3://bucket/train", sequence_length="16K").train(wait=False)
+
+        mock_serverless_config.assert_called_once()
+        _, passed_kwargs = mock_serverless_config.call_args
+        assert passed_kwargs["sequence_length"] == "16K"