diff --git a/docs/features/reward_system.md b/docs/features/reward_system.md
index 15f0654..49e2fc1 100644
--- a/docs/features/reward_system.md
+++ b/docs/features/reward_system.md
@@ -5,7 +5,9 @@ We put reward calculation into the agent side instead of trainer side and use a
 2. Reward calculation can be designed to be asynchronous for efficiency.
 
 ### Definition
-Similar to tools, we can decide whether to use environments in the reward definition. The return should either be a value, or a dictionary containing `reward` as one of keys. We can use decorator `@tool` or inherit from the `BaseReward` class.
+Similar to tools, we can decide whether to use environments in the reward definition. The return value should either be a `float`, or a dictionary containing `reward` as one of the keys. We can use the `@reward` decorator or inherit from the `BaseReward` class. Any additional keys in the returned dict (e.g. `em`, `f1`, `fmt`) are passed through and logged during training and validation.
+
+
 
 ```python
 @reward(name="qa_f1_reward")
@@ -105,3 +107,21 @@ def summary_reward(final_response, length_penalty, max_length):
     else:
         return 1.0
 ```
+
+## Return Values
+
+Either a `float` value or a dictionary containing `reward` as a key should be returned. If the return value is a `float`, it is used directly as the reward. If a dictionary is returned, the value under `reward` is used as the reward, while the other keys are still logged.
+
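+For illustration, a reward that returns an extra diagnostic key alongside the training signal might be written like the sketch below (the reward name `exact_match_with_extras`, its parameters, and the extra `em` key are only illustrative):
+
+```python
+@reward(name="exact_match_with_extras")
+def exact_match_with_extras(final_response: str, answer: str) -> dict:
+    # Hypothetical reward: exact string match against the reference answer.
+    em = 1.0 if final_response.strip().lower() == answer.strip().lower() else 0.0
+    # "reward" is used as the training signal; "em" is an extra key that is
+    # additionally logged (see the metric names described below).
+    return {"reward": em, "em": em}
+```
+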
+Extra keys (besides `reward`) are logged as `reward_extra/{key}/mean`, `reward_extra/{key}/max`, `reward_extra/{key}/min` in the metrics produced by `compute_data_metrics` (`verl/verl/trainer/ppo/metric_utils.py`).
\ No newline at end of file
diff --git a/src/agentfly/rewards/qa_reward.py b/src/agentfly/rewards/qa_reward.py
index 19dbab8..5a8618d 100644
--- a/src/agentfly/rewards/qa_reward.py
+++ b/src/agentfly/rewards/qa_reward.py
@@ -149,6 +149,78 @@ def qa_f1_reward_tool(final_response: str, answer: str, trajectory: List[str]) -
     return rewards_dict
 
 
+def _extract_answer_tag(text: str) -> str:
+    """Extract the content between <answer> and </answer>, or return the original text if the tags are not present."""
+    match = re.search(r"<answer>\s*(.*?)\s*</answer>", text, re.DOTALL)
+    return match.group(1).strip() if match else text
+
+
+def _format_ok(final_response: str, trajectory: List) -> tuple:
+    """Return (fmt, previous_have_think, has_tool_calling); fmt is True if final_response has <answer>...</answer>, the trajectory contains a tool call, and all assistant turns except the last contain <think>...</think>."""
+    has_answer_tags = "<answer>" in final_response and "</answer>" in final_response
+    if not has_answer_tags or not trajectory:
+        return False, False, False
+    has_tool_calling = any(
+        isinstance(msg, dict) and msg.get("role") == "tool" for msg in trajectory
+    )
+    # Collect assistant turns; only previous (non-last) ones must have think
+    assistant_turns = []
+    for msg in trajectory:
+        if isinstance(msg, dict):
+            if msg.get("role") == "assistant":
+                content = msg.get("content") or msg.get("text") or ""
+                assistant_turns.append(content)
+            elif msg.get("role") == "tool":
+                pass  # already counted in has_tool_calling
+        else:
+            assistant_turns.append(str(msg))
+    if not assistant_turns:
+        previous_have_think = True
+    else:
+        previous = assistant_turns[:-1]  # all but last
+        previous_have_think = all(
+            "<think>" in c and "</think>" in c for c in previous if c
+        )
+    fmt = has_answer_tags and has_tool_calling and previous_have_think
+    return fmt, previous_have_think, has_tool_calling
+
+
+@reward(name="qa_em_format_reward")
+def qa_em_format_reward(final_response: str, golden_answers: List[str], trajectory: List[str]) -> dict:
+    """
+    Calculate the reward for the agent's response based on the EM score and the output format.
+
+    - 1.0 if the format is correct and the EM matches
+    - 0.1 if the format is correct but the EM does not match
+    - 0.0 if the format is incorrect
+    """
+    predicted = _extract_answer_tag(final_response)
+    if not golden_answers:
+        max_em, max_f1 = 0.0, 0.0
+    else:
+        max_em = max(em_score(predicted, g) for g in golden_answers)
+        max_f1 = max(f1_score(predicted, g)[0] for g in golden_answers)
+    fmt, previous_have_think, has_tool_calling = _format_ok(final_response, trajectory)
+
+    reward = 0.0
+    if fmt and max_em:
+        reward = 1.0
+    elif fmt and not max_em:
+        reward = 0.1
+    elif max_em and not fmt:
+        reward = 0.0
+
+    return {
+        "reward": reward,
+        "em": max_em,
+        "f1": max_f1,
+        "fmt": 1.0 if fmt else 0.0,
+        "fmt_think": 1.0 if previous_have_think else 0.0,
+        "fmt_tool": 1.0 if has_tool_calling else 0.0,
+    }
+
+
+
 @reward(name="ok_vqa_reward")
 def ok_vqa_reward(
     final_response: str, answers: List[str], trajectory: List[str]
diff --git a/src/agentfly/tools/src/search/async_dense_retriever.py b/src/agentfly/tools/src/search/async_dense_retriever.py
index 8dd5680..d30bf4e 100644
--- a/src/agentfly/tools/src/search/async_dense_retriever.py
+++ b/src/agentfly/tools/src/search/async_dense_retriever.py
@@ -7,7 +7,7 @@ from collections import deque
 from concurrent.futures import ThreadPoolExecutor
 from functools import lru_cache
 
-
+from .... import AGENT_CACHE_DIR
 import datasets
 import numpy as np
 import torch
diff --git a/verl b/verl
index a6b57f2..7da6e92 160000
--- a/verl
+++ b/verl
@@ -1 +1 @@
-Subproject commit a6b57f2659a44d0ccba543dcabbc6bd5425b3689
+Subproject commit 7da6e9249de7b455ad48f5b4f22148686ef1c29b