22
33import asyncio
44import json
5+ import logging
56import os
67import random
78import socket
1920from tensordict import TensorDict
2021from verl import DataProto
2122
23+ # =============================================================================
24+ # Tool Call Filtering Support (for filtering unexpected tool call turns)
25+ # Reference: Youtu-Agent implementation in contrib/youtu-agent-lightning branch
26+ # The ToolParser extracts tool calls from response tokens to detect cases where
27+ # the model continues generating after a tool call (hallucinated tool responses)
28+ # instead of properly stopping with </tool_call><|im_end|>
29+ # =============================================================================
30+ try :
31+ from verl .experimental .agent_loop .tool_parser import ToolParser
32+ TOOL_PARSER_AVAILABLE = True
33+ except ImportError :
34+ TOOL_PARSER_AVAILABLE = False
35+
2236from agentlightning import LLM , AgentLightningServer , NamedResources , RolloutLegacy
2337from agentlightning .adapter .triplet import TracerTraceToTriplet , TraceToTripletBase
2438from agentlightning .llm_proxy import LLMProxy , ModelConfig
2539from agentlightning .store .base import LightningStore
2640from agentlightning .types import EnqueueRolloutRequest , Rollout , RolloutConfig , Task
2741
42+ logger = logging .getLogger (__name__ )
43+
2844__all__ = [
2945 "AgentModeDaemon" ,
3046 "get_left_padded_ids_and_attention_mask" ,
@@ -283,6 +299,11 @@ def __init__(
283299 self ._proxy_thread : Optional [threading .Thread ] = None
284300 self .is_train = True
285301
302+ # Tool Call Filtering Setup (config key: trace_aggregator["filter_unexpected_tool_calls"])
303+ self .tool_parser = None
304+ self .toolcall_candidate_token_last2_list = []
305+ self ._setup_tool_call_filter (train_information , tokenizer )
306+
286307 def _internal_loop_runner (self ):
287308 """Run the internal loop."""
288309 loop = asyncio .new_event_loop ()
@@ -291,6 +312,112 @@ def _internal_loop_runner(self):
291312 loop .run_forever ()
292313 loop .close ()
293314
315+ # =========================================================================
316+ # Tool Call Filtering Methods
317+ # Reference: Youtu-Agent implementation (contrib/youtu-agent-lightning)
318+ # Purpose: Filter out "unexpected tool call turns" where the model continues
319+ # generating text after a tool call instead of stopping properly.
320+ # =========================================================================
321+
def _setup_tool_call_filter(self, train_information: Dict[str, Any], tokenizer: Any) -> None:
    """Initialize the tool parser and the valid tool-call ending token pairs.

    A sample tool-call conversation is rendered through the tokenizer's chat
    template and the last two token IDs of the rendered text are taken as the
    canonical tool-call ending (e.g. ``</tool_call><|im_end|>``) instead of
    hardcoding token IDs. Variants whose final token is the eos or pad token
    are added as well so that several ending conditions are accepted and
    valid turns are not over-filtered.

    On return, ``self.tool_parser`` holds the parser (or ``None`` when
    filtering is disabled) and ``self.toolcall_candidate_token_last2_list``
    holds the accepted ``[second_to_last, last]`` token-ID pairs.

    Args:
        train_information: Training config; its ``"format"`` entry selects the
            tool-call format and defaults to ``"hermes"``.
        tokenizer: Tokenizer providing ``apply_chat_template``, ``encode``,
            ``eos_token_id`` and ``pad_token_id``.
    """
    if not TOOL_PARSER_AVAILABLE:
        # Report through the module logger, like every other message here.
        logger.warning("ToolParser not available, tool call filtering disabled.")
        self.tool_parser = None
        return

    toolcall_format = train_information.get("format", "hermes")
    self.tool_parser = ToolParser.get_tool_parser(toolcall_format, tokenizer)

    # Use the chat template to detect the actual tool call ending token sequence.
    # The example uses a calculator tool to match the calc-x example.
    tools_examples = [
        {
            "type": "function",
            "name": "calculate",
            "description": "Evaluate a mathematical expression",
            "parameters": {
                "type": "object",
                "properties": {
                    "expression": {"type": "string", "description": "Math expression, e.g., '2 + 3 * 4'"},
                },
                "required": ["expression"],
            },
        }
    ]
    toolcall_message_examples = [
        {"role": "user", "content": "What is 15 + 27?"},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": "call_001",
                    "type": "function",
                    "function": {"name": "calculate", "arguments": '{"expression":"15 + 27"}'},
                }
            ],
        },
    ]
    rendered_example = tokenizer.apply_chat_template(
        toolcall_message_examples,
        tools=tools_examples,
        add_generation_prompt=False,
        tokenize=False,
    )

    # The rendered template already spells out its special tokens as text, so
    # the tokenizer must not add its own (an appended EOS would shift the
    # detected ending away from the real tool-call terminator).
    ending_ids = list(tokenizer.encode(rendered_example.strip(), add_special_tokens=False)[-2:])
    if len(ending_ids) < 2:
        # Without a full token pair the strict last-2 check cannot work;
        # disable filtering instead of failing later on a short candidate.
        logger.warning(
            "Chat template rendered fewer than 2 tokens for the tool call example; "
            "tool call filtering disabled."
        )
        self.tool_parser = None
        self.toolcall_candidate_token_last2_list = []
        return

    eos_token_id = tokenizer.eos_token_id
    pad_token_id = tokenizer.pad_token_id

    # Candidate endings: the detected pair plus variants ending in eos/pad.
    # Undefined (None) terminators and duplicates (already-eos ending,
    # eos == pad) are skipped because they add nothing to the check.
    candidates = [ending_ids]
    for terminator_id in (eos_token_id, pad_token_id):
        if terminator_id is None:
            continue
        variant = [ending_ids[0], terminator_id]
        if variant not in candidates:
            candidates.append(variant)

    self.toolcall_candidate_token_last2_list = candidates
    logger.info(
        "Tool call filter initialized: eos_token_id=%r, pad_token_id=%r, candidates=%s",
        eos_token_id,
        pad_token_id,
        self.toolcall_candidate_token_last2_list,
    )
