@@ -259,6 +259,16 @@ def __init__(
259259 else :
260260 logger .info ("No API key provided, using existing Claude CLI authentication" )
261261
262+ def _is_retryable_error (self , exc : BaseException ) -> bool :
263+ """Return True for transient errors that warrant a retry.
264+ asyncio.TimeoutError is intentional (user-configured timeout) — not retried.
265+ Only non-MCP CLIConnectionError is considered transient.
266+ """
267+ if isinstance (exc , CLIConnectionError ):
268+ msg = str (exc ).lower ()
269+ return "mcp" not in msg # "server" alone is too broad
270+ return False
271+
262272 async def execute_command (
263273 self ,
264274 prompt : str ,
@@ -393,43 +403,84 @@ async def _run_client() -> None:
393403 finally :
394404 await client .disconnect ()
395405
396- # Execute: race client against timeout and optional interrupt
397- run_task = asyncio .create_task (_run_client ())
406+ # Execute with timeout and retry, racing against optional interrupt
407+ max_attempts = max (1 , self .config .claude_retry_max_attempts )
408+ last_exc : Optional [BaseException ] = None
409+
410+ for attempt in range (max_attempts ):
411+ # Reset message accumulator each attempt so that a failed attempt
412+ # does not pollute the next one with partial/duplicate messages.
413+ # _run_client() closes over `messages` by reference (late-binding
414+ # closure), so clearing it here is seen by every new call.
415+ messages .clear ()
416+
417+ if attempt > 0 :
418+ delay = min (
419+ self .config .claude_retry_base_delay
420+ * (self .config .claude_retry_backoff_factor ** (attempt - 1 )),
421+ self .config .claude_retry_max_delay ,
422+ )
423+ logger .warning (
424+ "Retrying Claude SDK command" ,
425+ attempt = attempt + 1 ,
426+ max_attempts = max_attempts ,
427+ delay_seconds = delay ,
428+ )
429+ await asyncio .sleep (delay )
398430
399- interrupt_watcher : Optional ["asyncio.Task[None]" ] = None
400- if interrupt_event is not None :
431+ run_task = asyncio .create_task (_run_client ())
401432
402- async def _cancel_on_interrupt () -> None :
403- nonlocal interrupted
404- await interrupt_event .wait ()
405- interrupted = True
406- run_task .cancel ()
433+ interrupt_watcher : Optional ["asyncio.Task[None]" ] = None
434+ if interrupt_event is not None :
407435
408- interrupt_watcher = asyncio .create_task (_cancel_on_interrupt ())
436+ async def _cancel_on_interrupt () -> None :
437+ nonlocal interrupted
438+ await interrupt_event .wait ()
439+ interrupted = True
440+ run_task .cancel ()
409441
410- try :
411- await asyncio .wait_for (
412- asyncio .shield (run_task ),
413- timeout = self .config .claude_timeout_seconds ,
414- )
415- except asyncio .CancelledError :
416- if not interrupted :
417- raise
418- # Interrupt cancelled the task — wait for cleanup
419- try :
420- await run_task
421- except asyncio .CancelledError :
422- pass
423- except asyncio .TimeoutError :
424- run_task .cancel ()
442+ interrupt_watcher = asyncio .create_task (_cancel_on_interrupt ())
443+
444+ # Note: asyncio.TimeoutError is intentionally NOT retried —
445+ # it reflects a user-configured hard limit.
425446 try :
426- await run_task
447+ await asyncio .wait_for (
448+ asyncio .shield (run_task ),
449+ timeout = self .config .claude_timeout_seconds ,
450+ )
451+ break # success — exit retry loop
427452 except asyncio .CancelledError :
428- pass
429- raise
430- finally :
431- if interrupt_watcher is not None :
432- interrupt_watcher .cancel ()
453+ if not interrupted :
454+ raise
455+ # Interrupt cancelled the task — wait for cleanup
456+ try :
457+ await run_task
458+ except asyncio .CancelledError :
459+ pass
460+ break # user interrupted — don't retry
461+ except asyncio .TimeoutError :
462+ run_task .cancel ()
463+ try :
464+ await run_task
465+ except asyncio .CancelledError :
466+ pass
467+ raise # timeout — don't retry
468+ except CLIConnectionError as exc :
469+ if self ._is_retryable_error (exc ) and attempt < max_attempts - 1 :
470+ last_exc = exc
471+ logger .warning (
472+ "Transient connection error, will retry" ,
473+ attempt = attempt + 1 ,
474+ error = str (exc ),
475+ )
476+ continue
477+ raise # non-retryable or attempts exhausted
478+ finally :
479+ if interrupt_watcher is not None :
480+ interrupt_watcher .cancel ()
481+ else :
482+ if last_exc is not None :
483+ raise last_exc
433484
434485 # Extract cost, tools, and session_id from result message
435486 cost = 0.0
0 commit comments