Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from __future__ import absolute_import

import json
import time
import pytest
import logging

Expand Down Expand Up @@ -140,21 +141,33 @@ def test_base_model_evaluation_uses_correct_weights(self):

# Step 3: Verify pipeline structure
logger.info("\nVerifying pipeline structure...")
execution.refresh()

# Check that we have both base and custom inference steps
step_names = [step.name for step in execution.status.step_details] if execution.status.step_details else []

logger.info(f"Pipeline steps ({len(step_names)}): {step_names}")

# If no steps yet, wait a bit for pipeline to initialize
if not step_names:
logger.info("No steps found yet, waiting for pipeline initialization...")
import time
time.sleep(10)
# Poll for steps to appear since the pipeline takes time to initialize all steps
max_wait_seconds = 120
poll_interval = 10
elapsed = 0
step_names = []

while elapsed < max_wait_seconds:
execution.refresh()
step_names = [step.name for step in execution.status.step_details] if execution.status.step_details else []
logger.info(f"Pipeline steps after wait ({len(step_names)}): {step_names}")
logger.info(f"Pipeline steps after {elapsed}s ({len(step_names)}): {step_names}")

# Check if both inference steps have appeared
has_base_step = any("base" in name.lower() and "inference" in name.lower() for name in step_names)
has_custom_step = any("custom" in name.lower() and "inference" in name.lower() for name in step_names)
if has_base_step and has_custom_step:
break

# Also break if the pipeline has finished (all steps reported)
if execution.status.overall_status in ("Succeeded", "Failed", "Stopped"):
logger.info(f"Pipeline reached terminal status: {execution.status.overall_status}")
break

time.sleep(poll_interval)
elapsed += poll_interval

logger.info(f"Final pipeline steps ({len(step_names)}): {step_names}")

# Verify both inference steps exist (case-insensitive, flexible matching)
has_base_step = any("base" in name.lower() and "inference" in name.lower() for name in step_names)
Expand Down Expand Up @@ -275,19 +288,31 @@ def test_base_model_false_still_works(self):
logger.info(f" Execution ARN: {execution.arn}")

# Verify pipeline structure - should only have custom inference step
execution.refresh()
step_names = [step.name for step in execution.status.step_details] if execution.status.step_details else []

logger.info(f"Pipeline steps ({len(step_names)}): {step_names}")

# If no steps yet, wait a bit for pipeline to initialize
if not step_names:
logger.info("No steps found yet, waiting for pipeline initialization...")
import time
time.sleep(10)
# Poll for steps to appear since the pipeline takes time to initialize all steps
max_wait_seconds = 120
poll_interval = 10
elapsed = 0
step_names = []

while elapsed < max_wait_seconds:
execution.refresh()
step_names = [step.name for step in execution.status.step_details] if execution.status.step_details else []
logger.info(f"Pipeline steps after wait ({len(step_names)}): {step_names}")
logger.info(f"Pipeline steps after {elapsed}s ({len(step_names)}): {step_names}")

# Check if the custom inference step has appeared
has_custom_step = any("custom" in name.lower() and "inference" in name.lower() for name in step_names)
if has_custom_step:
break

# Also break if the pipeline has finished (all steps reported)
if execution.status.overall_status in ("Succeeded", "Failed", "Stopped"):
logger.info(f"Pipeline reached terminal status: {execution.status.overall_status}")
break

time.sleep(poll_interval)
elapsed += poll_interval

logger.info(f"Final pipeline steps ({len(step_names)}): {step_names}")

# Should NOT have base inference step (case-insensitive, flexible matching)
has_base_step = any("base" in name.lower() and "inference" in name.lower() for name in step_names)
Expand Down
14 changes: 11 additions & 3 deletions sagemaker-train/tests/unit/train/evaluate/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,7 +1114,7 @@ def test_wait_fails_on_failed_status(self):
with pytest.raises(FailedStatusError):
execution.wait(target_status="Succeeded", poll=1)

@patch("time.time")
@patch("sagemaker.train.evaluate.execution.time")
def test_wait_timeout_exceeded(self, mock_time):
"""Test wait raises exception on timeout."""
execution = EvaluationPipelineExecution(
Expand All @@ -1130,8 +1130,16 @@ def test_wait_timeout_exceeded(self, mock_time):

execution._pipeline_execution = mock_pe

# Mock time to simulate timeout
mock_time.side_effect = [0, 10, 20, 30, 40, 50, 60] # Exceeds timeout
# Mock time.time() to simulate timeout - use a counter that always exceeds timeout
# First call sets start_time=0, second call returns value >= timeout (5)
call_count = {"n": 0}
def time_side_effect():
val = call_count["n"] * 10
call_count["n"] += 1
return val

mock_time.time.side_effect = time_side_effect
mock_time.sleep = MagicMock() # no-op sleep

with pytest.raises(TimeoutExceededError, match="EvaluationJob") as exc_info:
execution.wait(target_status="Succeeded", poll=1, timeout=5)
Expand Down
Loading