Skip to content

Commit 7a5ee47

Browse files
committed
feat(verl): add unexpected tool call filtering
Add filtering for "unexpected tool call" turns, where the model continues generating after a tool call instead of stopping at </tool_call><|im_end|>. This helps prevent entropy explosion during GRPO training.

Changes:
- daemon.py: add _setup_tool_call_filter(), _count_invalid_turns(), _filter_invalid_turns(), and void-turn filtering
- config.yaml: add filter_unexpected_tool_calls option (default: False)
- trainer.py: fix missing gts parameter in _dump_generations()
- examples/calc_x/train_calc_agent.py: add --filter-unexpected-tool-calls CLI flag

Key improvements over the Youtu branch:
- uses apply_chat_template() for model-agnostic token detection
- supports multiple valid endings (eos_token, pad_token variants)
- uses a calculator tool example for calc-x consistency

Reference: contrib/youtu-agent-lightning branch
1 parent 5f3093d commit 7a5ee47

4 files changed

Lines changed: 254 additions & 0 deletions

File tree

agentlightning/verl/config.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@ agentlightning:
1414
trajectory_max_response_length: 8192 # supported in trajectory level aggregation, suggest to set as maximum length for the cumulative agent responses in the full trajectory, i.e., n_turns * (max_response_length + max_prompt_length)
1515
debug: False # supported in trajectory level aggregation, enable to diagnose trace merging failures
1616
mismatch_log_dir: ./mismatch_cases # supported in trajectory level aggregation with debug=True, directory to store logs of mismatch cases
17+
# =========================================================================
18+
# Tool Call Filtering (Youtu-Agent style)
19+
# When enabled, filters out "unexpected tool call" turns where the model
20+
# continues generating after a tool call instead of stopping properly.
21+
# This helps prevent entropy explosion during RL training.
22+
# Reference: contrib/youtu-agent-lightning branch
23+
# =========================================================================
24+
filter_unexpected_tool_calls: False # set to True to enable filtering
1725

1826
data:
1927
filter_overlong_prompts: false

agentlightning/verl/daemon.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import json
5+
import logging
56
import os
67
import random
78
import socket
@@ -19,12 +20,27 @@
1920
from tensordict import TensorDict
2021
from verl import DataProto
2122

23+
# =============================================================================
24+
# Tool Call Filtering Support (for filtering unexpected tool call turns)
25+
# Reference: Youtu-Agent implementation in contrib/youtu-agent-lightning branch
26+
# The ToolParser extracts tool calls from response tokens to detect cases where
27+
# the model continues generating after a tool call (hallucinated tool responses)
28+
# instead of properly stopping with </tool_call><|im_end|>
29+
# =============================================================================
30+
try:
31+
from verl.experimental.agent_loop.tool_parser import ToolParser
32+
TOOL_PARSER_AVAILABLE = True
33+
except ImportError:
34+
TOOL_PARSER_AVAILABLE = False
35+
2236
from agentlightning import LLM, AgentLightningServer, NamedResources, RolloutLegacy
2337
from agentlightning.adapter.triplet import TracerTraceToTriplet, TraceToTripletBase
2438
from agentlightning.llm_proxy import LLMProxy, ModelConfig
2539
from agentlightning.store.base import LightningStore
2640
from agentlightning.types import EnqueueRolloutRequest, Rollout, RolloutConfig, Task
2741

