Skip to content
179 changes: 146 additions & 33 deletions tests/integ/sagemaker/jumpstart/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import json
import os
import pathlib
from datetime import datetime, timedelta, timezone

import boto3
import pytest
from filelock import FileLock
from botocore.config import Config
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
from sagemaker.jumpstart.hub.hub import Hub
Expand All @@ -39,19 +44,28 @@
)


def _setup():
# Only delete leftover hubs from previous test runs that are older than this many
# hours. This guards against deleting a hub that another concurrent test run (or
# xdist worker) is actively using.
STALE_HUB_AGE_HOURS = 3


def _setup(test_suite_id=None, test_hub_name=None):
print("Setting up...")
test_suite_id = get_test_suite_id()
test_hub_name = f"{HUB_NAME_PREFIX}{test_suite_id}"
test_suite_id = test_suite_id or get_test_suite_id()
test_hub_name = test_hub_name or f"{HUB_NAME_PREFIX}{test_suite_id}"
test_hub_description = "PySDK Integ Test Private Hub"

os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID: test_suite_id})
os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME: test_hub_name})

# Create a private hub to use for the test session
hub = Hub(
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session()
)
hub = Hub(hub_name=test_hub_name, sagemaker_session=get_sm_session())

# Proactively reclaim stale hubs from prior runs so we don't accumulate
# toward the per-account private hub limit. This only deletes hubs older
# than STALE_HUB_AGE_HOURS and never the hub we are about to use.
_cleanup_old_hubs(get_sm_session(), active_hub_name=test_hub_name)

# Check if hub already exists before creating
try:
Expand All @@ -73,14 +87,14 @@ def _setup():
raise


def _teardown():
def _teardown(test_suite_id=None, test_hub_name=None, delete_hub=False):
print("Tearing down...")

test_cache_bucket = get_test_artifact_bucket()

test_suite_id = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]
test_suite_id = test_suite_id or os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]

test_hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
test_hub_name = test_hub_name or os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]

boto3_session = boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME)

Expand Down Expand Up @@ -152,34 +166,49 @@ def _teardown():
bucket = s3_resource.Bucket(test_cache_bucket)
bucket.objects.filter(Prefix=test_suite_id + "/").delete()

# delete private hubs
_delete_hubs(sagemaker_session, test_hub_name)
# delete private hubs (only when explicitly requested). During an xdist run
# we never delete the active hub, because a straggler worker may still be
# running a hub test when another process reaches teardown; stale hubs from
# prior runs are reclaimed by the age-based _cleanup_old_hubs instead.
if delete_hub:
_delete_hubs(sagemaker_session, test_hub_name)


def _cleanup_old_hubs(sagemaker_session, active_hub_name=None):
"""Clean up stale test hubs from previous runs to free up resources.

def _cleanup_old_hubs(sagemaker_session):
"""Clean up old test hubs to free up resources."""
Only deletes hubs that are clearly stale (older than ``STALE_HUB_AGE_HOURS``)
so that hubs actively in use by the current test run or by concurrent xdist
workers are never removed. The hub for the current run (``active_hub_name``)
is always preserved.
"""
try:
active_hub_name = active_hub_name or os.environ.get(ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME)
cutoff = datetime.now(timezone.utc) - timedelta(hours=STALE_HUB_AGE_HOURS)

response = sagemaker_session.list_hubs()
test_hubs = [
hub
for hub in response.get("HubSummaries", [])
if hub["HubName"].startswith(HUB_NAME_PREFIX)
]

# Sort by creation time and delete oldest hubs
test_hubs.sort(key=lambda x: x.get("CreationTime", ""))

# Delete oldest hubs (keep only the most recent 10)
hubs_to_delete = (
test_hubs[:-10] if len(test_hubs) > 10 else test_hubs[: max(0, len(test_hubs) - 40)]
)
for hub in response.get("HubSummaries", []):
hub_name = hub["HubName"]
if not hub_name.startswith(HUB_NAME_PREFIX):
continue
if hub_name == active_hub_name:
continue

creation_time = hub.get("CreationTime")
# Only delete hubs we can confirm are older than the cutoff. If the
# creation time is unavailable, err on the side of keeping the hub.
if creation_time is None:
continue
if creation_time.tzinfo is None:
creation_time = creation_time.replace(tzinfo=timezone.utc)
if creation_time >= cutoff:
continue

for hub in hubs_to_delete:
try:
print(f"Deleting old hub: {hub['HubName']}")
_delete_hubs(sagemaker_session, hub["HubName"])
print(f"Deleting stale hub: {hub_name}")
_delete_hubs(sagemaker_session, hub_name)
except Exception as e:
print(f"Failed to delete hub {hub['HubName']}: {e}")
print(f"Failed to delete hub {hub_name}: {e}")
except Exception as e:
print(f"Failed to cleanup old hubs: {e}")

