docs/features/reward_system.md (10 changes: 9 additions & 1 deletion)
@@ -5,7 +5,9 @@ We put reward calculation into the agent side instead of trainer side and use a
2. Reward calculation can be designed to be asynchronous for efficiency.

### Definition
Similar to tools, we can decide whether to use environments in the reward definition. The return should either be a value or a dictionary containing `reward` as one of its keys. We can use the decorator `@reward` or inherit from the `BaseReward` class. Any additional keys in the returned dict (e.g. `em`, `f1`, `fmt`) are passed through and logged during training and validation.



```python
@reward(name="qa_f1_reward")
@@ -105,3 +107,9 @@ def summary_reward(final_response, length_penalty, max_length):
    else:
        return 1.0
```

## Return Values

Either a `float` value or a dictionary containing `reward` as a key should be returned. If the return value is a `float`, it is used directly as the reward. If a dictionary is returned, the value under `reward` is used as the reward, while the other keys are still logged.

Extra keys (besides `reward`) are logged as `reward_extra/{key}/mean`, `reward_extra/{key}/max`, and `reward_extra/{key}/min` in the metrics produced by `compute_data_metrics` (`verl/verl/trainer/ppo/metric_utils.py`).
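
As an illustration of this convention, here is a minimal sketch of a dictionary-returning reward in the style of the examples above. The import path and the `exact_match_reward` function are assumptions for illustration only, not part of the library:

```python
# Minimal sketch only: the import path below is an assumption; use whatever
# module exposes the `@reward` decorator in your installation.
from agentfly.rewards import reward


@reward(name="exact_match_reward")
def exact_match_reward(final_response: str, answer: str) -> dict:
    # Hypothetical normalization; real rewards typically use em_score/f1_score helpers.
    pred = final_response.strip().lower()
    gold = answer.strip().lower()
    em = 1.0 if pred == gold else 0.0
    return {
        "reward": em,                            # used as the training reward
        "em": em,                                # extra key -> reward_extra/em/{mean,max,min}
        "length": float(len(final_response)),    # extra key -> reward_extra/length/*
    }
```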
src/agentfly/rewards/qa_reward.py (72 changes: 72 additions & 0 deletions)
@@ -149,6 +149,78 @@ def qa_f1_reward_tool(final_response: str, answer: str, trajectory: List[str]) -
    return rewards_dict


def _extract_answer_tag(text: str) -> str:
    """Extract content between <answer> and </answer>, or return original if not present."""
    match = re.search(r"<answer>\s*(.*?)\s*</answer>", text, re.DOTALL)
    return match.group(1).strip() if match else text


def _format_ok(final_response: str, trajectory: List) -> tuple:
    """Return (fmt, previous_have_think, has_tool_calling).

    fmt is True when final_response has <answer>...</answer>, the trajectory
    has a tool call, and all assistant turns except the last have <think>...</think>.
    """
    has_answer_tags = "<answer>" in final_response and "</answer>" in final_response
    if not has_answer_tags or not trajectory:
        return False, False, False
    has_tool_calling = any(
        isinstance(msg, dict) and msg.get("role") == "tool" for msg in trajectory
    )
    # Collect assistant turns; only previous (non-last) ones must have think
    assistant_turns = []
    for msg in trajectory:
        if isinstance(msg, dict):
            if msg.get("role") == "assistant":
                content = msg.get("content") or msg.get("text") or ""
                assistant_turns.append(content)
            elif msg.get("role") == "tool":
                pass  # already counted has_tool_calling
        else:
            assistant_turns.append(str(msg))
    if not assistant_turns:
        previous_have_think = True
    else:
        previous = assistant_turns[:-1]  # all but last
        previous_have_think = all(
            "<think>" in c and "</think>" in c for c in previous if c
        )
    fmt = has_answer_tags and has_tool_calling and previous_have_think
    return fmt, previous_have_think, has_tool_calling
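
# Illustrative behaviour of the helpers above on a hypothetical trajectory
# (sketch for clarity, not part of the original code):
#
#   _extract_answer_tag("<answer> Paris </answer>")  ->  "Paris"
#
#   _format_ok(
#       "<answer>Paris</answer>",
#       [
#           {"role": "assistant", "content": "<think>look it up</think>"},
#           {"role": "tool", "content": "Paris is the capital of France."},
#           {"role": "assistant", "content": "<answer>Paris</answer>"},
#       ],
#   )  ->  (True, True, True)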


@reward(name="qa_em_format_reward")
def qa_em_format_reward(final_response: str, golden_answers: List[str], trajectory: List[str]) -> float:
"""
Calculate the reward for the agent's response based on the EM score.

- 1.0 if the format is correct, and the em is true
- 0.1 if the format is correct, but the em is wrong
- 0.0 if the format is incorrect
"""
predicted = _extract_answer_tag(final_response)
if not golden_answers:
max_em, max_f1 = 0.0, 0.0
else:
max_em = max(em_score(predicted, g) for g in golden_answers)
max_f1 = max(f1_score(predicted, g)[0] for g in golden_answers)
fmt, previous_have_think, has_tool_calling = _format_ok(final_response, trajectory)

reward = 0.0
if fmt and max_em:
reward = 1.0
elif fmt and not max_em:
reward = 0.1
elif max_em and not fmt:
reward = 0.0

return {
"reward": reward,
"em": max_em,
"f1": max_f1,
"fmt": 1.0 if fmt else 0.0,
"fmt_think": 1.0 if previous_have_think else 0.0,
"fmt_tool": 1.0 if has_tool_calling else 0.0,
}
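
# The extra keys returned above (em, f1, fmt, fmt_think, fmt_tool) are carried
# into the trainer metrics as reward_extra/{key}/mean, reward_extra/{key}/max,
# and reward_extra/{key}/min, per docs/features/reward_system.md.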



@reward(name="ok_vqa_reward")
def ok_vqa_reward(
    final_response: str, answers: List[str], trajectory: List[str]
src/agentfly/tools/src/search/async_dense_retriever.py (2 changes: 1 addition & 1 deletion)
@@ -7,7 +7,7 @@
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache

from .... import AGENT_CACHE_DIR
import datasets
import numpy as np
import torch