Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 98 additions & 12 deletions sagemaker-serve/src/sagemaker/serve/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,12 +1076,105 @@ def _is_nova_model_for_telemetry(self) -> bool:
except Exception:
return False

def _select_nova_hosting_config_entry(self, configs, instance_type, identifier):
"""Select a single hosting config entry from a list of Nova configs.

Picks the entry matching ``instance_type`` when provided, otherwise the
entry with ``Profile == "Default"`` (falling back to the first entry).

Args:
configs: List of hosting config dicts.
instance_type: Requested instance type, or None.
identifier: Model identifier used for error messages.

Returns:
The selected hosting config dict.

Raises:
ValueError: If ``instance_type`` is provided but no entry matches it.
"""
if instance_type:
config = next(
(c for c in configs if c.get("InstanceType") == instance_type), None
)
if not config:
supported = [c.get("InstanceType") for c in configs]
raise ValueError(
f"Instance type '{instance_type}' not supported for '{identifier}'. "
f"Supported: {supported}"
)
return config
return next((c for c in configs if c.get("Profile") == "Default"), configs[0])

def _get_nova_hosting_config_from_hub_document(self, instance_type=None):
"""Resolve Nova hosting config from the JumpStart hub document, if present.

Reads hosting configs published in the hub content document, matching the
standard schema used by other custom models. Looks first inside the
``RecipeCollection`` entry whose ``Name`` matches the recipe, then falls
back to the top-level ``HostingConfigs``.

Returns:
A dict with ``image_uri``, ``env_vars``, and ``instance_type`` when a
usable hosting config is found, otherwise ``None``.
"""
try:
hub_document = self._fetch_hub_document_for_custom_model()
except Exception as e: # pragma: no cover - defensive, hub may be unavailable
logger.debug(f"Could not fetch hub document for Nova hosting config: {e}")
return None

if not hub_document:
return None

container = self._fetch_model_package().inference_specification.containers[0]
recipe_name = getattr(container.base_model, "recipe_name", None) or ""

hosting_configs = None
for recipe in hub_document.get("RecipeCollection", []):
if recipe.get("Name") == recipe_name:
hosting_configs = recipe.get("HostingConfigs")
break
if not hosting_configs:
hosting_configs = hub_document.get("HostingConfigs")

if not hosting_configs:
return None

config = self._select_nova_hosting_config_entry(
hosting_configs, instance_type, recipe_name or "nova"
)

image_uri = config.get("EcrAddress")
if not image_uri:
# Hosting config present but no image override; let the hardcoded
# fallback supply the escrow image URI.
return None

resolved_instance_type = config.get("InstanceType") or config.get(
"DefaultInstanceType"
)

return {
"image_uri": image_uri,
"env_vars": config.get("Environment", {}),
"instance_type": resolved_instance_type,
}

def _get_nova_hosting_config(self, instance_type=None):
"""Get Nova hosting config (image URI, env vars, instance type).

Nova training recipes don't have hosting configs in the JumpStart hub document.
This provides the hardcoded fallback, matching Rhinestone's getNovaHostingConfigs().
Prefers hosting configs published in the JumpStart hub document (the
standard location used by other custom models). Falls back to the
hardcoded ``_NOVA_HOSTING_CONFIGS``, matching Rhinestone's
getNovaHostingConfigs(), when the hub document does not provide one.
"""
hub_config = self._get_nova_hosting_config_from_hub_document(
instance_type=instance_type
)
if hub_config:
return hub_config

model_package = self._fetch_model_package()
hub_content_name = model_package.inference_specification.containers[0].base_model.hub_content_name

Expand All @@ -1102,16 +1195,9 @@ def _get_nova_hosting_config(self, instance_type=None):

image_uri = f"{escrow_account}.dkr.ecr.{region}.amazonaws.com/nova-inference-repo:SM-Inference-latest"

