diff --git a/.gitignore b/.gitignore index c89af90d7..693fcadb1 100644 --- a/.gitignore +++ b/.gitignore @@ -218,3 +218,7 @@ agentlightning/dashboard/**/*.svg # Docker data docker/data/ + +# AGL simulation +agl_envs/ +wandb/ diff --git a/contrib/agentlightning/contrib/agent/empo2_agent.py b/contrib/agentlightning/contrib/agent/empo2_agent.py new file mode 100644 index 000000000..3cfcde477 --- /dev/null +++ b/contrib/agentlightning/contrib/agent/empo2_agent.py @@ -0,0 +1,259 @@ +import copy +import logging +from typing import Any, Dict + +import numpy as np +import requests +from add_instruction import add_chat_all_tips, add_chat_instruction, add_chat_tips +from agl_envs import make_env_manager + +from agentlightning import LLM, NamedResources, Rollout, configure_logger, emit_reward, operation +from agentlightning.utils.otel import make_link_attributes +from contrib.agentlightning.contrib.agent.env_agent import EnvAgent +from contrib.recipes.envs.prompt_builder import HistoryPromptBuilder + +configure_logger() +logger = configure_logger(name=__name__, level=logging.ERROR) + + +def do_compress(text): + url = "http://127.0.0.1:8000/key_cal/" + headers = {"Content-Type": "application/json"} # 明确指定 JSON 格式 + data = {"text": text} + response = requests.post(url, json=data, headers=headers) # 使用 json 参数 + return response.json() + + +url_mem = "http://127.0.0.1:8001/mem/" + + +def retrieve_memory(idx, key): + response = requests.post(url_mem, json={"key": key, "idx": idx}) + count, data = response.json() + return count, data + + +def reset_memory(mem_list_num): + requests.post(url_mem, json={"key": [], "idx": mem_list_num, "content": "Reset"}) # 用于初始化多个 memory slot + + +def add_memory(idx, key, content, score): + requests.post(url_mem, json={"key": key, "idx": idx, "content": content, "score": score}) + + +def gather_chats(prompt): + chat_list = [] + for item in prompt: + role = item.type + content = item.content + if "System" in role: + continue + elif "User" in role: + 
role = "user" + else: + role = "assistant" + chat_list.append(f"{role}: {content}") + text = " ".join(chat_list) + return text + + +class EMPO2Agent(EnvAgent): + def __init__(self, config, trained_agents: str | None = None) -> None: + super().__init__(config=config, trained_agents=trained_agents) + + def _get_tip_prompt(self, prompt, tips): + prompt_type = self.config.captioner.prompt_type + + if prompt_type == "chat": + return add_chat_tips(prompt, tips) + else: + raise ValueError(f"Unsupported prompt_type '{prompt_type}' for _get_tip_obs (expected 'chat')") + + def _get_all_tip_prompt(self, prompt, tip_list): + prompt_type = self.config.captioner.prompt_type + if prompt_type == "chat": + return add_chat_all_tips(prompt, tip_list) + else: + raise ValueError(f"Unsupported prompt_type '{prompt_type}' for _get_tip_obs (expected 'chat')") + + def _get_tip_generation_prompt(self, prompt): + return add_chat_instruction(prompt, "tip") + + async def rollout_async( + self, + task: Dict[str, Any], + resources: NamedResources, + rollout: Rollout, + ) -> float | None: + rollout_id = rollout.rollout_id + logger.info(f"[Rollout {rollout_id}] Task: {task}") + + reward_scale = float(self.config["reawrd_scale"]) + + # Setup LLM + agent + llm: LLM = resources.get("main_llm") + print("Training with model:", llm.model, "on endpoint:", llm.endpoint) + self.agent = self._build_agent(llm, 1.0 if rollout.mode == "train" else 0.4) + + if rollout.mode == "train": + train_mode = task["train_mode"] + global_steps = task["global_steps"] + else: + train_mode = "on-policy" + + if rollout.mode == "train" and (train_mode == "off-policy" or train_mode == "on-policy-with-tips"): + use_tips = True + else: + use_tips = False + + variation_idx = task["variation_idx"] + + try: + # Setup environment + prompt_builder = HistoryPromptBuilder( + max_history=self.config.captioner.max_history, prompt_type=self.config.captioner.prompt_type + ) + + self.env = make_env_manager(self.config.env_name, task, 
self.config) + env_obs, infos, available_actions_hint = self.env.reset() + + prompt_builder.init(self.env) + prompt_builder.update_observation(env_obs) + # prompt_builder.update_admissible_actions(available_actions_hint) + + prompt = prompt_builder.get_prompt() + + episode_reward, done = 0.0, False + + pure_prompt_for_mem = [] + history_actions_for_mem = [] + tip_list = [] + + step_count = 0 + while not done: + if use_tips: + text = gather_chats(prompt) + key = ( + np.array(do_compress(text)["key"]) + .reshape( + -1, + ) + .tolist() + ) + count, mem_list = retrieve_memory(variation_idx, key) + else: + count, mem_list = 0, [] + + ret_tips, intrinsic_reward = "", 0.0 + + if use_tips: + if count > 0: + ret_tips = "Here are some memories you collected in your previous exploration:\n" + for mem in mem_list: + ret_tips += mem + "\n" + + tip_list.append(ret_tips) + intrinsic_reward = 1 / (count + 1) + else: + tip_list.append("") + intrinsic_reward = 1 + + try: + if count > 0: + tip_prompt = self._get_all_tip_prompt(prompt, tip_list) + instructed_prompt = self._get_instructed_prompt(tip_prompt, sep="") + else: + instructed_prompt = self._get_instructed_prompt(prompt) + + # Main agent step + with operation(step_count=step_count): + result = await self.agent._model_client.create(instructed_prompt) + output = result.content + logger.info(f"[LLM output]: {output}") + + except Exception as e: + logger.error(f"[Rollout {rollout_id}] Error during training rollout: {e}", exc_info=True) + break + + # Environment step + pure_prompt_for_mem.append([copy.deepcopy(prompt), None]) + env_obs, executed_action, is_valid, step_reward, terminated, truncated, info, available_actions_hint = ( + self.env.step(output, use_reasoning=self.config.captioner.type == "cot") + ) + history_actions_for_mem.append(executed_action) + + prompt_builder.update_step_count() + prompt_builder.update_action(executed_action) + prompt_builder.update_observation(env_obs) + # 
prompt_builder.update_admissible_actions(available_actions_hint) + + prompt = prompt_builder.get_prompt() + + if rollout.mode == "train": + step_reward = reward_scale * step_reward + + emit_reward( + { + "extrinsic_reward": step_reward, + "intrinsic_reward": intrinsic_reward, + }, + primary_key="extrinsic_reward", + attributes=make_link_attributes({"step_count": str(step_count)}), + ) + + episode_reward += float(step_reward) + done = np.logical_or(terminated, truncated) + + step_count += 1 + + if rollout.mode == "train" and self.config.captioner.prompt_type == "chat" and self.config.save_rollout: + filename = f"empo2_rollouts/variant_{variation_idx}/step_{global_steps}/{rollout_id}_{round(episode_reward, 1)}_use_tip_{use_tips}.json" + if use_tips: + _rollout = self._get_all_tip_obs(obs, tip_list) + else: + _rollout = obs + self._save_chat_rollout(_rollout, filename) + + if rollout.mode == "train": + prompt_builder.prompt_type = "chat" + prompt_builder.max_history = -1 + prompt = prompt_builder.get_prompt() + prompt.pop() + + tip_generation_prompt = self._get_tip_generation_prompt(prompt) + + self.agent._model_client.max_tokens = 128 + result = await self.agent._model_client.create(tip_generation_prompt) + tips = result.content + logger.info(f"Tips: {tips}") + + #! Fill the ret and tip + for i in range(len(pure_prompt_for_mem)): + max_score = 100 * reward_scale + pure_prompt_for_mem[i][1] = ( + tips + + f"; At that timestep, the specific action your took was {history_actions_for_mem[i]}; Eventually you got the score {round(episode_reward, 1)}/{int(max_score)}." + ) + + #! 
Generate the tips and save the mem + for i in range(len(pure_prompt_for_mem)): + text = gather_chats(pure_prompt_for_mem[i][0]) + key = ( + np.array(do_compress(text)["key"]) + .reshape( + -1, + ) + .tolist() + ) + content = pure_prompt_for_mem[i][1] + score = episode_reward + add_memory(variation_idx, key, content, round(score, 1)) + + if self.config.use_success_rate: + return self.env.get_success_score() * reward_scale + else: + return episode_reward + + finally: + if self.env is not None: + self.env.close() diff --git a/contrib/agentlightning/contrib/algorithm/env_verl/core_empo2.py b/contrib/agentlightning/contrib/algorithm/env_verl/core_empo2.py new file mode 100644 index 000000000..ce11cef68 --- /dev/null +++ b/contrib/agentlightning/contrib/algorithm/env_verl/core_empo2.py @@ -0,0 +1,67 @@ +from typing import Any, List + +import torch + + +def is_sublist(sub, full): + n, m = len(sub), len(full) + return any(full[i : i + n] == sub for i in range(m - n + 1)) + + +# Function to remove segments of a list between a start pattern and an end pattern +def remove_pattern_ranges(seq: List[Any], start_pat: List[Any], end_pat: List[Any]) -> List[Any]: + """Remove every [start_pat ... 
end_pat] slice (inclusive) from seq.""" + + out: List[Any] = [] + i = 0 + n = len(seq) + ls, le = len(start_pat), len(end_pat) + + while i < n: + # Check if the start pattern matches at the current position + if i + ls <= n and seq[i : i + ls] == start_pat: + # Look for the first occurrence of the end pattern after the start pattern + j = i + ls + found_end = -1 + while j + le <= n: + if seq[j : j + le] == end_pat: + found_end = j + break # Stop when the end pattern is found + j += 1 + + # If the end pattern is found, skip the whole segment from start to end + if found_end != -1: + i = found_end + le # Move the index past the end pattern + continue # Skip the current iteration and go to the next + else: + # If the end pattern is not found, keep the current element and move one step forward + out.append(seq[i]) + i += 1 + else: + # If the start pattern is not found, just append the current element + out.append(seq[i]) + i += 1 + + # Return the filtered list with the start-end pattern segments removed + return out + + +def low_prob_token_masking(batch): + response_mask = batch.batch["response_mask"] # [N, T] + old_log_prob = batch.batch["old_log_probs"] # [N, T] + # advantages = batch.batch["advantages"] # [N, T] + + masked_old_log_prob = old_log_prob.masked_fill(response_mask == 0, 1e9) + min_values, _ = torch.min(masked_old_log_prob, dim=1) # [N] + + mask = min_values < -5 # [N] + + combined_mask = mask.unsqueeze(1) & (response_mask == 1) + + # advantages masking + response_mask = response_mask.masked_fill(combined_mask, 0) + batch.batch["response_mask"] = response_mask + + print(f"Number of tokens masked: {combined_mask.sum().item()}") + + return batch diff --git a/contrib/agentlightning/contrib/algorithm/env_verl/daemon.py b/contrib/agentlightning/contrib/algorithm/env_verl/daemon.py index 10c947c16..609dd4cbd 100644 --- a/contrib/agentlightning/contrib/algorithm/env_verl/daemon.py +++ b/contrib/agentlightning/contrib/algorithm/env_verl/daemon.py @@ -1,6 +1,7 @@ 
# Copyright (c) Microsoft. All rights reserved. import asyncio +import copy import json import random import socket @@ -18,8 +19,8 @@ from tensordict import TensorDict from verl import DataProto +import contrib.agentlightning.contrib.algorithm.env_verl.core_empo2 as core_empo2 from agentlightning import LLM, AgentLightningServer, NamedResources, RolloutLegacy -from agentlightning.adapter.triplet import TraceToTripletBase from agentlightning.llm_proxy import LLMProxy, ModelConfig from agentlightning.reward import find_final_reward from agentlightning.store.base import LightningStore @@ -145,7 +146,7 @@ def __init__( mode: Literal["v0", "v1"] = "v1", llm_proxy: LLMProxy | None = None, store: LightningStore | None = None, - adapter: TraceToTripletBase | None = None, + adapter: TracerTraceToTripletGroup | None = None, ): self.mode = mode self.llm_timeout_seconds = llm_timeout_seconds @@ -658,9 +659,11 @@ def get_train_data_batch( max_prompt_length: int, max_response_length: int, device: torch.device, + max_train_length: int = -1, use_final_reward_as_step_reward: bool = True, use_intrinsic_reward: bool = False, is_gigpo: bool = False, + empo2_train_mode: bool = False, ): """ Processes completed rollouts to generate a training data batch. 
@@ -741,13 +744,21 @@ def get_train_data_batch( for rollout_id, sample_info in finished_id_to_sample_info.items(): for turn_index, trace in enumerate(sample_info["trace_list"]): + prompt_ids, response_ids = trace["prompt_ids"], trace["response_ids"] + + if max_train_length > -1 and len(prompt_ids) + len(response_ids) > max_train_length: + continue final_reward_list.append(sample_info["final_reward"]) step_reward_list.append(trace["step_reward"]) step_intrinsic_reward_list.append(trace["step_intrinsic_reward"]) message_list.append(trace["message"]) - prompt_ids, response_ids = trace["prompt_ids"], trace["response_ids"] + if empo2_train_mode == "off-policy": + START_PATTERN = self.tokenizer.encode("") + END_PATTERN = self.tokenizer.encode("\n\n") + if core_empo2.is_sublist(START_PATTERN, prompt_ids): + prompt_ids = core_empo2.remove_pattern_ranges(prompt_ids, START_PATTERN, END_PATTERN) # Mark samples with prompts exceeding max_prompt_length to be dropped later if len(prompt_ids) > max_prompt_length: @@ -787,6 +798,7 @@ def get_train_data_batch( batch_seq = torch.cat([batch_input_ids, batch_response_ids], dim=-1) attention_mask = torch.cat([input_attention_mask, response_attention_mask], dim=-1) position_ids = torch.clamp(torch.cumsum(attention_mask, dim=-1) - 1, min=0) + is_drop_mask = torch.BoolTensor(is_drop_list).to(device) if use_final_reward_as_step_reward: scores = torch.tensor(final_reward_list, dtype=torch.float32).to(device) diff --git a/contrib/agentlightning/contrib/algorithm/env_verl/trainer.py b/contrib/agentlightning/contrib/algorithm/env_verl/trainer.py index 3dd2458f4..147bb6879 100644 --- a/contrib/agentlightning/contrib/algorithm/env_verl/trainer.py +++ b/contrib/agentlightning/contrib/algorithm/env_verl/trainer.py @@ -34,6 +34,7 @@ from verl.utils.metric import reduce_metrics from verl.utils.tracking import Tracking +import contrib.agentlightning.contrib.algorithm.env_verl.core_empo2 as core_empo2 from agentlightning.adapter import TraceAdapter, 
TraceToTripletBase from agentlightning.llm_proxy import LLMProxy from agentlightning.store.base import LightningStore @@ -250,6 +251,21 @@ def _train_step(self, batch_dict: dict) -> dict: # generate a batch with _timer("gen", timing_raw): self.async_rollout_manager.wake_up() + + num_problems = self.config.data.train_batch_size + gen_batch.non_tensor_batch["global_steps"] = [self.global_steps for _ in range(num_problems)] + + if hasattr(self.config, "tips") and self.config.tips.use_tips: + touzi = random.random() + if touzi < 0.17: + self.empo2_train_mode = "off-policy" # Update with Tips and give them to the pure_chats + elif touzi < 0.25: + self.empo2_train_mode = "on-policy-with-tips" + else: + self.empo2_train_mode = "on-policy" # Normal Update, No Tips + + gen_batch.non_tensor_batch["train_mode"] = [self.empo2_train_mode for _ in range(num_problems)] + self.agent_mode_daemon.set_up_data_and_server( gen_batch.non_tensor_batch, self.async_rollout_manager.server_addresses ) @@ -257,9 +273,11 @@ def _train_step(self, batch_dict: dict) -> dict: batch, agent_metrics = self.agent_mode_daemon.get_train_data_batch( max_prompt_length=self.config.data.max_prompt_length, max_response_length=self.config.data.max_response_length, + max_train_length=getattr(self.config.data, "max_train_length", -1), device=gen_batch.batch["fake_ids"].device, use_final_reward_as_step_reward=self.config.algorithm.use_final_reward_as_step_reward, use_intrinsic_reward=self.config.algorithm.use_intrinsic_reward, + empo2_train_mode=getattr(self, "empo2_train_mode", None), ) metrics.update(agent_metrics) self.agent_mode_daemon.clear_data_and_server() @@ -362,6 +380,9 @@ def _train_step(self, batch_dict: dict) -> dict: config=self.config.algorithm, ) + if hasattr(self.config, "tips") and self.config.tips.use_tips: + batch = core_empo2.low_prob_token_masking(batch) + # Calculate the metrics before processing. Refer to the comments of function `compute_data_metrics` for details. 
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic, suffix="_before_processing")) @@ -500,6 +521,14 @@ def fit(self): # train step metrics = self._train_step(batch_dict) + if hasattr(self.config, "tips") and self.config.tips.use_tips: + mode_map = { + "off-policy": 0, + "on-policy-with-tips": 1, + "on-policy": 2, + } + metrics["empo2/train_mode"] = mode_map.get(self.empo2_train_mode) + # validate if ( self.val_reward_fn is not None diff --git a/contrib/recipes/envs/README.md b/contrib/recipes/envs/README.md index 105190d5e..f32f2f732 100644 --- a/contrib/recipes/envs/README.md +++ b/contrib/recipes/envs/README.md @@ -80,15 +80,28 @@ We follow the single-mode prompt for ALFWorld from [verl-agent](https://github.c   -## Run RL Training (GRPO) +## Run RL Training + +### GRPO ```bash # Run alfworld -python3 train_env_agent.py --algorithm grpo --env alfworld +python3 train_env_agent.py --algorithm grpo_qwen_1.5b_instruct --env alfworld # Run scienceworld single task task_num 0 -python3 train_env_agent.py --algorithm grpo --env scienceworld --task_num 0 +python3 train_env_agent.py --algorithm grpo_qwen_1.5b_instruct --env scienceworld --task_num 0 # Run scienceworld multi-task -python3 train_env_agent.py --algorithm grpo --env scienceworld --task_num -1 +python3 train_env_agent.py --algorithm grpo_qwen_1.5b_instruct --env scienceworld --task_num -1 ``` + +### EMPO² Integration + +We integrate **EMPO²** (*Memory-Augmented LLM Agent via Online Self-Distillation*, ICLR 2026) [[paper]](https://arxiv.org/abs/2602.23008) into AGL. EMPO² leverages a memory-augmented mechanism combined with online self-distillation to enhance LLM agent performance. In our experiments, EMPO² consistently outperforms GRPO, demonstrating stronger learning efficiency. 
+ +```bash +# Run scienceworld single task task_num 25 +python3 train_env_agent.py --algorithm empo2_qwen_7b_instruct --env scienceworld2 --task_num 25 +``` + +![agl_empo2_25](./assets/agl_empo2_25.png) diff --git a/contrib/recipes/envs/add_instruction.py b/contrib/recipes/envs/add_instruction.py index 627500153..ac1033ec2 100644 --- a/contrib/recipes/envs/add_instruction.py +++ b/contrib/recipes/envs/add_instruction.py @@ -11,13 +11,23 @@ """.strip() NAIVE_INSTRUCTION = """ +You could try to explore different actions, especially when you are not sure what the best action for your current observation. Please response with only one line with one sentence, following the possible action format shown above. No extra words are allowed. """.strip() +TIP_INSTRUCTION = """ +Thanks for your playing. +Now you have ended a trajectory and collect some meaningless or valuable information from the interactions with the environment. +Please summary the trajectory, and also summary what information you get from this trajectory, and how far this trajectory is from fully completing the task. +Please response with only one sentence with only one line, do not include any extra words. +You sentence should be less than 100 words. +""".strip() + # Mapping for instruction text types INSTRUCTION_MAP = { "cot": COT_INSTRUCTION, "naive": NAIVE_INSTRUCTION, + "tip": TIP_INSTRUCTION, } @@ -26,7 +36,7 @@ def _get_instruction(type: str, env_name: str = None): Retrieve an instruction string from INSTRUCTION_MAP based on the given type. Args: - type (str): Instruction type key (e.g., "cot", "naive", "critic", "tip"). + type (str): Instruction type key (e.g., "cot", "naive", "tip"). env_name (str, optional): Currently unused. Reserved for future environment-specific instruction handling. @@ -60,11 +70,18 @@ def add_chat_instruction(prompt, type: str, sep: str = "\n\n", env_name: str = N Returns: list: A new prompt list with the instruction appended to the last message. 
""" - new_prompt = copy.deepcopy(prompt) - instruction = _get_instruction(type, env_name) - new_prompt[-1].content += sep + instruction + if type == "tip": + new_prompt = copy.deepcopy(prompt) + tip_instruction = _get_instruction(type, env_name) + new_prompt.append(UserMessage(source="user", content=tip_instruction)) - return new_prompt + return new_prompt + else: + new_prompt = copy.deepcopy(prompt) + instruction = _get_instruction(type, env_name) + new_prompt[-1].content += sep + instruction + + return new_prompt def add_single_instruction(prompt, type: str, sep: str = "\n\n", env_name: str = None): @@ -99,3 +116,24 @@ def add_single_instruction(prompt, type: str, sep: str = "\n\n", env_name: str = return new_prompt else: raise TypeError("Prompt must be a string or a list of strings") + + +def add_chat_tips(prompt, tips): + new_prompt = copy.deepcopy(prompt) + new_prompt[-1].content += f"\n\n {tips}\n\n\n" + return new_prompt + + +def add_chat_all_tips(prompt, tip_list): + new_prompt = copy.deepcopy(prompt) + tips_iter = iter(tip_list) + + for item in new_prompt: + if "User" in item.type: + tip = next(tips_iter, None) + if tip is None: + break + if not tip == "": + item.content += f"\n\n {tip}\n\n\n" + + return new_prompt diff --git a/contrib/recipes/envs/assets/agl_empo2_25.png b/contrib/recipes/envs/assets/agl_empo2_25.png new file mode 100644 index 000000000..6233ac16e Binary files /dev/null and b/contrib/recipes/envs/assets/agl_empo2_25.png differ diff --git a/contrib/recipes/envs/assets/prompt_type.png b/contrib/recipes/envs/assets/prompt_type.png index da6349ddc..98c0ed0f6 100644 Binary files a/contrib/recipes/envs/assets/prompt_type.png and b/contrib/recipes/envs/assets/prompt_type.png differ diff --git a/contrib/recipes/envs/clean.sh b/contrib/recipes/envs/clean.sh new file mode 100755 index 000000000..5ea8b7b1c --- /dev/null +++ b/contrib/recipes/envs/clean.sh @@ -0,0 +1,14 @@ +#!/bin/bash +set -e + +echo "Stopping AgentLightning and simulation_agent..." 
+pkill -f AgentLightning || true +pkill -f simulation_agent || true + +echo "Stopping Ray cluster..." +ray stop + +echo "Killing VLLM::EngineCore processes..." +ps aux | grep VLLM::EngineCore | grep -v grep | awk '{print $2}' | xargs --no-run-if-empty kill -9 + +echo "✅ Cleanup complete." diff --git a/contrib/recipes/envs/config_env/scienceworld2.yaml b/contrib/recipes/envs/config_env/scienceworld2.yaml new file mode 100644 index 000000000..bd4be3bb5 --- /dev/null +++ b/contrib/recipes/envs/config_env/scienceworld2.yaml @@ -0,0 +1,16 @@ +env_name: scienceworld # scienceworld, babyai, alfworld +seed: 0 +format_penalty: 0.0 +binary_reward: False +save_rollout: False +log_env_obs: False # True for GiGPO +reawrd_scale: 1.0 +use_success_rate: False + +# only for scienceworld +use_action_correction: True + +captioner: + type: naive # naive or cot + prompt_type: chat # chat or single + max_history: -1 diff --git a/contrib/recipes/envs/config_verl/alfworld/grpo_qwen_1.5b_instruct.yaml b/contrib/recipes/envs/config_verl/alfworld/grpo_qwen_1.5b_instruct.yaml new file mode 100644 index 000000000..3fed97e60 --- /dev/null +++ b/contrib/recipes/envs/config_verl/alfworld/grpo_qwen_1.5b_instruct.yaml @@ -0,0 +1,86 @@ +# ========================== +# Variable definitions +# ========================== +variables: + NUM_GPUS: 2 + MINI_BATCH_SIZE: 32 + PER_GPU_BATCH_SIZE: 16 + TENSOR_MODEL_PARALLEL_SIZE: 2 + NUM_ROLLOUTS: 8 + BASE_MODEL: Qwen/Qwen2.5-1.5B-Instruct + PROJECT_NAME: AGL-Simulation-ALFWorld + TRIAL: ${oc.env:TRIAL,0} + EXPERIMENT_NAME: grpo-alfworld-${variables.TRIAL} + DATA_DIR: agl_envs/task_data/alfworld + +# ========================== +# Main Config +# ========================== +agentlightning: + port: 9999 + +algorithm: + adv_estimator: grpo + use_kl_in_reward: false + use_final_reward_as_step_reward: true + use_intrinsic_reward: true + +data: + train_files: ${variables.DATA_DIR}/train.parquet + val_files: ${variables.DATA_DIR}/test.parquet + train_batch_size: 32 + 
val_batch_size: 140 + max_prompt_length: 2048 + max_response_length: 512 + truncation: error + return_raw_chat: true + +actor_rollout_ref: + rollout: + tensor_model_parallel_size: ${variables.TENSOR_MODEL_PARALLEL_SIZE} + n: ${variables.NUM_ROLLOUTS} + log_prob_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + multi_turn: + format: hermes + name: vllm + gpu_memory_utilization: 0.6 + enable_chunked_prefill: false + enforce_eager: false + free_cache_engine: true + val_kwargs: + temperature: 0.4 + do_sample: true + actor: + ppo_mini_batch_size: ${variables.MINI_BATCH_SIZE} + ppo_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + optim: + lr: 1.0e-6 + use_kl_loss: true + kl_loss_coef: 0.01 + kl_loss_type: low_var_kl + entropy_coeff: 0.001 + fsdp_config: + param_offload: false + optimizer_offload: false + ref: + log_prob_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + fsdp_config: + param_offload: true + model: + path: ${variables.BASE_MODEL} + use_remove_padding: true + enable_gradient_checkpointing: true + +trainer: + n_gpus_per_node: ${variables.NUM_GPUS} + val_before_train: false + critic_warmup: 0 + logger: + - console + - wandb + project_name: ${variables.PROJECT_NAME} + experiment_name: ${variables.EXPERIMENT_NAME} + nnodes: 1 + save_freq: 100 + test_freq: 5 + total_epochs: 200 diff --git a/contrib/recipes/envs/config_verl/scienceworld/empo2_qwen_7b_instruct.yaml b/contrib/recipes/envs/config_verl/scienceworld/empo2_qwen_7b_instruct.yaml new file mode 100644 index 000000000..64d79f48e --- /dev/null +++ b/contrib/recipes/envs/config_verl/scienceworld/empo2_qwen_7b_instruct.yaml @@ -0,0 +1,101 @@ +# ========================== +# Variable definitions +# ========================== +variables: + NUM_GPUS: 8 + MINI_BATCH_SIZE: 16 + PER_GPU_BATCH_SIZE: 1 + TENSOR_MODEL_PARALLEL_SIZE: 2 + NUM_ROLLOUTS: 8 + BASE_MODEL: Qwen/Qwen2.5-7B-Instruct + PROJECT_NAME: EMPO2-ScienceWorld2 + TASK_NUM: ${oc.env:TASK_NUM,25} + TRIAL: ${oc.env:TRIAL,0} + 
EXPERIMENT_NAME: (all-off-policy-final-reward)empo2-${variables.TASK_NUM}-sciworld-${variables.TRIAL} + DATA_DIR: agl_envs/task_data/scienceworld/single_data/${variables.TASK_NUM} + OUTPUT_DIR: /mnt/jeonghyekim/empo2_checkpoint/0211/${variables.EXPERIMENT_NAME} + +# ========================== +# Main Config +# ========================== +agentlightning: + port: 9999 + +algorithm: + adv_estimator: grpo + use_kl_in_reward: false + use_final_reward_as_step_reward: true + use_intrinsic_reward: true + +data: + train_files: ${variables.DATA_DIR}/train.parquet + val_files: ${variables.DATA_DIR}/test.parquet + train_batch_size: 16 + val_batch_size: 80 + max_prompt_length: 16384 + max_response_length: 32 + max_train_length: 8192 + truncation: error + return_raw_chat: true + +actor_rollout_ref: + rollout: + tensor_model_parallel_size: ${variables.TENSOR_MODEL_PARALLEL_SIZE} + n: ${variables.NUM_ROLLOUTS} + log_prob_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + multi_turn: + format: hermes + name: vllm + gpu_memory_utilization: 0.5 + enable_chunked_prefill: false + enforce_eager: false + free_cache_engine: true + val_kwargs: + temperature: 0.4 + do_sample: true + actor: + ppo_mini_batch_size: ${variables.MINI_BATCH_SIZE} + ppo_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + optim: + lr: 1.0e-6 + use_kl_loss: false + kl_loss_coef: 0.00 + entropy_coeff: 0.0 + clip_ratio_high: 0.30 + clip_ratio_low: 0.20 + clip_ratio_c: 10.0 + entropy_checkpointing: true + entropy_from_logits_with_chunking: true + fsdp_config: + param_offload: true + optimizer_offload: true + forward_prefetch: true + ref: + log_prob_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + entropy_checkpointing: true + entropy_from_logits_with_chunking: true + fsdp_config: + param_offload: true + forward_prefetch: true + model: + path: ${variables.BASE_MODEL} + use_remove_padding: true + enable_gradient_checkpointing: true + +trainer: + default_local_dir: 
${variables.OUTPUT_DIR}/checkpoints + n_gpus_per_node: ${variables.NUM_GPUS} + val_before_train: false + critic_warmup: 0 + logger: + - console + - wandb + project_name: ${variables.PROJECT_NAME} + experiment_name: ${variables.EXPERIMENT_NAME} + nnodes: 1 + save_freq: 50 + test_freq: 20 + total_epochs: 500 + +tips: + use_tips: true diff --git a/contrib/recipes/envs/config_verl/scienceworld/grpo_qwen_1.5b_instruct.yaml b/contrib/recipes/envs/config_verl/scienceworld/grpo_qwen_1.5b_instruct.yaml new file mode 100644 index 000000000..b90418b8b --- /dev/null +++ b/contrib/recipes/envs/config_verl/scienceworld/grpo_qwen_1.5b_instruct.yaml @@ -0,0 +1,87 @@ +# ========================== +# Variable definitions +# ========================== +variables: + NUM_GPUS: 2 + MINI_BATCH_SIZE: 32 + PER_GPU_BATCH_SIZE: 16 + TENSOR_MODEL_PARALLEL_SIZE: 2 + NUM_ROLLOUTS: 8 + BASE_MODEL: Qwen/Qwen2.5-1.5B-Instruct + PROJECT_NAME: AGL-Simulation-ScienceWorld + TASK_NUM: ${oc.env:TASK_NUM,-1} + TRIAL: ${oc.env:TRIAL,0} + EXPERIMENT_NAME: grpo-sciworld-${variables.TRIAL} + DATA_DIR: agl_envs/task_data/scienceworld/multi_data + +# ========================== +# Main Config +# ========================== +agentlightning: + port: 9999 + +algorithm: + adv_estimator: grpo + use_kl_in_reward: false + use_final_reward_as_step_reward: true + use_intrinsic_reward: true + +data: + train_files: ${variables.DATA_DIR}/train.parquet + val_files: ${variables.DATA_DIR}/test.parquet + train_batch_size: 32 + val_batch_size: 144 + max_prompt_length: 6000 + max_response_length: 1024 + truncation: error + return_raw_chat: true + +actor_rollout_ref: + rollout: + tensor_model_parallel_size: ${variables.TENSOR_MODEL_PARALLEL_SIZE} + n: ${variables.NUM_ROLLOUTS} + log_prob_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE} + multi_turn: + format: hermes + name: vllm + gpu_memory_utilization: 0.6 + enable_chunked_prefill: false + enforce_eager: false + free_cache_engine: true + val_kwargs: + temperature: 0.4 
+      do_sample: true
+  actor:
+    ppo_mini_batch_size: ${variables.MINI_BATCH_SIZE}
+    ppo_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE}
+    optim:
+      lr: 1.0e-6
+    use_kl_loss: true
+    kl_loss_coef: 0.01
+    kl_loss_type: low_var_kl
+    entropy_coeff: 0.001
+    fsdp_config:
+      param_offload: false
+      optimizer_offload: false
+  ref:
+    log_prob_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE}
+    fsdp_config:
+      param_offload: true
+  model:
+    path: ${variables.BASE_MODEL}
+    use_remove_padding: true
+    enable_gradient_checkpointing: true
+
+trainer:
+  n_gpus_per_node: ${variables.NUM_GPUS}
+  val_before_train: false
+  critic_warmup: 0
+  logger:
+    - console
+    - wandb
+  project_name: ${variables.PROJECT_NAME}
+  experiment_name: ${variables.EXPERIMENT_NAME}
+  nnodes: 1
+  save_freq: 100
+  test_freq: 5
+  total_epochs: 500
diff --git a/contrib/recipes/envs/config_verl/scienceworld/grpo_qwen_7b_instruct.yaml b/contrib/recipes/envs/config_verl/scienceworld/grpo_qwen_7b_instruct.yaml
new file mode 100644
index 000000000..7de1e0b5b
--- /dev/null
+++ b/contrib/recipes/envs/config_verl/scienceworld/grpo_qwen_7b_instruct.yaml
@@ -0,0 +1,92 @@
+# ==========================
+# Variable definitions
+# ==========================
+variables:
+  NUM_GPUS: 8
+  MINI_BATCH_SIZE: 16
+  PER_GPU_BATCH_SIZE: 1
+  TENSOR_MODEL_PARALLEL_SIZE: 2
+  NUM_ROLLOUTS: 8
+  BASE_MODEL: Qwen/Qwen2.5-7B-Instruct
+  PROJECT_NAME: EMPO2-ScienceWorld2
+  TASK_NUM: ${oc.env:TASK_NUM,25}
+  TRIAL: ${oc.env:TRIAL,0}
+  EXPERIMENT_NAME: (final-reward)grpo-${variables.TASK_NUM}-sciworld-${variables.TRIAL}
+  DATA_DIR: agl_envs/task_data/scienceworld/single_data/${variables.TASK_NUM}
+  OUTPUT_DIR: /mnt/jeonghyekim/empo2_grpo_checkpoint/0211/${variables.EXPERIMENT_NAME}
+
+# ==========================
+# Main Config
+# ==========================
+agentlightning:
+  port: 9999
+
+algorithm:
+  adv_estimator: grpo
+  use_kl_in_reward: false
+  use_final_reward_as_step_reward: true
+  use_intrinsic_reward: true
+
+data:
+  train_files: ${variables.DATA_DIR}/train.parquet
+  val_files: ${variables.DATA_DIR}/test.parquet
+  train_batch_size: 16
+  val_batch_size: 80
+  max_prompt_length: 16384
+  max_response_length: 32
+  max_train_length: 8192
+  truncation: error
+  return_raw_chat: true
+
+actor_rollout_ref:
+  rollout:
+    tensor_model_parallel_size: ${variables.TENSOR_MODEL_PARALLEL_SIZE}
+    n: ${variables.NUM_ROLLOUTS}
+    log_prob_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE}
+    multi_turn:
+      format: hermes
+    name: vllm
+    gpu_memory_utilization: 0.5
+    enable_chunked_prefill: false
+    enforce_eager: false
+    free_cache_engine: true
+    val_kwargs:
+      temperature: 0.4
+      do_sample: true
+  actor:
+    ppo_mini_batch_size: ${variables.MINI_BATCH_SIZE}
+    ppo_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE}
+    optim:
+      lr: 1.0e-6
+    use_kl_loss: false
+    kl_loss_coef: 0.00
+    entropy_coeff: 0.0
+    clip_ratio_high: 0.30
+    clip_ratio_low: 0.20
+    clip_ratio_c: 10.0
+    fsdp_config:
+      param_offload: true
+      optimizer_offload: true
+  ref:
+    log_prob_micro_batch_size_per_gpu: ${variables.PER_GPU_BATCH_SIZE}
+    fsdp_config:
+      param_offload: true
+  model:
+    path: ${variables.BASE_MODEL}
+    use_remove_padding: true
+    enable_gradient_checkpointing: true
+
+trainer:
+  default_local_dir: ${variables.OUTPUT_DIR}/checkpoints
+  n_gpus_per_node: ${variables.NUM_GPUS}
+  val_before_train: false
+  critic_warmup: 0
+  logger:
+    - console
+    - wandb
+  project_name: ${variables.PROJECT_NAME}
+  experiment_name: ${variables.EXPERIMENT_NAME}
+  nnodes: 1
+  save_freq: 50
+  test_freq: 20
+  total_epochs: 500
diff --git a/contrib/recipes/envs/empo2_server/server_bert.py b/contrib/recipes/envs/empo2_server/server_bert.py
new file mode 100644
index 000000000..f60d2b2cb
--- /dev/null
+++ b/contrib/recipes/envs/empo2_server/server_bert.py
@@ -0,0 +1,34 @@
+import os
+import time
+
+import torch
+import uvicorn
+from fastapi import FastAPI, Request
+from pydantic import BaseModel
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+torch.cuda.set_per_process_memory_fraction(0.1, 0)
+
+num_works = 1
+
+app = FastAPI()
+
+from sentence_transformers import SentenceTransformer
+
+model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+
+
+@app.post("/key_cal/")
+async def compress(request: Request):
+    try:
+        data = await request.json()
+        text = data.get("text", "")
+    except Exception:  # not bare: a bare except would also swallow asyncio.CancelledError
+        text = (await request.body()).decode("utf-8")
+
+    key = model.encode(text)
+    return {"key": key.tolist()}
+
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=8000, workers=num_works)
diff --git a/contrib/recipes/envs/empo2_server/server_mem.py b/contrib/recipes/envs/empo2_server/server_mem.py
new file mode 100644
index 000000000..7b7618144
--- /dev/null
+++ b/contrib/recipes/envs/empo2_server/server_mem.py
@@ -0,0 +1,74 @@
+import random
+import time
+from collections import deque
+
+import numpy as np
+import uvicorn
+from fastapi import FastAPI, Request
+from pydantic import BaseModel
+
+num_works = 1
+app = FastAPI()
+
+mem_list = None
+content_set = None
+
+
+class MemRequest(BaseModel):
+    key: list
+    idx: int = None
+    content: str = None
+    score: float = None
+
+
+@app.post("/mem/")
+async def mem_handler(mem_req: MemRequest):
+    global cnt, mem_list, content_set
+
+    key = mem_req.key
+    idx = mem_req.idx
+    content = mem_req.content
+    score = mem_req.score
+
+    if content == "Reset":
+        mem_list_num = idx
+        content_set = {id: set() for id in range(mem_list_num)}
+        mem_list = {id: [] for id in range(mem_list_num)}
+        cnt = {id: 0 for id in range(mem_list_num)}
+        print(f"Clean all the mem. The num of mem_list is {mem_list_num}")
+        return None
+
+    if content is not None:
+        if content not in content_set[idx]:
+            content_set[idx].add(content)
+            mem_list[idx].append(
+                {
+                    "cnt": cnt[idx],
+                    "key": key,
+                    "content": content,
+                    "score": score,
+                }
+            )
+            cnt[idx] += 1
+            if len(mem_list[idx]) > 1000:
+                oldest_hash = mem_list[idx][0]["content"]
+                content_set[idx].discard(oldest_hash)
+                mem_list[idx] = mem_list[idx][-1000:]
+            print("Add,", "id", idx, "cnt", cnt[idx], "content", content, "score", score)
+    else:
+        data = []
+        for mem in mem_list[idx]:
+            mem_key = mem["key"]
+            sim = np.dot(key, mem_key) / (np.linalg.norm(key) * np.linalg.norm(mem_key))
+            if sim > 0.5:
+                data.append(mem)
+        # data = random.sample(data, min(len(data), 10)) if len(data) > 0 else []
+        data = sorted(data, key=lambda x: -x["score"])[:10] if len(data) > 0 else []
+        data = [x["content"] for x in data]
+        count = len(data)
+        print("Load", count, data)
+        return count, data
+
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=8001, workers=num_works)
diff --git a/contrib/recipes/envs/prompt_builder.py b/contrib/recipes/envs/prompt_builder.py
index 905285d55..48645bdac 100644
--- a/contrib/recipes/envs/prompt_builder.py
+++ b/contrib/recipes/envs/prompt_builder.py
@@ -80,7 +80,7 @@ def init(self, env):
         self._events.clear()
 
         if self.prompt_type == "chat":