Expand Down Expand Up @@ -210,8 +239,92 @@ def _delete_hub_contents(sagemaker_session, hub_name, model):
)


def _hub_state_root(config):
"""Return the run-level tmp dir shared by the xdist controller and workers.

The controller's basetemp is the run root (e.g. ``.../pytest-N``) while each
worker's basetemp is a ``popen-gw*`` subdir of it. Normalizing to the run
root gives every process the same location for the shared state file.

Works across pytest versions: prefers the ``TempPathFactory`` attached as
``config._tmp_path_factory`` and falls back to the legacy ``_tmpdirhandler``.
"""
factory = getattr(config, "_tmp_path_factory", None)
if factory is not None:
basetemp = pathlib.Path(str(factory.getbasetemp()))
else:
basetemp = pathlib.Path(str(config._tmpdirhandler.getbasetemp()))

if basetemp.name.startswith("popen-gw"):
return basetemp.parent
return basetemp


@pytest.fixture(scope="session", autouse=True)
def setup(request):
_setup()

request.addfinalizer(_teardown)
"""Ensure a single shared private hub exists for the whole test run.

Under pytest-xdist every worker is a separate process, so a naive
``scope="session"`` fixture would create one hub per worker. With high
parallelism (e.g. ``-n 120``) that quickly exhausts the per-account private
hub limit (100). All workers therefore coordinate through a lock file and a
shared JSON state file: the first worker creates the hub, the rest reuse it.

The hub is intentionally NOT deleted at the end of the run. xdist
distributes tests dynamically and hub tests deploy long-lived endpoints, so
a straggler worker can still be running a hub test (at ~100%) while another
process reaches teardown. Deleting the hub there pulls it out from under the
straggler ("Hub ... does not exist" failures). Instead, leaked endpoints and
artifacts are cleaned at run end, and the hub itself is reclaimed on a later
run by the age-based ``_cleanup_old_hubs`` (older than STALE_HUB_AGE_HOURS).
"""
root_tmp_dir = _hub_state_root(request.config)
state_file = root_tmp_dir / "jumpstart_hub_state.json"
lock_file = root_tmp_dir / "jumpstart_hub_state.json.lock"

with FileLock(str(lock_file)):
if state_file.is_file():
state = json.loads(state_file.read_text())
else:
test_suite_id = get_test_suite_id()
test_hub_name = f"{HUB_NAME_PREFIX}{test_suite_id}"
_setup(test_suite_id=test_suite_id, test_hub_name=test_hub_name)
state = {
"test_suite_id": test_suite_id,
"test_hub_name": test_hub_name,
}
state_file.write_text(json.dumps(state))

# Ensure this worker's environment points at the shared hub.
os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID: state["test_suite_id"]})
os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME: state["test_hub_name"]})


def pytest_sessionfinish(session, exitstatus):
"""Clean up leaked test resources once, after all xdist workers finish.

Runs only on the controller (xdist workers carry a ``workerinput`` attribute
on their config; a non-xdist run has none). Deletes endpoints/models/configs
and S3 artifacts tagged for this run, but deliberately does NOT delete the
shared hub (see ``setup``); stale hubs are reclaimed by ``_cleanup_old_hubs``
on a subsequent run.
"""
if hasattr(session.config, "workerinput"):
return # xdist worker: the controller handles cleanup.

root_tmp_dir = _hub_state_root(session.config)
state_file = root_tmp_dir / "jumpstart_hub_state.json"
lock_file = root_tmp_dir / "jumpstart_hub_state.json.lock"

with FileLock(str(lock_file)):
if not state_file.is_file():
return
state = json.loads(state_file.read_text())
try:
_teardown(
test_suite_id=state["test_suite_id"],
test_hub_name=state["test_hub_name"],
delete_hub=False,
)
finally:
state_file.unlink()
5 changes: 4 additions & 1 deletion tests/integ/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str:
("huggingface-spc-bert-base-cased", "1.0.0"): ("training-datasets/QNLI-tiny/"),
("huggingface-spc-bert-base-cased", "1.2.3"): ("training-datasets/QNLI-tiny/"),
("huggingface-spc-bert-base-cased", "2.0.3"): ("training-datasets/QNLI-tiny/"),
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI/"),
# Use the tiny dataset for the floating "*" version too: these are canary
# tests that only need to exercise the train/deploy flow, not produce a
# well-trained model. The full QNLI dataset made fit() dramatically slower.
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI-tiny/"),
("js-trainable-model", "*"): ("training-datasets/QNLI-tiny/"),
("meta-textgeneration-llama-2-7b", "*"): ("training-datasets/sec_amazon/"),
("meta-textgeneration-llama-2-7b", "2.*"): ("training-datasets/sec_amazon/"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def test_jumpstart_estimator(setup):
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
max_run=259200, # avoid exceeding resource limits
instance_type="ml.g4dn.xlarge",
# Canary only needs to exercise the train/deploy flow, so cap training
# to a single epoch to keep fit() fast.
hyperparameters={"epochs": "1"},
)

