Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,32 +448,51 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique}")

# Select recipe based on training type
# Collect override_params from ALL matching recipes (standard + subscription)
# Prefer non-subscription (standard) recipe as primary; fallback to subscription
# recipe if model only has subscription recipes (e.g., nova-textgeneration-micro-v2)
recipe = None
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None)
if not recipe:
recipe = next((r for r in recipes_with_template if r.get("Peft") and r.get("IsSubscriptionModel")), None)
elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL":
recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None)
if not recipe:
recipe = next((r for r in recipes_with_template if not r.get("Peft") and r.get("IsSubscriptionModel")), None)

if not recipe:
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}")

# Start with the standard recipe's override_params
# Start with the primary recipe's override_params
options_dict = {}
if recipe.get("SmtjOverrideParamsS3Uri"):
s3_uri = recipe["SmtjOverrideParamsS3Uri"]
if "{customer_id}" in s3_uri:
s3_uri = s3_uri.replace("{customer_id}", sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"])
s3 = sagemaker_session.boto_session.client("s3")
uri_path = s3_uri.replace("s3://", "")
bucket, key = uri_path.split("/", 1)
# Handle access point ARN URIs (subscription-only models)
if uri_path.startswith("arn:"):
arn_parts = uri_path.split("/", 2)
bucket = arn_parts[0] + "/" + arn_parts[1]
key = arn_parts[2] if len(arn_parts) > 2 else ""
else:
bucket, key = uri_path.split("/", 1)
obj = s3.get_object(Bucket=bucket, Key=key)
options_dict = json.loads(obj["Body"].read())

# Auto-detect and merge subscription recipe's override_params if available
# Skip if primary recipe is already the subscription recipe (subscription-only model)
sub_recipe = None
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
sub_recipe = next((r for r in recipes_with_template if r.get("Peft") and r.get("IsSubscriptionModel")), None)
else:
sub_recipe = next((r for r in recipes_with_template if not r.get("Peft") and r.get("IsSubscriptionModel")), None)

# Guard: don't merge subscription override_params into itself
if sub_recipe and sub_recipe is recipe:
sub_recipe = None

if sub_recipe and sub_recipe.get("SmtjOverrideParamsS3Uri"):
try:
sub_s3_uri = sub_recipe["SmtjOverrideParamsS3Uri"].replace("{customer_id}", sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"])
Expand Down
143 changes: 143 additions & 0 deletions sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,3 +864,146 @@ def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self,
# Should still have standard params, just not datamix ones
assert "max_steps" in options._specs
assert "customer_data_percent" not in options._specs

@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
def test__get_fine_tuning_options_subscription_only_model_lora(self, mock_get_hub_content):
"""When model has ONLY subscription recipes (e.g. nova-micro-v2), should select it as primary."""
mock_session = Mock()
mock_session.boto_session.region_name = "us-east-1"
mock_s3 = Mock()
mock_sts = Mock()
mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts

mock_get_hub_content.return_value = {
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2",
'hub_content_document': {
"GatedBucket": False,
"RecipeCollection": [
{
"CustomizationTechnique": "SFT",
"SmtjRecipeTemplateS3Uri": "s3://subscription-bucket/lora_template.yaml",
"SmtjOverrideParamsS3Uri": "s3://subscription-bucket/lora_params.json",
"Name": "nova_micro_2_lora_sft_datamix",
"Peft": "LORA",
"IsSubscriptionModel": True
},
{
"CustomizationTechnique": "SFT",
"SmtjRecipeTemplateS3Uri": "s3://subscription-bucket/full_template.yaml",
"SmtjOverrideParamsS3Uri": "s3://subscription-bucket/full_params.json",
"Name": "nova_micro_2_full_sft_datamix",
"IsSubscriptionModel": True
}
]
}
}

datamix_params = json.dumps({
"max_steps": {"type": "integer", "required": True, "default": 100},
"customer_data_percent": {"type": "integer", "required": False, "default": 50}
})
mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=datamix_params.encode()))}

options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
"nova-micro-v2", "SFT", "LORA", mock_session,
)

# Should succeed (not raise ValueError) and have params from subscription recipe
assert "max_steps" in options._specs
assert "customer_data_percent" in options._specs
# Should only call get_object ONCE (no merge into itself)
assert mock_s3.get_object.call_count == 1

@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
def test__get_fine_tuning_options_subscription_only_model_full(self, mock_get_hub_content):
"""When model has ONLY subscription recipes with FULL rank, should select it as primary."""
mock_session = Mock()
mock_session.boto_session.region_name = "us-east-1"
mock_s3 = Mock()
mock_sts = Mock()
mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts

mock_get_hub_content.return_value = {
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2",
'hub_content_document': {
"GatedBucket": False,
"RecipeCollection": [
{
"CustomizationTechnique": "SFT",
"SmtjRecipeTemplateS3Uri": "s3://subscription-bucket/full_template.yaml",
"SmtjOverrideParamsS3Uri": "s3://subscription-bucket/full_params.json",
"Name": "nova_micro_2_full_sft_datamix",
"IsSubscriptionModel": True
}
]
}
}

datamix_params = json.dumps({
"max_steps": {"type": "integer", "required": True, "default": 100},
"customer_data_percent": {"type": "integer", "required": False, "default": 50}
})
mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=datamix_params.encode()))}

options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
"nova-micro-v2", "SFT", "FULL", mock_session,
)

assert "max_steps" in options._specs
assert "customer_data_percent" in options._specs
assert mock_s3.get_object.call_count == 1

@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
def test__get_fine_tuning_options_mixed_recipes_still_prefers_standard(self, mock_get_hub_content):
"""When model has both standard and subscription recipes, standard is preferred as primary."""
mock_session = Mock()
mock_session.boto_session.region_name = "us-east-1"
mock_s3 = Mock()
mock_sts = Mock()
mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts

mock_get_hub_content.return_value = {
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-lite-v2",
'hub_content_document': {
"GatedBucket": False,
"RecipeCollection": [
{
"CustomizationTechnique": "SFT",
"SmtjRecipeTemplateS3Uri": "s3://bucket/standard_lora.yaml",
"SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json",
"Name": "standard_lora_sft",
"Peft": "LORA"
},
{
"CustomizationTechnique": "SFT",
"SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-123456789012/source/datamix_lora.yaml",
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/datamix_params.json",
"Name": "datamix_lora_sft",
"Peft": "LORA",
"IsSubscriptionModel": True
}
]
}
}

standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}})
datamix_params = json.dumps({"customer_data_percent": {"type": "integer", "required": False, "default": 50}})
mock_s3.get_object.side_effect = [
{"Body": Mock(read=Mock(return_value=standard_params.encode()))},
{"Body": Mock(read=Mock(return_value=datamix_params.encode()))},
]

options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
"nova-lite-v2", "SFT", "LORA", mock_session,
)

# Standard recipe params present with original default
assert options._specs["max_steps"]["default"] == 100
# Subscription params merged with None default
assert "customer_data_percent" in options._specs
assert options._specs["customer_data_percent"]["default"] is None
# Two S3 calls: one for standard, one for subscription
assert mock_s3.get_object.call_count == 2
Loading