@@ -157,7 +157,7 @@ def _calculate_transition_duration(trans) -> Tuple[str, str]:
157157def wait (
158158 training_job : TrainingJob ,
159159 poll : int = 5 ,
160- timeout : Optional [int ] = None
160+ timeout : Optional [int ] = 3000
161161) -> None :
162162 """Wait for training job to complete with progress tracking.
163163
@@ -192,8 +192,10 @@ def wait(
192192 iteration = 0
193193 while True :
194194 iteration += 1
195- time .sleep (poll )
196- training_job .refresh ()
195+ time .sleep (1 )
196+ if iteration == poll :
197+ training_job .refresh ()
198+ iteration = 0
197199 clear_output (wait = True )
198200
199201 status = training_job .training_job_status
@@ -302,7 +304,7 @@ def wait(
302304 raise FailedStatusError (resource_type = "TrainingJob" , status = status , reason = failure_reason )
303305
304306 if timeout and elapsed >= timeout :
305- raise TimeoutExceededError (resouce_type = "TrainingJob" , status = status )
307+ raise TimeoutExceededError (resource_type = "TrainingJob" , status = status )
306308
307309 else :
308310 print (f"\n TrainingJob Name: { training_job .training_job_name } " )
@@ -363,7 +365,7 @@ def wait(
363365 raise FailedStatusError (resource_type = "TrainingJob" , status = status , reason = failure_reason )
364366
365367 if timeout and elapsed >= timeout :
366- raise TimeoutExceededError (resouce_type = "TrainingJob" , status = status )
368+ raise TimeoutExceededError (resource_type = "TrainingJob" , status = status )
367369
368370
369371 except (FailedStatusError , TimeoutExceededError ):
0 commit comments