From b526d225697cee6faec732d06a64dd39b627f3ec Mon Sep 17 00:00:00 2001 From: Haard Mehta Date: Mon, 15 Jun 2026 19:05:16 +0000 Subject: [PATCH] fix(train): Handle subscription-only models in recipe selection Models like nova-textgeneration-micro-v2 have only IsSubscriptionModel recipes. The primary recipe filter required not IsSubscriptionModel, causing ValueError when no standard recipe exists. Fix: - Fall back to the subscription recipe as primary when no standard one exists - Handle access point ARN URIs in primary recipe download path - Guard against merging subscription override_params into itself - Only resolve {customer_id} placeholder when present in URI Tests: - subscription_only_model_lora: Micro v2 LoRA case - subscription_only_model_full: Micro v2 full-rank case - mixed_recipes_still_prefers_standard: Lite v2 regression guard Fixes: V2248468914 --- .../train/common_utils/finetune_utils.py | 25 ++- .../train/common_utils/test_finetune_utils.py | 143 ++++++++++++++++++ 2 files changed, 165 insertions(+), 3 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index 6479e803bd..f12d27b512 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -448,32 +448,51 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni raise ValueError(f"No recipes found with Smtj for technique: {customization_technique}") # Select recipe based on training type - # Collect override_params from ALL matching recipes (standard + subscription) + # Prefer non-subscription (standard) recipe as primary; fallback to subscription + # recipe if model only has subscription recipes (e.g., nova-textgeneration-micro-v2) recipe = None if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": recipe = next((r for 
r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None) + if not recipe: + recipe = next((r for r in recipes_with_template if r.get("Peft") and r.get("IsSubscriptionModel")), None) elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None) + if not recipe: + recipe = next((r for r in recipes_with_template if not r.get("Peft") and r.get("IsSubscriptionModel")), None) if not recipe: raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}") - # Start with the standard recipe's override_params + # Start with the primary recipe's override_params options_dict = {} if recipe.get("SmtjOverrideParamsS3Uri"): s3_uri = recipe["SmtjOverrideParamsS3Uri"] + if "{customer_id}" in s3_uri: + s3_uri = s3_uri.replace("{customer_id}", sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"]) s3 = sagemaker_session.boto_session.client("s3") uri_path = s3_uri.replace("s3://", "") - bucket, key = uri_path.split("/", 1) + # Handle access point ARN URIs (subscription-only models) + if uri_path.startswith("arn:"): + arn_parts = uri_path.split("/", 2) + bucket = arn_parts[0] + "/" + arn_parts[1] + key = arn_parts[2] if len(arn_parts) > 2 else "" + else: + bucket, key = uri_path.split("/", 1) obj = s3.get_object(Bucket=bucket, Key=key) options_dict = json.loads(obj["Body"].read()) # Auto-detect and merge subscription recipe's override_params if available + # Skip if primary recipe is already the subscription recipe (subscription-only model) + sub_recipe = None if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": sub_recipe = next((r for r in recipes_with_template if r.get("Peft") and r.get("IsSubscriptionModel")), None) else: sub_recipe = next((r for r in 
recipes_with_template if not r.get("Peft") and r.get("IsSubscriptionModel")), None) + # Guard: don't merge subscription override_params into itself + if sub_recipe and sub_recipe is recipe: + sub_recipe = None + if sub_recipe and sub_recipe.get("SmtjOverrideParamsS3Uri"): try: sub_s3_uri = sub_recipe["SmtjOverrideParamsS3Uri"].replace("{customer_id}", sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"]) diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index c98dea477f..37609df42d 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -864,3 +864,146 @@ def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self, # Should still have standard params, just not datamix ones assert "max_steps" in options._specs assert "customer_data_percent" not in options._specs + + @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') + def test__get_fine_tuning_options_subscription_only_model_lora(self, mock_get_hub_content): + """When model has ONLY subscription recipes (e.g. 
nova-micro-v2), should select it as primary.""" + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + mock_s3 = Mock() + mock_sts = Mock() + mock_sts.get_caller_identity.return_value = {"Account": "123456789012"} + mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts + + mock_get_hub_content.return_value = { + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2", + 'hub_content_document': { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://subscription-bucket/lora_template.yaml", + "SmtjOverrideParamsS3Uri": "s3://subscription-bucket/lora_params.json", + "Name": "nova_micro_2_lora_sft_datamix", + "Peft": "LORA", + "IsSubscriptionModel": True + }, + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://subscription-bucket/full_template.yaml", + "SmtjOverrideParamsS3Uri": "s3://subscription-bucket/full_params.json", + "Name": "nova_micro_2_full_sft_datamix", + "IsSubscriptionModel": True + } + ] + } + } + + datamix_params = json.dumps({ + "max_steps": {"type": "integer", "required": True, "default": 100}, + "customer_data_percent": {"type": "integer", "required": False, "default": 50} + }) + mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=datamix_params.encode()))} + + options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn( + "nova-micro-v2", "SFT", "LORA", mock_session, + ) + + # Should succeed (not raise ValueError) and have params from subscription recipe + assert "max_steps" in options._specs + assert "customer_data_percent" in options._specs + # Should only call get_object ONCE (no merge into itself) + assert mock_s3.get_object.call_count == 1 + + @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') + def test__get_fine_tuning_options_subscription_only_model_full(self, mock_get_hub_content): + 
"""When model has ONLY subscription recipes with FULL rank, should select it as primary.""" + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + mock_s3 = Mock() + mock_sts = Mock() + mock_sts.get_caller_identity.return_value = {"Account": "123456789012"} + mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts + + mock_get_hub_content.return_value = { + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2", + 'hub_content_document': { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://subscription-bucket/full_template.yaml", + "SmtjOverrideParamsS3Uri": "s3://subscription-bucket/full_params.json", + "Name": "nova_micro_2_full_sft_datamix", + "IsSubscriptionModel": True + } + ] + } + } + + datamix_params = json.dumps({ + "max_steps": {"type": "integer", "required": True, "default": 100}, + "customer_data_percent": {"type": "integer", "required": False, "default": 50} + }) + mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=datamix_params.encode()))} + + options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn( + "nova-micro-v2", "SFT", "FULL", mock_session, + ) + + assert "max_steps" in options._specs + assert "customer_data_percent" in options._specs + assert mock_s3.get_object.call_count == 1 + + @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') + def test__get_fine_tuning_options_mixed_recipes_still_prefers_standard(self, mock_get_hub_content): + """When model has both standard and subscription recipes, standard is preferred as primary.""" + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + mock_s3 = Mock() + mock_sts = Mock() + mock_sts.get_caller_identity.return_value = {"Account": "123456789012"} + mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if 
service == "s3" else mock_sts + + mock_get_hub_content.return_value = { + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-lite-v2", + 'hub_content_document': { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/standard_lora.yaml", + "SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json", + "Name": "standard_lora_sft", + "Peft": "LORA" + }, + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-123456789012/source/datamix_lora.yaml", + "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/datamix_params.json", + "Name": "datamix_lora_sft", + "Peft": "LORA", + "IsSubscriptionModel": True + } + ] + } + } + + standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}}) + datamix_params = json.dumps({"customer_data_percent": {"type": "integer", "required": False, "default": 50}}) + mock_s3.get_object.side_effect = [ + {"Body": Mock(read=Mock(return_value=standard_params.encode()))}, + {"Body": Mock(read=Mock(return_value=datamix_params.encode()))}, + ] + + options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn( + "nova-lite-v2", "SFT", "LORA", mock_session, + ) + + # Standard recipe params present with original default + assert options._specs["max_steps"]["default"] == 100 + # Subscription params merged with None default + assert "customer_data_percent" in options._specs + assert options._specs["customer_data_percent"]["default"] is None + # Two S3 calls: one for standard, one for subscription + assert mock_s3.get_object.call_count == 2