diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index 4aa67e3d28..6f37601aa4 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -358,21 +358,58 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni raise ValueError(f"No recipes found with Smtj for technique: {customization_technique}") # Select recipe based on training type + # Collect override_params from ALL matching recipes (standard + subscription) recipe = None if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": - recipe = next((r for r in recipes_with_template if r.get("Peft")), None) + recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None) elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": - recipe = next((r for r in recipes_with_template if not r.get("Peft")), None) + recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None) if not recipe: raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}") - elif recipe and recipe.get("SmtjOverrideParamsS3Uri"): + # Start with the standard recipe's override_params + options_dict = {} + if recipe.get("SmtjOverrideParamsS3Uri"): s3_uri = recipe["SmtjOverrideParamsS3Uri"] - s3 = boto3.client("s3") - bucket, key = s3_uri.replace("s3://", "").split("/", 1) + s3 = sagemaker_session.boto_session.client("s3") + uri_path = s3_uri.replace("s3://", "") + bucket, key = uri_path.split("/", 1) obj = s3.get_object(Bucket=bucket, Key=key) options_dict = json.loads(obj["Body"].read()) + + # Auto-detect and merge subscription recipe's override_params if available + if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": + sub_recipe = next((r for r in recipes_with_template if r.get("Peft") and r.get("IsSubscriptionModel")), None) + else: + sub_recipe = next((r for r in recipes_with_template if not r.get("Peft") and r.get("IsSubscriptionModel")), None) + + if sub_recipe and sub_recipe.get("SmtjOverrideParamsS3Uri"): + try: + sub_s3_uri = sub_recipe["SmtjOverrideParamsS3Uri"].replace("{customer_id}", sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"]) + sub_uri_path = sub_s3_uri.replace("s3://", "") + # Handle access point ARN URIs + if sub_uri_path.startswith("arn:"): + arn_parts = sub_uri_path.split("/", 2) + sub_bucket = arn_parts[0] + "/" + arn_parts[1] + sub_key = arn_parts[2] if len(arn_parts) > 2 else "" + else: + sub_bucket, sub_key = sub_uri_path.split("/", 1) + s3_sub = sagemaker_session.boto_session.client("s3") + sub_obj = s3_sub.get_object(Bucket=sub_bucket, Key=sub_key) + sub_options = json.loads(sub_obj["Body"].read()) + # Merge: subscription params into _specs only (don't set defaults) + # This makes them settable but not serialized unless user explicitly sets them + for k, v in sub_options.items(): + if k not in options_dict: + v_copy = v.copy() if isinstance(v, dict) else v + if isinstance(v_copy, dict): + v_copy['default'] = None # No default — won't appear in to_dict() unless set + options_dict[k] = v_copy + except Exception as e: + logger.debug(f"Could not fetch subscription recipe override_params: {type(e).__name__}: {e}") + + if options_dict: return FineTuningOptions(options_dict), model_arn, is_gated_model else: return FineTuningOptions({}), model_arn, is_gated_model diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index 5abad5d296..7a63e36234 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -1,3 +1,4 @@ +import json import pytest from unittest.mock import Mock, patch, MagicMock from sagemaker.train.common_utils.finetune_utils import ( @@ -304,6 +305,8 @@ def test__get_fine_tuning_options_and_model_arn(self, mock_boto_client, mock_get mock_s3_client.get_object.return_value = { "Body": Mock(read=Mock(return_value=b'{"learning_rate": 0.001}')) } + mock_session.boto_session.client.return_value = mock_s3_client + mock_session.boto_session.client.return_value = mock_s3_client result = _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session) @@ -551,3 +554,140 @@ def test__validate_s3_path_exists_with_prefix_not_exists(self, mock_boto_client) mock_s3_client.put_object.assert_called_once_with(Bucket="test-bucket", Key="prefix/", Body=b'') + + @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') + def test__get_fine_tuning_options_with_subscription_recipe_enabled(self, mock_get_hub_content): + """When and user is subscribed, datamix HPs are available.""" + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + mock_s3 = Mock() + mock_sts = Mock() + mock_sts.get_caller_identity.return_value = {"Account": "123456789012"} + mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts + + mock_get_hub_content.return_value = { + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + 'hub_content_document': { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml", + "SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json", + "Name": "standard_sft" + }, + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-123456789012/source/template.yaml", + "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json", + "Name": "datamix_sft", + "IsSubscriptionModel": True + } + ] + } + } + + # Standard recipe returns base params + standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}}) + # Subscription recipe returns datamix params + datamix_params = json.dumps({"customer_data_percent": {"type": "integer", "required": False, "default": 50}}) + + mock_s3.get_object.side_effect = [ + {"Body": Mock(read=Mock(return_value=standard_params.encode()))}, + {"Body": Mock(read=Mock(return_value=datamix_params.encode()))}, + ] + + options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn( + "test-model", "SFT", "FULL", mock_session, + ) + + assert "max_steps" in options._specs + assert "customer_data_percent" in options._specs + assert options._specs["customer_data_percent"]["default"] is None # defaults are None so they dont serialize unless explicitly set + + @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') + def test__get_fine_tuning_options_subscription_disabled_no_datamix_hps(self, mock_get_hub_content): + """When (default), datamix HPs are NOT available.""" + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + mock_s3 = Mock() + mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 + + mock_get_hub_content.return_value = { + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + 'hub_content_document': { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml", + "SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json", + "Name": "standard_sft" + }, + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/template.yaml", + "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json", + "Name": "datamix_sft", + "IsSubscriptionModel": True + } + ] + } + } + + standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}}) + mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=standard_params.encode()))} + + options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn( + "test-model", "SFT", "FULL", mock_session, + ) + + assert "max_steps" in options._specs + assert "customer_data_percent" not in options._specs + + @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') + def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self, mock_get_hub_content): + """When but user is NOT subscribed, falls back gracefully.""" + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + mock_s3 = Mock() + mock_sts = Mock() + mock_sts.get_caller_identity.return_value = {"Account": "999999999999"} + mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts + + mock_get_hub_content.return_value = { + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + 'hub_content_document': { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml", + "SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json", + "Name": "standard_sft" + }, + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/template.yaml", + "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json", + "Name": "datamix_sft", + "IsSubscriptionModel": True + } + ] + } + } + + standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}}) + # First call succeeds (standard recipe), second call fails (access denied) + mock_s3.get_object.side_effect = [ + {"Body": Mock(read=Mock(return_value=standard_params.encode()))}, + Exception("Access Denied"), + ] + + options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn( + "test-model", "SFT", "FULL", mock_session, + ) + + # Should still have standard params, just not datamix ones + assert "max_steps" in options._specs + assert "customer_data_percent" not in options._specs