Skip to content
155 changes: 60 additions & 95 deletions sagemaker-serve/tests/integ/test_model_customization_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,22 @@
"""Integration tests for ModelBuilder model customization deployment."""
from __future__ import absolute_import

import boto3
import pytest
import random

from sagemaker.core.helper.session_helper import Session

# This test relies on resources in a specific region
AWS_REGION = "us-west-2"


@pytest.fixture(scope="module")
def sagemaker_session():
    """Provide a module-scoped SageMaker ``Session`` pinned to ``AWS_REGION``.

    The region is passed explicitly through a boto3 session built with
    ``region_name=AWS_REGION``.
    """
    return Session(boto_session=boto3.Session(region_name=AWS_REGION))


@pytest.fixture(scope="module")
def training_job_name():
Expand Down Expand Up @@ -48,51 +61,6 @@ def endpoint_name():
return f"e2e-{int(time.time())}-{random.randint(100, 10000)}"


def _delete_e2e_endpoints():
    """Best-effort deletion of every endpoint whose name starts with ``e2e-``.

    Failures are logged and otherwise ignored so that a cleanup problem never
    fails the test session.
    """
    import logging

    # Imported here (not at module import time) so the caller has already set
    # SAGEMAKER_REGION before the SDK resources module is loaded.
    from sagemaker.core.resources import Endpoint

    log = logging.getLogger(__name__)
    try:
        for endpoint in Endpoint.get_all():
            try:
                if endpoint.endpoint_name.startswith('e2e-'):
                    endpoint.delete()
            except Exception as exc:  # best-effort: keep going with the rest
                log.warning("Failed to delete an e2e endpoint: %s", exc)
    except Exception as exc:  # listing itself failed; nothing more to clean
        log.warning("Failed to list endpoints for e2e cleanup: %s", exc)


@pytest.fixture(scope="session", autouse=True)
def cleanup_e2e_endpoints():
    """Cleanup e2e endpoints before and after tests.

    This file's tests use us-west-2 resources, so ``SAGEMAKER_REGION`` is set
    to ``us-west-2`` for the duration of the session and restored afterwards
    to avoid leaking the override into other test files.

    Yields:
        None. The fixture exists only for its setup/teardown side effects.
    """
    import os

    original_sm_region = os.environ.get("SAGEMAKER_REGION")
    os.environ["SAGEMAKER_REGION"] = "us-west-2"

    try:
        # Cleanup before tests
        _delete_e2e_endpoints()

        yield

        # Cleanup after tests
        _delete_e2e_endpoints()
    finally:
        # Restore the original value; ``is not None`` keeps an original
        # empty-string value instead of treating it as "unset".
        if original_sm_region is not None:
            os.environ["SAGEMAKER_REGION"] = original_sm_region
        else:
            os.environ.pop("SAGEMAKER_REGION", None)


@pytest.fixture(scope="module")
def cleanup_endpoints():
"""Track endpoints to cleanup after tests."""
Expand All @@ -102,7 +70,7 @@ def cleanup_endpoints():
for ep_name in endpoints_to_cleanup:
try:
from sagemaker.core.resources import Endpoint
endpoint = Endpoint.get(endpoint_name=ep_name)
endpoint = Endpoint.get(endpoint_name=ep_name, region=AWS_REGION)
endpoint.delete()
except Exception:
pass
Expand All @@ -111,24 +79,23 @@ def cleanup_endpoints():
class TestModelCustomizationFromTrainingJob:
"""Test model customization deployment from TrainingJob."""

def test_build_from_training_job(self, training_job_name, sagemaker_session):
    """Test building model from training job.

    Looks up the training job in ``AWS_REGION``, builds a model through
    ``ModelBuilder`` using the shared ``sagemaker_session`` fixture, and
    checks that the model, its ARN, the image URI and the instance type
    are all populated.
    """
    from sagemaker.core.resources import TrainingJob
    from sagemaker.serve import ModelBuilder
    import time

    training_job = TrainingJob.get(training_job_name=training_job_name, region=AWS_REGION)
    model_builder = ModelBuilder(model=training_job, sagemaker_session=sagemaker_session)
    model_builder.accept_eula = True
    # Timestamp plus random suffix keeps model names unique across runs.
    model = model_builder.build(
        model_name=f"test-model-{int(time.time())}-{random.randint(100, 10000)}",
        region=AWS_REGION,
    )

    assert model is not None
    assert model.model_arn is not None
    assert model_builder.image_uri is not None
    assert model_builder.instance_type is not None

