@@ -139,6 +139,12 @@ class Coder:
139139 partial_response_tool_calls = []
140140 commit_before_message = []
141141 message_cost = 0.0
142+ total_tokens_sent = 0
143+ total_tokens_received = 0
144+ total_cached_tokens = 0
145+ message_tokens_sent = 0
146+ message_tokens_received = 0
147+ message_cached_tokens = 0
142148 add_cache_headers = False
143149 cache_warming_thread = None
144150 num_cache_warming_pings = 0
@@ -227,6 +233,7 @@ async def create(
227233 ignore_mentions = from_coder .ignore_mentions ,
228234 total_tokens_sent = from_coder .total_tokens_sent ,
229235 total_tokens_received = from_coder .total_tokens_received ,
236+ total_cached_tokens = from_coder .total_cached_tokens ,
230237 file_watcher = from_coder .file_watcher ,
231238 mcp_manager = from_coder .mcp_manager ,
232239 uuid = from_coder .uuid ,
@@ -316,6 +323,7 @@ def __init__(
316323 ignore_mentions = None ,
317324 total_tokens_sent = 0 ,
318325 total_tokens_received = 0 ,
326+ total_cached_tokens = 0 ,
319327 file_watcher = None ,
320328 auto_copy_context = False ,
321329 auto_accept_architect = True ,
@@ -331,6 +339,7 @@ def __init__(
331339 ):
332340 # initialize from args.map_cache_dir
333341 self .interrupt_event = asyncio .Event ()
342+ self .coroutines = coroutines
334343 self .uuid = generate_unique_id ()
335344 if uuid :
336345 self .uuid = uuid
@@ -388,8 +397,10 @@ def __init__(
388397 self .total_cost = total_cost
389398 self .total_tokens_sent = total_tokens_sent
390399 self .total_tokens_received = total_tokens_received
400+ self .total_cached_tokens = total_cached_tokens
391401 self .message_tokens_sent = 0
392402 self .message_tokens_received = 0
403+ self .message_cached_tokens = 0
393404
394405 self .token_profiler = TokenProfiler (
395406 enable_printing = nested .getter (self .args , "show_speed" , False )
@@ -1370,11 +1381,6 @@ async def _run_parallel(self, with_message=None, preproc=True):
13701381 except (SwitchCoderSignal , SystemExit ):
13711382 # Re-raise SwitchCoder to be handled by outer try block
13721383 raise
1373- except KeyboardInterrupt :
1374- # Handle keyboard interrupt gracefully
1375- self .io .set_placeholder ("" )
1376- self .io .stop_spinner ()
1377- self .keyboard_interrupt ()
13781384 finally :
13791385 # Signal tasks to stop
13801386 self .input_running = False
@@ -1454,10 +1460,6 @@ async def input_task(self, preproc):
14541460
14551461 await asyncio .sleep (0.1 ) # Small yield to prevent tight loop
14561462
1457- except KeyboardInterrupt :
1458- self .io .set_placeholder ("" )
1459- self .keyboard_interrupt ()
1460- await self .io .stop_task_streams ()
14611463 except (SwitchCoderSignal , SystemExit ):
14621464 raise
14631465 except Exception as e :
@@ -1738,8 +1740,7 @@ def keyboard_interrupt(self):
17381740 # Ensure cursor is visible on exit
17391741 Console ().show_cursor (True )
17401742
1741- self .io .tool_warning ("\n \n ^C KeyboardInterrupt" )
1742-
1743+ self .io .tool_warning ("^C KeyboardInterrupt" )
17431744 self .interrupt_event .set ()
17441745 self .last_keyboard_interrupt = time .time ()
17451746
@@ -2262,9 +2263,16 @@ async def send_message(self, inp):
22622263 self .io .tool_error (err_msg )
22632264
22642265 self .io .tool_output (f"Retrying in { retry_delay :.1f} seconds..." )
2265- await asyncio .sleep (retry_delay )
2266+
2267+ _res , interrupted_sleep = await coroutines .interruptible (
2268+ asyncio .sleep (retry_delay ), self .interrupt_event
2269+ )
2270+ if interrupted_sleep :
2271+ interrupted = True
2272+ break
2273+
22662274 continue
2267- except KeyboardInterrupt :
2275+ except ( KeyboardInterrupt , asyncio . CancelledError ) :
22682276 interrupted = True
22692277 break
22702278 except FinishReasonLength :
@@ -2629,11 +2637,19 @@ async def _execute_mcp_tools(self, server, tool_calls):
26292637 all_results_content .append ("Tool Request Aborted." )
26302638 continue
26312639
2632- call_result = await experimental_mcp_client .call_openai_tool (
2633- session = session ,
2634- openai_tool = new_tool_call ,
2640+ async def do_tool_call ():
2641+ return await experimental_mcp_client .call_openai_tool (
2642+ session = session ,
2643+ openai_tool = new_tool_call ,
2644+ )
2645+
2646+ call_result , interrupted = await coroutines .interruptible (
2647+ do_tool_call (), self .interrupt_event
26352648 )
26362649
2650+ if interrupted :
2651+ raise KeyboardInterrupt ("Tool call interrupted" )
2652+
26372653 content_parts = []
26382654 if call_result .content :
26392655 for item in call_result .content :
@@ -2678,6 +2694,9 @@ async def _execute_mcp_tools(self, server, tool_calls):
26782694 }
26792695 )
26802696
2697+ except KeyboardInterrupt :
2698+ self .io .tool_warning (f"Tool call { tool_call .function .name } interrupted." )
2699+ raise
26812700 except Exception as e :
26822701 tool_error = f"Error executing tool call { tool_call .function .name } : \n { e } "
26832702 self .io .tool_warning (
@@ -2694,6 +2713,9 @@ async def _execute_mcp_tools(self, server, tool_calls):
26942713 tool_responses .append (
26952714 {"role" : "tool" , "tool_call_id" : tool_call .id , "content" : connection_error }
26962715 )
2716+ except asyncio .CancelledError :
2717+ # Re-raise CancelledError to ensure the task cancellation propagates
2718+ raise
26972719 except Exception as e :
26982720 connection_error = f"Could not connect to server { server .name } \n { e } "
26992721 self .io .tool_warning (connection_error )
@@ -2728,7 +2750,15 @@ async def process_tool_calls(self, tool_call_response):
27282750 return False
27292751
27302752 # 5. Execute tools
2731- tool_responses_by_server = await self ._execute_tool_groups (tool_groups )
2753+ self .interrupt_event .clear ()
2754+
2755+ tool_responses_by_server , interrupted = await coroutines .interruptible (
2756+ self ._execute_tool_groups (tool_groups ), self .interrupt_event
2757+ )
2758+
2759+ if interrupted :
2760+ self .io .tool_warning ("Tool execution interrupted." )
2761+ return False
27322762
27332763 # 6. Add responses to conversation (re-prefixing if necessary)
27342764 tool_responses = []
@@ -3040,33 +3070,22 @@ async def send(self, messages, model=None, functions=None, tools=None):
30403070 self .token_profiler .start ()
30413071
30423072 try :
3043- completion_task = asyncio .create_task (
3044- model .send_completion (
3045- messages ,
3046- functions ,
3047- self .stream ,
3048- self .temperature ,
3049- # This could include any tools, but for now it is just MCP tools
3050- tools = tools ,
3051- override_kwargs = self .model_kwargs .copy (),
3052- )
3073+ completion_coro = model .send_completion (
3074+ messages ,
3075+ functions ,
3076+ self .stream ,
3077+ self .temperature ,
3078+ # This could include any tools, but for now it is just MCP tools
3079+ tools = tools ,
3080+ override_kwargs = self .model_kwargs .copy (),
3081+ interrupt_event = self .interrupt_event ,
30533082 )
3054- interrupt_task = asyncio .create_task (self .interrupt_event .wait ())
30553083
3056- done , pending = await asyncio .wait (
3057- {completion_task , interrupt_task },
3058- return_when = asyncio .FIRST_COMPLETED ,
3084+ (hash_object , completion ), interrupted = await coroutines .interruptible (
3085+ completion_coro , self .interrupt_event
30593086 )
3060-
3061- if interrupt_task in done :
3062- completion_task .cancel ()
3063- try :
3064- await completion_task
3065- except asyncio .CancelledError :
3066- pass
3087+ if interrupted :
30673088 raise KeyboardInterrupt
3068-
3069- hash_object , completion = completion_task .result ()
30703089 self .chat_completion_call_hashes .append (hash_object .hexdigest ())
30713090
30723091 if not isinstance (completion , ModelResponse ):
@@ -3089,7 +3108,7 @@ async def send(self, messages, model=None, functions=None, tools=None):
30893108 self .token_profiler .on_error ()
30903109 self .calculate_and_show_tokens_and_cost (messages , completion )
30913110 raise
3092- except KeyboardInterrupt as kbi :
3111+ except ( KeyboardInterrupt , asyncio . CancelledError ) as kbi :
30933112 self .keyboard_interrupt ()
30943113 raise kbi
30953114 finally :
@@ -3498,10 +3517,13 @@ def calculate_and_show_tokens_and_cost(self, messages, completion=None):
34983517 if completion and hasattr (completion , "usage" ) and completion .usage is not None :
34993518 prompt_tokens = completion .usage .prompt_tokens
35003519 completion_tokens = completion .usage .completion_tokens
3501- cache_hit_tokens = getattr (completion .usage , "prompt_cache_hit_tokens" , 0 ) or getattr (
3502- completion .usage , "cache_read_input_tokens" , 0
3520+ cache_hit_tokens = (
3521+ getattr (completion .usage , "prompt_cache_hit_tokens" , 0 )
3522+ or getattr (completion .usage , "cache_read_input_tokens" , 0 )
3523+ or 0
35033524 )
3504- cache_write_tokens = getattr (completion .usage , "cache_creation_input_tokens" , 0 )
3525+ cache_write_tokens = getattr (completion .usage , "cache_creation_input_tokens" , 0 ) or 0
3526+ self .message_cached_tokens += cache_hit_tokens
35053527
35063528 if hasattr (completion .usage , "cache_read_input_tokens" ) or hasattr (
35073529 completion .usage , "cache_creation_input_tokens"
@@ -3534,8 +3556,22 @@ def calculate_and_show_tokens_and_cost(self, messages, completion=None):
35343556 tokens_report , self .message_tokens_sent , self .message_tokens_received
35353557 )
35363558
3559+ total_combined_tokens = (
3560+ self .total_tokens_sent
3561+ + self .total_tokens_received
3562+ + self .message_tokens_sent
3563+ + self .message_tokens_received
3564+ )
3565+ total_combined_cached = self .total_cached_tokens + self .message_cached_tokens
3566+
3567+ total_stats = f"{ format_tokens (total_combined_tokens )} "
3568+ if total_combined_cached :
3569+ total_stats += f"/{ format_tokens (total_combined_cached )} "
3570+
3571+ total_stats += " ↑↓"
3572+
35373573 if not self .get_active_model ().info .get ("input_cost_per_token" ):
3538- self .usage_report = tokens_report
3574+ self .usage_report = tokens_report + " \n " + total_stats
35393575 return
35403576
35413577 try :
@@ -3552,11 +3588,8 @@ def calculate_and_show_tokens_and_cost(self, messages, completion=None):
35523588 self .total_cost += cost
35533589 self .message_cost += cost
35543590
3555- total_combined_tokens = (
3556- self .total_tokens_sent + self .total_tokens_received + prompt_tokens + completion_tokens
3557- )
35583591 cost_report = (
3559- f"${ self .format_cost (self .message_cost )} • { format_tokens ( total_combined_tokens ) } ↑↓ "
3592+ f"${ self .format_cost (self .message_cost )} • { total_stats } "
35603593 f" ${ self .format_cost (self .total_cost )} "
35613594 )
35623595
@@ -3614,6 +3647,7 @@ def show_usage_report(self):
36143647
36153648 self .total_tokens_sent += self .message_tokens_sent
36163649 self .total_tokens_received += self .message_tokens_received
3650+ self .total_cached_tokens += self .message_cached_tokens
36173651
36183652 if self .tui and self .tui ():
36193653 self .tui ().update_cost (self .usage_report .replace ("\n " , " " ))
@@ -3624,6 +3658,7 @@ def show_usage_report(self):
36243658 self .message_cost = 0.0
36253659 self .message_tokens_sent = 0
36263660 self .message_tokens_received = 0
3661+ self .message_cached_tokens = 0
36273662
36283663 def get_multi_response_content_in_progress (self , final = False ):
36293664 cur = self .multi_response_content or ""
0 commit comments