# uses ml.g4dn.xlarge instance
Expand Down Expand Up @@ -111,6 +114,9 @@ def test_gated_model_training_v1(setup):
environment={"accept_eula": "true"},
max_run=259200, # avoid exceeding resource limits
tolerate_vulnerable_model=True,
# Canary only verifies the train/deploy flow, so cap training to a
# single step to keep fit() fast (sec_amazon has no tiny variant).
hyperparameters={"max_steps": "1"},
)

# uses ml.g5.12xlarge instance
Expand Down Expand Up @@ -153,6 +159,9 @@ def test_gated_model_training_v2(setup):
environment={"accept_eula": "true"},
max_run=259200, # avoid exceeding resource limits
tolerate_vulnerable_model=True, # tolerate old version of model
# Canary only verifies the train/deploy flow, so cap training to a
# single step to keep fit() fast (sec_amazon has no tiny variant).
hyperparameters={"max_steps": "1"},
)

# uses ml.g5.12xlarge instance
Expand Down Expand Up @@ -190,6 +199,7 @@ def test_gated_model_training_v2(setup):


@x_fail_if_ice
@pytest.mark.slow_test
@pytest.mark.skipif(
tests.integ.test_region() not in TRN2_SUPPORTED_REGIONS,
reason=f"TRN2 instances unavailable in {tests.integ.test_region()}.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
download_inference_assets,
get_sm_session,
get_tabular_data,
x_fail_if_ice,
)

INF2_SUPPORTED_REGIONS = {
Expand Down Expand Up @@ -192,6 +193,7 @@ def test_jumpstart_gated_model(setup):
assert response is not None


@x_fail_if_ice
def test_jumpstart_gated_model_inference_component_enabled(setup):

model_id = "meta-textgeneration-llama-2-7b"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import pytest
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
from sagemaker.jumpstart.hub.hub import Hub

from sagemaker.jumpstart.estimator import JumpStartEstimator
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
Expand All @@ -28,10 +27,9 @@
JUMPSTART_TAG,
)
from tests.integ.sagemaker.jumpstart.utils import (
get_public_hub_model_arn,
get_sm_session,
with_exponential_backoff,
get_training_dataset_for_model_and_version,
add_model_references_to_hub,
)

MAX_INIT_TIME_SECONDS = 5
Expand All @@ -43,23 +41,13 @@
}


@with_exponential_backoff()
def create_model_reference(hub_instance, model_arn):
try:
hub_instance.create_model_reference(model_arn=model_arn)
except Exception:
pass


@pytest.fixture(scope="session")
def add_model_references():
# Create Model References to test in Hub
hub_instance = Hub(
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session()
# Create Model References to test in Hub (idempotent + waits for readiness)
add_model_references_to_hub(
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
model_ids=TEST_MODEL_IDS,
)
for model in TEST_MODEL_IDS:
model_arn = get_public_hub_model_arn(hub_instance, model)
create_model_reference(hub_instance, model_arn)


def test_jumpstart_hub_estimator(setup, add_model_references):
Expand All @@ -70,6 +58,9 @@ def test_jumpstart_hub_estimator(setup, add_model_references):
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
instance_type="ml.g4dn.xlarge",
# Canary only needs to exercise the train/deploy flow, so cap training
# to a single epoch to keep fit() fast.
hyperparameters={"epochs": "1"},
)

estimator.fit(
Expand Down Expand Up @@ -110,6 +101,9 @@ def test_jumpstart_hub_estimator_with_session(setup, add_model_references):
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
instance_type="ml.g4dn.xlarge",
# Canary only needs to exercise the train/deploy flow, so cap training
# to a single epoch to keep fit() fast.
hyperparameters={"epochs": "1"},
)

estimator.fit(
Expand Down Expand Up @@ -149,6 +143,9 @@ def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references):
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
instance_type="ml.g5.2xlarge",
# Canary only verifies the train/deploy flow, so cap training to a
# single step to keep fit() fast (sec_amazon has no tiny variant).
hyperparameters={"max_steps": "1"},
)

estimator.fit(
Expand Down
Loading
Loading