diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index bf6fcaa376..2d40f40160 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -142,7 +142,11 @@ def _is_jumpstart_model_id(self) -> bool: return False try: - model_uris.retrieve(model_id=self.model, model_version="*", model_scope=_JS_SCOPE) + model_uris.retrieve( + model_id=self.model, + model_version=getattr(self, "model_version", None) or "*", + model_scope=_JS_SCOPE, + ) except KeyError: logger.warning(_NO_JS_MODEL_EX) return False @@ -154,6 +158,7 @@ def _create_pre_trained_js_model(self) -> Type[Model]: """Placeholder docstring""" pysdk_model = JumpStartModel( self.model, + model_version=getattr(self, "model_version", None) or "*", vpc_config=self.vpc_config, sagemaker_session=self.sagemaker_session, name=self.name, diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 2e74f5eba5..cfa10f669a 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -225,6 +225,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, available for providing s3 path to fine-tuned model artifacts. ``FINE_TUNING_JOB_NAME`` is available for providing fine-tuned job name. Both ``FINE_TUNING_MODEL_PATH`` and ``FINE_TUNING_JOB_NAME`` are mutually exclusive. + model_version (Optional[str]): Override the JumpStart model version to resolve. + Defaults to ``"*"`` (latest) when not set. Ignored for non-JumpStart models. inference_component_name (Optional[str]): The name for an inference component created from this ModelBuilder instance. This or ``resource_requirements`` must be set to denote that this instance refers to an inference component. @@ -337,6 +339,13 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, "in the Hub, Adding unsupported task types will throw an exception." }, ) + model_version: Optional[str] = field( + default=None, + metadata={ + "help": "Override the JumpStart model version to resolve. Defaults to the " + "latest version ('*') when not set. Ignored for non-JumpStart models." + }, + ) inference_component_name: Optional[str] = field( default=None, metadata={ diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 56e3234863..8a19baaf40 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -197,12 +197,11 @@ def test_jumpstart_gated_model_inference_component_enabled(setup): model = JumpStartModel( model_id=model_id, - model_version="*", # version >=3.0.0 stores artifacts in jumpstart-private-cache-* buckets + model_version="4.*", # pin: v5.0.0 default image fails to start role=get_sm_session().get_caller_identity_arn(), sagemaker_session=get_sm_session(), ) - # uses ml.g5.2xlarge instance model.deploy( tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], accept_eula=True, diff --git a/tests/integ/sagemaker/serve/test_serve_model_builder_inference_component_happy.py b/tests/integ/sagemaker/serve/test_serve_model_builder_inference_component_happy.py index 7191de4e7d..531b63305c 100644 --- a/tests/integ/sagemaker/serve/test_serve_model_builder_inference_component_happy.py +++ b/tests/integ/sagemaker/serve/test_serve_model_builder_inference_component_happy.py @@ -48,6 +48,8 @@ def model_builder_llama_inference_component(): return ModelBuilder( model=LLAMA_2_7B_JS_ID, + # Pin: JumpStart v5.0.0 default image fails to start for this model. + model_version="4.*", schema_builder=SchemaBuilder(sample_input, sample_output), resource_requirements=ResourceRequirements( requests={"memory": 98304, "num_accelerators": 4, "copies": 1, "num_cpus": 40} diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index 25d829b056..0f976ba060 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -1957,3 +1957,38 @@ def test_optimize_on_js_model_test_image_defaulting_scenarios( optimization_args["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi13.0.1-cu124", ) + + +class TestModelBuilderJumpStartModelVersion(unittest.TestCase): + """``ModelBuilder.model_version`` should be threaded into JumpStart resolution.""" + + @patch("sagemaker.serve.builder.jumpstart_builder.model_uris.retrieve") + def test_is_jumpstart_model_id_defaults_to_star(self, mock_retrieve): + mb = ModelBuilder(model=mock_model_id, schema_builder=mock_schema_builder) + self.assertTrue(mb._is_jumpstart_model_id()) + mock_retrieve.assert_called_once() + self.assertEqual(mock_retrieve.call_args.kwargs["model_version"], "*") + + @patch("sagemaker.serve.builder.jumpstart_builder.model_uris.retrieve") + def test_is_jumpstart_model_id_uses_override(self, mock_retrieve): + mb = ModelBuilder( + model=mock_model_id, schema_builder=mock_schema_builder, model_version="4.*" + ) + self.assertTrue(mb._is_jumpstart_model_id()) + self.assertEqual(mock_retrieve.call_args.kwargs["model_version"], "4.*") + + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel") + def test_create_pre_trained_js_model_defaults_to_star(self, mock_js_model_cls): + mock_js_model_cls.return_value = MagicMock(deploy=MagicMock()) + mb = ModelBuilder(model=mock_model_id, schema_builder=mock_schema_builder) + mb._create_pre_trained_js_model() + self.assertEqual(mock_js_model_cls.call_args.kwargs["model_version"], "*") + + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel") + def test_create_pre_trained_js_model_uses_override(self, mock_js_model_cls): + mock_js_model_cls.return_value = MagicMock(deploy=MagicMock()) + mb = ModelBuilder( + model=mock_model_id, schema_builder=mock_schema_builder, model_version="4.*" + ) + mb._create_pre_trained_js_model() + self.assertEqual(mock_js_model_cls.call_args.kwargs["model_version"], "4.*")