if instance_type:
config = next((c for c in configs if c["InstanceType"] == instance_type), None)
if not config:
supported = [c["InstanceType"] for c in configs]
raise ValueError(
f"Instance type '{instance_type}' not supported for '{hub_content_name}'. "
f"Supported: {supported}"
)
else:
config = next((c for c in configs if c.get("Profile") == "Default"), configs[0])
config = self._select_nova_hosting_config_entry(
configs, instance_type, hub_content_name
)

return {
"image_uri": image_uri,
Expand Down
208 changes: 208 additions & 0 deletions sagemaker-serve/tests/unit/test_nova_hosting_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Unit tests for Nova hosting config resolution in ModelBuilder.

Verifies that hosting configs published in the JumpStart hub document take
priority over the hardcoded ``_NOVA_HOSTING_CONFIGS`` fallback.
"""

import unittest
from unittest.mock import MagicMock, patch

from sagemaker.serve.model_builder import ModelBuilder


def _make_builder(region="us-east-1"):
"""Create a ModelBuilder without running __init__."""
mb = ModelBuilder.__new__(ModelBuilder)
mb.image_uri = None
mb.env_vars = None
mb.instance_type = None
session = MagicMock()
session.boto_region_name = region
mb.sagemaker_session = session
return mb


def _make_model_package(recipe_name="", hub_content_name="nova-textgeneration-lite"):
pkg = MagicMock()
base_model = MagicMock()
base_model.recipe_name = recipe_name
base_model.hub_content_name = hub_content_name
pkg.inference_specification.containers = [MagicMock(base_model=base_model)]
return pkg


class TestNovaHostingConfigResolution(unittest.TestCase):
"""Tests for ModelBuilder._get_nova_hosting_config priority behavior."""

def test_hub_recipe_collection_config_takes_priority(self):
"""Hosting config from RecipeCollection in the hub doc is preferred."""
mb = _make_builder()
hub_doc = {
"RecipeCollection": [
{
"Name": "my-nova-recipe",
"HostingConfigs": [
{
"Profile": "Default",
"EcrAddress": "111.dkr.ecr.us-east-1.amazonaws.com/custom:tag",
"InstanceType": "ml.p5.48xlarge",
"Environment": {
"CONTEXT_LENGTH": "999",
"MAX_CONCURRENCY": "3",
},
}
],
}
]
}
mp = _make_model_package(
recipe_name="my-nova-recipe", hub_content_name="nova-textgeneration-lite"
)
with patch.object(
ModelBuilder, "_fetch_hub_document_for_custom_model", return_value=hub_doc
), patch.object(ModelBuilder, "_fetch_model_package", return_value=mp):
cfg = mb._get_nova_hosting_config()

self.assertEqual(
cfg["image_uri"], "111.dkr.ecr.us-east-1.amazonaws.com/custom:tag"
)
self.assertEqual(
cfg["env_vars"], {"CONTEXT_LENGTH": "999", "MAX_CONCURRENCY": "3"}
)
self.assertEqual(cfg["instance_type"], "ml.p5.48xlarge")

def test_top_level_hosting_configs_used_when_no_recipe_match(self):
"""Top-level HostingConfigs is used when no RecipeCollection matches."""
mb = _make_builder()
hub_doc = {
"HostingConfigs": [
{
"Profile": "Default",
"EcrAddress": "222.dkr.ecr.us-east-1.amazonaws.com/top:tag",
"InstanceType": "ml.g6.24xlarge",
"Environment": {"CONTEXT_LENGTH": "100"},
}
]
}
mp = _make_model_package(
recipe_name="unmatched", hub_content_name="nova-textgeneration-micro"
)
with patch.object(
ModelBuilder, "_fetch_hub_document_for_custom_model", return_value=hub_doc
), patch.object(ModelBuilder, "_fetch_model_package", return_value=mp):
cfg = mb._get_nova_hosting_config()

self.assertEqual(
cfg["image_uri"], "222.dkr.ecr.us-east-1.amazonaws.com/top:tag"
)

def test_hardcoded_fallback_when_hub_has_no_hosting_config(self):
"""Hardcoded escrow config is used when the hub doc has no hosting config."""
mb = _make_builder()
mp = _make_model_package(hub_content_name="nova-textgeneration-lite")
with patch.object(
ModelBuilder, "_fetch_hub_document_for_custom_model", return_value={}
), patch.object(ModelBuilder, "_fetch_model_package", return_value=mp):
cfg = mb._get_nova_hosting_config()

self.assertIn("nova-inference-repo:SM-Inference-latest", cfg["image_uri"])
self.assertEqual(cfg["instance_type"], "ml.g6.48xlarge")

def test_hardcoded_fallback_when_hub_fetch_raises(self):
"""Hardcoded config is used defensively when hub fetch raises."""
mb = _make_builder()
mp = _make_model_package(hub_content_name="nova-textgeneration-pro")
with patch.object(
ModelBuilder,
"_fetch_hub_document_for_custom_model",
side_effect=RuntimeError("hub unavailable"),
), patch.object(ModelBuilder, "_fetch_model_package", return_value=mp):
cfg = mb._get_nova_hosting_config()

self.assertEqual(cfg["instance_type"], "ml.p5.48xlarge")
self.assertIn("nova-inference-repo:SM-Inference-latest", cfg["image_uri"])

def test_missing_ecr_address_falls_through_to_hardcoded(self):
"""A hub hosting config without EcrAddress falls back to the escrow image."""
mb = _make_builder()
hub_doc = {
"RecipeCollection": [
{
"Name": "r",
"HostingConfigs": [
{"Profile": "Default", "InstanceType": "ml.p5.48xlarge"}
],
}
]
}
mp = _make_model_package(
recipe_name="r", hub_content_name="nova-textgeneration-pro"
)
with patch.object(
ModelBuilder, "_fetch_hub_document_for_custom_model", return_value=hub_doc
), patch.object(ModelBuilder, "_fetch_model_package", return_value=mp):
cfg = mb._get_nova_hosting_config()

self.assertIn("nova-inference-repo:SM-Inference-latest", cfg["image_uri"])

def test_instance_type_match_in_hub_config(self):
"""A requested instance type selects the matching hub config entry."""
mb = _make_builder()
hub_doc = {
"RecipeCollection": [
{
"Name": "r",
"HostingConfigs": [
{
"Profile": "Default",
"EcrAddress": "333.dkr.ecr.us-east-1.amazonaws.com/a:tag",
"InstanceType": "ml.p5.48xlarge",
"Environment": {"CONTEXT_LENGTH": "1"},
},
{
"EcrAddress": "333.dkr.ecr.us-east-1.amazonaws.com/b:tag",
"InstanceType": "ml.g6.48xlarge",
"Environment": {"CONTEXT_LENGTH": "2"},
},
],
}
]
}
mp = _make_model_package(
recipe_name="r", hub_content_name="nova-textgeneration-lite"
)
with patch.object(
ModelBuilder, "_fetch_hub_document_for_custom_model", return_value=hub_doc
), patch.object(ModelBuilder, "_fetch_model_package", return_value=mp):
cfg = mb._get_nova_hosting_config(instance_type="ml.g6.48xlarge")

self.assertEqual(
cfg["image_uri"], "333.dkr.ecr.us-east-1.amazonaws.com/b:tag"
)
self.assertEqual(cfg["instance_type"], "ml.g6.48xlarge")

def test_unsupported_instance_type_raises(self):
"""Requesting an unsupported instance type raises ValueError (fallback path)."""
mb = _make_builder()
mp = _make_model_package(hub_content_name="nova-textgeneration-pro")
with patch.object(
ModelBuilder, "_fetch_hub_document_for_custom_model", return_value={}
), patch.object(ModelBuilder, "_fetch_model_package", return_value=mp):
with self.assertRaises(ValueError):
mb._get_nova_hosting_config(instance_type="ml.invalid.type")


if __name__ == "__main__":
unittest.main()
Loading