-            inst_prompt = env.get_instruction_prompt(info)
+            inst_prompt = env.get_instruction_prompt()
             self.update_instruction_prompt(inst_prompt)
         elif self.prompt_type == "single":
             template_wo_his, template = env.get_single_prompt_template()
diff --git a/contrib/recipes/envs/train_env_agent.py b/contrib/recipes/envs/train_env_agent.py
index 81aefe841..ae5147238 100644
--- a/contrib/recipes/envs/train_env_agent.py
+++ b/contrib/recipes/envs/train_env_agent.py
@@ -1,8 +1,9 @@
 # Copyright (c) Microsoft. All rights reserved.
-
 import argparse
 import os
+import re
 import subprocess
+import time
 
 from omegaconf import OmegaConf
 
@@ -51,9 +52,8 @@ def get_config(path):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--env", type=str, default="scienceworld")
-    parser.add_argument("--algorithm", type=str, default="grpo")
-    parser.add_argument("--debug", action="store_true")
+    parser.add_argument("--env", type=str, default="scienceworld2")
+    parser.add_argument("--algorithm", type=str, default="empo2_qwen_7b_instruct")
     parser.add_argument("--n_workers", type=int, default=64, help="Number of workers for training")
     parser.add_argument("--trial", type=int, default=0, help="Number of trials")
     parser.add_argument("--task_num", type=int, default=25, help="ScienceWorld Task number to inject as env var")