42+
logger = logging.getLogger(__name__)
43+
2844
__all__ = [
2945
"AgentModeDaemon",
3046
"get_left_padded_ids_and_attention_mask",
@@ -283,6 +299,11 @@ def __init__(
283299
self._proxy_thread: Optional[threading.Thread] = None
284300
self.is_train = True
285301

302+
# Tool Call Filtering Setup (config key: trace_aggregator["filter_unexpected_tool_calls"])
303+
self.tool_parser = None
304+
self.toolcall_candidate_token_last2_list = []
305+
self._setup_tool_call_filter(train_information, tokenizer)
306+
286307
def _internal_loop_runner(self):
287308
"""Run the internal loop."""
288309
loop = asyncio.new_event_loop()
@@ -291,6 +312,112 @@ def _internal_loop_runner(self):
291312
loop.run_forever()
292313
loop.close()
293314

315+
# =========================================================================
316+
# Tool Call Filtering Methods
317+
# Reference: Youtu-Agent implementation (contrib/youtu-agent-lightning)
318+
# Purpose: Filter out "unexpected tool call turns" where the model continues
319+
# generating text after a tool call instead of stopping properly.
320+
# =========================================================================
321+
322+
def _setup_tool_call_filter(self, train_information: Dict[str, Any], tokenizer: Any) -> None:
    """Initialize the tool parser and valid ending-token patterns for filtering.

    Uses apply_chat_template to auto-detect the correct tool call ending tokens
    rather than hardcoding token IDs. Also builds variants ending with the
    eos/pad token to allow alternative stop conditions and prevent
    over-filtering.

    Args:
        train_information: Training config containing 'format' for the toolcall
            format (defaults to "hermes").
        tokenizer: The tokenizer used for encoding/decoding.
    """
    if not TOOL_PARSER_AVAILABLE:
        # Use the module logger instead of print() so the message lands in the
        # daemon's normal log stream, consistent with the rest of this module.
        logger.warning("ToolParser not available, tool call filtering disabled.")
        self.tool_parser = None
        return

    toolcall_format = train_information.get("format", "hermes")
    self.tool_parser = ToolParser.get_tool_parser(toolcall_format, tokenizer)

    # Render a minimal tool-call conversation through the chat template to
    # detect the actual tool call ending token sequence for this tokenizer.
    # The calculator tool mirrors the calc-x example for consistency.
    # NOTE(review): this tool schema is flat ({"type", "name", ...}); some chat
    # templates expect the nested OpenAI form ({"type": "function",
    # "function": {...}}) -- confirm against the tokenizer's template if the
    # detected ending tokens ever look wrong.
    tools_examples = [{
        "type": "function",
        "name": "calculate",
        "description": "Evaluate a mathematical expression",
        "parameters": {
            "type": "object",
            "properties": {
                "expression": {"type": "string", "description": "Math expression, e.g., '2 + 3 * 4'"},
            },
            "required": ["expression"],
        },
    }]
    toolcall_message_examples = [
        {"role": "user", "content": "What is 15 + 27?"},
        {"role": "assistant", "content": "", "tool_calls": [{
            "id": "call_001",
            "type": "function",
            "function": {"name": "calculate", "arguments": '{"expression":"15 + 27"}'},
        }]},
    ]
    toolcall_example_chat_template = tokenizer.apply_chat_template(
        toolcall_message_examples, tools=tools_examples,
        add_generation_prompt=False, tokenize=False,
    )
    # Extract the last 2 tokens from the rendered chat template output
    # (e.g., </tool_call><|im_end|>).
    toolcall_example_token_last2 = tokenizer.encode(toolcall_example_chat_template.strip())[-2:]

    eos_token_id = tokenizer.eos_token_id
    pad_token_id = tokenizer.pad_token_id

    # Build the candidate list: the detected ending plus eos/pad variants, to
    # allow various tool-call ending conditions and prevent over-filtering.
    # Skip a missing pad token (pad_token_id is None on some tokenizers) and
    # de-duplicate when pad == eos so the candidate list stays minimal.
    toolcall_candidate_token_last2_list = [toolcall_example_token_last2]
    for alt_token_id in (eos_token_id, pad_token_id):
        if alt_token_id is None:
            continue
        variant = [toolcall_example_token_last2[0], alt_token_id]
        if variant not in toolcall_candidate_token_last2_list:
            toolcall_candidate_token_last2_list.append(variant)

    self.toolcall_candidate_token_last2_list = toolcall_candidate_token_last2_list
    logger.info(
        f"Tool call filter initialized: {eos_token_id=}, {pad_token_id=}, "
        f"candidates={self.toolcall_candidate_token_last2_list}"
    )
386+
387+
388+
def _is_valid_tool_call_response(self, response_ids: List[int]) -> Tuple[bool, bool]:
389+
"""Check if a response with tool calls ends with valid ending tokens.
390+
391+
Uses strict last-2-token check (same as youtu branch): the response must end
392+
with one of the candidate token pairs (e.g., </tool_call><|im_end|> or
393+
</tool_call><|endoftext|>).
394+
395+
Args:
396+
response_ids: List of token IDs from the model's response
397+
398+
Returns:
399+
Tuple of (has_tool_calls, has_valid_ending):
400+
- has_tool_calls: True if the response contains tool calls
401+
- has_valid_ending: True if no tool calls, or tool calls with proper ending
402+
"""
403+
if self.tool_parser is None:
404+
return False, True
405+
406+
_, tool_calls = asyncio.run(self.tool_parser.extract_tool_calls(response_ids))
407+
408+
if not tool_calls:
409+
return False, True
410+
411+
if len(response_ids) < 2:
412+
return True, False
413+
414+
# Strict last-2 check against all valid ending candidates
415+
for candidate in self.toolcall_candidate_token_last2_list:
416+
if response_ids[-2] == candidate[0] and response_ids[-1] == candidate[1]:
417+
return True, True
418+
419+
return True, False
420+
294421
# Multimodal utilities for M-RoPE position embeddings
295422

296423
def _is_mrope_model(self) -> bool:
@@ -821,6 +948,12 @@ def get_train_data_batch(
821948
finished_id_to_sample_info: Dict[str, Dict[str, Any]] = {}
822949
finished_id_to_final_reward: Dict[str, float] = {}
823950
sample_with_reward_count = 0
951+
952+
# Tool call filtering metrics
953+
n_total_turns_before_filter = 0
954+
n_unexpected_tool_calls = 0
955+
n_skipped_rollouts_by_filter = 0
956+
824957
for rollout_id, rollout in self._completed_rollouts_v0.items():
825958
original_sample = self._task_id_to_original_sample[rollout_id]
826959
sample_with_reward_count += int(rollout.final_reward is not None)
@@ -842,6 +975,31 @@ def get_train_data_batch(
842975
}
843976
for t in rollout.triplets
844977
]
978+
979+
# Filter void turns (empty prompt or response)
980+
trace_list = [
981+
trace for trace in trace_list
982+
if len(trace["prompt_ids"]) and len(trace["response_ids"])
983+
]
984+
985+
# Filter turns with unexpected tool calls (model continues after tool call)
986+
if self.tool_parser is not None:
987+
n_total_turns_before_filter += len(trace_list)
988+
trace_list_filtered = []
989+
for trace in trace_list:
990+
has_tool_calls, has_valid_ending = self._is_valid_tool_call_response(trace["response_ids"])
991+
if has_tool_calls and not has_valid_ending:
992+
n_unexpected_tool_calls += 1
993+
else:
994+
trace_list_filtered.append(trace)
995+
996+
if self.trace_aggregator.get("filter_unexpected_tool_calls", False):
997+
if len(trace_list_filtered) <= 1:
998+
n_skipped_rollouts_by_filter += 1
999+
finished_id_to_final_reward[rollout_id] = final_reward
1000+
continue
1001+
trace_list = trace_list_filtered
1002+
8451003
info = {
8461004
"reward": final_reward,
8471005
"trace_list": trace_list,
@@ -1123,6 +1281,14 @@ def get_train_data_batch(
11231281
and self.trace_aggregator.get("debug", False)
11241282
else {}
11251283
),
1284+
"training/n_unexpected_tool_calls": n_unexpected_tool_calls,
1285+
"training/n_total_turns_before_filter": n_total_turns_before_filter,
1286+
"training/unexpected_tool_call_ratio": (
1287+
n_unexpected_tool_calls / n_total_turns_before_filter if n_total_turns_before_filter > 0 else 0.0
1288+
),
1289+
"training/n_skipped_rollouts_by_filter": n_skipped_rollouts_by_filter,
1290+
"training/filter_enabled": float(self.trace_aggregator.get("filter_unexpected_tool_calls", False)),
1291+
"training/reward_std": np.std(list(finished_id_to_final_reward.values())),
11261292
}
11271293

11281294
# Add non-tensor data for advantage calculation and logging

agentlightning/verl/trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@ def _train_step(self, batch_dict: dict) -> dict:
421421
self._dump_generations(
422422
inputs=inputs,
423423
outputs=outputs,
424+
gts=[""] * len(inputs),
424425
scores=scores,
425426
reward_extra_infos_dict=reward_extra_infos_dict,
426427
dump_path=rollout_data_dir,

examples/calc_x/train_calc_agent.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@
4040
import agentlightning as agl
4141
from agentlightning.env_var import LightningEnvVar, resolve_bool_env_var, resolve_str_env_var
4242

43+
# Ensure venv bin is in PATH (needed for uvx/mcp-server-calculator in Ray workers)
44+
_script_dir = os.path.dirname(os.path.abspath(__file__))
45+
_venv_bin = os.path.join(_script_dir, "..", "..", ".venv", "bin")
46+
if os.path.isdir(_venv_bin):
47+
os.environ["PATH"] = os.path.abspath(_venv_bin) + ":" + os.environ.get("PATH", "")
48+
4349

4450
def verl_default_config() -> Dict[str, Any]:
4551
config = {
@@ -123,6 +129,11 @@ def train(
123129
trajectory_level: bool = False,
124130
weave: bool,
125131
mongo_uri: Optional[str],
132+
filter_unexpected_tool_calls: bool = False,
133+
experiment_name: Optional[str] = None,
134+
n_gpus: int = 1,
135+
checkpoint_dir: str = "/home/jovyan/msra/experiments/checkpoints",
136+
resume: bool = False,
126137
):
127138
"""The training entrypoint function for Calc-X agent with VERL algorithm.
128139
@@ -141,6 +152,7 @@ def train(
141152
trajectory_level: Whether to enable trajectory level in trace aggregator.
142153
weave: Whether to enable Weave tracing.
143154
mongo_uri: MongoDB URI to use for the store.
155+
experiment_name: Custom experiment name for W&B logging.
144156
"""
145157
# Load datasets (respect CLI file paths)
146158
train_dataset = cast(agl.Dataset[MathProblem], HuggingFaceDataset.from_parquet(train_file).to_list()) # type: ignore
@@ -156,6 +168,26 @@ def train(
156168
if model:
157169
config["actor_rollout_ref"]["model"]["path"] = model
158170

171+
# Override experiment name if provided (for W&B logging)
172+
if experiment_name:
173+
config["trainer"]["experiment_name"] = experiment_name
174+
print(f"Using custom experiment name: {experiment_name}")
175+
176+
# Override n_gpus_per_node for multi-GPU training
177+
if n_gpus > 1:
178+
config["trainer"]["n_gpus_per_node"] = n_gpus
179+
print(f"Multi-GPU training enabled: n_gpus_per_node={n_gpus}")
180+
181+
# Set checkpoint directory and conversation dump directory
182+
config["trainer"]["default_local_dir"] = checkpoint_dir
183+
config["trainer"]["resume_mode"] = "auto" if resume else "disable"
184+
conversations_dir = checkpoint_dir.replace("checkpoints", "conversations")
185+
config["trainer"]["rollout_data_dir"] = conversations_dir
186+
os.makedirs(conversations_dir, exist_ok=True)
187+
print(f"Checkpoint directory: {checkpoint_dir}")
188+
print(f"Conversations directory: {conversations_dir}")
189+
print(f"Resume mode: {config['trainer']['resume_mode']}")
190+
159191
# Enable LoRA configuration if requested
160192
if lora:
161193
config["actor_rollout_ref"]["model"]["lora_rank"] = lora_rank
@@ -175,6 +207,19 @@ def train(
175207
}
176208
print("Trajectory level enabled in trace aggregator.")
177209

210+
# =========================================================================
211+
# Tool Call Filtering (Youtu-Agent style)
212+
# Filters out turns where the model generates unexpected content after
213+
# a tool call (hallucinated tool responses). Helps prevent entropy explosion.
214+
# =========================================================================
215+
if filter_unexpected_tool_calls:
216+
if "agentlightning" not in config:
217+
config["agentlightning"] = {"trace_aggregator": {}}
218+
if "trace_aggregator" not in config["agentlightning"]:
219+
config["agentlightning"]["trace_aggregator"] = {}
220+
config["agentlightning"]["trace_aggregator"]["filter_unexpected_tool_calls"] = True
221+
print("Tool call filtering enabled (Youtu-Agent style).")
222+
178223
# CI toggle keeps everything else the same but you can tweak the lightweight bits here if desired
179224
if ci or ci_fast:
180225
# Config the experiment name and project name so that they are available to CI
@@ -290,6 +335,35 @@ def main():
290335
default=None,
291336
help="MongoDB URI to use for the store.",
292337
)
338+
parser.add_argument(
339+
"--filter-unexpected-tool-calls",
340+
action="store_true",
341+
help="Enable Youtu-Agent style tool call filtering. "
342+
"Filters out turns where the model generates unexpected content after a tool call.",
343+
)
344+
parser.add_argument(
345+
"--experiment-name",
346+
type=str,
347+
default=None,
348+
help="Custom experiment name for W&B logging (default: calc_x or auto-generated for CI)",
349+
)
350+
parser.add_argument(
351+
"--n-gpus",
352+
type=int,
353+
default=1,
354+
help="Number of GPUs per node for distributed training (default: 1)",
355+
)
356+
parser.add_argument(
357+
"--checkpoint-dir",
358+
type=str,
359+
default="/home/jovyan/msra/experiments/checkpoints",
360+
help="Directory to save checkpoints (default: /home/jovyan/msra/experiments/checkpoints)",
361+
)
362+
parser.add_argument(
363+
"--resume",
364+
action="store_true",
365+
help="Resume training from the latest checkpoint in checkpoint-dir",
366+
)
293367

294368
args = parser.parse_args()
295369

@@ -321,6 +395,11 @@ def main():
321395
trajectory_level=args.trajectory_level,
322396
weave=args.weave,
323397
mongo_uri=args.mongo_uri,
398+
filter_unexpected_tool_calls=args.filter_unexpected_tool_calls,
399+
experiment_name=args.experiment_name,
400+
n_gpus=args.n_gpus,
401+
checkpoint_dir=args.checkpoint_dir,
402+
resume=args.resume,
324403
)
325404

326405

0 commit comments

Comments
 (0)