@@ -1957,3 +1957,38 @@ def test_optimize_on_js_model_test_image_defaulting_scenarios(
19571957 optimization_args ["OptimizationConfigs" ][0 ]["ModelQuantizationConfig" ]["Image" ],
19581958 "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi13.0.1-cu124" ,
19591959 )
1960+
1961+
1962+ class TestModelBuilderJumpStartModelVersion (unittest .TestCase ):
1963+ """``ModelBuilder.model_version`` should be threaded into JumpStart resolution."""
1964+
1965+ @patch ("sagemaker.serve.builder.jumpstart_builder.model_uris.retrieve" )
1966+ def test_is_jumpstart_model_id_defaults_to_star (self , mock_retrieve ):
1967+ mb = ModelBuilder (model = mock_model_id , schema_builder = mock_schema_builder )
1968+ self .assertTrue (mb ._is_jumpstart_model_id ())
1969+ mock_retrieve .assert_called_once ()
1970+ self .assertEqual (mock_retrieve .call_args .kwargs ["model_version" ], "*" )
1971+
1972+ @patch ("sagemaker.serve.builder.jumpstart_builder.model_uris.retrieve" )
1973+ def test_is_jumpstart_model_id_uses_override (self , mock_retrieve ):
1974+ mb = ModelBuilder (
1975+ model = mock_model_id , schema_builder = mock_schema_builder , model_version = "4.*"
1976+ )
1977+ self .assertTrue (mb ._is_jumpstart_model_id ())
1978+ self .assertEqual (mock_retrieve .call_args .kwargs ["model_version" ], "4.*" )
1979+
1980+ @patch ("sagemaker.serve.builder.jumpstart_builder.JumpStartModel" )
1981+ def test_create_pre_trained_js_model_defaults_to_star (self , mock_js_model_cls ):
1982+ mock_js_model_cls .return_value = MagicMock (deploy = MagicMock ())
1983+ mb = ModelBuilder (model = mock_model_id , schema_builder = mock_schema_builder )
1984+ mb ._create_pre_trained_js_model ()
1985+ self .assertEqual (mock_js_model_cls .call_args .kwargs ["model_version" ], "*" )
1986+
1987+ @patch ("sagemaker.serve.builder.jumpstart_builder.JumpStartModel" )
1988+ def test_create_pre_trained_js_model_uses_override (self , mock_js_model_cls ):
1989+ mock_js_model_cls .return_value = MagicMock (deploy = MagicMock ())
1990+ mb = ModelBuilder (
1991+ model = mock_model_id , schema_builder = mock_schema_builder , model_version = "4.*"
1992+ )
1993+ mb ._create_pre_trained_js_model ()
1994+ self .assertEqual (mock_js_model_cls .call_args .kwargs ["model_version" ], "4.*" )
0 commit comments