Skip to content

Commit 95bdf1f

Browse files
Pin JS model integ test to 4.* version (#5777)
* Pin JS model to 4.* version
* Add model_version to ModelBuilder and fix JS version pinning
1 parent 8a2f9d2 commit 95bdf1f

5 files changed

Lines changed: 53 additions & 3 deletions

File tree

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,11 @@ def _is_jumpstart_model_id(self) -> bool:
142142
return False
143143

144144
try:
145-
model_uris.retrieve(model_id=self.model, model_version="*", model_scope=_JS_SCOPE)
145+
model_uris.retrieve(
146+
model_id=self.model,
147+
model_version=getattr(self, "model_version", None) or "*",
148+
model_scope=_JS_SCOPE,
149+
)
146150
except KeyError:
147151
logger.warning(_NO_JS_MODEL_EX)
148152
return False
@@ -154,6 +158,7 @@ def _create_pre_trained_js_model(self) -> Type[Model]:
154158
"""Placeholder docstring"""
155159
pysdk_model = JumpStartModel(
156160
self.model,
161+
model_version=getattr(self, "model_version", None) or "*",
157162
vpc_config=self.vpc_config,
158163
sagemaker_session=self.sagemaker_session,
159164
name=self.name,

src/sagemaker/serve/builder/model_builder.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
225225
available for providing s3 path to fine-tuned model artifacts. ``FINE_TUNING_JOB_NAME``
226226
is available for providing fine-tuned job name. Both ``FINE_TUNING_MODEL_PATH`` and
227227
``FINE_TUNING_JOB_NAME`` are mutually exclusive.
228+
model_version (Optional[str]): Override the JumpStart model version to resolve.
229+
Defaults to ``"*"`` (latest) when not set. Ignored for non-JumpStart models.
228230
inference_component_name (Optional[str]): The name for an inference component
229231
created from this ModelBuilder instance. This or ``resource_requirements`` must be set
230232
to denote that this instance refers to an inference component.
@@ -337,6 +339,13 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
337339
"in the Hub, Adding unsupported task types will throw an exception."
338340
},
339341
)
342+
model_version: Optional[str] = field(
343+
default=None,
344+
metadata={
345+
"help": "Override the JumpStart model version to resolve. Defaults to the "
346+
"latest version ('*') when not set. Ignored for non-JumpStart models."
347+
},
348+
)
340349
inference_component_name: Optional[str] = field(
341350
default=None,
342351
metadata={

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,12 +197,11 @@ def test_jumpstart_gated_model_inference_component_enabled(setup):
197197

198198
model = JumpStartModel(
199199
model_id=model_id,
200-
model_version="*", # version >=3.0.0 stores artifacts in jumpstart-private-cache-* buckets
200+
model_version="4.*", # pin: v5.0.0 default image fails to start
201201
role=get_sm_session().get_caller_identity_arn(),
202202
sagemaker_session=get_sm_session(),
203203
)
204204

205-
# uses ml.g5.2xlarge instance
206205
model.deploy(
207206
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
208207
accept_eula=True,

tests/integ/sagemaker/serve/test_serve_model_builder_inference_component_happy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
def model_builder_llama_inference_component():
4949
return ModelBuilder(
5050
model=LLAMA_2_7B_JS_ID,
51+
# Pin: JumpStart v5.0.0 default image fails to start for this model.
52+
model_version="4.*",
5153
schema_builder=SchemaBuilder(sample_input, sample_output),
5254
resource_requirements=ResourceRequirements(
5355
requests={"memory": 98304, "num_accelerators": 4, "copies": 1, "num_cpus": 40}

tests/unit/sagemaker/serve/builder/test_js_builder.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1957,3 +1957,38 @@ def test_optimize_on_js_model_test_image_defaulting_scenarios(
19571957
optimization_args["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"],
19581958
"763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi13.0.1-cu124",
19591959
)
1960+
1961+
1962+
class TestModelBuilderJumpStartModelVersion(unittest.TestCase):
1963+
"""``ModelBuilder.model_version`` should be threaded into JumpStart resolution."""
1964+
1965+
@patch("sagemaker.serve.builder.jumpstart_builder.model_uris.retrieve")
1966+
def test_is_jumpstart_model_id_defaults_to_star(self, mock_retrieve):
1967+
mb = ModelBuilder(model=mock_model_id, schema_builder=mock_schema_builder)
1968+
self.assertTrue(mb._is_jumpstart_model_id())
1969+
mock_retrieve.assert_called_once()
1970+
self.assertEqual(mock_retrieve.call_args.kwargs["model_version"], "*")
1971+
1972+
@patch("sagemaker.serve.builder.jumpstart_builder.model_uris.retrieve")
1973+
def test_is_jumpstart_model_id_uses_override(self, mock_retrieve):
1974+
mb = ModelBuilder(
1975+
model=mock_model_id, schema_builder=mock_schema_builder, model_version="4.*"
1976+
)
1977+
self.assertTrue(mb._is_jumpstart_model_id())
1978+
self.assertEqual(mock_retrieve.call_args.kwargs["model_version"], "4.*")
1979+
1980+
@patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel")
1981+
def test_create_pre_trained_js_model_defaults_to_star(self, mock_js_model_cls):
1982+
mock_js_model_cls.return_value = MagicMock(deploy=MagicMock())
1983+
mb = ModelBuilder(model=mock_model_id, schema_builder=mock_schema_builder)
1984+
mb._create_pre_trained_js_model()
1985+
self.assertEqual(mock_js_model_cls.call_args.kwargs["model_version"], "*")
1986+
1987+
@patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel")
1988+
def test_create_pre_trained_js_model_uses_override(self, mock_js_model_cls):
1989+
mock_js_model_cls.return_value = MagicMock(deploy=MagicMock())
1990+
mb = ModelBuilder(
1991+
model=mock_model_id, schema_builder=mock_schema_builder, model_version="4.*"
1992+
)
1993+
mb._create_pre_trained_js_model()
1994+
self.assertEqual(mock_js_model_cls.call_args.kwargs["model_version"], "4.*")

0 commit comments

Comments
 (0)