Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/sagemaker/serve/builder/jumpstart_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,11 @@ def _is_jumpstart_model_id(self) -> bool:
return False

try:
model_uris.retrieve(model_id=self.model, model_version="*", model_scope=_JS_SCOPE)
model_uris.retrieve(
model_id=self.model,
model_version=getattr(self, "model_version", None) or "*",
model_scope=_JS_SCOPE,
)
except KeyError:
logger.warning(_NO_JS_MODEL_EX)
return False
Expand All @@ -154,6 +158,7 @@ def _create_pre_trained_js_model(self) -> Type[Model]:
"""Placeholder docstring"""
pysdk_model = JumpStartModel(
self.model,
model_version=getattr(self, "model_version", None) or "*",
vpc_config=self.vpc_config,
sagemaker_session=self.sagemaker_session,
name=self.name,
Expand Down
9 changes: 9 additions & 0 deletions src/sagemaker/serve/builder/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
available for providing s3 path to fine-tuned model artifacts. ``FINE_TUNING_JOB_NAME``
is available for providing fine-tuned job name. Both ``FINE_TUNING_MODEL_PATH`` and
``FINE_TUNING_JOB_NAME`` are mutually exclusive.
model_version (Optional[str]): Override the JumpStart model version to resolve.
Defaults to ``"*"`` (latest) when not set. Ignored for non-JumpStart models.
inference_component_name (Optional[str]): The name for an inference component
created from this ModelBuilder instance. This or ``resource_requirements`` must be set
to denote that this instance refers to an inference component.
Expand Down Expand Up @@ -337,6 +339,13 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
"in the Hub, Adding unsupported task types will throw an exception."
},
)
model_version: Optional[str] = field(
default=None,
metadata={
"help": "Override the JumpStart model version to resolve. Defaults to the "
"latest version ('*') when not set. Ignored for non-JumpStart models."
},
)
inference_component_name: Optional[str] = field(
default=None,
metadata={
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,11 @@ def test_jumpstart_gated_model_inference_component_enabled(setup):

model = JumpStartModel(
model_id=model_id,
model_version="*", # version >=3.0.0 stores artifacts in jumpstart-private-cache-* buckets
model_version="4.*", # pin: v5.0.0 default image fails to start
role=get_sm_session().get_caller_identity_arn(),
sagemaker_session=get_sm_session(),
)

# uses ml.g5.2xlarge instance
model.deploy(
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
accept_eula=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
def model_builder_llama_inference_component():
return ModelBuilder(
model=LLAMA_2_7B_JS_ID,
# Pin: JumpStart v5.0.0 default image fails to start for this model.
model_version="4.*",
schema_builder=SchemaBuilder(sample_input, sample_output),
resource_requirements=ResourceRequirements(
requests={"memory": 98304, "num_accelerators": 4, "copies": 1, "num_cpus": 40}
Expand Down
35 changes: 35 additions & 0 deletions tests/unit/sagemaker/serve/builder/test_js_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1957,3 +1957,38 @@ def test_optimize_on_js_model_test_image_defaulting_scenarios(
optimization_args["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"],
"763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi13.0.1-cu124",
)


class TestModelBuilderJumpStartModelVersion(unittest.TestCase):
"""``ModelBuilder.model_version`` should be threaded into JumpStart resolution."""

@patch("sagemaker.serve.builder.jumpstart_builder.model_uris.retrieve")
def test_is_jumpstart_model_id_defaults_to_star(self, mock_retrieve):
mb = ModelBuilder(model=mock_model_id, schema_builder=mock_schema_builder)
self.assertTrue(mb._is_jumpstart_model_id())
mock_retrieve.assert_called_once()
self.assertEqual(mock_retrieve.call_args.kwargs["model_version"], "*")

@patch("sagemaker.serve.builder.jumpstart_builder.model_uris.retrieve")
def test_is_jumpstart_model_id_uses_override(self, mock_retrieve):
mb = ModelBuilder(
model=mock_model_id, schema_builder=mock_schema_builder, model_version="4.*"
)
self.assertTrue(mb._is_jumpstart_model_id())
self.assertEqual(mock_retrieve.call_args.kwargs["model_version"], "4.*")

@patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel")
def test_create_pre_trained_js_model_defaults_to_star(self, mock_js_model_cls):
mock_js_model_cls.return_value = MagicMock(deploy=MagicMock())
mb = ModelBuilder(model=mock_model_id, schema_builder=mock_schema_builder)
mb._create_pre_trained_js_model()
self.assertEqual(mock_js_model_cls.call_args.kwargs["model_version"], "*")

@patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel")
def test_create_pre_trained_js_model_uses_override(self, mock_js_model_cls):
mock_js_model_cls.return_value = MagicMock(deploy=MagicMock())
mb = ModelBuilder(
model=mock_model_id, schema_builder=mock_schema_builder, model_version="4.*"
)
mb._create_pre_trained_js_model()
self.assertEqual(mock_js_model_cls.call_args.kwargs["model_version"], "4.*")
Loading