@@ -68,17 +68,16 @@ def get_config(path):
 
     # set environment variable before loading configs
     os.environ["TRIAL"] = str(args.trial)
-    if args.env == "scienceworld":
+    if "scienceworld" in args.env:
         os.environ["TASK_NUM"] = str(args.task_num)
 
     # Load configs
     agent_config_path = f"config_env/{args.env}.yaml"
-    if args.debug:
-        trainer_config_path = f"config_verl/{args.env}/debug/{args.algorithm}.yaml"
-    else:
-        trainer_config_path = f"config_verl/{args.env}/{args.algorithm}.yaml"
     agent_config = get_config(agent_config_path)
+    env_prefix = re.sub(r"\d+$", "", args.env)
+    trainer_config_path = f"config_verl/{env_prefix}/{args.algorithm}.yaml"
+
     if "gigpo" in args.algorithm:
         agent_config.log_env_obs = True
     rl_training_config = get_config(trainer_config_path)
 
@@ -87,9 +86,26 @@ def get_config(path):
     train_dataset, val_dataset = train_val_dataset(rl_training_config)
 
     # Initialize agent
-    from contrib.agentlightning.contrib.agent.env_agent import EnvAgent
+    if "empo2" in args.algorithm:
+        from contrib.agentlightning.contrib.agent.empo2_agent import EMPO2Agent, reset_memory
+
+        kill_process_on_port(8000)
+        kill_process_on_port(8001)
+
+        os.makedirs("logs", exist_ok=True)
+
+        subprocess.Popen(f"nohup python empo2_server/server_bert.py > logs/bert_{args.task_num}.log 2>&1 &", shell=True)
+        subprocess.Popen(f"nohup python empo2_server/server_mem.py > logs/mem_{args.task_num}.log 2>&1 &", shell=True)
+
+        NUM_MEMORY = 5
+        time.sleep(1)
+        reset_memory(NUM_MEMORY)
+
+        agent = EMPO2Agent(agent_config)
+    else:
+        from contrib.agentlightning.contrib.agent.env_agent import EnvAgent
 
-    agent = EnvAgent(agent_config)
+        agent = EnvAgent(agent_config)
 
     # Initialize trainer and start training
     trainer = Trainer(