diff --git a/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json b/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json index a7640e3f46..1dbcf64d36 100644 --- a/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json +++ b/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json @@ -43236,6 +43236,10 @@ "EvaluatorArn":{ "shape":"EvaluatorArn", "documentation":"

The evaluator Amazon Resource Name (ARN) used as reward function or reward prompt.

" + }, + "SequenceLength":{ + "shape":"SequenceLength", + "documentation":"

The sequence length for the training job.

" } }, "documentation":"

The configuration for the serverless training job.

" @@ -43247,6 +43251,19 @@ "Evaluation" ] }, + "SequenceLength":{ + "type":"string", + "enum":[ + "1K", + "2K", + "4K", + "8K", + "16K", + "32K", + "64K", + "128K" + ] + }, "ServerlessMaxConcurrency":{ "type":"integer", "box":true, diff --git a/sagemaker-core/src/sagemaker/core/shapes/shapes.py b/sagemaker-core/src/sagemaker/core/shapes/shapes.py index 8f99bcba8c..9749b2a4b5 100644 --- a/sagemaker-core/src/sagemaker/core/shapes/shapes.py +++ b/sagemaker-core/src/sagemaker/core/shapes/shapes.py @@ -9588,6 +9588,7 @@ class ServerlessJobConfig(Base): peft: The parameter-efficient fine-tuning configuration. evaluation_type: The evaluation job type. Required when serverless job type is Evaluation. evaluator_arn: The evaluator Amazon Resource Name (ARN) used as reward function or reward prompt. + sequence_length: The sequence length for the training job. """ base_model_arn: StrPipeVar @@ -9597,6 +9598,7 @@ class ServerlessJobConfig(Base): peft: Optional[StrPipeVar] = Unassigned() evaluation_type: Optional[StrPipeVar] = Unassigned() evaluator_arn: Optional[StrPipeVar] = Unassigned() + sequence_length: Optional[StrPipeVar] = Unassigned() class MlflowConfig(Base): diff --git a/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py b/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py index d7ad54ee25..f34a282348 100644 --- a/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py +++ b/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py @@ -15877,6 +15877,7 @@ {"name": "Peft", "shape": "Peft", "type": "string"}, {"name": "EvaluationType", "shape": "EvaluationType", "type": "string"}, {"name": "EvaluatorArn", "shape": "EvaluatorArn", "type": "string"}, + {"name": "SequenceLength", "shape": "SequenceLength", "type": "string"}, ], "type": "structure", }, diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index 0ea74ee207..0c04b29d2a 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -12,7 +12,13 @@ from sagemaker.core.helper.session_helper import Session from sagemaker.train.common_utils.recipe_utils import _get_hub_content_metadata from sagemaker.train.common import TrainingType, CustomizationTechnique, JOB_TYPE, FineTuningOptions -from sagemaker.core.shapes import ServerlessJobConfig, Channel, DataSource, ModelPackageConfig, MlflowConfig +from sagemaker.core.shapes import ( + ServerlessJobConfig, + Channel, + DataSource, + ModelPackageConfig, + MlflowConfig, +) from sagemaker.train.configs import InputData, OutputDataConfig from sagemaker.train.defaults import TrainDefaults from sagemaker.train.constants import get_sagemaker_hub_name @@ -20,11 +26,17 @@ logger = logging.getLogger(__name__) # Region mappings for model availability -OPEN_WEIGHTS_REGIONS = ['us-east-1', 'us-west-2', 'ap-northeast-1', 'eu-west-1'] # IAD, PDX, NRT, DUB -NOVA_REGIONS = ['us-east-1', 'us-west-2'] # IAD, PDX +OPEN_WEIGHTS_REGIONS = [ + "us-east-1", + "us-west-2", + "ap-northeast-1", + "eu-west-1", +] # IAD, PDX, NRT, DUB +NOVA_REGIONS = ["us-east-1", "us-west-2"] # IAD, PDX # Constants DEFAULT_REGION = "us-west-2" + def _validate_model_region_availability(model_name: str, region_name: str): """Validate if the model is available in the specified region.""" if "nova" in model_name.lower(): @@ -48,26 +60,24 @@ def _validate_model_region_availability(model_name: str, region_name: str): ) - - def _get_beta_session(): """Create a SageMaker session with beta endpoint for demo purposes.""" - sm_client = boto3.client('sagemaker', region_name=DEFAULT_REGION) + sm_client = boto3.client("sagemaker", region_name=DEFAULT_REGION) return Session(sagemaker_client=sm_client) def _read_domain_id_from_metadata() -> Optional[str]: """Read domain ID from Studio metadata file. - + This is the standard location for domain information in Studio with Spaces. Returns None if not running in Studio or if metadata file doesn't exist. """ try: - metadata_path = '/opt/ml/metadata/resource-metadata.json' + metadata_path = "/opt/ml/metadata/resource-metadata.json" if os.path.exists(metadata_path): - with open(metadata_path, 'r') as f: + with open(metadata_path, "r") as f: metadata = json.load(f) - return metadata.get('DomainId') + return metadata.get("DomainId") except Exception as e: logger.debug(f"Could not read Studio metadata file: {e}") return None @@ -75,78 +85,88 @@ def _read_domain_id_from_metadata() -> Optional[str]: def _get_current_domain_id(sagemaker_session) -> Optional[str]: """Get current SageMaker Studio domain ID. - + Checks multiple sources in order of reliability: 1. Studio metadata file (Studio with Spaces - newer architecture) 2. User profile ARN (Studio Classic with User Profiles - legacy) - + Returns None if not running in a Studio environment with domain. """ # Try metadata file first (Studio with Spaces) domain_id = _read_domain_id_from_metadata() if domain_id: return domain_id - + # Fallback to original logic (Studio Classic with User Profiles) try: user_profile_arn = sagemaker_session.get_caller_identity_arn() - if user_profile_arn and 'user-profile' in user_profile_arn: + if user_profile_arn and "user-profile" in user_profile_arn: # ARN format: arn:aws:sagemaker:region:account:user-profile/domain-id/profile-name - return user_profile_arn.split('/')[1] + return user_profile_arn.split("/")[1] except Exception as e: logger.debug(f"Could not extract domain ID from user profile ARN: {e}") - + return None -def _resolve_mlflow_resource_arn(sagemaker_session, mlflow_resource_arn: Optional[str] = None) -> Optional[str]: +def _resolve_mlflow_resource_arn( + sagemaker_session, mlflow_resource_arn: Optional[str] = None +) -> Optional[str]: """Resolve MLflow resource ARN using default experience logic.""" if mlflow_resource_arn: return mlflow_resource_arn - + try: - + mlflow_apps = MlflowApp.get_all( session=sagemaker_session.boto_session, - region=sagemaker_session.boto_session.region_name + region=sagemaker_session.boto_session.region_name, ) - + mlflow_apps_list = list(mlflow_apps) current_domain_id = _get_current_domain_id(sagemaker_session) - + # Check for domain match if current_domain_id: - domain_match = next((app for app in mlflow_apps_list - if isinstance(app.default_domain_id_list, list) - and current_domain_id in app.default_domain_id_list), None) + domain_match = next( + ( + app + for app in mlflow_apps_list + if isinstance(app.default_domain_id_list, list) + and current_domain_id in app.default_domain_id_list + ), + None, + ) if domain_match: logger.info("Using domain-matched MLflow app: %s", domain_match.arn) return domain_match.arn - + # Check for account default - account_default = next((app for app in mlflow_apps_list - if app.account_default_status == "ENABLED"), None) + account_default = next( + (app for app in mlflow_apps_list if app.account_default_status == "ENABLED"), None + ) if account_default: logger.info("Using account default MLflow app: %s", account_default.arn) return account_default.arn - + # Use first available with ready status if mlflow_apps_list: - ready_app = next((app for app in mlflow_apps_list - if app.status in ["Created", "Updated"]), None) + ready_app = next( + (app for app in mlflow_apps_list if app.status in ["Created", "Updated"]), None + ) if ready_app: logger.info("Using first available ready MLflow app: %s", ready_app.arn) return ready_app.arn - + # Create new app new_app = _create_mlflow_app(sagemaker_session) if new_app: logger.info("Created new MLflow app: %s", new_app.arn) return new_app.arn - + logger.warning("Failed to create MLflow app. MLflow tracking disabled.") return None - + except Exception as e: logger.error("Error resolving MLflow resource ARN: %s", e) return None @@ -156,45 +176,46 @@ def _create_mlflow_app(sagemaker_session) -> Optional[MlflowApp]: """Create a new MLflow app with minimal configuration.""" try: app_name = f"finetune-mlflow-{int(time.time())}" - account_id = sagemaker_session.boto_session.client('sts').get_caller_identity()['Account'] + account_id = sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"] region = sagemaker_session.boto_session.region_name artifact_store_uri = f"s3://sagemaker-{region}-{account_id}/mlflow-artifacts" role_arn = TrainDefaults.get_role(role=None, sagemaker_session=sagemaker_session) - + # Ensure S3 bucket and prefix exist - s3_client = sagemaker_session.boto_session.client('s3') + s3_client = sagemaker_session.boto_session.client("s3") bucket_name = f"sagemaker-{region}-{account_id}" - + try: # Check if prefix exists - response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix="mlflow-artifacts/", MaxKeys=1) - if 'Contents' not in response: + response = s3_client.list_objects_v2( + Bucket=bucket_name, Prefix="mlflow-artifacts/", MaxKeys=1 + ) + if "Contents" not in response: s3_client.put_object(Bucket=bucket_name, Key="mlflow-artifacts/") except s3_client.exceptions.NoSuchBucket: # Bucket doesn't exist, create bucket and prefix - if region == 'us-east-1': + if region == "us-east-1": s3_client.create_bucket(Bucket=bucket_name) else: s3_client.create_bucket( - Bucket=bucket_name, - CreateBucketConfiguration={'LocationConstraint': region} + Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region} ) s3_client.put_object(Bucket=bucket_name, Key="mlflow-artifacts/") - + new_app = MlflowApp.create( name=app_name, artifact_store_uri=artifact_store_uri, role_arn=role_arn, account_default_status="DISABLED", session=sagemaker_session.boto_session, - region=region + region=region, ) - + # Wait for app to reach Created/Updated state max_wait_time = 600 # 10 minutes - poll_interval = 10 # 10 seconds + poll_interval = 10 # 10 seconds start_time = time.time() - + while time.time() - start_time < max_wait_time: new_app.refresh() if new_app.status in ["Created", "Updated"]: @@ -202,18 +223,18 @@ def _create_mlflow_app(sagemaker_session) -> Optional[MlflowApp]: elif new_app.status in ["Failed", "Stopped"]: # Get detailed error from MLflow app error_msg = f"MLflow app creation failed with status: {new_app.status}" - if hasattr(new_app, 'failure_reason') and new_app.failure_reason: + if hasattr(new_app, "failure_reason") and new_app.failure_reason: error_msg += f". Reason: {new_app.failure_reason}" raise RuntimeError(error_msg) time.sleep(poll_interval) - + # Timeout case - get current status and any error details new_app.refresh() error_msg = f"MLflow app creation failed. Current status: {new_app.status}" - if hasattr(new_app, 'failure_reason') and new_app.failure_reason: + if hasattr(new_app, "failure_reason") and new_app.failure_reason: error_msg += f". Reason: {new_app.failure_reason}" raise RuntimeError(error_msg) - + except Exception as e: logger.error("Failed to create MLflow app: %s", e) return None @@ -229,14 +250,18 @@ def _validate_dataset_arn(dataset: str, param_name: str): def _validate_evaluator_arn(evaluator_arn: str, param_name: str): """Validate that evaluator_arn is in correct ARN format.""" arn_pattern = r"^arn:aws:sagemaker:[^:]+:\d+:hub-content/[^/]+/JsonDoc/[^/]+/[\d\.]+$" - if not evaluator_arn.startswith("arn:aws:sagemaker:") or not re.match(arn_pattern, evaluator_arn): + if not evaluator_arn.startswith("arn:aws:sagemaker:") or not re.match( + arn_pattern, evaluator_arn + ): raise ValueError(f"{param_name} must be a valid SageMaker hub-content evaluator ARN") def _validate_model_package_group_requirement(model, model_package_group_name): """Validate model_package_group_name when source_model_package_arn is not available.""" if not isinstance(model, ModelPackage) and not model_package_group_name: - raise ValueError("model_package_group_name must be provided when source_model_package_arn is not available") + raise ValueError( + "model_package_group_name must be provided when source_model_package_arn is not available" + ) def _resolve_model_package_group_arn(model_package_group_name_or_arn, sagemaker_session) -> str: @@ -244,7 +269,7 @@ def _resolve_model_package_group_arn(model_package_group_name_or_arn, sagemaker_ if isinstance(model_package_group_name_or_arn, str): # Check if it's already an ARN using pattern matching arn_pattern = r"^arn:aws:sagemaker:[^:]+:\d+:model-package-group/[^/]+$" - + if re.match(arn_pattern, model_package_group_name_or_arn): # It's already an ARN return model_package_group_name_or_arn @@ -253,7 +278,7 @@ def _resolve_model_package_group_arn(model_package_group_name_or_arn, sagemaker_ model_package_group = ModelPackageGroup.get( model_package_group_name=model_package_group_name_or_arn, session=sagemaker_session.boto_session, - region=sagemaker_session.boto_session.region_name + region=sagemaker_session.boto_session.region_name, ) return model_package_group.model_package_group_arn else: @@ -263,7 +288,7 @@ def _resolve_model_package_group_arn(model_package_group_name_or_arn, sagemaker_ def _get_default_s3_output_path(sagemaker_session) -> str: """Generate default S3 output path: s3://sagemaker--/output""" - account_id = sagemaker_session.boto_session.client('sts').get_caller_identity()['Account'] + account_id = sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"] region = sagemaker_session.boto_session.region_name return f"s3://sagemaker-{region}-{account_id}/output" @@ -295,17 +320,21 @@ def _resolve_model_name(model_package) -> str: if model_package: try: # Extract base model from InferenceSpecification - if (model_package.inference_specification and - model_package.inference_specification.containers): + if ( + model_package.inference_specification + and model_package.inference_specification.containers + ): container = model_package.inference_specification.containers[0] - if hasattr(container, 'base_model') and container.base_model: + if hasattr(container, "base_model") and container.base_model: return container.base_model.hub_content_name - - raise ValueError("Continued fine tuning is only allowed on model packages fine tuned with sagemaker 1p models") + + raise ValueError( + "Continued fine tuning is only allowed on model packages fine tuned with sagemaker 1p models" + ) except Exception as e: logger.error("Failed to resolve model_name from model package: %s", e) raise - + raise ValueError("model name or package must be provided") @@ -318,10 +347,44 @@ def _resolve_model_package_arn(model_package) -> Optional[str]: return None -def _get_fine_tuning_options_and_model_arn(model_name: str, customization_technique: str, training_type, sagemaker_session, - hub_name: Optional[str] = None) -> tuple: +def _parse_context_length(value) -> int: + """Parse a context length value like '8K', '32K', '128K' into an integer (e.g., 8192). + + Returns 0 if value is None or unparseable. + """ + if not value: + return 0 + value = str(value).strip().upper() + if value.endswith("K"): + try: + return int(value[:-1]) * 1024 + except ValueError: + return 0 + try: + return int(value) + except ValueError: + return 0 + + +def _get_fine_tuning_options_and_model_arn( + model_name: str, + customization_technique: str, + training_type, + sagemaker_session, + sequence_length=None, + hub_name: Optional[str] = None, +) -> tuple: """Get fine-tuning options and model ARN for given customization technique. - + + Args: + model_name: Name of the model in the hub. + customization_technique: Technique (e.g., "SFT", "DPO", "RLVR", "RLAIF"). + training_type: TrainingType enum or string ("LORA", "FULL"). + sagemaker_session: SageMaker session for API calls. + sequence_length: Optional sequence length (e.g., "8K"). When provided, filters + recipes by MaxContextLength >= the requested value. + hub_name: Hub name (default: "SageMakerPublicHub"). + Returns: tuple: (FineTuningOptions, model_arn, is_gated_model) """ @@ -332,42 +395,95 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni hub_content = _get_hub_content_metadata( hub_name=hub_name, - hub_content_type="Model", + hub_content_type="Model", hub_content_name=model_name, session=sagemaker_session.boto_session, - region=sagemaker_session.boto_session.region_name + region=sagemaker_session.boto_session.region_name, ) - - model_arn = hub_content.get('hub_content_arn') - document = hub_content.get('hub_content_document') - + + model_arn = hub_content.get("hub_content_arn") + document = hub_content.get("hub_content_document") + # Check if model is gated is_gated_model = document.get("GatedBucket", False) - + recipe_collection = document.get("RecipeCollection", []) - + # Filter recipes by customization technique - matching_recipes = [r for r in recipe_collection if r.get("CustomizationTechnique") == customization_technique] - + matching_recipes = [ + r + for r in recipe_collection + if r.get("CustomizationTechnique") == customization_technique + ] + if not matching_recipes: - raise ValueError(f"No recipes found for customization technique: {customization_technique}") - + raise ValueError( + f"No recipes found for customization technique: {customization_technique}" + ) + # Filter recipes that have SmtjRecipeTemplateS3Uri key recipes_with_template = [r for r in matching_recipes if r.get("SmtjRecipeTemplateS3Uri")] - + if not recipes_with_template: raise ValueError(f"No recipes found with Smtj for technique: {customization_technique}") + # Filter by SequenceLength before recipe selection if sequence_length is requested + if sequence_length: + requested = _parse_context_length(sequence_length) + candidates_with_context = [r for r in recipes_with_template if r.get("SequenceLength")] + if candidates_with_context: + filtered = [ + r + for r in candidates_with_context + if _parse_context_length(r.get("SequenceLength")) >= requested + ] + if filtered: + filtered.sort(key=lambda r: _parse_context_length(r.get("SequenceLength"))) + recipes_with_template = filtered + else: + available = sorted( + set(r.get("SequenceLength") for r in candidates_with_context) + ) + raise ValueError( + f"No recipes found with SequenceLength >= {sequence_length}. " + f"Available sequence lengths: {available}" + ) + else: + raise ValueError( + f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}, " + f"and sequence length:{sequence_length}" + ) + # Select recipe based on training type # Collect override_params from ALL matching recipes (standard + subscription) recipe = None - if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": - recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None) - elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": - recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None) + if ( + isinstance(training_type, TrainingType) and training_type == TrainingType.LORA + ) or training_type == "LORA": + recipe = next( + ( + r + for r in recipes_with_template + if r.get("Peft") and not r.get("IsSubscriptionModel") + ), + None, + ) + elif ( + isinstance(training_type, TrainingType) and training_type == TrainingType.FULL + ) or training_type == "FULL": + recipe = next( + ( + r + for r in recipes_with_template + if not r.get("Peft") and not r.get("IsSubscriptionModel") + ), + None, + ) if not recipe: - raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}") + raise ValueError( + f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}" + ) # Start with the standard recipe's override_params options_dict = {} @@ -380,14 +496,33 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni options_dict = json.loads(obj["Body"].read()) # Auto-detect and merge subscription recipe's override_params if available - if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": - sub_recipe = next((r for r in recipes_with_template if r.get("Peft") and r.get("IsSubscriptionModel")), None) + if ( + isinstance(training_type, TrainingType) and training_type == TrainingType.LORA + ) or training_type == "LORA": + sub_recipe = next( + ( + r + for r in recipes_with_template + if r.get("Peft") and r.get("IsSubscriptionModel") + ), + None, + ) else: - sub_recipe = next((r for r in recipes_with_template if not r.get("Peft") and r.get("IsSubscriptionModel")), None) + sub_recipe = next( + ( + r + for r in recipes_with_template + if not r.get("Peft") and r.get("IsSubscriptionModel") + ), + None, + ) if sub_recipe and sub_recipe.get("SmtjOverrideParamsS3Uri"): try: - sub_s3_uri = sub_recipe["SmtjOverrideParamsS3Uri"].replace("{customer_id}", sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"]) + sub_s3_uri = sub_recipe["SmtjOverrideParamsS3Uri"].replace( + "{customer_id}", + sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"], + ) sub_uri_path = sub_s3_uri.replace("s3://", "") # Handle access point ARN URIs if sub_uri_path.startswith("arn:"): @@ -405,73 +540,77 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni if k not in options_dict: v_copy = v.copy() if isinstance(v, dict) else v if isinstance(v_copy, dict): - v_copy['default'] = None # No default — won't appear in to_dict() unless set + v_copy["default"] = ( + None # No default — won't appear in to_dict() unless set + ) options_dict[k] = v_copy except Exception as e: - logger.debug(f"Could not fetch subscription recipe override_params: {type(e).__name__}: {e}") + logger.debug( + f"Could not fetch subscription recipe override_params: {type(e).__name__}: {e}" + ) if options_dict: return FineTuningOptions(options_dict), model_arn, is_gated_model else: return FineTuningOptions({}), model_arn, is_gated_model - + except Exception as e: logger.error("Exception getting fine-tuning options: %s", e) raise -def _create_input_channels(dataset: str, content_type: Optional[str] = None, - input_compression_type: Optional[str] = None, - record_wrapper_type: Optional[str] = None, - input_mode: Optional[str] = None): +def _create_input_channels( + dataset: str, + content_type: Optional[str] = None, + input_compression_type: Optional[str] = None, + record_wrapper_type: Optional[str] = None, + input_mode: Optional[str] = None, +): """Create input channels from dataset (S3 URI or dataset ARN). - + Args: dataset: S3 URI (s3://bucket/key) or dataset ARN (arn:aws:sagemaker:...) - + Returns: list: List of Channel objects """ channels = [] - if dataset.startswith("s3://"): # S3 URI - create S3DataSource data_source = DataSource( s3_data_source={ "s3_uri": dataset, "s3_data_type": "S3Prefix", - "s3_data_distribution_type": "FullyReplicated" + "s3_data_distribution_type": "FullyReplicated", } ) else: # Dataset ARN - validate and create dataset source _validate_dataset_arn(dataset, "dataset") - data_source = DataSource( - dataset_source={"dataset_arn": dataset} - ) - + data_source = DataSource(dataset_source={"dataset_arn": dataset}) + channel = Channel( channel_name="train", data_source=data_source, content_type=content_type, compression_type=input_compression_type, record_wrapper_type=record_wrapper_type, - input_mode=input_mode - ) + input_mode=input_mode, + ) channels.append(channel) - + return channels def _resolve_model_and_name(model, sagemaker_session=None): """Resolve model and extract model name from string, ARN, or ModelPackage object. - + Args: model: Can be a model name (str), model package ARN (str), or ModelPackage object sagemaker_session: SageMaker session for API calls (required for ARN resolution) - + Returns: tuple: (resolved_model, model_name) """ @@ -481,14 +620,15 @@ def _resolve_model_and_name(model, sagemaker_session=None): region_name = sagemaker_session.boto_region_name else: # Try to get region from SAGEMAKER_REGION env var, then boto3 session, then AWS_DEFAULT_REGION - region_name = os.environ.get('SAGEMAKER_REGION') + region_name = os.environ.get("SAGEMAKER_REGION") if not region_name: try: import boto3 - region_name = boto3.Session().region_name or os.environ.get('AWS_DEFAULT_REGION') + + region_name = boto3.Session().region_name or os.environ.get("AWS_DEFAULT_REGION") except: pass - + if isinstance(model, str): # Check if it's a model package ARN if model.startswith("arn:aws:sagemaker:") and ":model-package/" in model: @@ -496,7 +636,7 @@ def _resolve_model_and_name(model, sagemaker_session=None): model_package = ModelPackage.get( model_package_name=model, session=sagemaker_session.boto_session if sagemaker_session else None, - region=sagemaker_session.boto_session.region_name if sagemaker_session else None + region=sagemaker_session.boto_session.region_name if sagemaker_session else None, ) model_name = _resolve_model_name(model_package) # Validate region availability @@ -518,23 +658,34 @@ def _resolve_model_and_name(model, sagemaker_session=None): return model, model_name -def _create_serverless_config(model_arn, customization_technique, - training_type, accept_eula, evaluator_arn=None, job_type=JOB_TYPE) -> Optional['ServerlessJobConfig']: +def _create_serverless_config( + model_arn, + customization_technique, + training_type, + accept_eula, + evaluator_arn=None, + sequence_length=None, + job_type=JOB_TYPE, +) -> Optional["ServerlessJobConfig"]: """Create serverless job configuration for fine-tuning. - + Args: model_arn: ARN of the base model customization_technique: Technique used (e.g., "SFT", "DPO", "RLVR", "RLAIF") training_type: Training type (TrainingType enum or string) accept_eula: Boolean indicating if EULA is accepted evaluator_arn: Optional evaluator ARN for RLVR/RLAIF + sequence_length: Optional sequence length enum value (e.g., "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K") job_type: Type of job (default: "FineTuning") - + Returns: ServerlessJobConfig object or None if required parameters are missing """ - peft = None if (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) \ + peft = ( + None + if (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) else (training_type.value if isinstance(training_type, TrainingType) else training_type) + ) # Create ServerlessJobConfig using shapes serverless_config = ServerlessJobConfig( @@ -543,7 +694,8 @@ def _create_serverless_config(model_arn, customization_technique, customization_technique=customization_technique, peft=peft, evaluator_arn=evaluator_arn, - accept_eula=accept_eula + accept_eula=accept_eula, + sequence_length=sequence_length, ) return serverless_config @@ -551,44 +703,41 @@ def _create_serverless_config(model_arn, customization_technique, def _create_input_data_config(training_dataset, validation_dataset=None): """Create input data configuration from training and validation datasets. - + Args: training_dataset: Training dataset (method parameter takes priority over class attribute) validation_dataset: Validation dataset (method parameter takes priority over class attribute) - + Returns: List of InputData objects for training job configuration """ # Extract and validate training dataset final_training_dataset = _extract_dataset_source(training_dataset, "training_dataset") - - input_data_config = [ - InputData(channel_name="train", data_source=final_training_dataset) - ] - + + input_data_config = [InputData(channel_name="train", data_source=final_training_dataset)] + # Add validation dataset if provided if validation_dataset: final_validation_dataset = _extract_dataset_source(validation_dataset, "validation_dataset") input_data_config.append( InputData(channel_name="validation", data_source=final_validation_dataset) ) - - return input_data_config + return input_data_config def _create_model_package_config(model_package_group_name, model, sagemaker_session): """Create model package configuration with resolved ARNs. - + Args: model_package_group_name: Model package group name to resolve model: Model object (used to resolve source model package ARN if it's a ModelPackage) sagemaker_session: SageMaker session for API calls - + Returns: ModelPackageConfig object or None if no model package group name provided """ - + model_package_group_arn = None if model_package_group_name: model_package_group_arn = _resolve_model_package_group_arn( @@ -605,22 +754,21 @@ def _create_model_package_config(model_package_group_name, model, sagemaker_sess ) - -def _create_mlflow_config(sagemaker_session, mlflow_resource_arn=None, - mlflow_experiment_name=None, mlflow_run_name=None): +def _create_mlflow_config( + sagemaker_session, mlflow_resource_arn=None, mlflow_experiment_name=None, mlflow_run_name=None +): """Create MLflow configuration with resolved resource ARN. - + Args: sagemaker_session: SageMaker session for resolving MLflow ARN mlflow_resource_arn: MLflow resource ARN (if None, uses default experience) mlflow_experiment_name: MLflow experiment name mlflow_run_name: MLflow run name - + Returns: MlflowConfig object or None if no MLflow resource ARN is resolved """ - # Derive mlflow_resource_arn with default experience resolved_mlflow_arn = _resolve_mlflow_resource_arn(sagemaker_session, mlflow_resource_arn) logger.info(f"MLflow resource ARN: {resolved_mlflow_arn}") @@ -633,18 +781,18 @@ def _create_mlflow_config(sagemaker_session, mlflow_resource_arn=None, mlflow_experiment_name=mlflow_experiment_name, mlflow_run_name=mlflow_run_name, ) - + return mlflow_config -def _create_output_config(sagemaker_session,s3_output_path=None, kms_key_id=None): +def _create_output_config(sagemaker_session, s3_output_path=None, kms_key_id=None): """Create output data configuration with default S3 path if needed. - + Args: s3_output_path: S3 output path (if None, generates default path) sagemaker_session: SageMaker session for generating default path kms_key_id: Optional KMS key ID for encryption - + Returns: OutputDataConfig object """ @@ -652,7 +800,7 @@ def _create_output_config(sagemaker_session,s3_output_path=None, kms_key_id=None # Use default S3 output path if none provided if s3_output_path is None: s3_output_path = _get_default_s3_output_path(sagemaker_session) - + # Validate S3 path exists _validate_s3_path_exists(s3_output_path, sagemaker_session) @@ -662,16 +810,16 @@ def _create_output_config(sagemaker_session,s3_output_path=None, kms_key_id=None ) -def _convert_input_data_to_channels(input_data_config ): +def _convert_input_data_to_channels(input_data_config): """Convert InputData objects to Channel objects with S3 and dataset ARN support. - + Args: input_data_config: List of InputData objects - + Returns: List of Channel objects """ - + channels = [] for input_data in input_data_config: if input_data.data_source.startswith("s3://"): @@ -680,21 +828,19 @@ def _convert_input_data_to_channels(input_data_config ): s3_data_source={ "s3_uri": input_data.data_source, "s3_data_type": "S3Prefix", - "s3_data_distribution_type": "FullyReplicated" + "s3_data_distribution_type": "FullyReplicated", } ) else: # Dataset ARN - create dataset source - data_source = DataSource( - dataset_source={"dataset_arn": input_data.data_source} - ) + data_source = DataSource(dataset_source={"dataset_arn": input_data.data_source}) channel = Channel( channel_name=input_data.channel_name, data_source=data_source, ) channels.append(channel) - + return channels @@ -703,41 +849,45 @@ def _validate_and_resolve_model_package_group(model, model_package_group_name): # If model_package_group_name is already provided, return it as-is if model_package_group_name: return model_package_group_name - + # Try to resolve from ModelPackage if available if isinstance(model, ModelPackage): return model.model_package_group_name - + # Only validate if model_package_group_name is None and model is not ModelPackage - raise ValueError("model_package_group_name must be provided when model given is " - "not a ModelPackage artifact/not continued finetuning") + raise ValueError( + "model_package_group_name must be provided when model given is " + "not a ModelPackage artifact/not continued finetuning" + ) def _validate_eula_for_gated_model(model, accept_eula, is_gated_model): """Validate EULA acceptance for gated models. - + Args: model: Original model input (string, ARN, or ModelPackage) accept_eula: Boolean indicating if EULA is accepted is_gated_model: Boolean indicating if the model is gated - + Returns: bool: True if EULA is accepted (either explicitly or by default for ARN/ModelPackage) - + Raises: ValueError: If model is gated but accept_eula is False """ # For ModelPackage/ARN inputs, EULA is assumed accepted by default - if isinstance(model, ModelPackage) or (isinstance(model, str) and model.startswith("arn:aws:sagemaker:")): + if isinstance(model, ModelPackage) or ( + isinstance(model, str) and model.startswith("arn:aws:sagemaker:") + ): return True - + # Validate EULA acceptance for gated models if is_gated_model and not accept_eula: raise ValueError( f"Model '{model}' is a gated model and requires EULA acceptance. " "Please set accept_eula=True to proceed with training." ) - + return accept_eula @@ -745,14 +895,14 @@ def _validate_s3_path_exists(s3_path: str, sagemaker_session): """Validate S3 path and create bucket/prefix if they don't exist.""" if not s3_path.startswith("s3://"): raise ValueError(f"Invalid S3 path format: {s3_path}") - + # Parse S3 URI s3_parts = s3_path.replace("s3://", "").split("/", 1) bucket_name = s3_parts[0] prefix = s3_parts[1] if len(s3_parts) > 1 else "" - - s3_client = sagemaker_session.boto_session.client('s3') - + + s3_client = sagemaker_session.boto_session.client("s3") + try: # Check if bucket exists, create if it doesn't try: @@ -761,25 +911,24 @@ def _validate_s3_path_exists(s3_path: str, sagemaker_session): if "NoSuchBucket" in str(e) or "Not Found" in str(e): # Create bucket region = sagemaker_session.boto_region_name - if region == 'us-east-1': + if region == "us-east-1": s3_client.create_bucket(Bucket=bucket_name) else: s3_client.create_bucket( - Bucket=bucket_name, - CreateBucketConfiguration={'LocationConstraint': region} + Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region} ) else: raise - + # If prefix is provided, check if it exists, create if it doesn't if prefix: response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=prefix, MaxKeys=1) - if 'Contents' not in response: + if "Contents" not in response: # Create the prefix by putting an empty object - if not prefix.endswith('/'): - prefix += '/' - s3_client.put_object(Bucket=bucket_name, Key=prefix, Body=b'') - + if not prefix.endswith("/"): + prefix += "/" + s3_client.put_object(Bucket=bucket_name, Key=prefix, Body=b"") + except Exception as e: raise ValueError(f"Failed to validate/create S3 path '{s3_path}': {str(e)}") @@ -787,6 +936,7 @@ def _validate_s3_path_exists(s3_path: str, sagemaker_session): def _validate_hyperparameter_values(hyperparameters: dict): """Validate hyperparameter values for allowed characters.""" import re + allowed_chars = r"^[a-zA-Z0-9/_.:,\-\s'\"\[\]]*$" for key, value in hyperparameters.items(): if isinstance(value, str) and not re.match(allowed_chars, value): diff --git a/sagemaker-train/src/sagemaker/train/dpo_trainer.py b/sagemaker-train/src/sagemaker/train/dpo_trainer.py index bd5d9a11bd..6a6f3f07bd 100644 --- a/sagemaker-train/src/sagemaker/train/dpo_trainer.py +++ b/sagemaker-train/src/sagemaker/train/dpo_trainer.py @@ -19,7 +19,7 @@ _create_mlflow_config, _create_model_package_config, _validate_eula_for_gated_model, - _validate_hyperparameter_values + _validate_hyperparameter_values, ) from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature @@ -53,19 +53,19 @@ class DPOTrainer(BaseTrainer): model="meta-llama/Llama-2-7b-hf", model_package_group="my-dpo-models" ) - + # Create training job (non-blocking) training_job = trainer.train( training_dataset="s3://bucket/preference_data.jsonl", wait=False ) - + # Wait for completion training_job.wait() - + # Refresh job status training_job.refresh() - + # Get the fine-tuned model package ARN model_package_arn = training_job.output_model_package_arn @@ -100,31 +100,38 @@ class DPOTrainer(BaseTrainer): stopping_condition (Optional[StoppingCondition]): The stopping condition to override training runtime limit. If not specified, uses SageMaker service default (24 hours for serverless training). + sequence_length (Optional[str]): + The sequence length for the training job. Valid values are + "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". + If not specified, the service will use default recipe selection behavior. """ + def __init__( - self, - model: Union[str, ModelPackage], - training_type: Union[TrainingType, str] = TrainingType.LORA, - model_package_group: Optional[Union[str, ModelPackageGroup]] = None, - mlflow_resource_arn: Optional[str] = None, - mlflow_experiment_name: Optional[str] = None, - mlflow_run_name: Optional[str] = None, - training_dataset: Optional[Union[str, DataSet]] = None, - validation_dataset: Optional[Union[str, DataSet]] = None, - s3_output_path: Optional[str] = None, - kms_key_id: Optional[str] = None, - networking: Optional[VpcConfig] = None, - accept_eula: bool = False, - stopping_condition: Optional[StoppingCondition] = None, - **kwargs, + self, + model: Union[str, ModelPackage], + training_type: Union[TrainingType, str] = TrainingType.LORA, + model_package_group: Optional[Union[str, ModelPackageGroup]] = None, + mlflow_resource_arn: Optional[str] = None, + mlflow_experiment_name: Optional[str] = None, + mlflow_run_name: Optional[str] = None, + training_dataset: Optional[Union[str, DataSet]] = None, + validation_dataset: Optional[Union[str, DataSet]] = None, + s3_output_path: Optional[str] = None, + kms_key_id: Optional[str] = None, + networking: Optional[VpcConfig] = None, + accept_eula: bool = False, + stopping_condition: Optional[StoppingCondition] = None, + sequence_length: Optional[str] = None, + **kwargs, ): super().__init__(**kwargs) - + self.model, self._model_name = _resolve_model_and_name(model, self.sagemaker_session) self.training_type = training_type - self.model_package_group = _validate_and_resolve_model_package_group(model, - model_package_group) + self.model_package_group = _validate_and_resolve_model_package_group( + model, model_package_group + ) self.mlflow_resource_arn = mlflow_resource_arn self.mlflow_experiment_name = mlflow_experiment_name self.mlflow_run_name = mlflow_run_name @@ -134,19 +141,23 @@ def __init__( self.kms_key_id = kms_key_id self.networking = networking self.stopping_condition = stopping_condition + self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, - CustomizationTechnique.DPO.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session( - sagemaker_session=self.sagemaker_session - - )) - + self.hyperparameters, self._model_arn, is_gated_model = ( + _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.DPO.value, + self.training_type, + self.sagemaker_session + or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length, + ) + ) + # Process hyperparameters self._process_hyperparameters() - + # Validate and set EULA acceptance self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model) @@ -154,35 +165,37 @@ def _process_hyperparameters(self): """Remove hyperparameter keys that are handled by constructor inputs.""" if self.hyperparameters: # Remove keys that are handled by constructor inputs - if hasattr(self.hyperparameters, 'data_path'): - delattr(self.hyperparameters, 'data_path') - self.hyperparameters._specs.pop('data_path', None) - if hasattr(self.hyperparameters, 'output_path'): - delattr(self.hyperparameters, 'output_path') - self.hyperparameters._specs.pop('output_path', None) - if hasattr(self.hyperparameters, 'data_s3_path'): - delattr(self.hyperparameters, 'data_s3_path') - self.hyperparameters._specs.pop('data_s3_path', None) - if hasattr(self.hyperparameters, 'output_s3_path'): - delattr(self.hyperparameters, 'output_s3_path') - self.hyperparameters._specs.pop('output_s3_path', None) - if hasattr(self.hyperparameters, 'training_data_name'): - delattr(self.hyperparameters, 'training_data_name') - self.hyperparameters._specs.pop('training_data_name', None) - if hasattr(self.hyperparameters, 'validation_data_name'): - delattr(self.hyperparameters, 'validation_data_name') - self.hyperparameters._specs.pop('validation_data_name', None) - if hasattr(self.hyperparameters, 'validation_data_path'): - delattr(self.hyperparameters, 'validation_data_path') - self.hyperparameters._specs.pop('validation_data_path', None) + if hasattr(self.hyperparameters, "data_path"): + delattr(self.hyperparameters, "data_path") + self.hyperparameters._specs.pop("data_path", None) + if hasattr(self.hyperparameters, "output_path"): + delattr(self.hyperparameters, "output_path") + self.hyperparameters._specs.pop("output_path", None) + if hasattr(self.hyperparameters, "data_s3_path"): + delattr(self.hyperparameters, "data_s3_path") + self.hyperparameters._specs.pop("data_s3_path", None) + if hasattr(self.hyperparameters, "output_s3_path"): + delattr(self.hyperparameters, "output_s3_path") + self.hyperparameters._specs.pop("output_s3_path", None) + if hasattr(self.hyperparameters, "training_data_name"): + delattr(self.hyperparameters, "training_data_name") + self.hyperparameters._specs.pop("training_data_name", None) + if hasattr(self.hyperparameters, "validation_data_name"): + delattr(self.hyperparameters, "validation_data_name") + self.hyperparameters._specs.pop("validation_data_name", None) + if hasattr(self.hyperparameters, "validation_data_path"): + delattr(self.hyperparameters, "validation_data_path") + self.hyperparameters._specs.pop("validation_data_path", None) @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="DPOTrainer.train") - def train(self, - training_dataset: Optional[Union[str, DataSet]] = None, - validation_dataset: Optional[Union[str, DataSet]] = None, - wait: bool = True, - wait_timeout: Optional[int] = None, - poll: int = 5): + def train( + self, + training_dataset: Optional[Union[str, DataSet]] = None, + validation_dataset: Optional[Union[str, DataSet]] = None, + wait: bool = True, + wait_timeout: Optional[int] = None, + poll: int = 5, + ): """Execute the DPO training job. Parameters: @@ -215,24 +228,26 @@ def train(self, logger.info(f"Training Job Name: {current_training_job_name}") print(f"Training Job Name: {current_training_job_name}") - #data - input_data_config = _create_input_data_config(training_dataset or self.training_dataset, - validation_dataset or self.validation_dataset - ) + # data + input_data_config = _create_input_data_config( + training_dataset or self.training_dataset, validation_dataset or self.validation_dataset + ) channels = _convert_input_data_to_channels(input_data_config) output_config = _create_output_config( s3_output_path=self.s3_output_path, sagemaker_session=sagemaker_session, - kms_key_id=self.kms_key_id + kms_key_id=self.kms_key_id, ) - serverless_config = _create_serverless_config(model_arn=self._model_arn, - customization_technique=CustomizationTechnique.DPO.value, - training_type=self.training_type, - accept_eula=self.accept_eula, - job_type=JOB_TYPE - ) + serverless_config = _create_serverless_config( + model_arn=self._model_arn, + customization_technique=CustomizationTechnique.DPO.value, + training_type=self.training_type, + accept_eula=self.accept_eula, + sequence_length=self.sequence_length, + job_type=JOB_TYPE, + ) mlflow_config = _create_mlflow_config( sagemaker_session, @@ -247,7 +262,7 @@ def train(self, model_package_config = _create_model_package_config( model_package_group_name=self.model_package_group, model=self.model, - sagemaker_session=sagemaker_session + sagemaker_session=sagemaker_session, ) vpc_config = self.networking if self.networking else None @@ -268,7 +283,7 @@ def train(self, "region": sagemaker_session.boto_session.region_name, "tags": tags, } - + # Only pass stopping_condition if explicitly provided by user if self.stopping_condition is not None: create_args["stopping_condition"] = self.stopping_condition @@ -282,15 +297,15 @@ def train(self, if wait: from sagemaker.train.common_utils.trainer_wait import wait as _wait from sagemaker.core.utils.exceptions import TimeoutExceededError - try : + + try: wait_kwargs = {} if wait_timeout is not None: - wait_kwargs['timeout'] = wait_timeout - wait_kwargs['poll'] = poll + wait_kwargs["timeout"] = wait_timeout + wait_kwargs["poll"] = poll _wait(training_job, **wait_kwargs) except TimeoutExceededError as e: logger.error("Error: %s", e) self.latest_training_job = training_job return training_job - diff --git a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py index f2d8460989..09084359f1 100644 --- a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py @@ -2,7 +2,12 @@ import logging from sagemaker.train.base_trainer import BaseTrainer from sagemaker.train.common import TrainingType, CustomizationTechnique, JOB_TYPE -from sagemaker.core.resources import TrainingJob, ModelPackageGroup, MlflowTrackingServer, ModelPackage +from sagemaker.core.resources import ( + TrainingJob, + ModelPackageGroup, + MlflowTrackingServer, + ModelPackage, +) from sagemaker.core.shapes import VpcConfig from sagemaker.train.defaults import TrainDefaults from sagemaker.train.utils import _get_unique_name, _get_studio_tags @@ -23,7 +28,7 @@ _create_mlflow_config, _create_model_package_config, _validate_eula_for_gated_model, - _validate_hyperparameter_values + _validate_hyperparameter_values, ) from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature @@ -60,19 +65,19 @@ class RLAIFTrainer(BaseTrainer): reward_model_id="reward-model-id", reward_prompt="Rate the helpfulness of this response on a scale of 1-10" ) - + # Create training job (non-blocking) training_job = trainer.train( training_dataset="s3://bucket/rlaif_data.jsonl", wait=False ) - + # Wait for completion training_job.wait() - + # Refresh job status training_job.refresh() - + # Get the fine-tuned model package ARN model_package_arn = training_job.output_model_package_arn @@ -114,6 +119,10 @@ class RLAIFTrainer(BaseTrainer): stopping_condition (Optional[StoppingCondition]): The stopping condition to override training runtime limit. If not specified, uses SageMaker service default (24 hours for serverless training). + sequence_length (Optional[str]): + The sequence length for the training job. Valid values are + "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". + If not specified, the service will use default recipe selection behavior. """ def __init__( @@ -135,6 +144,7 @@ def __init__( networking: Optional[VpcConfig] = None, accept_eula: bool = False, stopping_condition: Optional[StoppingCondition] = None, + sequence_length: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) @@ -143,8 +153,9 @@ def __init__( self.model, self._model_name = _resolve_model_and_name(model, self.sagemaker_session) self.training_type = training_type - self.model_package_group = _validate_and_resolve_model_package_group(model, - model_package_group) + self.model_package_group = _validate_and_resolve_model_package_group( + model, model_package_group + ) self.reward_model_id = self._validate_reward_model_id(reward_model_id) self.reward_prompt = reward_prompt self.mlflow_resource_arn = mlflow_resource_arn @@ -156,18 +167,23 @@ def __init__( self.kms_key_id = kms_key_id self.networking = networking self.stopping_condition = stopping_condition + self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, - CustomizationTechnique.RLAIF.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session( - sagemaker_session=self.sagemaker_session - )) - + self.hyperparameters, self._model_arn, is_gated_model = ( + _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.RLAIF.value, + self.training_type, + self.sagemaker_session + or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length, + ) + ) + # Validate and set EULA acceptance self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model) - + # Process reward_prompt parameter self._process_hyperparameters() @@ -181,23 +197,33 @@ def _validate_reward_model_id(self, reward_model_id): f"Invalid reward_model_id '{reward_model_id}'. " f"Available models are: {list(_ALLOWED_REWARD_MODEL_IDS.keys())}" ) - + # Check region compatibility - session = self.sagemaker_session if hasattr(self, 'sagemaker_session') and self.sagemaker_session else TrainDefaults.get_sagemaker_session() + session = ( + self.sagemaker_session + if hasattr(self, "sagemaker_session") and self.sagemaker_session + else TrainDefaults.get_sagemaker_session() + ) current_region = session.boto_region_name allowed_regions = _ALLOWED_REWARD_MODEL_IDS[reward_model_id] - + if current_region not in allowed_regions: raise ValueError( f"Reward model '{reward_model_id}' is not available in region '{current_region}'. " f"Available regions for this model: {allowed_regions}" ) - + return reward_model_id - @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLAIFTrainer.train") - def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None, poll: int = 5): + def train( + self, + training_dataset: Optional[Union[str, DataSet]] = None, + validation_dataset: Optional[Union[str, DataSet]] = None, + wait: bool = True, + wait_timeout: Optional[int] = None, + poll: int = 5, + ): """Execute the RLAIF training job. Parameters: @@ -229,26 +255,28 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati logger.info(f"Training Job Name: {current_training_job_name}") - #data - input_data_config = _create_input_data_config(training_dataset or self.training_dataset, - validation_dataset or self.validation_dataset - ) + # data + input_data_config = _create_input_data_config( + training_dataset or self.training_dataset, validation_dataset or self.validation_dataset + ) channels = _convert_input_data_to_channels(input_data_config) output_config = _create_output_config( s3_output_path=self.s3_output_path, sagemaker_session=sagemaker_session, - kms_key_id=self.kms_key_id + kms_key_id=self.kms_key_id, ) - evaluator_arn = getattr(self, '_evaluator_arn', None) - serverless_config = _create_serverless_config(model_arn=self._model_arn, - customization_technique=CustomizationTechnique.RLAIF.value, - training_type=self.training_type, - accept_eula=self.accept_eula, - evaluator_arn=evaluator_arn, - job_type=JOB_TYPE - ) + evaluator_arn = getattr(self, "_evaluator_arn", None) + serverless_config = _create_serverless_config( + model_arn=self._model_arn, + customization_technique=CustomizationTechnique.RLAIF.value, + training_type=self.training_type, + accept_eula=self.accept_eula, + evaluator_arn=evaluator_arn, + sequence_length=self.sequence_length, + job_type=JOB_TYPE, + ) mlflow_config = _create_mlflow_config( sagemaker_session, @@ -264,7 +292,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati model_package_config = _create_model_package_config( model_package_group_name=self.model_package_group, model=self.model, - sagemaker_session=sagemaker_session + sagemaker_session=sagemaker_session, ) vpc_config = self.networking if self.networking else None @@ -285,7 +313,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati "region": sagemaker_session.boto_session.region_name, "tags": tags, } - + # Only pass stopping_condition if explicitly provided by user if self.stopping_condition is not None: create_args["stopping_condition"] = self.stopping_condition @@ -299,11 +327,12 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati if wait: from sagemaker.train.common_utils.trainer_wait import wait as _wait from sagemaker.core.utils.exceptions import TimeoutExceededError - try : + + try: wait_kwargs = {} if wait_timeout is not None: - wait_kwargs['timeout'] = wait_timeout - wait_kwargs['poll'] = poll + wait_kwargs["timeout"] = wait_timeout + wait_kwargs["poll"] = poll _wait(training_job, **wait_kwargs) except TimeoutExceededError as e: logger.error("Error: %s", e) @@ -313,27 +342,31 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati def _process_hyperparameters(self): """Update hyperparameters based on constructor inputs and process reward_prompt.""" - if not self.hyperparameters or not hasattr(self.hyperparameters, '_specs') or not self.hyperparameters._specs: + if ( + not self.hyperparameters + or not hasattr(self.hyperparameters, "_specs") + or not self.hyperparameters._specs + ): return - + # Remove keys that are handled by constructor inputs - if hasattr(self.hyperparameters, 'output_path'): - delattr(self.hyperparameters, 'output_path') - self.hyperparameters._specs.pop('output_path', None) - if hasattr(self.hyperparameters, 'data_path'): - delattr(self.hyperparameters, 'data_path') - self.hyperparameters._specs.pop('data_path', None) - if hasattr(self.hyperparameters, 'validation_data_path'): - delattr(self.hyperparameters, 'validation_data_path') - self.hyperparameters._specs.pop('validation_data_path', None) - + if hasattr(self.hyperparameters, "output_path"): + delattr(self.hyperparameters, "output_path") + self.hyperparameters._specs.pop("output_path", None) + if hasattr(self.hyperparameters, "data_path"): + delattr(self.hyperparameters, "data_path") + self.hyperparameters._specs.pop("data_path", None) + if hasattr(self.hyperparameters, "validation_data_path"): + delattr(self.hyperparameters, "validation_data_path") + self.hyperparameters._specs.pop("validation_data_path", None) + # Update judge_model_id if reward_model_id is provided - if hasattr(self, 'reward_model_id') and self.reward_model_id: + if hasattr(self, "reward_model_id") and self.reward_model_id: judge_model_value = f"bedrock/{self.reward_model_id}" self.hyperparameters.judge_model_id = judge_model_value - + # Process reward_prompt parameter - if hasattr(self, 'reward_prompt') and self.reward_prompt: + if hasattr(self, "reward_prompt") and self.reward_prompt: if isinstance(self.reward_prompt, str): if self.reward_prompt.startswith("Builtin"): # Handle builtin reward prompts @@ -343,9 +376,9 @@ def _process_hyperparameters(self): self._process_non_builtin_reward_prompt() else: # Handle evaluator object - if hasattr(self.hyperparameters, 'judge_prompt_template'): - delattr(self.hyperparameters, 'judge_prompt_template') - self.hyperparameters._specs.pop('judge_prompt_template', None) + if hasattr(self.hyperparameters, "judge_prompt_template"): + delattr(self.hyperparameters, "judge_prompt_template") + self.hyperparameters._specs.pop("judge_prompt_template", None) evaluator_arn = _extract_evaluator_arn(self.reward_prompt, "reward_prompt") self._evaluator_arn = evaluator_arn @@ -353,10 +386,10 @@ def _process_hyperparameters(self): def _process_non_builtin_reward_prompt(self): """Process non-builtin reward prompt (ARN or hub content name).""" # Remove judge_prompt_template for non-builtin prompts - if hasattr(self.hyperparameters, 'judge_prompt_template'): - delattr(self.hyperparameters, 'judge_prompt_template') - self.hyperparameters._specs.pop('judge_prompt_template', None) - + if hasattr(self.hyperparameters, "judge_prompt_template"): + delattr(self.hyperparameters, "judge_prompt_template") + self.hyperparameters._specs.pop("judge_prompt_template", None) + if self.reward_prompt.startswith("arn:aws:sagemaker:"): # Validate and assign ARN evaluator_arn = _extract_evaluator_arn(self.reward_prompt, "reward_prompt") @@ -364,39 +397,39 @@ def _process_non_builtin_reward_prompt(self): else: try: session = TrainDefaults.get_sagemaker_session( - sagemaker_session=self.sagemaker_session - ) + sagemaker_session=self.sagemaker_session + ) hub_content = _get_hub_content_metadata( hub_name=get_sagemaker_hub_name(), hub_content_type="JsonDoc", hub_content_name=self.reward_prompt, session=session.boto_session, - region=session.boto_session.region_name + region=session.boto_session.region_name, ) # Store ARN for evaluator_arn self._evaluator_arn = hub_content.hub_content_arn except Exception as e: - raise ValueError(f"Custom prompt '{self.reward_prompt}' not found in HubContent: {e}") - - + raise ValueError( + f"Custom prompt '{self.reward_prompt}' not found in HubContent: {e}" + ) def _update_judge_prompt_template_direct(self, reward_prompt): """Update judge_prompt_template based on Builtin reward function.""" # Get available templates from hyperparameters specs - judge_prompt_spec = self.hyperparameters._specs.get('judge_prompt_template', {}) - available_templates = judge_prompt_spec.get('enum', []) - + judge_prompt_spec = self.hyperparameters._specs.get("judge_prompt_template", {}) + available_templates = judge_prompt_spec.get("enum", []) + if not available_templates: # If no enum found, use the current value as the only available option - current_value = getattr(self.hyperparameters, 'judge_prompt_template', None) + current_value = getattr(self.hyperparameters, "judge_prompt_template", None) if current_value: available_templates = [current_value] else: return - + # Extract template name after "Builtin." and convert to lowercase template_name = reward_prompt.split(".", 1)[1].lower() - + # Find matching template by extracting filename without extension matching_template = None for template in available_templates: @@ -404,14 +437,15 @@ def _update_judge_prompt_template_direct(self, reward_prompt): if template_filename == template_name: matching_template = template break - + if matching_template: self.hyperparameters.judge_prompt_template = matching_template else: - available_options = [f"Builtin.{t.split('/')[-1].replace('.jinja', '')}" for t in available_templates] + available_options = [ + f"Builtin.{t.split('/')[-1].replace('.jinja', '')}" for t in available_templates + ] raise ValueError( f"Selected reward function option '{reward_prompt}' is not available. " f"Choose one from the available options: {available_options}. " f"Example: reward_prompt='Builtin.summarize'" ) - diff --git a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py index 333a93fc55..3abcbbf47e 100644 --- a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py @@ -2,7 +2,12 @@ import logging from sagemaker.train.base_trainer import BaseTrainer from sagemaker.train.common import TrainingType, CustomizationTechnique, JOB_TYPE -from sagemaker.core.resources import TrainingJob, ModelPackageGroup, MlflowTrackingServer, ModelPackage +from sagemaker.core.resources import ( + TrainingJob, + ModelPackageGroup, + MlflowTrackingServer, + ModelPackage, +) from sagemaker.core.shapes import VpcConfig from sagemaker.train.defaults import TrainDefaults from sagemaker.train.utils import _get_unique_name, _get_studio_tags @@ -21,7 +26,7 @@ _create_mlflow_config, _create_model_package_config, _validate_eula_for_gated_model, - _validate_hyperparameter_values + _validate_hyperparameter_values, ) from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature @@ -56,19 +61,19 @@ class RLVRTrainer(BaseTrainer): model_package_group="my-rlvr-models", custom_reward_function="arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/JsonDoc/my-evaluator/1.0" ) - + # Create training job (non-blocking) training_job = trainer.train( training_dataset="s3://bucket/rlvr_data.jsonl", wait=False ) - + # Wait for completion training_job.wait() - + # Refresh job status training_job.refresh() - + # Get the fine-tuned model package ARN model_package_arn = training_job.output_model_package_arn @@ -106,6 +111,10 @@ class RLVRTrainer(BaseTrainer): stopping_condition (Optional[StoppingCondition]): The stopping condition to override training runtime limit. If not specified, uses SageMaker service default (24 hours for serverless training). + sequence_length (Optional[str]): + The sequence length for the training job. Valid values are + "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". + If not specified, the service will use default recipe selection behavior. """ def __init__( @@ -126,6 +135,7 @@ def __init__( networking: Optional[VpcConfig] = None, accept_eula: bool = False, stopping_condition: Optional[StoppingCondition] = None, + sequence_length: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) @@ -134,8 +144,9 @@ def __init__( self.model, self._model_name = _resolve_model_and_name(model, self.sagemaker_session) self.training_type = training_type - self.model_package_group = _validate_and_resolve_model_package_group(model, - model_package_group) + self.model_package_group = _validate_and_resolve_model_package_group( + model, model_package_group + ) self.custom_reward_function = custom_reward_function self.mlflow_resource_arn = mlflow_resource_arn self.mlflow_experiment_name = mlflow_experiment_name @@ -146,18 +157,23 @@ def __init__( self.kms_key_id = kms_key_id self.networking = networking self.stopping_condition = stopping_condition + self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, - CustomizationTechnique.RLVR.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session( - sagemaker_session=self.sagemaker_session - )) - + self.hyperparameters, self._model_arn, is_gated_model = ( + _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.RLVR.value, + self.training_type, + self.sagemaker_session + or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length, + ) + ) + # Remove constructor-handled hyperparameters self._process_hyperparameters() - + # Validate and set EULA acceptance self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model) @@ -165,28 +181,34 @@ def _process_hyperparameters(self): """Remove hyperparameter keys that are handled by constructor inputs.""" if self.hyperparameters: # Remove keys that are handled by constructor inputs - if hasattr(self.hyperparameters, 'data_s3_path'): - delattr(self.hyperparameters, 'data_s3_path') - self.hyperparameters._specs.pop('data_s3_path', None) - if hasattr(self.hyperparameters, 'reward_lambda_arn'): - delattr(self.hyperparameters, 'reward_lambda_arn') - self.hyperparameters._specs.pop('reward_lambda_arn', None) - if hasattr(self.hyperparameters, 'preset_reward_function'): - delattr(self.hyperparameters, 'preset_reward_function') - self.hyperparameters._specs.pop('preset_reward_function', None) - if hasattr(self.hyperparameters, 'data_path'): - delattr(self.hyperparameters, 'data_path') - self.hyperparameters._specs.pop('data_path', None) - if hasattr(self.hyperparameters, 'validation_data_path'): - delattr(self.hyperparameters, 'validation_data_path') - self.hyperparameters._specs.pop('validation_data_path', None) - if hasattr(self.hyperparameters, 'output_path'): - delattr(self.hyperparameters, 'output_path') - self.hyperparameters._specs.pop('output_path', None) + if hasattr(self.hyperparameters, "data_s3_path"): + delattr(self.hyperparameters, "data_s3_path") + self.hyperparameters._specs.pop("data_s3_path", None) + if hasattr(self.hyperparameters, "reward_lambda_arn"): + delattr(self.hyperparameters, "reward_lambda_arn") + self.hyperparameters._specs.pop("reward_lambda_arn", None) + if hasattr(self.hyperparameters, "preset_reward_function"): + delattr(self.hyperparameters, "preset_reward_function") + self.hyperparameters._specs.pop("preset_reward_function", None) + if hasattr(self.hyperparameters, "data_path"): + delattr(self.hyperparameters, "data_path") + self.hyperparameters._specs.pop("data_path", None) + if hasattr(self.hyperparameters, "validation_data_path"): + delattr(self.hyperparameters, "validation_data_path") + self.hyperparameters._specs.pop("validation_data_path", None) + if hasattr(self.hyperparameters, "output_path"): + delattr(self.hyperparameters, "output_path") + self.hyperparameters._specs.pop("output_path", None) @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLVRTrainer.train") - def train(self, training_dataset: Optional[Union[str, DataSet]] = None, - validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None, poll: int = 5): + def train( + self, + training_dataset: Optional[Union[str, DataSet]] = None, + validation_dataset: Optional[Union[str, DataSet]] = None, + wait: bool = True, + wait_timeout: Optional[int] = None, + poll: int = 5, + ): """Execute the RLVR training job. Parameters: @@ -219,27 +241,33 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, logger.info(f"Training Job Name: {current_training_job_name}") - #data - input_data_config = _create_input_data_config(training_dataset or self.training_dataset, - validation_dataset or self.validation_dataset - ) + # data + input_data_config = _create_input_data_config( + training_dataset or self.training_dataset, validation_dataset or self.validation_dataset + ) channels = _convert_input_data_to_channels(input_data_config) output_config = _create_output_config( s3_output_path=self.s3_output_path, sagemaker_session=sagemaker_session, - kms_key_id=self.kms_key_id + kms_key_id=self.kms_key_id, ) # Extract and validate evaluator ARN - evaluator_arn = _extract_evaluator_arn(self.custom_reward_function) if self.custom_reward_function else None - serverless_config = _create_serverless_config(model_arn=self._model_arn, - customization_technique=CustomizationTechnique.RLVR.value, - training_type=self.training_type, - accept_eula=self.accept_eula, - evaluator_arn=evaluator_arn, - job_type=JOB_TYPE - ) + evaluator_arn = ( + _extract_evaluator_arn(self.custom_reward_function) + if self.custom_reward_function + else None + ) + serverless_config = _create_serverless_config( + model_arn=self._model_arn, + customization_technique=CustomizationTechnique.RLVR.value, + training_type=self.training_type, + accept_eula=self.accept_eula, + evaluator_arn=evaluator_arn, + sequence_length=self.sequence_length, + job_type=JOB_TYPE, + ) mlflow_config = _create_mlflow_config( sagemaker_session, mlflow_resource_arn=self.mlflow_resource_arn, @@ -248,14 +276,14 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, ) final_hyperparameters = self.hyperparameters.to_dict() - + # Validate hyperparameter values _validate_hyperparameter_values(final_hyperparameters) model_package_config = _create_model_package_config( model_package_group_name=self.model_package_group, model=self.model, - sagemaker_session=sagemaker_session + sagemaker_session=sagemaker_session, ) vpc_config = self.networking if self.networking else None @@ -276,7 +304,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, "region": sagemaker_session.boto_session.region_name, "tags": tags, } - + # Only pass stopping_condition if explicitly provided by user if self.stopping_condition is not None: create_args["stopping_condition"] = self.stopping_condition @@ -290,11 +318,12 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, if wait: from sagemaker.train.common_utils.trainer_wait import wait as _wait from sagemaker.core.utils.exceptions import TimeoutExceededError + try: wait_kwargs = {} if wait_timeout is not None: - wait_kwargs['timeout'] = wait_timeout - wait_kwargs['poll'] = poll + wait_kwargs["timeout"] = wait_timeout + wait_kwargs["poll"] = poll _wait(training_job, **wait_kwargs) except TimeoutExceededError as e: logger.error("Error: %s", e) diff --git a/sagemaker-train/src/sagemaker/train/sft_trainer.py b/sagemaker-train/src/sagemaker/train/sft_trainer.py index 233f169d0f..9d47e9742a 100644 --- a/sagemaker-train/src/sagemaker/train/sft_trainer.py +++ b/sagemaker-train/src/sagemaker/train/sft_trainer.py @@ -20,7 +20,7 @@ _create_mlflow_config, _create_model_package_config, _validate_eula_for_gated_model, - _validate_hyperparameter_values + _validate_hyperparameter_values, ) from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature @@ -55,19 +55,19 @@ class SFTTrainer(BaseTrainer): model="meta-llama/Llama-2-7b-hf", model_package_group="my-fine-tuned-models" ) - + # Create training job (non-blocking) training_job = trainer.train( training_dataset="s3://bucket/train.jsonl", wait=False ) - + # Wait for completion training_job.wait() - + # Refresh job status training_job.refresh() - + # Get the fine-tuned model artifacts ARN model_package_arn = training_job.output_model_package_arn @@ -102,6 +102,10 @@ class SFTTrainer(BaseTrainer): stopping_condition (Optional[StoppingCondition]): The stopping condition to override training runtime limit. If not specified, uses SageMaker service default (24 hours for serverless training). + sequence_length (Optional[str]): + The sequence length for the training job. Valid values are + "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". + If not specified, the service will use default recipe selection behavior. """ def __init__( @@ -119,16 +123,18 @@ def __init__( networking: Optional[VpcConfig] = None, accept_eula: Optional[bool] = False, stopping_condition: Optional[StoppingCondition] = None, + sequence_length: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) - + # Resolve model and model name self.model, self._model_name = _resolve_model_and_name(model, self.sagemaker_session) self.training_type = training_type - self.model_package_group = _validate_and_resolve_model_package_group(model, - model_package_group) + self.model_package_group = _validate_and_resolve_model_package_group( + model, model_package_group + ) self.mlflow_resource_arn = mlflow_resource_arn self.mlflow_experiment_name = mlflow_experiment_name self.mlflow_run_name = mlflow_run_name @@ -138,18 +144,23 @@ def __init__( self.kms_key_id = kms_key_id self.networking = networking self.stopping_condition = stopping_condition + self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, - CustomizationTechnique.SFT.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session( - sagemaker_session=self.sagemaker_session - )) - + self.hyperparameters, self._model_arn, is_gated_model = ( + _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.SFT.value, + self.training_type, + self.sagemaker_session + or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length, + ) + ) + # Process hyperparameters self._process_hyperparameters() - + # Validate and set EULA acceptance self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model) @@ -157,30 +168,37 @@ def _process_hyperparameters(self): """Remove hyperparameter keys that are handled by constructor inputs.""" if self.hyperparameters: # Remove keys that are handled by constructor inputs - if hasattr(self.hyperparameters, 'data_path'): - delattr(self.hyperparameters, 'data_path') - self.hyperparameters._specs.pop('data_path', None) - if hasattr(self.hyperparameters, 'output_path'): - delattr(self.hyperparameters, 'output_path') - self.hyperparameters._specs.pop('output_path', None) - if hasattr(self.hyperparameters, 'data_s3_path'): - delattr(self.hyperparameters, 'data_s3_path') - self.hyperparameters._specs.pop('data_s3_path', None) - if hasattr(self.hyperparameters, 'output_s3_path'): - delattr(self.hyperparameters, 'output_s3_path') - self.hyperparameters._specs.pop('output_s3_path', None) - if hasattr(self.hyperparameters, 'training_data_name'): - delattr(self.hyperparameters, 'training_data_name') - self.hyperparameters._specs.pop('training_data_name', None) - if hasattr(self.hyperparameters, 'validation_data_name'): - delattr(self.hyperparameters, 'validation_data_name') - self.hyperparameters._specs.pop('validation_data_name', None) - if hasattr(self.hyperparameters, 'validation_data_path'): - delattr(self.hyperparameters, 'validation_data_path') - self.hyperparameters._specs.pop('validation_data_path', None) + if hasattr(self.hyperparameters, "data_path"): + delattr(self.hyperparameters, "data_path") + self.hyperparameters._specs.pop("data_path", None) + if hasattr(self.hyperparameters, "output_path"): + delattr(self.hyperparameters, "output_path") + self.hyperparameters._specs.pop("output_path", None) + if hasattr(self.hyperparameters, "data_s3_path"): + delattr(self.hyperparameters, "data_s3_path") + self.hyperparameters._specs.pop("data_s3_path", None) + if hasattr(self.hyperparameters, "output_s3_path"): + delattr(self.hyperparameters, "output_s3_path") + self.hyperparameters._specs.pop("output_s3_path", None) + if hasattr(self.hyperparameters, "training_data_name"): + delattr(self.hyperparameters, "training_data_name") + self.hyperparameters._specs.pop("training_data_name", None) + if hasattr(self.hyperparameters, "validation_data_name"): + delattr(self.hyperparameters, "validation_data_name") + self.hyperparameters._specs.pop("validation_data_name", None) + if hasattr(self.hyperparameters, "validation_data_path"): + delattr(self.hyperparameters, "validation_data_path") + self.hyperparameters._specs.pop("validation_data_path", None) @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="SFTTrainer.train") - def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None, poll: int = 5): + def train( + self, + training_dataset: Optional[Union[str, DataSet]] = None, + validation_dataset: Optional[Union[str, DataSet]] = None, + wait: bool = True, + wait_timeout: Optional[int] = None, + poll: int = 5, + ): """Execute the SFT training job. Parameters: @@ -213,24 +231,26 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati logger.info(f"Training Job Name: {current_training_job_name}") - #data - input_data_config = _create_input_data_config(training_dataset or self.training_dataset, - validation_dataset or self.validation_dataset - ) + # data + input_data_config = _create_input_data_config( + training_dataset or self.training_dataset, validation_dataset or self.validation_dataset + ) channels = _convert_input_data_to_channels(input_data_config) output_config = _create_output_config( s3_output_path=self.s3_output_path, sagemaker_session=sagemaker_session, - kms_key_id=self.kms_key_id + kms_key_id=self.kms_key_id, ) - serverless_config = _create_serverless_config(model_arn=self._model_arn, - customization_technique=CustomizationTechnique.SFT.value, - training_type=self.training_type, - accept_eula=self.accept_eula, - job_type=JOB_TYPE - ) + serverless_config = _create_serverless_config( + model_arn=self._model_arn, + customization_technique=CustomizationTechnique.SFT.value, + training_type=self.training_type, + accept_eula=self.accept_eula, + sequence_length=self.sequence_length, + job_type=JOB_TYPE, + ) mlflow_config = _create_mlflow_config( sagemaker_session, mlflow_resource_arn=self.mlflow_resource_arn, @@ -239,14 +259,14 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati ) final_hyperparameters = self.hyperparameters.to_dict() - + # Validate hyperparameter values _validate_hyperparameter_values(final_hyperparameters) model_package_config = _create_model_package_config( model_package_group_name=self.model_package_group, model=self.model, - sagemaker_session=sagemaker_session + sagemaker_session=sagemaker_session, ) vpc_config = self.networking if self.networking else None @@ -267,7 +287,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati "region": sagemaker_session.boto_session.region_name, "tags": tags, } - + # Only pass stopping_condition if explicitly provided by user if self.stopping_condition is not None: create_args["stopping_condition"] = self.stopping_condition @@ -281,16 +301,15 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati if wait: from sagemaker.train.common_utils.trainer_wait import wait as _wait from sagemaker.core.utils.exceptions import TimeoutExceededError - try : + + try: wait_kwargs = {} if wait_timeout is not None: - wait_kwargs['timeout'] = wait_timeout - wait_kwargs['poll'] = poll + wait_kwargs["timeout"] = wait_timeout + wait_kwargs["poll"] = poll _wait(training_job, **wait_kwargs) except TimeoutExceededError as e: logger.error("Error: %s", e) self._latest_training_job = training_job return training_job - - diff --git a/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py b/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py index 93be84a738..b09b608ae3 100644 --- a/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py +++ b/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py @@ -21,11 +21,12 @@ from sagemaker.train.sft_trainer import SFTTrainer from sagemaker.train.common import TrainingType + @pytest.mark.gpu_intensive def test_sft_trainer_lora_complete_workflow(sagemaker_session): """Test complete SFT training workflow with LORA.""" unique_id = f"{int(time.time())}-{random.randint(1000, 9999)}" - + sft_trainer = SFTTrainer( model="meta-textgeneration-llama-3-2-1b-instruct", training_type=TrainingType.LORA, @@ -35,27 +36,27 @@ def test_sft_trainer_lora_complete_workflow(sagemaker_session): accept_eula=True, base_job_name=f"sft-lora-integ-{unique_id}", ) - + # Create training job training_job = sft_trainer.train(wait=False) - + # Manual wait loop to avoid resource_config issue max_wait_time = 3600 # 1 hour timeout - poll_interval = 30 # Check every 30 seconds + poll_interval = 30 # Check every 30 seconds start_time = time.time() - + while time.time() - start_time < max_wait_time: training_job.refresh() status = training_job.training_job_status - + if status in ["Completed", "Failed", "Stopped"]: break - + time.sleep(poll_interval) - + # Verify job completed successfully assert training_job.training_job_status == "Completed" - assert hasattr(training_job, 'output_model_package_arn') + assert hasattr(training_job, "output_model_package_arn") assert training_job.output_model_package_arn is not None @@ -73,26 +74,26 @@ def test_sft_trainer_with_validation_dataset(sagemaker_session): accept_eula=True, base_job_name=f"sft-val-integ-{unique_id}", ) - + training_job = sft_trainer.train(wait=False) - + # Manual wait loop max_wait_time = 3600 poll_interval = 30 start_time = time.time() - + while time.time() - start_time < max_wait_time: training_job.refresh() status = training_job.training_job_status - + if status in ["Completed", "Failed", "Stopped"]: break - + time.sleep(poll_interval) - + # Verify job completed successfully assert training_job.training_job_status == "Completed" - assert hasattr(training_job, 'output_model_package_arn') + assert hasattr(training_job, "output_model_package_arn") @pytest.mark.gpu_intensive @@ -104,7 +105,7 @@ def test_sft_trainer_nova_workflow(sagemaker_session_us_east_1): unique_id = f"{int(time.time())}-{random.randint(1000, 9999)}" sft_trainer_nova = SFTTrainer( model="nova-textgeneration-lite-v2", - training_type=TrainingType.LORA, + training_type=TrainingType.LORA, model_package_group="sdk-test-finetuned-models", mlflow_experiment_name="test-nova-finetuned-models-exp", mlflow_run_name="test-nova-finetuned-models-run", @@ -113,25 +114,61 @@ def test_sft_trainer_nova_workflow(sagemaker_session_us_east_1): sagemaker_session=sagemaker_session_us_east_1, base_job_name=f"sft-nova-integ-{unique_id}", ) - + # Create training job training_job = sft_trainer_nova.train(wait=False) - + # Manual wait loop max_wait_time = 3600 # 1 hour timeout - poll_interval = 30 # Check every 30 seconds + poll_interval = 30 # Check every 30 seconds start_time = time.time() - + while time.time() - start_time < max_wait_time: training_job.refresh() status = training_job.training_job_status - + if status in ["Completed", "Failed", "Stopped"]: break - + time.sleep(poll_interval) - + # Verify job completed successfully assert training_job.training_job_status == "Completed" - assert hasattr(training_job, 'output_model_package_arn') + assert hasattr(training_job, "output_model_package_arn") + assert training_job.output_model_package_arn is not None + + +@pytest.mark.gpu_intensive +def test_sft_trainer_lora_with_sequence_length(sagemaker_session): + """Test SFT training workflow with LORA and sequence_length specified.""" + unique_id = f"{int(time.time())}-{random.randint(1000, 9999)}" + + sft_trainer = SFTTrainer( + model="meta-textgeneration-llama-3-2-1b-instruct", + training_type=TrainingType.LORA, + model_package_group="arn:aws:sagemaker:us-west-2:729646638167:model-package-group/sdk-test-finetuned-models", + training_dataset="s3://mc-flows-sdk-testing/input_data/sft/sample_data_256_final.jsonl", + s3_output_path="s3://mc-flows-sdk-testing/output/", + accept_eula=True, + sequence_length="8K", + base_job_name=f"sft-seqlen-integ-{unique_id}", + ) + + training_job = sft_trainer.train(wait=False) + + max_wait_time = 3600 + poll_interval = 30 + start_time = time.time() + + while time.time() - start_time < max_wait_time: + training_job.refresh() + status = training_job.training_job_status + + if status in ["Completed", "Failed", "Stopped"]: + break + + time.sleep(poll_interval) + + assert training_job.training_job_status == "Completed" + assert hasattr(training_job, "output_model_package_arn") assert training_job.output_model_package_arn is not None diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index 7a63e36234..0f5ff6f92e 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -27,7 +27,8 @@ _create_mlflow_config, _validate_eula_for_gated_model, _validate_model_region_availability, - _validate_s3_path_exists + _validate_s3_path_exists, + _parse_context_length, ) from sagemaker.core.resources import ModelPackage, ModelPackageGroup from sagemaker.ai_registry.dataset import DataSet @@ -37,47 +38,53 @@ class TestFinetuneUtils: - @patch('sagemaker.train.common_utils.finetune_utils.boto3.client') - @patch('sagemaker.train.common_utils.finetune_utils.Session') + @patch("sagemaker.train.common_utils.finetune_utils.boto3.client") + @patch("sagemaker.train.common_utils.finetune_utils.Session") def test__get_beta_session(self, mock_session, mock_boto_client): mock_client = Mock() mock_boto_client.return_value = mock_client mock_sagemaker_session = Mock() mock_session.return_value = mock_sagemaker_session - + result = _get_beta_session() - + assert result == mock_sagemaker_session mock_boto_client.assert_called_once() def test_get_current_domain_id_with_studio_arn(self): mock_session = Mock() - mock_session.get_caller_identity_arn.return_value = "arn:aws:sts::123456789012:assumed-role/SageMakerStudioExecutionRole/SageMaker" - + mock_session.get_caller_identity_arn.return_value = ( + "arn:aws:sts::123456789012:assumed-role/SageMakerStudioExecutionRole/SageMaker" + ) + result = _get_current_domain_id(mock_session) - + assert result is None def test_get_current_domain_id_with_domain_arn(self): mock_session = Mock() - mock_session.get_caller_identity_arn.return_value = "arn:aws:sagemaker:us-east-1:123456789012:user-profile/d-123456789/test-user" - + mock_session.get_caller_identity_arn.return_value = ( + "arn:aws:sagemaker:us-east-1:123456789012:user-profile/d-123456789/test-user" + ) + result = _get_current_domain_id(mock_session) - + assert result == "d-123456789" def test__resolve_mlflow_resource_arn_with_provided_arn(self): mock_session = Mock() provided_arn = "arn:aws:mlflow:us-east-1:123456789012:tracking-server/test" - + result = _resolve_mlflow_resource_arn(mock_session, provided_arn) - + assert result == provided_arn - @patch('sagemaker.train.common_utils.finetune_utils._get_current_domain_id') - @patch('sagemaker.train.common_utils.finetune_utils._create_mlflow_app') - @patch('sagemaker.core.resources.MlflowApp.get_all') - def test__resolve_mlflow_resource_arn_creates_new_app(self, mock_get_all, mock_create_app, mock_get_domain): + @patch("sagemaker.train.common_utils.finetune_utils._get_current_domain_id") + @patch("sagemaker.train.common_utils.finetune_utils._create_mlflow_app") + @patch("sagemaker.core.resources.MlflowApp.get_all") + def test__resolve_mlflow_resource_arn_creates_new_app( + self, mock_get_all, mock_create_app, mock_get_domain + ): mock_session = Mock() mock_session.boto_session.region_name = "us-east-1" mock_get_domain.return_value = "d-123456789" @@ -85,13 +92,13 @@ def test__resolve_mlflow_resource_arn_creates_new_app(self, mock_get_all, mock_c mock_app = Mock() mock_app.arn = "arn:aws:mlflow:us-east-1:123456789012:tracking-server/new-app" mock_create_app.return_value = mock_app - + result = _resolve_mlflow_resource_arn(mock_session, None) - + assert result == mock_app.arn - @patch('sagemaker.train.common_utils.finetune_utils.TrainDefaults.get_role') - @patch('sagemaker.core.resources.MlflowApp.create') + @patch("sagemaker.train.common_utils.finetune_utils.TrainDefaults.get_role") + @patch("sagemaker.core.resources.MlflowApp.create") def test_create_mlflow_app_success(self, mock_create, mock_get_role): mock_session = Mock() mock_session.region_name = "us-east-1" @@ -99,63 +106,67 @@ def test_create_mlflow_app_success(self, mock_create, mock_get_role): mock_sts_client.get_caller_identity.return_value = {"Account": "123456789012"} mock_s3_client = Mock() mock_s3_client.list_objects_v2.return_value = {"Contents": [{"Key": "mlflow-artifacts/"}]} - + def mock_client(service_name): - if service_name == 'sts': + if service_name == "sts": return mock_sts_client - elif service_name == 's3': + elif service_name == "s3": return mock_s3_client return Mock() - + mock_session.boto_session.client.side_effect = mock_client mock_session.boto_session.region_name = "us-east-1" mock_get_role.return_value = "arn:aws:iam::123456789012:role/test-role" mock_app = Mock() mock_app.status = "Created" mock_create.return_value = mock_app - + result = _create_mlflow_app(mock_session) - + assert result == mock_app mock_create.assert_called_once() mock_app.refresh.assert_called() - @patch('sagemaker.core.resources.MlflowApp.create') + @patch("sagemaker.core.resources.MlflowApp.create") def test_create_mlflow_app_failure(self, mock_create): mock_session = Mock() mock_create.side_effect = Exception("Creation failed") - + result = _create_mlflow_app(mock_session) - + assert result is None def test__validate_dataset_arn_valid(self): valid_arn = "arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/DataSet/test-dataset/1.0" - + # Should not raise exception _validate_dataset_arn(valid_arn, "test_dataset") def test__validate_dataset_arn_invalid(self): invalid_arn = "invalid-arn" - - with pytest.raises(ValueError, match="test_dataset must be a valid SageMaker hub-content DataSet ARN"): + + with pytest.raises( + ValueError, match="test_dataset must be a valid SageMaker hub-content DataSet ARN" + ): _validate_dataset_arn(invalid_arn, "test_dataset") def test_validate_evaluator_arn_valid(self): valid_arn = "arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/JsonDoc/test-evaluator/1.0" - + # Should not raise exception _validate_evaluator_arn(valid_arn, "test_evaluator") def test_validate_evaluator_arn_invalid(self): invalid_arn = "invalid-arn" - - with pytest.raises(ValueError, match="test_evaluator must be a valid SageMaker hub-content evaluator ARN"): + + with pytest.raises( + ValueError, match="test_evaluator must be a valid SageMaker hub-content evaluator ARN" + ): _validate_evaluator_arn(invalid_arn, "test_evaluator") def test__validate_model_package_group_requirement_with_model_package(self): model_package = Mock(spec=ModelPackage) - + # Should not raise exception _validate_model_package_group_requirement(model_package, None) @@ -163,33 +174,37 @@ def test__validate_model_package_group_requirement_without_group_name(self): with pytest.raises(ValueError, match="model_package_group_name must be provided"): _validate_model_package_group_requirement("string-model", None) - @patch('sagemaker.core.resources.ModelPackageGroup.get') + @patch("sagemaker.core.resources.ModelPackageGroup.get") def test__resolve_model_package_group_arn_with_name(self, mock_get): mock_session = Mock() mock_session.boto_session.region_name = "us-east-1" mock_group = Mock() - mock_group.model_package_group_arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package-group/test-group" + mock_group.model_package_group_arn = ( + "arn:aws:sagemaker:us-east-1:123456789012:model-package-group/test-group" + ) mock_get.return_value = mock_group - + result = _resolve_model_package_group_arn("test-group", mock_session) - + assert result == mock_group.model_package_group_arn def test__resolve_model_package_group_arn_with_arn(self): mock_session = Mock() arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package-group/test-group" - + result = _resolve_model_package_group_arn(arn, mock_session) - + assert result == arn def test__resolve_model_package_group_arn_with_object(self): mock_session = Mock() mock_group = Mock(spec=ModelPackageGroup) - mock_group.model_package_group_arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package-group/test-group" - + mock_group.model_package_group_arn = ( + "arn:aws:sagemaker:us-east-1:123456789012:model-package-group/test-group" + ) + result = _resolve_model_package_group_arn(mock_group, mock_session) - + assert result == mock_group.model_package_group_arn def test__get_default_s3_output_path(self): @@ -198,50 +213,50 @@ def test__get_default_s3_output_path(self): mock_sts_client.get_caller_identity.return_value = {"Account": "123456789012"} mock_session.boto_session.client.return_value = mock_sts_client mock_session.boto_session.region_name = "us-east-1" - + result = _get_default_s3_output_path(mock_session) - + assert result == "s3://sagemaker-us-east-1-123456789012/output" def test__extract_dataset_source_s3_uri(self): s3_uri = "s3://bucket/dataset" - + result = _extract_dataset_source(s3_uri, "test_dataset") - + assert result == s3_uri - @patch('sagemaker.train.common_utils.finetune_utils._validate_dataset_arn') + @patch("sagemaker.train.common_utils.finetune_utils._validate_dataset_arn") def test__extract_dataset_source_arn(self, mock_validate): arn = "arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/DataSet/test/1.0" - + result = _extract_dataset_source(arn, "test_dataset") - + assert result == arn mock_validate.assert_called_once_with(arn, "test_dataset") def test__extract_dataset_source_dataset_object(self): mock_dataset = Mock(spec=DataSet) mock_dataset.arn = "arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/DataSet/test/1.0" - + result = _extract_dataset_source(mock_dataset, "test_dataset") - + assert result == mock_dataset.arn - @patch('sagemaker.train.common_utils.finetune_utils._validate_evaluator_arn') + @patch("sagemaker.train.common_utils.finetune_utils._validate_evaluator_arn") def test_extract_evaluator_arn_string(self, mock_validate): arn = "arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/JsonDoc/test/1.0" - + result = _extract_evaluator_arn(arn, "test_evaluator") - + assert result == arn mock_validate.assert_called_once_with(arn, "test_evaluator") def test_extract_evaluator_arn_object(self): mock_evaluator = Mock() mock_evaluator.arn = "arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/JsonDoc/test/1.0" - + result = _extract_evaluator_arn(mock_evaluator, "test_evaluator") - + assert result == mock_evaluator.arn def test__resolve_model_name_with_model_package(self): @@ -251,9 +266,9 @@ def test__resolve_model_name_with_model_package(self): mock_base_model.hub_content_name = "test-model" mock_container.base_model = mock_base_model mock_model_package.inference_specification.containers = [mock_container] - + result = _resolve_model_name(mock_model_package) - + assert result == "test-model" def test__resolve_model_name_with_none(self): @@ -264,41 +279,41 @@ def test__resolve_model_package_arn_success(self): mock_model_package = Mock() expected_arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package/test-package" mock_model_package.model_package_arn = expected_arn - + result = _resolve_model_package_arn(mock_model_package) - + assert result == expected_arn def test__resolve_model_package_arn_failure(self): mock_model_package = Mock() mock_model_package.model_package_arn = None - + result = _resolve_model_package_arn(mock_model_package) - + assert result is None - @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') - @patch('boto3.client') + @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata") + @patch("boto3.client") def test__get_fine_tuning_options_and_model_arn(self, mock_boto_client, mock_get_hub_content): mock_session = Mock() mock_session.boto_session.region_name = "us-east-1" - + # Mock hub content metadata mock_get_hub_content.return_value = { - 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", - 'hub_content_document': { + "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + "hub_content_document": { "GatedBucket": False, "RecipeCollection": [ { "CustomizationTechnique": "SFT", "SmtjRecipeTemplateS3Uri": "s3://bucket/template.json", "SmtjOverrideParamsS3Uri": "s3://bucket/params.json", - "Peft": True + "Peft": True, } - ] - } + ], + }, } - + # Mock S3 client mock_s3_client = Mock() mock_boto_client.return_value = mock_s3_client @@ -307,9 +322,9 @@ def test__get_fine_tuning_options_and_model_arn(self, mock_boto_client, mock_get } mock_session.boto_session.client.return_value = mock_s3_client mock_session.boto_session.client.return_value = mock_s3_client - + result = _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session) - + # Handle case where function might return None if result is not None: options, model_arn, is_gated_model = result @@ -322,7 +337,7 @@ def test__get_fine_tuning_options_and_model_arn(self, mock_boto_client, mock_get def test_create_input_channels_s3_uri(self): result = _create_input_channels("s3://bucket/data", "application/json") - + assert len(result) == 1 assert result[0].channel_name == "train" assert result[0].data_source.s3_data_source.s3_uri == "s3://bucket/data" @@ -330,9 +345,9 @@ def test_create_input_channels_s3_uri(self): def test_create_input_channels_dataset_arn(self): arn = "arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/DataSet/test/1.0" - + result = _create_input_channels(arn) - + assert len(result) == 1 assert result[0].channel_name == "train" assert result[0].data_source.dataset_source.dataset_arn == arn @@ -340,24 +355,24 @@ def test_create_input_channels_dataset_arn(self): def test__validate_and_resolve_model_package_group_with_provided_name(self): model = "test-model" group_name = "test-group" - + result = _validate_and_resolve_model_package_group(model, group_name) - + assert result == group_name def test__validate_and_resolve_model_package_group_from_model_package(self): mock_model = Mock(spec=ModelPackage) mock_model.model_package_group_name = "extracted-group" - + result = _validate_and_resolve_model_package_group(mock_model, None) - + assert result == "extracted-group" def test__validate_and_resolve_model_package_group_missing_both(self): with pytest.raises(ValueError, match="model_package_group_name must be provided"): _validate_and_resolve_model_package_group("string-model", None) - @patch('sagemaker.core.resources.ModelPackage.get') + @patch("sagemaker.core.resources.ModelPackage.get") def test__resolve_model_and_name_with_model_package_arn(self, mock_get): mock_session = Mock() mock_session.boto_region_name = "us-east-1" # Set valid region @@ -369,15 +384,17 @@ def test__resolve_model_and_name_with_model_package_arn(self, mock_get): mock_model_package.inference_specification = Mock() mock_model_package.inference_specification.containers = [mock_container] mock_get.return_value = mock_model_package - - model, name = _resolve_model_and_name("arn:aws:sagemaker:us-east-1:123456789012:model-package/test", mock_session) - + + model, name = _resolve_model_and_name( + "arn:aws:sagemaker:us-east-1:123456789012:model-package/test", mock_session + ) + assert model == mock_model_package assert name == "test-model" def test__resolve_model_and_name_with_string(self): model, name = _resolve_model_and_name("test-model") - + assert model == "test-model" assert name == "test-model" @@ -389,15 +406,15 @@ def test__resolve_model_and_name_with_model_package_object(self): mock_container.base_model = mock_base_model mock_model_package.inference_specification = Mock() mock_model_package.inference_specification.containers = [mock_container] - + model, name = _resolve_model_and_name(mock_model_package) - + assert model == mock_model_package assert name == "test-model" def test__create_serverless_config_with_lora(self): config = _create_serverless_config("model-arn", "SFT", TrainingType.LORA, accept_eula=True) - + assert config.job_type == "FineTuning" assert config.base_model_arn == "model-arn" assert config.customization_technique == "SFT" @@ -405,14 +422,13 @@ def test__create_serverless_config_with_lora(self): def test__create_serverless_config_with_full(self): config = _create_serverless_config("model-arn", "SFT", TrainingType.FULL, accept_eula=True) - + assert config.peft is None def test__create_input_data_config(self): - config = _create_input_data_config("s3://bucket/train", "s3://bucket/val") - + assert len(config) == 2 assert config[0].channel_name == "train" assert config[1].channel_name == "validation" @@ -421,30 +437,34 @@ def test__create_model_package_config(self): mock_session = Mock() mock_model = Mock(spec=ModelPackage) mock_model.model_package_arn = "source-arn" - - with patch('sagemaker.train.common_utils.finetune_utils._resolve_model_package_group_arn') as mock_resolve: + + with patch( + "sagemaker.train.common_utils.finetune_utils._resolve_model_package_group_arn" + ) as mock_resolve: mock_resolve.return_value = "group-arn" config = _create_model_package_config("test-group", mock_model, mock_session) - + assert config.model_package_group_arn == "group-arn" assert config.source_model_package_arn == "source-arn" def test__create_mlflow_config(self): mock_session = Mock() - - with patch('sagemaker.train.common_utils.finetune_utils._resolve_mlflow_resource_arn') as mock_resolve: + + with patch( + "sagemaker.train.common_utils.finetune_utils._resolve_mlflow_resource_arn" + ) as mock_resolve: mock_resolve.return_value = "mlflow-arn" config = _create_mlflow_config(mock_session, mlflow_experiment_name="test-exp") - + assert config.mlflow_resource_arn == "mlflow-arn" assert config.mlflow_experiment_name == "test-exp" - @patch('sagemaker.train.common_utils.finetune_utils._validate_s3_path_exists') + @patch("sagemaker.train.common_utils.finetune_utils._validate_s3_path_exists") def test__create_output_config(self, mock_validate_s3): mock_session = Mock() - + config = _create_output_config(mock_session, "s3://bucket/output", "kms-key") - + assert config.s3_output_path == "s3://bucket/output" assert config.kms_key_id == "kms-key" mock_validate_s3.assert_called_once_with("s3://bucket/output", mock_session) @@ -453,22 +473,23 @@ def test__convert_input_data_to_channels(self): input_data = [InputData(channel_name="train", data_source="s3://bucket/data")] channels = _convert_input_data_to_channels(input_data) - + assert len(channels) == 1 assert channels[0].channel_name == "train" def test__validate_eula_for_gated_model_with_model_package(self): """Test EULA validation returns True for ModelPackage input""" from sagemaker.core.resources import ModelPackage + model_package = Mock(spec=ModelPackage) - + result = _validate_eula_for_gated_model(model_package, False, True) assert result == True def test__validate_eula_for_gated_model_with_arn(self): """Test EULA validation returns True for ARN input""" model_arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package/test/1" - + result = _validate_eula_for_gated_model(model_arn, False, True) assert result == True @@ -495,7 +516,9 @@ def test__validate_model_region_availability_nova_valid_region(self): def test__validate_model_region_availability_nova_invalid_region(self): """Test Nova model validation fails for invalid region""" - with pytest.raises(ValueError, match="Region 'eu-west-1' does not support model customization"): + with pytest.raises( + ValueError, match="Region 'eu-west-1' does not support model customization" + ): _validate_model_region_availability("nova-textgeneration-lite-v2", "eu-west-1") def test__validate_model_region_availability_open_weights_valid_region(self): @@ -505,57 +528,63 @@ def test__validate_model_region_availability_open_weights_valid_region(self): def test__validate_model_region_availability_open_weights_invalid_region(self): """Test open weights model validation fails for invalid region""" - with pytest.raises(ValueError, match="Region 'us-west-1' does not support model customization"): + with pytest.raises( + ValueError, match="Region 'us-west-1' does not support model customization" + ): _validate_model_region_availability("meta-textgeneration-llama-3-2-1b", "us-west-1") def test__validate_s3_path_exists_invalid_format(self): """Test S3 path validation fails for invalid format""" mock_session = Mock() - + with pytest.raises(ValueError, match="Invalid S3 path format"): _validate_s3_path_exists("invalid-path", mock_session) - @patch('boto3.client') + @patch("boto3.client") def test__validate_s3_path_exists_bucket_only_success(self, mock_boto_client): """Test S3 path validation succeeds for bucket-only path""" mock_session = Mock() mock_s3_client = Mock() mock_session.boto_session.client.return_value = mock_s3_client - + _validate_s3_path_exists("s3://test-bucket", mock_session) - + mock_s3_client.head_bucket.assert_called_once_with(Bucket="test-bucket") - @patch('boto3.client') + @patch("boto3.client") def test__validate_s3_path_exists_with_prefix_exists(self, mock_boto_client): """Test S3 path validation succeeds when prefix exists""" mock_session = Mock() mock_s3_client = Mock() mock_session.boto_session.client.return_value = mock_s3_client mock_s3_client.list_objects_v2.return_value = {"Contents": [{"Key": "prefix/file.txt"}]} - + _validate_s3_path_exists("s3://test-bucket/prefix/", mock_session) - + mock_s3_client.head_bucket.assert_called_once_with(Bucket="test-bucket") - mock_s3_client.list_objects_v2.assert_called_once_with(Bucket="test-bucket", Prefix="prefix/", MaxKeys=1) + mock_s3_client.list_objects_v2.assert_called_once_with( + Bucket="test-bucket", Prefix="prefix/", MaxKeys=1 + ) - @patch('boto3.client') + @patch("boto3.client") def test__validate_s3_path_exists_with_prefix_not_exists(self, mock_boto_client): """Test S3 path validation creates prefix when it doesn't exist""" mock_session = Mock() mock_s3_client = Mock() mock_session.boto_session.client.return_value = mock_s3_client mock_s3_client.list_objects_v2.return_value = {} # No contents - - _validate_s3_path_exists("s3://test-bucket/prefix", mock_session) - - mock_s3_client.head_bucket.assert_called_once_with(Bucket="test-bucket") - mock_s3_client.list_objects_v2.assert_called_once_with(Bucket="test-bucket", Prefix="prefix", MaxKeys=1) - mock_s3_client.put_object.assert_called_once_with(Bucket="test-bucket", Key="prefix/", Body=b'') + _validate_s3_path_exists("s3://test-bucket/prefix", mock_session) + mock_s3_client.head_bucket.assert_called_once_with(Bucket="test-bucket") + mock_s3_client.list_objects_v2.assert_called_once_with( + Bucket="test-bucket", Prefix="prefix", MaxKeys=1 + ) + mock_s3_client.put_object.assert_called_once_with( + Bucket="test-bucket", Key="prefix/", Body=b"" + ) - @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') + @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata") def test__get_fine_tuning_options_with_subscription_recipe_enabled(self, mock_get_hub_content): """When and user is subscribed, datamix HPs are available.""" mock_session = Mock() @@ -563,34 +592,40 @@ def test__get_fine_tuning_options_with_subscription_recipe_enabled(self, mock_ge mock_s3 = Mock() mock_sts = Mock() mock_sts.get_caller_identity.return_value = {"Account": "123456789012"} - mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts + mock_session.boto_session.client.side_effect = lambda service, **kwargs: ( + mock_s3 if service == "s3" else mock_sts + ) mock_get_hub_content.return_value = { - 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", - 'hub_content_document': { + "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + "hub_content_document": { "GatedBucket": False, "RecipeCollection": [ { "CustomizationTechnique": "SFT", "SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml", "SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json", - "Name": "standard_sft" + "Name": "standard_sft", }, { "CustomizationTechnique": "SFT", "SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-123456789012/source/template.yaml", "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json", "Name": "datamix_sft", - "IsSubscriptionModel": True - } - ] - } + "IsSubscriptionModel": True, + }, + ], + }, } # Standard recipe returns base params - standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}}) + standard_params = json.dumps( + {"max_steps": {"type": "integer", "required": True, "default": 100}} + ) # Subscription recipe returns datamix params - datamix_params = json.dumps({"customer_data_percent": {"type": "integer", "required": False, "default": 50}}) + datamix_params = json.dumps( + {"customer_data_percent": {"type": "integer", "required": False, "default": 50}} + ) mock_s3.get_object.side_effect = [ {"Body": Mock(read=Mock(return_value=standard_params.encode()))}, @@ -598,15 +633,22 @@ def test__get_fine_tuning_options_with_subscription_recipe_enabled(self, mock_ge ] options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn( - "test-model", "SFT", "FULL", mock_session, + "test-model", + "SFT", + "FULL", + mock_session, ) assert "max_steps" in options._specs assert "customer_data_percent" in options._specs - assert options._specs["customer_data_percent"]["default"] is None # defaults are None so they dont serialize unless explicitly set - - @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') - def test__get_fine_tuning_options_subscription_disabled_no_datamix_hps(self, mock_get_hub_content): + assert ( + options._specs["customer_data_percent"]["default"] is None + ) # defaults are None so they dont serialize unless explicitly set + + @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata") + def test__get_fine_tuning_options_subscription_disabled_no_datamix_hps( + self, mock_get_hub_content + ): """When (default), datamix HPs are NOT available.""" mock_session = Mock() mock_session.boto_session.region_name = "us-east-1" @@ -614,70 +656,83 @@ def test__get_fine_tuning_options_subscription_disabled_no_datamix_hps(self, moc mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 mock_get_hub_content.return_value = { - 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", - 'hub_content_document': { + "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + "hub_content_document": { "GatedBucket": False, "RecipeCollection": [ { "CustomizationTechnique": "SFT", "SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml", "SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json", - "Name": "standard_sft" + "Name": "standard_sft", }, { "CustomizationTechnique": "SFT", "SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/template.yaml", "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json", "Name": "datamix_sft", - "IsSubscriptionModel": True - } - ] - } + "IsSubscriptionModel": True, + }, + ], + }, } - standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}}) - mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=standard_params.encode()))} + standard_params = json.dumps( + {"max_steps": {"type": "integer", "required": True, "default": 100}} + ) + mock_s3.get_object.return_value = { + "Body": Mock(read=Mock(return_value=standard_params.encode())) + } options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn( - "test-model", "SFT", "FULL", mock_session, + "test-model", + "SFT", + "FULL", + mock_session, ) assert "max_steps" in options._specs assert "customer_data_percent" not in options._specs - @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') - def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self, mock_get_hub_content): + @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata") + def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed( + self, mock_get_hub_content + ): """When but user is NOT subscribed, falls back gracefully.""" mock_session = Mock() mock_session.boto_session.region_name = "us-east-1" mock_s3 = Mock() mock_sts = Mock() mock_sts.get_caller_identity.return_value = {"Account": "999999999999"} - mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts + mock_session.boto_session.client.side_effect = lambda service, **kwargs: ( + mock_s3 if service == "s3" else mock_sts + ) mock_get_hub_content.return_value = { - 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", - 'hub_content_document': { + "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + "hub_content_document": { "GatedBucket": False, "RecipeCollection": [ { "CustomizationTechnique": "SFT", "SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml", "SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json", - "Name": "standard_sft" + "Name": "standard_sft", }, { "CustomizationTechnique": "SFT", "SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/template.yaml", "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json", "Name": "datamix_sft", - "IsSubscriptionModel": True - } - ] - } + "IsSubscriptionModel": True, + }, + ], + }, } - standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}}) + standard_params = json.dumps( + {"max_steps": {"type": "integer", "required": True, "default": 100}} + ) # First call succeeds (standard recipe), second call fails (access denied) mock_s3.get_object.side_effect = [ {"Body": Mock(read=Mock(return_value=standard_params.encode()))}, @@ -685,9 +740,289 @@ def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self, ] options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn( - "test-model", "SFT", "FULL", mock_session, + "test-model", + "SFT", + "FULL", + mock_session, ) # Should still have standard params, just not datamix ones assert "max_steps" in options._specs assert "customer_data_percent" not in options._specs + + def test__create_serverless_config_with_sequence_length(self): + config = _create_serverless_config( + "model-arn", "SFT", TrainingType.LORA, accept_eula=True, sequence_length="8K" + ) + + assert config.sequence_length == "8K" + assert config.base_model_arn == "model-arn" + + def test__create_serverless_config_without_sequence_length(self): + config = _create_serverless_config("model-arn", "SFT", TrainingType.LORA, accept_eula=True) + + assert config.sequence_length is None + + def test__parse_context_length_with_k_suffix(self): + assert _parse_context_length("8K") == 8192 + assert _parse_context_length("32K") == 32768 + assert _parse_context_length("128K") == 131072 + + def test__parse_context_length_with_lowercase(self): + assert _parse_context_length("8k") == 8192 + + def test__parse_context_length_with_integer(self): + assert _parse_context_length("4096") == 4096 + + def test__parse_context_length_with_none(self): + assert _parse_context_length(None) == 0 + + def test__parse_context_length_with_empty(self): + assert _parse_context_length("") == 0 + + @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata") + def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_get_hub_content): + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + mock_s3 = Mock() + mock_s3.get_object.return_value = { + "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 32768}}')) + } + mock_session.boto_session.client.return_value = mock_s3 + + mock_get_hub_content.return_value = { + "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + "hub_content_document": { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-4k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-4k.json", + "Peft": True, + "SequenceLength": "4K", + }, + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-32k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-32k.json", + "Peft": True, + "SequenceLength": "32K", + }, + ], + }, + } + + result = _get_fine_tuning_options_and_model_arn( + "test-model", "SFT", "LORA", mock_session, sequence_length="8K" + ) + + if result is not None: + options, model_arn, is_gated_model = result + # Should pick the 32K recipe (smallest >= 8K) + mock_s3.get_object.assert_called_once() + call_args = mock_s3.get_object.call_args[1] + assert "params-32k" in call_args["Key"] + + @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata") + def test__get_fine_tuning_options_raises_when_no_sufficient_context_length( + self, mock_get_hub_content + ): + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + + mock_get_hub_content.return_value = { + "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + "hub_content_document": { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-4k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-4k.json", + "Peft": True, + "SequenceLength": "4K", + } + ], + }, + } + + # Requesting 128K but only 4K available — should raise + with pytest.raises(ValueError, match="No recipes found with SequenceLength >= 128K"): + _get_fine_tuning_options_and_model_arn( + "test-model", "SFT", "LORA", mock_session, sequence_length="128K" + ) + + def test__parse_context_length_with_invalid_k_value(self): + assert _parse_context_length("abcK") == 0 + + def test__parse_context_length_with_non_numeric_string(self): + assert _parse_context_length("hello") == 0 + + @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata") + def test__get_fine_tuning_options_raises_when_no_recipes_have_sequence_length( + self, mock_get_hub_content + ): + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + + mock_get_hub_content.return_value = { + "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + "hub_content_document": { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params.json", + "Peft": True, + } + ], + }, + } + + with pytest.raises(ValueError, match="and sequence length"): + _get_fine_tuning_options_and_model_arn( + "test-model", "SFT", "LORA", mock_session, sequence_length="8K" + ) + + @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata") + def test__get_fine_tuning_options_filters_by_sequence_length_full_training( + self, mock_get_hub_content + ): + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + mock_s3 = Mock() + mock_s3.get_object.return_value = { + "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 8192}}')) + } + mock_session.boto_session.client.return_value = mock_s3 + + mock_get_hub_content.return_value = { + "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + "hub_content_document": { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-8k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-8k.json", + "SequenceLength": "8K", + }, + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-32k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-32k.json", + "SequenceLength": "32K", + }, + ], + }, + } + + result = _get_fine_tuning_options_and_model_arn( + "test-model", "SFT", "FULL", mock_session, sequence_length="8K" + ) + + if result is not None: + options, model_arn, is_gated_model = result + mock_s3.get_object.assert_called_once() + call_args = mock_s3.get_object.call_args[1] + assert "params-8k" in call_args["Key"] + + @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata") + def test__get_fine_tuning_options_selects_smallest_sufficient_sequence_length( + self, mock_get_hub_content + ): + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + mock_s3 = Mock() + mock_s3.get_object.return_value = { + "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 16384}}')) + } + mock_session.boto_session.client.return_value = mock_s3 + + mock_get_hub_content.return_value = { + "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + "hub_content_document": { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-4k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-4k.json", + "Peft": True, + "SequenceLength": "4K", + }, + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-16k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-16k.json", + "Peft": True, + "SequenceLength": "16K", + }, + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-128k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-128k.json", + "Peft": True, + "SequenceLength": "128K", + }, + ], + }, + } + + result = _get_fine_tuning_options_and_model_arn( + "test-model", "SFT", "LORA", mock_session, sequence_length="8K" + ) + + if result is not None: + options, model_arn, is_gated_model = result + # Should pick 16K (smallest >= 8K), not 128K + mock_s3.get_object.assert_called_once() + call_args = mock_s3.get_object.call_args[1] + assert "params-16k" in call_args["Key"] + + @patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata") + def test__get_fine_tuning_options_without_sequence_length_uses_first_recipe( + self, mock_get_hub_content + ): + """Verify that when no sequence_length is provided, existing behavior is unchanged.""" + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + mock_s3 = Mock() + mock_s3.get_object.return_value = { + "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 4096}}')) + } + mock_session.boto_session.client.return_value = mock_s3 + + mock_get_hub_content.return_value = { + "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + "hub_content_document": { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-first.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-first.json", + "Peft": True, + "SequenceLength": "4K", + }, + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-second.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-second.json", + "Peft": True, + "SequenceLength": "32K", + }, + ], + }, + } + + result = _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session) + + if result is not None: + options, model_arn, is_gated_model = result + # Without sequence_length, should pick the first matching recipe + mock_s3.get_object.assert_called_once() + call_args = mock_s3.get_object.call_args[1] + assert "params-first" in call_args["Key"] diff --git a/sagemaker-train/tests/unit/train/test_dpo_trainer.py b/sagemaker-train/tests/unit/train/test_dpo_trainer.py index 1b70e0bf89..7648b46e35 100644 --- a/sagemaker-train/tests/unit/train/test_dpo_trainer.py +++ b/sagemaker-train/tests/unit/train/test_dpo_trainer.py @@ -506,4 +506,67 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_ trainer = DPOTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=False, wait_timeout=600) - mock_wait.assert_not_called() \ No newline at end of file + mock_wait.assert_not_called() + + @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') + def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = DPOTrainer(model="test-model", model_package_group="test-group") + assert trainer.sequence_length is None + + @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') + def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = DPOTrainer(model="test-model", model_package_group="test-group", sequence_length="8K") + assert trainer.sequence_length == "8K" + + @patch('sagemaker.train.dpo_trainer._resolve_model_and_name') + @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.dpo_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.dpo_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.dpo_trainer._get_unique_name') + @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.dpo_trainer._create_input_data_config') + @patch('sagemaker.train.dpo_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.dpo_trainer._create_output_config') + @patch('sagemaker.train.dpo_trainer._create_serverless_config') + @patch('sagemaker.train.dpo_trainer._create_mlflow_config') + @patch('sagemaker.train.dpo_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create, + mock_model_package_config, mock_mlflow_config, mock_serverless_config, + mock_output_config, mock_convert_channels, mock_input_config, + mock_validate_group, mock_unique_name, mock_get_sagemaker_session, + mock_get_role, mock_get_options, mock_resolve_model): + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_serverless_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job_create.return_value = mock_training_job + + trainer = DPOTrainer(model="test-model", model_package_group="test-group", + training_dataset="s3://bucket/train", sequence_length="16K") + trainer.train(wait=False) + + mock_serverless_config.assert_called_once() + call_kwargs = mock_serverless_config.call_args[1] + assert call_kwargs["sequence_length"] == "16K" diff --git a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py index e5666883e8..6811c45540 100644 --- a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py +++ b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py @@ -682,4 +682,70 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_ trainer = RLAIFTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=False, wait_timeout=600) - mock_wait.assert_not_called() \ No newline at end of file + mock_wait.assert_not_called() + + @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') + def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_hyperparams._specs = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = RLAIFTrainer(model="test-model", model_package_group="test-group") + assert trainer.sequence_length is None + + @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') + def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_hyperparams._specs = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = RLAIFTrainer(model="test-model", model_package_group="test-group", sequence_length="128K") + assert trainer.sequence_length == "128K" + + @patch('sagemaker.train.rlaif_trainer._resolve_model_and_name') + @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.rlaif_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.rlaif_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.rlaif_trainer._get_unique_name') + @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlaif_trainer._create_input_data_config') + @patch('sagemaker.train.rlaif_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.rlaif_trainer._create_output_config') + @patch('sagemaker.train.rlaif_trainer._create_serverless_config') + @patch('sagemaker.train.rlaif_trainer._create_mlflow_config') + @patch('sagemaker.train.rlaif_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create, + mock_model_package_config, mock_mlflow_config, mock_serverless_config, + mock_output_config, mock_convert_channels, mock_input_config, + mock_validate_group, mock_unique_name, mock_get_sagemaker_session, + mock_get_role, mock_get_options, mock_resolve_model): + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {} + mock_fine_tuning_options._specs = {} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_serverless_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job_create.return_value = mock_training_job + + trainer = RLAIFTrainer(model="test-model", model_package_group="test-group", + training_dataset="s3://bucket/train", sequence_length="64K") + trainer.train(wait=False) + + mock_serverless_config.assert_called_once() + call_kwargs = mock_serverless_config.call_args[1] + assert call_kwargs["sequence_length"] == "64K" diff --git a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py index 320b81555d..b4c01385e2 100644 --- a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py +++ b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py @@ -509,4 +509,67 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_ trainer = RLVRTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=False, wait_timeout=600) - mock_wait.assert_not_called() \ No newline at end of file + mock_wait.assert_not_called() + + @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn') + def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = RLVRTrainer(model="test-model", model_package_group="test-group") + assert trainer.sequence_length is None + + @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn') + def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = RLVRTrainer(model="test-model", model_package_group="test-group", sequence_length="32K") + assert trainer.sequence_length == "32K" + + @patch('sagemaker.train.rlvr_trainer._resolve_model_and_name') + @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.rlvr_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.rlvr_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.rlvr_trainer._get_unique_name') + @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlvr_trainer._create_input_data_config') + @patch('sagemaker.train.rlvr_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.rlvr_trainer._create_output_config') + @patch('sagemaker.train.rlvr_trainer._create_serverless_config') + @patch('sagemaker.train.rlvr_trainer._create_mlflow_config') + @patch('sagemaker.train.rlvr_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create, + mock_model_package_config, mock_mlflow_config, mock_serverless_config, + mock_output_config, mock_convert_channels, mock_input_config, + mock_validate_group, mock_unique_name, mock_get_sagemaker_session, + mock_get_role, mock_get_options, mock_resolve_model): + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_serverless_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job_create.return_value = mock_training_job + + trainer = RLVRTrainer(model="test-model", model_package_group="test-group", + training_dataset="s3://bucket/train", sequence_length="4K") + trainer.train(wait=False) + + mock_serverless_config.assert_called_once() + call_kwargs = mock_serverless_config.call_args[1] + assert call_kwargs["sequence_length"] == "4K" diff --git a/sagemaker-train/tests/unit/train/test_sft_trainer.py b/sagemaker-train/tests/unit/train/test_sft_trainer.py index 108990f839..01fc21f4bd 100644 --- a/sagemaker-train/tests/unit/train/test_sft_trainer.py +++ b/sagemaker-train/tests/unit/train/test_sft_trainer.py @@ -520,4 +520,67 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_ trainer = SFTTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=False, wait_timeout=600) - mock_wait.assert_not_called() \ No newline at end of file + mock_wait.assert_not_called() + + @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') + def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = SFTTrainer(model="test-model", model_package_group="test-group") + assert trainer.sequence_length is None + + @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') + def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = SFTTrainer(model="test-model", model_package_group="test-group", sequence_length="8K") + assert trainer.sequence_length == "8K" + + @patch('sagemaker.train.sft_trainer._resolve_model_and_name') + @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.sft_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.sft_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.sft_trainer._get_unique_name') + @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.sft_trainer._create_input_data_config') + @patch('sagemaker.train.sft_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.sft_trainer._create_output_config') + @patch('sagemaker.train.sft_trainer._create_serverless_config') + @patch('sagemaker.train.sft_trainer._create_mlflow_config') + @patch('sagemaker.train.sft_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create, + mock_model_package_config, mock_mlflow_config, mock_serverless_config, + mock_output_config, mock_convert_channels, mock_input_config, + mock_validate_group, mock_unique_name, mock_get_sagemaker_session, + mock_get_role, mock_get_options, mock_resolve_model): + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_serverless_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job_create.return_value = mock_training_job + + trainer = SFTTrainer(model="test-model", model_package_group="test-group", + training_dataset="s3://bucket/train", sequence_length="16K") + trainer.train(wait=False) + + mock_serverless_config.assert_called_once() + call_kwargs = mock_serverless_config.call_args[1] + assert call_kwargs["sequence_length"] == "16K"