diff --git a/sagemaker-train/tests/integ/train/test_llm_as_judge_base_model_fix.py b/sagemaker-train/tests/integ/train/test_llm_as_judge_base_model_fix.py index 7490a373b5..1da31f71c6 100644 --- a/sagemaker-train/tests/integ/train/test_llm_as_judge_base_model_fix.py +++ b/sagemaker-train/tests/integ/train/test_llm_as_judge_base_model_fix.py @@ -19,6 +19,7 @@ from __future__ import absolute_import import json +import time import pytest import logging @@ -140,21 +141,33 @@ def test_base_model_evaluation_uses_correct_weights(self): # Step 3: Verify pipeline structure logger.info("\nVerifying pipeline structure...") - execution.refresh() - # Check that we have both base and custom inference steps - step_names = [step.name for step in execution.status.step_details] if execution.status.step_details else [] - - logger.info(f"Pipeline steps ({len(step_names)}): {step_names}") - - # If no steps yet, wait a bit for pipeline to initialize - if not step_names: - logger.info("No steps found yet, waiting for pipeline initialization...") - import time - time.sleep(10) + # Poll for steps to appear since the pipeline takes time to initialize all steps + max_wait_seconds = 120 + poll_interval = 10 + elapsed = 0 + step_names = [] + + while elapsed < max_wait_seconds: execution.refresh() step_names = [step.name for step in execution.status.step_details] if execution.status.step_details else [] - logger.info(f"Pipeline steps after wait ({len(step_names)}): {step_names}") + logger.info(f"Pipeline steps after {elapsed}s ({len(step_names)}): {step_names}") + + # Check if both inference steps have appeared + has_base_step = any("base" in name.lower() and "inference" in name.lower() for name in step_names) + has_custom_step = any("custom" in name.lower() and "inference" in name.lower() for name in step_names) + if has_base_step and has_custom_step: + break + + # Also break if the pipeline has finished (all steps reported) + if execution.status.overall_status in ("Succeeded", "Failed", "Stopped"): + logger.info(f"Pipeline reached terminal status: {execution.status.overall_status}") + break + + time.sleep(poll_interval) + elapsed += poll_interval + + logger.info(f"Final pipeline steps ({len(step_names)}): {step_names}") # Verify both inference steps exist (case-insensitive, flexible matching) has_base_step = any("base" in name.lower() and "inference" in name.lower() for name in step_names) @@ -275,19 +288,31 @@ def test_base_model_false_still_works(self): logger.info(f" Execution ARN: {execution.arn}") # Verify pipeline structure - should only have custom inference step - execution.refresh() - step_names = [step.name for step in execution.status.step_details] if execution.status.step_details else [] - - logger.info(f"Pipeline steps ({len(step_names)}): {step_names}") - - # If no steps yet, wait a bit for pipeline to initialize - if not step_names: - logger.info("No steps found yet, waiting for pipeline initialization...") - import time - time.sleep(10) + # Poll for steps to appear since the pipeline takes time to initialize all steps + max_wait_seconds = 120 + poll_interval = 10 + elapsed = 0 + step_names = [] + + while elapsed < max_wait_seconds: execution.refresh() step_names = [step.name for step in execution.status.step_details] if execution.status.step_details else [] - logger.info(f"Pipeline steps after wait ({len(step_names)}): {step_names}") + logger.info(f"Pipeline steps after {elapsed}s ({len(step_names)}): {step_names}") + + # Check if the custom inference step has appeared + has_custom_step = any("custom" in name.lower() and "inference" in name.lower() for name in step_names) + if has_custom_step: + break + + # Also break if the pipeline has finished (all steps reported) + if execution.status.overall_status in ("Succeeded", "Failed", "Stopped"): + logger.info(f"Pipeline reached terminal status: {execution.status.overall_status}") + break + + time.sleep(poll_interval) + elapsed += poll_interval + + logger.info(f"Final pipeline steps ({len(step_names)}): {step_names}") # Should NOT have base inference step (case-insensitive, flexible matching) has_base_step = any("base" in name.lower() and "inference" in name.lower() for name in step_names) diff --git a/sagemaker-train/tests/unit/train/evaluate/test_execution.py b/sagemaker-train/tests/unit/train/evaluate/test_execution.py index 562e2e9e00..b4cd07a59b 100644 --- a/sagemaker-train/tests/unit/train/evaluate/test_execution.py +++ b/sagemaker-train/tests/unit/train/evaluate/test_execution.py @@ -1114,7 +1114,7 @@ def test_wait_fails_on_failed_status(self): with pytest.raises(FailedStatusError): execution.wait(target_status="Succeeded", poll=1) - @patch("time.time") + @patch("sagemaker.train.evaluate.execution.time") def test_wait_timeout_exceeded(self, mock_time): """Test wait raises exception on timeout.""" execution = EvaluationPipelineExecution( @@ -1130,8 +1130,16 @@ def test_wait_timeout_exceeded(self, mock_time): execution._pipeline_execution = mock_pe - # Mock time to simulate timeout - mock_time.side_effect = [0, 10, 20, 30, 40, 50, 60] # Exceeds timeout + # Mock time.time() to simulate timeout - use a counter that always exceeds timeout + # First call sets start_time=0, second call returns value >= timeout (5) + call_count = {"n": 0} + def time_side_effect(): + val = call_count["n"] * 10 + call_count["n"] += 1 + return val + + mock_time.time.side_effect = time_side_effect + mock_time.sleep = MagicMock() # no-op sleep with pytest.raises(TimeoutExceededError, match="EvaluationJob") as exc_info: execution.wait(target_status="Succeeded", poll=1, timeout=5)