386+
387+
def _is_valid_tool_call_response(self, response_ids: List[int]) -> Tuple[bool, bool]:
    """Check whether a response containing tool calls ends with a valid token pair.

    The check is a strict last-2-token comparison: the response must end with
    one of the pairs in ``self.toolcall_candidate_token_last2_list``
    (e.g. ``</tool_call><|im_end|>`` or ``</tool_call><|endoftext|>``).

    Args:
        response_ids: Token IDs of the model's response.

    Returns:
        Tuple ``(has_tool_calls, has_valid_ending)``:
            - has_tool_calls: True if the parser found tool calls in the response.
            - has_valid_ending: True if there are no tool calls (or no parser),
              or the response ends with one of the candidate pairs.
    """
    if self.tool_parser is None:
        return False, True

    extraction = self.tool_parser.extract_tool_calls(response_ids)
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop is running in this thread: drive the coroutine directly.
        _, tool_calls = asyncio.run(extraction)
    else:
        # asyncio.run() refuses to start while a loop is already running in
        # the current thread, so run the coroutine on a short-lived worker
        # thread with its own loop and block until it finishes.
        import concurrent.futures

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            _, tool_calls = pool.submit(asyncio.run, extraction).result()

    if not tool_calls:
        return False, True

    if len(response_ids) < 2:
        # Too short to carry any two-token ending.
        return True, False

    # Strict last-2 check against all valid ending candidates. Comparing whole
    # tuples means a malformed (short) candidate simply never matches.
    tail = (response_ids[-2], response_ids[-1])
    has_valid_ending = any(
        tuple(candidate) == tail for candidate in self.toolcall_candidate_token_last2_list
    )
    return True, has_valid_ending
420+
294421 # Multimodal utilities for M-RoPE position embeddings
295422
296423 def _is_mrope_model (self ) -> bool :
@@ -821,6 +948,12 @@ def get_train_data_batch(
821948 finished_id_to_sample_info : Dict [str , Dict [str , Any ]] = {}
822949 finished_id_to_final_reward : Dict [str , float ] = {}
823950 sample_with_reward_count = 0
951+
952+ # Tool call filtering metrics
953+ n_total_turns_before_filter = 0
954+ n_unexpected_tool_calls = 0
955+ n_skipped_rollouts_by_filter = 0
956+
824957 for rollout_id , rollout in self ._completed_rollouts_v0 .items ():
825958 original_sample = self ._task_id_to_original_sample [rollout_id ]
826959 sample_with_reward_count += int (rollout .final_reward is not None )
@@ -842,6 +975,31 @@ def get_train_data_batch(
842975 }
843976 for t in rollout .triplets
844977 ]
978+
979+ # Filter void turns (empty prompt or response)
980+ trace_list = [
981+ trace for trace in trace_list
982+ if len (trace ["prompt_ids" ]) and len (trace ["response_ids" ])
983+ ]
984+
985+ # Filter turns with unexpected tool calls (model continues after tool call)
986+ if self .tool_parser is not None :
987+ n_total_turns_before_filter += len (trace_list )
988+ trace_list_filtered = []
989+ for trace in trace_list :
990+ has_tool_calls , has_valid_ending = self ._is_valid_tool_call_response (trace ["response_ids" ])
991+ if has_tool_calls and not has_valid_ending :
992+ n_unexpected_tool_calls += 1
993+ else :
994+ trace_list_filtered .append (trace )
995+
996+ if self .trace_aggregator .get ("filter_unexpected_tool_calls" , False ):
997+ if len (trace_list_filtered ) <= 1 :
998+ n_skipped_rollouts_by_filter += 1
999+ finished_id_to_final_reward [rollout_id ] = final_reward
1000+ continue
1001+ trace_list = trace_list_filtered
1002+
8451003 info = {
8461004 "reward" : final_reward ,
8471005 "trace_list" : trace_list ,
@@ -1123,6 +1281,14 @@ def get_train_data_batch(
11231281 and self .trace_aggregator .get ("debug" , False )
11241282 else {}
11251283 ),
1284+ "training/n_unexpected_tool_calls" : n_unexpected_tool_calls ,
1285+ "training/n_total_turns_before_filter" : n_total_turns_before_filter ,
1286+ "training/unexpected_tool_call_ratio" : (
1287+ n_unexpected_tool_calls / n_total_turns_before_filter if n_total_turns_before_filter > 0 else 0.0
1288+ ),
1289+ "training/n_skipped_rollouts_by_filter" : n_skipped_rollouts_by_filter ,
1290+ "training/filter_enabled" : float (self .trace_aggregator .get ("filter_unexpected_tool_calls" , False )),
1291+ "training/reward_std" : np .std (list (finished_id_to_final_reward .values ())),
11261292 }
11271293
11281294 # Add non-tensor data for advantage calculation and logging
0 commit comments