@pytest.mark.skip(reason="Skipped: parallel cleanup race condition under investigation")
def test_deploy_from_training_job(self, training_job_name, endpoint_name, cleanup_endpoints):
def test_deploy_from_training_job(self, training_job_name, endpoint_name, cleanup_endpoints, sagemaker_session):
"""Test deploying model from training job.

For LORA models, this verifies the two-step deployment:
Expand All @@ -138,10 +105,10 @@ def test_deploy_from_training_job(self, training_job_name, endpoint_name, cleanu
from sagemaker.serve import ModelBuilder
import time

training_job = TrainingJob.get(training_job_name=training_job_name)
model_builder = ModelBuilder(model=training_job, instance_type="ml.g5.4xlarge")
training_job = TrainingJob.get(training_job_name=training_job_name, region=AWS_REGION)
model_builder = ModelBuilder(model=training_job, instance_type="ml.g5.4xlarge", sagemaker_session=sagemaker_session)
model_builder.accept_eula = True
model_builder.build(model_name=f"test-model-{int(time.time())}-{random.randint(100, 10000)}")
model_builder.build(model_name=f"test-model-{int(time.time())}-{random.randint(100, 10000)}", region=AWS_REGION)

peft_type = model_builder._fetch_peft()
adapter_name = f"{endpoint_name}-adapter"
Expand All @@ -160,52 +127,52 @@ def test_deploy_from_training_job(self, training_job_name, endpoint_name, cleanu
if peft_type == "LORA":
# Verify base IC was created
base_ic_name = f"{endpoint_name}-inference-component"
base_ic = InferenceComponent.get(inference_component_name=base_ic_name)
base_ic = InferenceComponent.get(inference_component_name=base_ic_name, region=AWS_REGION)
assert base_ic is not None
assert base_ic.inference_component_status == "InService"

# Verify adapter IC was created
adapter_ic = InferenceComponent.get(inference_component_name=adapter_name)
adapter_ic = InferenceComponent.get(inference_component_name=adapter_name, region=AWS_REGION)
assert adapter_ic is not None

def test_fetch_endpoint_names_for_base_model(self, training_job_name, sagemaker_session):
    """Test fetching endpoint names for base model.

    Looks up the training job in ``AWS_REGION``, wraps it in a
    ``ModelBuilder`` bound to the shared session, and checks that
    ``fetch_endpoint_names_for_base_model()`` returns a ``set``.
    """
    from sagemaker.core.resources import TrainingJob
    from sagemaker.serve import ModelBuilder

    training_job = TrainingJob.get(training_job_name=training_job_name, region=AWS_REGION)
    model_builder = ModelBuilder(model=training_job, sagemaker_session=sagemaker_session)
    endpoint_names = model_builder.fetch_endpoint_names_for_base_model()

    assert isinstance(endpoint_names, set)


class TestModelCustomizationFromModelPackage:

def test_build_from_model_package(self, model_package_arn, sagemaker_session):
    """Test building model from model package.

    Looks up the model package in ``AWS_REGION``, builds a model through
    ``ModelBuilder`` using the shared session, and checks that the model
    and its ARN are populated.
    """
    from sagemaker.core.resources import ModelPackage
    from sagemaker.serve import ModelBuilder

    model_package = ModelPackage.get(model_package_name=model_package_arn, region=AWS_REGION)
    model_builder = ModelBuilder(model=model_package, sagemaker_session=sagemaker_session)
    model_builder.accept_eula = True
    model = model_builder.build(region=AWS_REGION)

    assert model is not None
    assert model.model_arn is not None

def test_deploy_from_model_package(self, model_package_arn, cleanup_endpoints):
def test_deploy_from_model_package(self, model_package_arn, cleanup_endpoints, sagemaker_session):
"""Test deploying model from model package."""
from sagemaker.core.resources import ModelPackage
from sagemaker.serve import ModelBuilder
import time

model_package = ModelPackage.get(model_package_name=model_package_arn)
model_package = ModelPackage.get(model_package_name=model_package_arn, region=AWS_REGION)
endpoint_name = f"e2e-{int(time.time())}-{random.randint(100, 10000)}"
model_builder = ModelBuilder(model=model_package)
model_builder = ModelBuilder(model=model_package, sagemaker_session=sagemaker_session)
model_builder.accept_eula = True
model_builder.build()
model_builder.build(region=AWS_REGION)
endpoint = model_builder.deploy(endpoint_name=endpoint_name)

cleanup_endpoints.append(endpoint_name)
Expand All @@ -217,15 +184,15 @@ def test_deploy_from_model_package(self, model_package_arn, cleanup_endpoints):
class TestInstanceTypeAutoDetection:
    """Test automatic instance type detection."""

    def test_instance_type_from_recipe(self, training_job_name, sagemaker_session):
        """Test instance type auto-detection from recipe.

        No ``instance_type`` is passed to ``ModelBuilder``; after ``build()``
        the builder is expected to expose a non-empty instance type that
        contains the ``ml.`` prefix.
        """
        from sagemaker.core.resources import TrainingJob
        from sagemaker.serve import ModelBuilder

        training_job = TrainingJob.get(training_job_name=training_job_name, region=AWS_REGION)
        model_builder = ModelBuilder(model=training_job, sagemaker_session=sagemaker_session)
        model_builder.accept_eula = True
        model_builder.build(region=AWS_REGION)

        assert model_builder.instance_type is not None
        assert "ml." in model_builder.instance_type
Expand All @@ -234,33 +201,33 @@ def test_instance_type_from_recipe(self, training_job_name):
class TestModelCustomizationDetection:
"""Test model customization detection logic."""

def test_is_model_customization_training_job(self, training_job_name, sagemaker_session):
    """Test detection from training job.

    Looks up the training job in ``AWS_REGION`` and checks that a
    ``ModelBuilder`` wrapping it reports ``_is_model_customization()`` as
    ``True``.
    """
    from sagemaker.core.resources import TrainingJob
    from sagemaker.serve import ModelBuilder

    training_job = TrainingJob.get(training_job_name=training_job_name, region=AWS_REGION)
    model_builder = ModelBuilder(model=training_job, sagemaker_session=sagemaker_session)

    assert model_builder._is_model_customization() is True

def test_is_model_customization_model_package(self, model_package_arn, sagemaker_session):
    """Test detection from model package.

    Looks up the model package in ``AWS_REGION`` and checks that a
    ``ModelBuilder`` wrapping it reports ``_is_model_customization()`` as
    ``True``.
    """
    from sagemaker.core.resources import ModelPackage
    from sagemaker.serve import ModelBuilder

    model_package = ModelPackage.get(model_package_name=model_package_arn, region=AWS_REGION)
    model_builder = ModelBuilder(model=model_package, sagemaker_session=sagemaker_session)

    assert model_builder._is_model_customization() is True

def test_fetch_model_package_arn(self, training_job_name):
def test_fetch_model_package_arn(self, training_job_name, sagemaker_session):
"""Test fetching model package ARN."""
from sagemaker.core.resources import TrainingJob
from sagemaker.serve import ModelBuilder

training_job = TrainingJob.get(training_job_name=training_job_name)
model_builder = ModelBuilder(model=training_job)
training_job = TrainingJob.get(training_job_name=training_job_name, region=AWS_REGION)
model_builder = ModelBuilder(model=training_job, sagemaker_session=sagemaker_session)

arn = model_builder._fetch_model_package_arn()

Expand All @@ -271,14 +238,14 @@ def test_fetch_model_package_arn(self, training_job_name):
class TestTrainerIntegration:
"""Test ModelBuilder integration with SFTTrainer and DPOTrainer."""

def test_sft_trainer_build(self, training_job_name):
def test_sft_trainer_build(self, training_job_name, sagemaker_session):
"""Test building model from SFTTrainer."""
from sagemaker.core.resources import TrainingJob
from sagemaker.train.sft_trainer import SFTTrainer
from sagemaker.serve import ModelBuilder

training_job = TrainingJob.get(
training_job_name=training_job_name
training_job_name=training_job_name, region=AWS_REGION
)

trainer = SFTTrainer(
Expand All @@ -289,21 +256,21 @@ def test_sft_trainer_build(self, training_job_name):
)
trainer._latest_training_job = training_job

model_builder = ModelBuilder(model=trainer)
model = model_builder.build()
model_builder = ModelBuilder(model=trainer, sagemaker_session=sagemaker_session)
model = model_builder.build(region=AWS_REGION)

assert model is not None
assert model.model_arn is not None

def test_dpo_trainer_build(self, training_job_name):
def test_dpo_trainer_build(self, training_job_name, sagemaker_session):
"""Test building model from DPOTrainer."""
from sagemaker.core.resources import TrainingJob
from sagemaker.train.dpo_trainer import DPOTrainer
from sagemaker.serve import ModelBuilder
from unittest.mock import patch

training_job = TrainingJob.get(
training_job_name=training_job_name
training_job_name=training_job_name, region=AWS_REGION
)

with patch('sagemaker.train.common_utils.finetune_utils._get_fine_tuning_options_and_model_arn',
Expand All @@ -316,8 +283,8 @@ def test_dpo_trainer_build(self, training_job_name):
)
trainer._latest_training_job = training_job

model_builder = ModelBuilder(model=trainer)
model = model_builder.build()
model_builder = ModelBuilder(model=trainer, sagemaker_session=sagemaker_session)
model = model_builder.build(region=AWS_REGION)

assert model is not None
assert model.model_arn is not None
Expand All @@ -335,8 +302,6 @@ def test_dpo_trainer_build(self, training_job_name):

import json
import time
import random
import boto3
import pytest
from sagemaker.core.resources import TrainingJob, ModelPackage
from sagemaker.serve.bedrock_model_builder import BedrockModelBuilder
Expand All @@ -361,6 +326,7 @@ def training_job(self, setup_config):
"""Get the training job."""
return TrainingJob.get(
training_job_name=setup_config["training_job_name"],
region=setup_config["region"],
)

@pytest.fixture(scope="class")
Expand Down Expand Up @@ -432,7 +398,7 @@ def _setup_model_files(self, training_job, s3_client, setup_config):
base_s3_path = training_job.model_artifacts.s3_model_artifacts
elif hasattr(training_job, 'output_model_package_arn'):
# If training job has model package ARN, get artifacts from model package
model_package = ModelPackage.get(training_job.output_model_package_arn)
model_package = ModelPackage.get(training_job.output_model_package_arn, region=AWS_REGION)
if hasattr(model_package,
'inference_specification') and model_package.inference_specification.containers:
container = model_package.inference_specification.containers[0]
Expand Down Expand Up @@ -561,8 +527,7 @@ def test_zzz_cleanup_deployed_model(self, bedrock_client):
def test_model_customization_workflow(training_job_name):
"""Standalone test function for pytest discovery.

Relies on SAGEMAKER_REGION being set by the cleanup_e2e_endpoints
session fixture (us-west-2).
Uses explicit region parameter for all SDK calls.
"""
config = {
"training_job_name": training_job_name,
Expand All @@ -572,7 +537,7 @@ def test_model_customization_workflow(training_job_name):

try:
s3_client = boto3.client('s3', region_name=config["region"])
training_job = TrainingJob.get(training_job_name=config["training_job_name"])
training_job = TrainingJob.get(training_job_name=config["training_job_name"], region=config["region"])

test_class = TestModelCustomizationDeployment()
test_class.test_training_job_exists(training_job)
Expand Down
Loading
Loading