diff --git a/.gitignore b/.gitignore index a83a714b7f..66e4b066be 100644 --- a/.gitignore +++ b/.gitignore @@ -161,13 +161,16 @@ pufferlib/ocean/impulse_wars/debug-*/ pufferlib/ocean/impulse_wars/release-*/ pufferlib/ocean/impulse_wars/benchmark/ - # Ignore data files data/ -pufferlib/resources/drive/binaries/ +pufferlib/resources/drive/binaries/* +pufferlib/resources/drive/binaries/training/ +pufferlib/resources/drive/binaries/validation/ # But keep map_000.bin for the training test !pufferlib/resources/drive/binaries/map_000.bin +!pufferlib/resources/drive/binaries/training/map_000.bin +pufferlib/resources/drive/sanity/sanity_binaries/ # Compiled drive binary in root /drive @@ -183,6 +186,10 @@ pufferlib/resources/drive/output_agent.gif pufferlib/resources/drive/output.gif # Local artifacts and outputs artifacts/ # Local drive renders +pufferlib/resources/drive/output*.gif +emsdk/ +docs/book/* +!docs/book/assets/ pufferlib/resources/drive/output*.mp4 # Local TODO tracking diff --git a/evaluate_human_logs.py b/evaluate_human_logs.py index 83472b6c71..a91e8bd2ce 100644 --- a/evaluate_human_logs.py +++ b/evaluate_human_logs.py @@ -6,8 +6,8 @@ import pufferlib import pufferlib.vector from pufferlib.ocean import env_creator -from pufferlib.ocean.torch import Drive, Recurrent - +from pufferlib.ocean.torch import Drive, Recurrent, Transformer +from pufferlib.ocean.benchmark.evaluator import HumanReplayEvaluator import matplotlib.pyplot as plt import numpy as np @@ -101,311 +101,204 @@ def plot_adaptive_metrics(first_metrics, last_metrics, delta_metrics, output_pat def main(): + print("Beginning human evaluations using HumanReplayEvaluator") parser = argparse.ArgumentParser() parser.add_argument("--policy-path", type=str, required=True) + parser.add_argument("--policy-architecture", type=str, default="Recurrent") parser.add_argument("--num-maps", type=int, default=10) - parser.add_argument("--num-rollouts", type=int, default=1000) - parser.add_argument("--batch-size", 
type=int, default=32, help="Max parallel rollouts per batch") - parser.add_argument("--num-workers", type=int, default=16) + parser.add_argument("--num-rollouts", type=int, default=100) parser.add_argument("--num-agents", type=int, default=64) - parser.add_argument( - "--condition-type", - type=str, - default="none", - choices=["none", "reward", "entropy", "discount", "all"], - help="Conditioning type (none, reward, entropy, discount, all)", - ) parser.add_argument("--output", type=str, default="eval_human_logs.json") parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") - parser.add_argument( - "--max-controlled-agents", type=int, default=-1 - ) ## needs to be 1 if you want human logs, -1 if you want Self Play - parser.add_argument("--adaptive-driving-agent", type=int, default=0, help="Enable adaptive driving agent") - parser.add_argument("--k-scenarios", type=int, default=1, help="Number of scenarios (default 1 for non-adaptive)") + parser.add_argument("--max-controlled-agents", type=int, default=1) + parser.add_argument("--adaptive-driving-agent", type=int, default=0) + parser.add_argument("--k-scenarios", type=int, default=1) parser.add_argument("--dynamics-model", type=str, default="classic") - args = parser.parse_args() - - num_batches = (args.num_rollouts + args.batch_size - 1) // args.batch_size + parser.add_argument("--human-replay", action="store_true") + args_parsed = parser.parse_args() print(f"Evaluation Configuration:") - print(f" Policy: {args.policy_path}") - print(f" Conditioning: {args.condition_type}") - print(f" Num maps: {args.num_maps}") - print(f" Total rollouts: {args.num_rollouts}") - print(f" Batch size: {args.batch_size}") - print(f" Num batches: {num_batches}") - print(f" Num agents per env: {args.num_agents}") - print(f" Adaptive agent: {bool(args.adaptive_driving_agent)}") - print(f" K scenarios: {args.k_scenarios}") - print(f" Output: {args.output}\n") - print(f" Dynamics Model: 
{args.dynamics_model}") - - # Load policy - print("Loading policy...") - env_name = "puffer_adaptive_drive" if args.adaptive_driving_agent else "puffer_drive" + print(f" Policy: {args_parsed.policy_path}") + print(f" Policy Architecture: {args_parsed.policy_architecture}") + print(f" Num maps: {args_parsed.num_maps}") + print(f" Total rollouts: {args_parsed.num_rollouts}") + print(f" Num agents per env: {args_parsed.num_agents}") + print(f" Adaptive agent: {bool(args_parsed.adaptive_driving_agent)}") + print(f" K scenarios: {args_parsed.k_scenarios}") + print(f" Dynamics Model: {args_parsed.dynamics_model}") + print(f" Output: {args_parsed.output}\n") + + # Build args dict in the format expected by HumanReplayEvaluator + env_name = "puffer_adaptive_drive" if args_parsed.adaptive_driving_agent else "puffer_drive" make_env = env_creator(env_name) - temp_env = make_env( - num_agents=64, - num_maps=args.num_maps, - scenario_length=91, - co_player_cond_type=args.condition_type, - adaptive_driving_agent=args.adaptive_driving_agent, - k_scenarios=args.k_scenarios, - dynamics_model=args.dynamics_model, - ) - - base_policy = Drive(temp_env, input_size=64, hidden_size=256) - policy = Recurrent(temp_env, base_policy, input_size=256, hidden_size=256).to(args.device) - state_dict = torch.load(args.policy_path, map_location=args.device) + + scenario_length = 91 + context_length = args_parsed.k_scenarios * scenario_length + + args = { + "train": { + "device": args_parsed.device, + "use_rnn": args_parsed.policy_architecture == "Recurrent", + "policy_architecture": args_parsed.policy_architecture, + "context_window": context_length, + }, + "env": { + "num_agents": args_parsed.num_agents, + "num_maps": args_parsed.num_maps, + "scenario_length": scenario_length, + "adaptive_driving_agent": args_parsed.adaptive_driving_agent, + "k_scenarios": args_parsed.k_scenarios, + "dynamics_model": args_parsed.dynamics_model, + "max_controlled_agents": args_parsed.max_controlled_agents, + 
"report_interval": 1, + "control_mode": "control_vehicles", + "episode_length": scenario_length, + "report_all_scenarios": args_parsed.adaptive_driving_agent, + "dynamics_model": "classic", + "reward_vehicle_collision": -0.5, + "reward_offroad_collision": -0.5, + "reward_goal": 1.0, + "reward_goal_post_respawn": 0.25, + }, + "vec": { + "backend": "PufferEnv", + "num_envs": 1, + }, + "eval": { + "human_replay_control_mode": "control_vehicles", + }, + } + + if args_parsed.human_replay: + args["env"]["human_replay_mode"] = True + + # Load policy once + print("Loading policy...") + temp_env = make_env(**args["env"]) + + if args_parsed.policy_architecture == "Recurrent": + base_policy = Drive(temp_env, input_size=64, hidden_size=256) + policy = Recurrent(temp_env, base_policy, input_size=256, hidden_size=256).to(args_parsed.device) + elif args_parsed.policy_architecture == "Transformer": + base_policy = Drive(temp_env, input_size=128, hidden_size=256) + policy = Transformer( + temp_env, + base_policy, + input_size=256, + hidden_size=256, + num_layers=2, + num_heads=4, + context_length=context_length, + ).to(args_parsed.device) + + state_dict = torch.load(args_parsed.policy_path, map_location=args_parsed.device) state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} policy.load_state_dict(state_dict) policy.eval() temp_env.close() print("Policy loaded successfully\n") - # Run evaluation in batches - all_returns = [] - all_scenario_metrics = [] # Track metrics per scenario for adaptive agents - all_metrics = [] - - env_kwargs = { - "num_agents": args.num_agents, - "num_maps": args.num_maps, - "max_controlled_agents": args.max_controlled_agents, - "report_interval": 1, - "scenario_length": 91, - "adaptive_driving_agent": args.adaptive_driving_agent, - "k_scenarios": args.k_scenarios, - "dynamics_model": args.dynamics_model, - "co_player_cond_type": args.condition_type, - "co_player_cond_entropy_ub": 0.05, - "co_player_cond_discount_lb": 0.40, - } - - 
print("Running evaluation...") - for batch_idx in range(num_batches): - batch_rollouts = min(args.batch_size, args.num_rollouts - batch_idx * args.batch_size) + # Create evaluator + from pufferlib.ocean.benchmark.evaluator import HumanReplayEvaluator - print(f"Batch {batch_idx + 1}/{num_batches} ({batch_rollouts} rollouts)") + evaluator = HumanReplayEvaluator(args) - # Find largest valid num_workers (divisor of batch_rollouts) - max_workers = min(args.num_workers, batch_rollouts) - while batch_rollouts % max_workers != 0: - max_workers -= 1 + # Run multiple rollouts + print(f"Running {args_parsed.num_rollouts} rollouts...") + all_results = [] + for rollout_idx in tqdm(range(args_parsed.num_rollouts), desc="Rollouts"): + # Create fresh env for each rollout vecenv = pufferlib.vector.make( make_env, - env_kwargs=env_kwargs, - backend=pufferlib.vector.Multiprocessing, - num_envs=batch_rollouts, - num_workers=max_workers, + env_kwargs=args["env"], + backend=pufferlib.vector.Serial, + num_envs=1, ) - obs, _ = vecenv.reset() - total_agents = obs.shape[0] - - state = { - "lstm_h": torch.zeros(total_agents, policy.hidden_size, device=args.device), - "lstm_c": torch.zeros(total_agents, policy.hidden_size, device=args.device), - } - - batch_infos = [] - total_reward = np.zeros(total_agents) - scenario_infos = [] # Track infos per scenario - - with torch.no_grad(): - # Run through all scenarios (1 for non-adaptive, k for adaptive) - for scenario in range(args.k_scenarios): - scenario_info_list = [] - desc = f" Scenario {scenario + 1}/{args.k_scenarios}" if args.k_scenarios > 1 else " Steps" - for t in tqdm(range(91), desc=desc, ncols=80, leave=False): - obs_t = torch.as_tensor(obs, device=args.device) - logits, _ = policy.forward_eval(obs_t, state) - action, _, _ = pufferlib.pytorch.sample_logits(logits) - - obs, reward, done, trunc, info = vecenv.step(action.cpu().numpy()) - total_reward += reward - - if info: - valid_infos = [inf for inf in info if "score" in inf] - 
batch_infos.extend(valid_infos) - scenario_info_list.extend(valid_infos) - - # Store per-scenario infos - if args.adaptive_driving_agent: - scenario_infos.append(scenario_info_list) + # Run single rollout + results = evaluator.rollout(args, vecenv, policy) + all_results.append(results) vecenv.close() - # Aggregate batch metrics - num_infos = len(batch_infos) or 1 - batch_metrics = { - "score": sum(info.get("score", 0) for info in batch_infos) / num_infos, - "collision_rate": sum(info.get("collision_rate", 0) for info in batch_infos) / num_infos, - "offroad_rate": sum(info.get("offroad_rate", 0) for info in batch_infos) / num_infos, - "completion_rate": sum(info.get("completion_rate", 0) for info in batch_infos) / num_infos, - "dnf_rate": sum(info.get("dnf_rate", 0) for info in batch_infos) / num_infos, - "avg_collisions_per_agent": sum(info.get("avg_collisions_per_agent", 0) for info in batch_infos) - / num_infos, - "avg_offroad_per_agent": sum(info.get("avg_offroad_per_agent", 0) for info in batch_infos) / num_infos, + # Aggregate results + print("\nAggregating results...") + aggregated = {} + + # Get all metric keys from first result + all_keys = list(all_results[0].keys()) + metric_keys = [k for k in all_keys if not k.startswith("ada_delta")] + delta_keys = [k for k in all_keys if k.startswith("ada_delta")] + + # Average regular metrics + for key in metric_keys: + values = [r.get(key, 0) for r in all_results] + aggregated[key] = float(np.mean(values)) + + # Average delta metrics if present + if delta_keys: + for key in delta_keys: + values = [r.get(key, 0) for r in all_results] + aggregated[key] = float(np.mean(values)) + + # Derive last scenario metrics from first + delta + # Extract first scenario metrics + first_scenario_keys = [k for k in metric_keys if k not in ["n"]] + + # Map metric names to their delta counterparts + metric_to_delta = { + "score": "ada_delta_score", + "collision_rate": "ada_delta_collision_rate", + "offroad_rate": 
"ada_delta_offroad_rate", + "completion_rate": "ada_delta_completion_rate", + "episode_return": "ada_delta_episode_return", + "dnf_rate": "ada_delta_dnf_rate", + "lane_alignment_rate": "ada_delta_lane_alignment_rate", } - rollout_rewards = total_reward.reshape(batch_rollouts, args.num_agents) - rollout_returns = rollout_rewards.mean(axis=1) - - all_returns.extend(rollout_returns.tolist()) - all_metrics.append(batch_metrics) - - # Store scenario-specific metrics for adaptive agents - if args.adaptive_driving_agent: - batch_scenario_metrics = [] - for scenario_info_list in scenario_infos: - num_scenario_infos = len(scenario_info_list) or 1 - scenario_metrics = { - "score": sum(info.get("score", 0) for info in scenario_info_list) / num_scenario_infos, - "collision_rate": sum(info.get("collision_rate", 0) for info in scenario_info_list) - / num_scenario_infos, - "offroad_rate": sum(info.get("offroad_rate", 0) for info in scenario_info_list) - / num_scenario_infos, - "completion_rate": sum(info.get("completion_rate", 0) for info in scenario_info_list) - / num_scenario_infos, - "dnf_rate": sum(info.get("dnf_rate", 0) for info in scenario_info_list) / num_scenario_infos, - "num_goals_reached": sum(info.get("num_goals_reached", 0) for info in scenario_info_list) - / num_scenario_infos, - "lane_alignment_rate": sum(info.get("lane_alignment_rate", 0) for info in scenario_info_list) - / num_scenario_infos, - "avg_displacement_error": sum(info.get("avg_displacement_error", 0) for info in scenario_info_list) - / num_scenario_infos, - "episode_return": sum(info.get("episode_return", 0) for info in scenario_info_list) - / num_scenario_infos, - } - # Compute perf (score without collision before goal) - scenario_metrics["perf"] = scenario_metrics["score"] - batch_scenario_metrics.append(scenario_metrics) - - all_scenario_metrics.append(batch_scenario_metrics) - - # Aggregate across all batches - all_returns = np.array(all_returns) - - # Divide by k_scenarios since we accumulated 
across all scenarios - all_returns = all_returns / args.k_scenarios - - metrics = {k: np.mean([m[k] for m in all_metrics]) for k in all_metrics[0].keys()} - - results = { - "avg_return": float(np.mean(all_returns)), - "std_return": float(np.std(all_returns)), - "se_return": float(np.std(all_returns) / np.sqrt(len(all_returns))), # Standard error - **{k: float(v) for k, v in metrics.items()}, - } - - # Compute adaptive delta metrics from scenario metrics - first_scenario_metrics = None - last_scenario_metrics = None - - if args.adaptive_driving_agent and len(all_scenario_metrics) > 0: - # Aggregate scenario metrics across all batches - # all_scenario_metrics is a list of [batch][scenario] metrics - # We need to compute average for each scenario across all batches - - aggregated_scenario_metrics = [] - for scenario_idx in range(args.k_scenarios): - scenario_metrics_list = [batch[scenario_idx] for batch in all_scenario_metrics] - - # Average each metric across batches for this scenario - avg_scenario_metrics = {} - for key in scenario_metrics_list[0].keys(): - avg_scenario_metrics[key] = np.mean([m[key] for m in scenario_metrics_list]) - - aggregated_scenario_metrics.append(avg_scenario_metrics) + # Store first scenario metrics + for metric_name in metric_to_delta.keys(): + if metric_name in aggregated: + aggregated[f"first_scenario_{metric_name}"] = aggregated[metric_name] - # Get first and last scenario metrics - first_scenario_metrics = aggregated_scenario_metrics[0] - last_scenario_metrics = aggregated_scenario_metrics[-1] + # Compute last scenario metrics: last = first + delta + for metric_name, delta_key in metric_to_delta.items(): + if metric_name in aggregated and delta_key in aggregated: + aggregated[f"last_scenario_{metric_name}"] = aggregated[metric_name] + aggregated[delta_key] - # Helper function to compute delta percentage - def compute_delta_percent(first_val, last_val): - if abs(first_val) < 0.0001: - return 0.0 - return (last_val - first_val) / 
first_val * 100.0 - - # Compute all delta metrics - results["ada_delta_completion_rate"] = compute_delta_percent( - first_scenario_metrics["completion_rate"], last_scenario_metrics["completion_rate"] - ) - results["ada_delta_score"] = compute_delta_percent( - first_scenario_metrics["score"], last_scenario_metrics["score"] - ) - results["ada_delta_perf"] = compute_delta_percent(first_scenario_metrics["perf"], last_scenario_metrics["perf"]) - results["ada_delta_collision_rate"] = compute_delta_percent( - first_scenario_metrics["collision_rate"], last_scenario_metrics["collision_rate"] - ) - results["ada_delta_offroad_rate"] = compute_delta_percent( - first_scenario_metrics["offroad_rate"], last_scenario_metrics["offroad_rate"] - ) - results["ada_delta_num_goals_reached"] = compute_delta_percent( - first_scenario_metrics["num_goals_reached"], last_scenario_metrics["num_goals_reached"] - ) - results["ada_delta_dnf_rate"] = compute_delta_percent( - first_scenario_metrics["dnf_rate"], last_scenario_metrics["dnf_rate"] - ) - results["ada_delta_lane_alignment_rate"] = compute_delta_percent( - first_scenario_metrics["lane_alignment_rate"], last_scenario_metrics["lane_alignment_rate"] - ) - results["ada_delta_avg_displacement_error"] = compute_delta_percent( - first_scenario_metrics["avg_displacement_error"], last_scenario_metrics["avg_displacement_error"] - ) - results["ada_delta_episode_return"] = compute_delta_percent( - first_scenario_metrics["episode_return"], last_scenario_metrics["episode_return"] - ) + # Save results + with open(args_parsed.output, "w") as f: + json.dump(aggregated, f, indent=2) - # Store first and last scenario values for reporting - results["first_scenario_score"] = float(first_scenario_metrics["score"]) - results["first_scenario_collision_rate"] = float(first_scenario_metrics["collision_rate"]) - results["first_scenario_offroad_rate"] = float(first_scenario_metrics["offroad_rate"]) - results["first_scenario_episode_return"] = 
float(first_scenario_metrics["episode_return"]) - - results["last_scenario_score"] = float(last_scenario_metrics["score"]) - results["last_scenario_collision_rate"] = float(last_scenario_metrics["collision_rate"]) - results["last_scenario_offroad_rate"] = float(last_scenario_metrics["offroad_rate"]) - results["last_scenario_episode_return"] = float(last_scenario_metrics["episode_return"]) - - with open(args.output, "w") as f: - json.dump(results, f, indent=2) - - print(f"\nResults:") - print(f" Return: {results['avg_return']:.2f} ± {results['se_return']:.2f} (SE)") - print(f" Score: {results['score']:.3f}") - print(f" Completion: {results['completion_rate']:.3f}") - print(f" Collision: {results['collision_rate']:.3f}") - print(f" Collision per agent: {results['avg_collisions_per_agent']:.3f}") - print(f" Offroad: {results['offroad_rate']:.3f}") - - if args.adaptive_driving_agent: + # Print results + if args_parsed.adaptive_driving_agent and delta_keys: print(f"\n0-Shot Performance (First Scenario):") - print(f" Score: {results['first_scenario_score']:.3f}") - print(f" Collision: {results['first_scenario_collision_rate']:.3f}") - print(f" Offroad: {results['first_scenario_offroad_rate']:.3f}") - print(f" Return: {results['first_scenario_episode_return']:.2f}") + print(f" Score: {aggregated.get('first_scenario_score', float('nan')):.3f}") + print(f" Collision: {aggregated.get('first_scenario_collision_rate', float('nan')):.3f}") + print(f" Offroad: {aggregated.get('first_scenario_offroad_rate', float('nan')):.3f}") + print(f" Return: {aggregated.get('first_scenario_episode_return', float('nan')):.2f}") print(f"\nAdapted Performance (Last Scenario):") - print(f" Score: {results['last_scenario_score']:.3f}") - print(f" Collision: {results['last_scenario_collision_rate']:.3f}") - print(f" Offroad: {results['last_scenario_offroad_rate']:.3f}") - print(f" Return: {results['last_scenario_episode_return']:.2f}") + print(f" Score: {aggregated.get('last_scenario_score', 
float('nan')):.3f}") + print(f" Collision: {aggregated.get('last_scenario_collision_rate', float('nan')):.3f}") + print(f" Offroad: {aggregated.get('last_scenario_offroad_rate', float('nan')):.3f}") + print(f" Return: {aggregated.get('last_scenario_episode_return', float('nan')):.2f}") + + print(f"\nAdaptive Metrics (Delta):") + print(f" Score: {aggregated.get('ada_delta_score', float('nan')):.4f}") + print(f" Collision rate: {aggregated.get('ada_delta_collision_rate', float('nan')):.4f}") + print(f" Offroad rate: {aggregated.get('ada_delta_offroad_rate', float('nan')):.4f}") + print(f" Episode return: {aggregated.get('ada_delta_episode_return', float('nan')):.4f}") - print(f"\nAdaptive Metrics (Delta %):") - print(f" Score: {results['ada_delta_score']:.2f}%") - print(f" Collision rate: {results['ada_delta_collision_rate']:.2f}%") - print(f" Offroad rate: {results['ada_delta_offroad_rate']:.2f}%") - print(f" Episode return: {results['ada_delta_episode_return']:.2f}%") + print(f"\nSaved to {args_parsed.output}") + import sys - # Generate visualization - plot_adaptive_metrics(first_scenario_metrics, last_scenario_metrics, results, args.output) + sys.exit(0) - print(f"\nSaved to {args.output}") +if __name__ == "__main__": + main() if __name__ == "__main__": main() diff --git a/external/pyxodr b/external/pyxodr new file mode 160000 index 0000000000..cd4b837a65 --- /dev/null +++ b/external/pyxodr @@ -0,0 +1 @@ +Subproject commit cd4b837a651d4f10c3c4e77b04a029cac367c64b diff --git a/pufferlib/config/ocean/adaptive.ini b/pufferlib/config/ocean/adaptive.ini index 3e4eaf60c4..598faf50ff 100644 --- a/pufferlib/config/ocean/adaptive.ini +++ b/pufferlib/config/ocean/adaptive.ini @@ -2,58 +2,73 @@ package = ocean env_name = puffer_adaptive_drive policy_name = Drive -rnn_name = Recurrent +transformer_name = Transformer + ; Changed from rnn_name [vec] num_workers = 16 num_envs = 16 -batch_size = 1 +batch_size = 2 ; backend = Serial [policy] -input_size = 64 +input_size = 128 +; 
Increased from 64 for richer representations hidden_size = 256 -[rnn] +[transformer] input_size = 256 hidden_size = 256 +num_layers = 2 +; Number of transformer layers +num_heads = 4 +; Number of attention heads (must divide hidden_size) +; context_length = 182 +; k_scenarios (2) * scenario_length (91) = maximum attention span +dropout = 0.0 +; Dropout (keep at 0 for RL stability initially) [env] -num_agents = 1024 -num_ego_agents = 512 +num_agents = 1512 +num_ego_agents = 756 ; Options: discrete, continuous action_type = discrete ; Options: classic, jerk dynamics_model = classic -; Number of consecutive scenarios per episode (adaptive-specific) -k_scenarios = 2 reward_vehicle_collision = -0.5 -reward_offroad_collision = -0.2 -reward_ade = 0.0 +reward_offroad_collision = -0.5 dt = 0.1 reward_goal = 1.0 reward_goal_post_respawn = 0.25 +; in case of reward conditioning, we scale the goal_weight by this number for post respawn ; Meters around goal to be considered "reached" goal_radius = 2.0 -; What to do when goal is reached. Options: 0:"respawn", 1:"generate_new_goals", 2:"stop" +; Max target speed in m/s for the agent to maintain towards the goal +goal_speed = 100.0 +; What to do when the goal is reached. Options: 0:"respawn", 1:"generate_new_goals", 2:"stop" goal_behavior = 0 +; Determines the target distance to the new goal in the case of goal_behavior = generate_new_goals. +; Large numbers will select a goal point further away from the agent's current position. 
+goal_target_distance = 30.0 ; Options: 0 - Ignore, 1 - Stop, 2 - Remove collision_behavior = 0 ; Options: 0 - Ignore, 1 - Stop, 2 - Remove offroad_behavior = 0 -; Number of steps before reset +; Number of steps in each scenario (constrained by base data) scenario_length = 91 -; Resample frequency = k_scenarios * scenario_length (adaptive-specific) -resample_frequency = 182 +k_scenarios = 2 +termination_mode = 1 +; 0 - terminate at episode_length, 1 - terminate after all agents have been reset +map_dir = "resources/drive/binaries/training" num_maps = 1000 -; Which step of the trajectory to initialize the agents at upon reset +; Determines which step of the trajectory to initialize the agents at upon reset init_steps = 0 -; Options: "control_vehicles", "control_agents", "control_tracks_to_predict" +; Options: "control_vehicles", "control_agents", "control_wosac", "control_sdc_only" control_mode = "control_vehicles" ; Options: "created_all_valid", "create_only_controlled" init_mode = "create_all_valid" ; train with co players -co_player_enabled = 1 +co_player_enabled = False [env.conditioning] @@ -71,10 +86,9 @@ discount_weight_lb = 0.80 discount_weight_ub = 0.98 [env.co_player_policy] -enabled = True policy_name = Drive rnn_name = Recurrent -policy_path = "experiments/puffer_drive_ewdjljwd.pt" +policy_path = "pufferlib/resources/drive/policies/varied_discount.pt" input_size = 64 hidden_size = 256 @@ -96,68 +110,71 @@ entropy_weight_ub = 0.001 discount_weight_lb = 0.80 discount_weight_ub = 0.98 - [train] +seed=42 total_timesteps = 2_000_000_000 -# learning_rate = 0.02 -# gamma = 0.985 anneal_lr = True -; Needs to be: num_agents * num_workers * BPTT horizon +; Needs to be: num_agents * num_workers * context_window batch_size = auto -; minibatch_size = 745472 -; minibatch_multiplier = 512 -; max_minibatch_size = 745472 -minibatch_size = 372736 -minibatch_multiplier = 256 -max_minibatch_size = 372736 -; BPTT horizon (overridden by pufferl.py for adaptive agents to 
k_scenarios * scenario_length) +minibatch_size = 36400 +; 400 * 91 +max_minibatch_size = 36400 +minibatch_multiplier = 400 +policy_architecture = Transformer +; Matches scenario_length for buffer organization bptt_horizon = 32 +; Keep for backward compatibility adam_beta1 = 0.9 adam_beta2 = 0.999 adam_eps = 1e-8 clip_coef = 0.2 -ent_coef = 0.001 +ent_coef = 0.005 gae_lambda = 0.95 gamma = 0.98 -learning_rate = 0.001 -max_grad_norm = 1 -prio_alpha = 0.8499999999999999 -prio_beta0 = 0.8499999999999999 +learning_rate = 0.0003 +; Reduced from 0.003 (transformers often need lower LR) +max_grad_norm = 1.0 +prio_alpha = 0.85 +prio_beta0 = 0.85 update_epochs = 1 -vf_clip_coef = 0.1999999999999999 -vf_coef = 2 +vf_clip_coef = 0.2 +vf_coef = 2.0 vtrace_c_clip = 1 vtrace_rho_clip = 1 -checkpoint_interval = 1000 +checkpoint_interval = 10 # Rendering options render = True -render_interval = 1000 +render_interval = 100 ; If True, show exactly what the agent sees in agent observation obs_only = True ; Show grid lines -show_grid = False +show_grid = True ; Draws lines from ego agent observed ORUs and road elements to show detection range show_lasers = False ; Display human xy logs in the background -show_human_logs = True -; Options: str to path (e.g., "resources/drive/binaries/map_001.bin"), None +show_human_logs = False +; If True, zoom in on a part of the map. 
Otherwise, show full map +zoom_in = True +; Options: List[str to path], str to path (e.g., "resources/drive/training/binaries/map_001.bin"), None render_map = none [eval] -eval_interval = 1000 +eval_interval = 10 +; Path to dataset used for evaluation +map_dir = "resources/drive/binaries/training" +; Evaluation will run on the first num_maps maps in the map_dir directory +num_maps = 20 backend = PufferEnv -# WOSAC (Waymo Open Sim Agents Challenge) evaluation settings +; WOSAC (Waymo Open Sim Agents Challenge) evaluation settings ; If True, enables evaluation on realism metrics each time we save a checkpoint -wosac_realism_eval = True +wosac_realism_eval = False ; Number of policy rollouts per scene wosac_num_rollouts = 32 ; When to start the simulation wosac_init_steps = 10 -; Total number of WOSAC agents to evaluate -wosac_num_agents = 256 -; Control the tracks to predict -wosac_control_mode = "control_tracks_to_predict" -; Initialize from the tracks to predict +; Control everything valid at init in the scene +wosac_control_mode = "control_wosac" +; Create everything in valid at init the scene wosac_init_mode = "create_all_valid" ; Stop when reaching the goal wosac_goal_behavior = 2 @@ -168,23 +185,27 @@ wosac_sanity_check = False wosac_aggregate_results = True ; If True, enable human replay evaluation (pair policy-controlled agent with human replays) human_replay_eval = True -; Control only the self-driving car -human_replay_control_mode = "control_sdc_only" -; This equals the number of scenarios, since we control one agent in each -human_replay_num_agents = 64 - -[sweep.env.reward_vehicle_collision] -distribution = uniform -min = -0.5 -max = 0.0 -mean = -0.05 +; Control mode for human replay (control_vehicles with max_controlled_agents=1 controls one agent) +human_replay_control_mode = "control_vehicles" +; Number of agents in human replay evaluation environment +human_replay_num_agents = 32 +; Number of rollouts for human replay evaluation 
+human_replay_num_rollouts = 100 +; Number of maps to use for human replay evaluation +human_replay_num_maps = 100 + +[sweep.train.learning_rate] +distribution = log_normal +min = 0.001 +mean = 0.003 +max = 0.005 scale = auto -[sweep.env.reward_offroad_collision] -distribution = uniform -min = -0.5 -max = 0.0 -mean = -0.05 +[sweep.train.ent_coef] +distribution = log_normal +min = 0.001 +mean = 0.005 +max = 0.03 scale = auto [sweep.env.goal_radius] @@ -194,16 +215,18 @@ max = 20.0 mean = 10.0 scale = auto -[sweep.env.reward_ade] -distribution = uniform -min = -0.1 -max = 0.0 -mean = -0.02 +[sweep.train.gae_lambda] +distribution = log_normal +min = 0.95 +mean = 0.98 +max = 0.999 scale = auto -[sweep.env.reward_goal_post_respawn] -distribution = uniform -min = 0.0 -max = 1.0 -mean = 0.5 -scale = auto +[controlled_exp.train.goal_speed] +values = [10, 20, 30, 3] + +[controlled_exp.train.ent_coef] +values = [0.001, 0.005, 0.01] + +[controlled_exp.train.seed] +values = [42, 55, 1] diff --git a/pufferlib/config/ocean/drive.ini b/pufferlib/config/ocean/drive.ini index b7bc0d021f..7d5efd43be 100644 --- a/pufferlib/config/ocean/drive.ini +++ b/pufferlib/config/ocean/drive.ini @@ -2,7 +2,7 @@ package = ocean env_name = puffer_drive policy_name = Drive -rnn_name = Recurrent +rnn_name = Transformer [vec] num_workers = 16 @@ -14,38 +14,56 @@ batch_size = 2 input_size = 64 hidden_size = 256 -[rnn] +; [rnn] +; input_size = 256 +; hidden_size = 256 + +[transformer] input_size = 256 hidden_size = 256 +num_layers = 2 +; Number of transformer layers +num_heads = 4 +; Number of attention heads (must divide hidden_size) +context_window = 32 +; k_scenarios (2) * scenario_length (91) = maximum attention span +dropout = 0.0 +; Dropout (keep at 0 for RL stability initially) [env] -num_agents = 1024 +num_agents = 512 num_ego_agents = 512 ; Options: discrete, continuous action_type = discrete ; Options: classic, jerk dynamics_model = classic reward_vehicle_collision = -0.5 
-reward_offroad_collision = -0.2 -reward_ade = 0.0 +reward_offroad_collision = -0.5 dt = 0.1 reward_goal = 1.0 reward_goal_post_respawn = 0.25 # in case of reward conditioning, we scale the goal_weight by this number for post respawn ; Meters around goal to be considered "reached" goal_radius = 2.0 -; What to do when goal is reached. Options: 0:"respawn", 1:"generate_new_goals", 2:"stop" -goal_behavior = 0 +; Max target speed in m/s for the agent to maintain towards the goal +goal_speed = 100.0 +; What to do when the goal is reached. Options: 0:"respawn", 1:"generate_new_goals", 2:"stop" +goal_behavior = 1 +; Determines the target distance to the new goal in the case of goal_behavior = generate_new_goals. +; Large numbers will select a goal point further away from the agent's current position. +goal_target_distance = 30.0 ; Options: 0 - Ignore, 1 - Stop, 2 - Remove collision_behavior = 0 ; Options: 0 - Ignore, 1 - Stop, 2 - Remove offroad_behavior = 0 -; Number of steps before reset +; Number of steps before scenario_length = 91 -resample_frequency = 182 +resample_frequency = 910 +termination_mode = 1 # 0 - terminate at episode_length, 1 - terminate after all agents have been reset +map_dir = "resources/drive/binaries/training" num_maps = 1000 -; Which step of the trajectory to initialize the agents at upon reset +; Determines which step of the trajectory to initialize the agents at upon reset init_steps = 0 -; Options: "control_vehicles", "control_agents", "control_tracks_to_predict", "control_sdc_only" +; Options: "control_vehicles", "control_agents", "control_wosac", "control_sdc_only" control_mode = "control_vehicles" ; Options: "created_all_valid", "create_only_controlled" init_mode = "create_all_valid" @@ -67,10 +85,9 @@ discount_weight_lb = 0.80 discount_weight_ub = 0.98 [env.co_player_policy] -enabled = False policy_name = Drive rnn_name = Recurrent -policy_path = "resources/drive/policies/varied_discount.pt" +policy_path = 
"pufferlib/resources/drive/policies/varied_discount.pt" input_size = 64 hidden_size = 256 @@ -93,6 +110,7 @@ discount_weight_lb = 0.98 discount_weight_ub = 0.80 [train] +seed=42 total_timesteps = 2_000_000_000 # learning_rate = 0.02 # gamma = 0.985 @@ -101,6 +119,8 @@ anneal_lr = True batch_size = auto minibatch_size = 32768 max_minibatch_size = 32768 +; minibatch_size = 256 +; max_minibatch_size = 256 bptt_horizon = 32 adam_beta1 = 0.9 adam_beta2 = 0.999 @@ -118,36 +138,42 @@ vf_clip_coef = 0.1999999999999999 vf_coef = 2 vtrace_c_clip = 1 vtrace_rho_clip = 1 -checkpoint_interval = 1000 +checkpoint_interval = 100 +use_transformer = True +context_window = 32 # Rendering options render = True -render_interval = 1000 +render_interval = 100 ; If True, show exactly what the agent sees in agent observation obs_only = True ; Show grid lines -show_grid = False +show_grid = True ; Draws lines from ego agent observed ORUs and road elements to show detection range show_lasers = False ; Display human xy logs in the background -show_human_logs = True -; Options: str to path (e.g., "resources/drive/binaries/map_001.bin"), None +show_human_logs = False +; If True, zoom in on a part of the map. 
Otherwise, show full map +zoom_in = True +; Options: List[str to path], str to path (e.g., "resources/drive/binaries/training/map_001.bin"), None render_map = none [eval] eval_interval = 1000 +; Path to dataset used for evaluation +map_dir = "resources/drive/binaries/training" +; Evaluation will run on the first num_maps maps in the map_dir directory +num_maps = 20 backend = PufferEnv -# WOSAC (Waymo Open Sim Agents Challenge) evaluation settings +; WOSAC (Waymo Open Sim Agents Challenge) evaluation settings ; If True, enables evaluation on realism metrics each time we save a checkpoint wosac_realism_eval = False ; Number of policy rollouts per scene wosac_num_rollouts = 32 ; When to start the simulation wosac_init_steps = 10 -; Total number of WOSAC agents to evaluate -wosac_num_agents = 256 -; Control the tracks to predict -wosac_control_mode = "control_tracks_to_predict" -; Initialize from the tracks to predict +; Control everything valid at init in the scene +wosac_control_mode = "control_wosac" +; Create everything valid at init in the scene wosac_init_mode = "create_all_valid" ; Stop when reaching the goal wosac_goal_behavior = 2 @@ -160,8 +186,6 @@ wosac_aggregate_results = True human_replay_eval = False ; Control only the self-driving car human_replay_control_mode = "control_sdc_only" -; This equals the number of scenarios, since we control one agent in each -human_replay_num_agents = 64 [sweep.train.learning_rate] distribution = log_normal @@ -174,10 +198,9 @@ scale = auto distribution = log_normal min = 0.001 mean = 0.005 -max = 0.01 +max = 0.03 scale = auto - [sweep.env.goal_radius] distribution = uniform min = 2.0 @@ -185,16 +208,18 @@ max = 20.0 mean = 10.0 scale = auto -[sweep.env.reward_ade] -distribution = uniform -min = -0.1 -max = 0.0 -mean = -0.02 +[sweep.train.gae_lambda] +distribution = log_normal +min = 0.95 +mean = 0.98 +max = 0.999 scale = auto -[sweep.env.reward_goal_post_respawn] -distribution = uniform -min = 0.0 -max = 1.0 -mean = 0.5
-scale = auto +[controlled_exp.train.goal_speed] +values = [10, 20, 30, 3] + +[controlled_exp.train.ent_coef] +values = [0.001, 0.005, 0.01] + +[controlled_exp.train.seed] +values = [42, 55, 1] diff --git a/pufferlib/models.py b/pufferlib/models.py index 0893a9db47..d0b438f9a3 100644 --- a/pufferlib/models.py +++ b/pufferlib/models.py @@ -7,6 +7,9 @@ import pufferlib.pytorch import pufferlib.spaces +import torch.nn.functional as F +import math + class Default(nn.Module): """Default PyTorch policy. Flattens obs and applies a linear layer. @@ -196,6 +199,205 @@ def forward(self, observations, state): return logits, values +class TransformerWrapper(nn.Module): # TransformerWrapper + def __init__( + self, + env, + policy, + input_size=128, + hidden_size=128, + num_layers=4, + num_heads=8, + context_length=512, + dropout=0.0, + ): + """Wraps your policy with a Transformer for temporal modeling. + + Args: + env: Environment instance + policy: Your Drive policy (must have encode_observations and decode_actions) + input_size: Size of encoded observations (from policy.encode_observations) + hidden_size: Transformer hidden dimension + num_layers: Number of transformer layers + num_heads: Number of attention heads + context_length: Maximum sequence length to attend over + dropout: Dropout probability + """ + super().__init__() + self.obs_shape = env.single_observation_space.shape + self.policy = policy + self.input_size = input_size + self.hidden_size = hidden_size + self.context_length = context_length + self.num_layers = num_layers + self.num_heads = num_heads + self.is_continuous = self.policy.is_continuous + + # Project encoded observations to transformer dimension if needed + if input_size != hidden_size: + self.input_projection = nn.Linear(input_size, hidden_size) + else: + self.input_projection = nn.Identity() + + # Learnable positional embeddings + self.positional_embedding = nn.Parameter(torch.zeros(1, context_length, hidden_size)) + 
nn.init.normal_(self.positional_embedding, std=0.02) + + # Transformer encoder + encoder_layer = nn.TransformerEncoderLayer( + d_model=hidden_size, + nhead=num_heads, + dim_feedforward=hidden_size * 2, + dropout=dropout, + activation="gelu", + batch_first=True, + norm_first=True, # Pre-LN architecture (more stable) + ) + self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + + # create cache for memory context + for T in [1, 2, 4, 8, 16, 32, 64, 91, 182, 273, 364, 455]: + mask = self.create_causal_mask(T, "cpu") + self.register_buffer(f"_causal_mask_{T}", mask, persistent=False) + + # Layer norm for output + self.output_norm = nn.LayerNorm(hidden_size) + + # Initialize weights + self._init_weights() + + def _init_weights(self): + """Initialize weights similar to GPT-2""" + for name, param in self.named_parameters(): + if "layer_norm" in name or "layernorm" in name or "output_norm" in name: + continue + if "bias" in name: + nn.init.constant_(param, 0) + elif "weight" in name and param.ndim >= 2: + nn.init.orthogonal_(param, 1.0) + + def create_causal_mask(self, seq_len, device): + """Create causal attention mask""" + mask = torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=device), diagonal=1) + return mask + + def get_causal_mask(self, T, device): + """Get cached causal mask or create new one""" + buffer_name = f"_causal_mask_{T}" + if hasattr(self, buffer_name): + return getattr(self, buffer_name).to(device) + return self.create_causal_mask(T, device) + + def create_episode_mask(self, terminals, seq_len): + """Episode mask which ensures that you arent attending over episode boundaries""" + B = terminals.shape[0] + device = terminals.device + + episode_ids = torch.nn.functional.pad(terminals[:, :-1], (1, 0)).cumsum(dim=1) + + mask_allow = episode_ids.unsqueeze(2) == episode_ids.unsqueeze(1) + + return torch.where(mask_allow, torch.zeros(1, device=device), torch.full((1,), float("-inf"), device=device)) + + def 
forward_eval(self, observations, state): + B = observations.shape[0] + device = observations.device + + hidden = self.policy.encode_observations(observations, state=state) + hidden = self.input_projection(hidden) + + if "transformer_context" not in state or state["transformer_context"] is None: + context = torch.zeros(B, self.context_length, self.hidden_size, device=device) + pos = torch.zeros(1, dtype=torch.long, device=device) + else: + context = state["transformer_context"] + pos = state.get("transformer_position", torch.zeros(1, dtype=torch.long, device=device)) + + if ( + context.shape[-1] != self.hidden_size + or context.shape[0] != B + or context.shape[1] != self.context_length + ): + context = torch.zeros(B, self.context_length, self.hidden_size, device=device) + pos = torch.zeros(1, dtype=torch.long, device=device) + + write_idx = (pos % self.context_length).long() + context[:, write_idx, :] = hidden.unsqueeze(1) + pos = pos + 1 + + pos_embed = self.positional_embedding[:, : self.context_length] + context_with_pos = context + pos_embed + + causal_mask = self.get_causal_mask(self.context_length, device) + + output = self.transformer(context_with_pos, mask=causal_mask, is_causal=True) + output = self.output_norm(output) + + read_idx = ((pos - 1) % self.context_length).long() + hidden_out = output[:, read_idx, :].squeeze(1) + + state["transformer_context"] = context + state["transformer_position"] = pos + state["hidden"] = hidden_out + + logits, values = self.policy.decode_actions(hidden_out) + return logits, values + + def forward(self, observations, state): + x = observations + device = x.device + + if x.ndim == len(self.obs_shape) + 1: + B, T = x.shape[0], 1 + elif x.ndim == len(self.obs_shape) + 2: + B, T = x.shape[:2] + else: + raise ValueError(f"Invalid input tensor shape: {x.shape}") + + x_flat = x.view(B * T, *self.obs_shape) + hidden = self.policy.encode_observations(x_flat, state) + + hidden = hidden.view(B, T, self.input_size) + hidden = 
self.input_projection(hidden) + + # Remove dynamic truncation - use clamp instead of if + T_actual = min(T, self.context_length) # Python int, fine + if T_actual < T: + hidden = hidden[:, -T_actual:] + T = T_actual + + hidden = hidden + self.positional_embedding[:, :T] + + use_episode_mask = "terminals" in state and state["terminals"] is not None + + if not use_episode_mask: + causal_mask = self.get_causal_mask(T, device) + hidden = self.transformer(hidden, mask=causal_mask, is_causal=True) + else: + terminals = state["terminals"] + if terminals.shape[1] > T: + terminals = terminals[:, -T:] + causal_mask = self.get_causal_mask(T, device) + episode_mask = self.create_episode_mask(terminals, T) + attn_mask = causal_mask.unsqueeze(0) + episode_mask + attn_mask = attn_mask.repeat_interleave(self.num_heads, dim=0) + hidden = self.transformer(hidden, mask=attn_mask, is_causal=False) + + hidden = self.output_norm(hidden) + flat_hidden = hidden.contiguous().view(B * T, self.hidden_size) + + logits, values = self.policy.decode_actions(flat_hidden) + values = values.view(B, T) + + # Use Python int for context_len - no sync + context_len = min(T, self.context_length) + state["hidden"] = hidden + state["transformer_context"] = hidden[:, -context_len:].detach() + state["transformer_position"] = torch.full((B,), context_len - 1, dtype=torch.long, device=device) + + return logits, values + + class Convolutional(nn.Module): def __init__( self, diff --git a/pufferlib/ocean/benchmark/evaluator.py b/pufferlib/ocean/benchmark/evaluator.py index 383384e623..dbf84c1906 100644 --- a/pufferlib/ocean/benchmark/evaluator.py +++ b/pufferlib/ocean/benchmark/evaluator.py @@ -625,7 +625,10 @@ class HumanReplayEvaluator: def __init__(self, config: Dict): self.config = config - self.sim_steps = 91 - self.config["env"]["init_steps"] + k_scenarios = self.config["env"].get("k_scenarios", 1) + scenario_length = self.config["env"].get("scenario_length", 91) + init_steps = 
self.config["env"].get("init_steps", 0) + self.sim_steps = scenario_length - init_steps def rollout(self, args, puffer_env, policy): """Roll out policy in env with human replays. Store statistics. @@ -635,14 +638,12 @@ def rollout(self, args, puffer_env, policy): the policy is with (static) human partners. Args: - args: Config dict with train settings (device, use_rnn, etc.) + args: Config dict with train settings (device, use_rnn, policy_architecture, etc.) puffer_env: PufferLib environment wrapper policy: Trained policy to evaluate Returns: - dict: Aggregated metrics including: - - avg_collisions_per_agent: Average collisions per agent - - avg_offroad_per_agent: Average offroad events per agent + dict: Aggregated metrics including delta metrics for adaptive agents """ import numpy as np import torch @@ -652,26 +653,72 @@ def rollout(self, args, puffer_env, policy): device = args["train"]["device"] obs, info = puffer_env.reset() - state = {} - if args["train"]["use_rnn"]: + + policy_architecture = args["train"].get("policy_architecture", "Recurrent") + k_scenarios = args["env"].get("k_scenarios", 1) + + if policy_architecture == "Recurrent": state = dict( lstm_h=torch.zeros(num_agents, policy.hidden_size, device=device), lstm_c=torch.zeros(num_agents, policy.hidden_size, device=device), ) + elif policy_architecture == "Transformer": + context_length = args["train"].get("context_window", 182) + state = dict( + transformer_context=torch.zeros(num_agents, context_length, policy.hidden_size, device=device), + transformer_position=torch.zeros(1, dtype=torch.long, device=device), + ) + else: + state = {} - for time_idx in range(self.sim_steps): - # Step policy - with torch.no_grad(): - ob_tensor = torch.as_tensor(obs).to(device) - logits, value = policy.forward_eval(ob_tensor, state) - action, logprob, _ = pufferlib.pytorch.sample_logits(logits) - action_np = action.cpu().numpy().reshape(puffer_env.action_space.shape) + collected_infos = [] + delta_metrics = None - if 
isinstance(logits, torch.distributions.Normal): - action_np = np.clip(action_np, puffer_env.action_space.low, puffer_env.action_space.high) + # Loop through scenarios + for scenario in range(k_scenarios): + for time_idx in range(self.sim_steps): + # Step policy + with torch.no_grad(): + ob_tensor = torch.as_tensor(obs).to(device) + logits, value = policy.forward_eval(ob_tensor, state) + action, logprob, _ = pufferlib.pytorch.sample_logits(logits) + action_np = action.cpu().numpy().reshape(puffer_env.action_space.shape) - obs, rewards, dones, truncs, info_list = puffer_env.step(action_np) + if isinstance(logits, torch.distributions.Normal): + action_np = np.clip(action_np, puffer_env.action_space.low, puffer_env.action_space.high) - if len(info_list) > 0: # Happens at the end of episode - results = info_list[0] - return results + obs, rewards, dones, truncs, info_list = puffer_env.step(action_np) + + # Reset transformer context on mid-scenario terminations (not at scenario boundaries) + if policy_architecture == "Transformer": + is_last_step = time_idx == self.sim_steps - 1 + if not is_last_step: + done_mask = dones | truncs + if done_mask.any(): + done_indices = np.where(done_mask)[0] + state["transformer_context"][done_indices] = 0.0 + + # Collect infos + if len(info_list) > 0: + for info_dict in info_list: + if isinstance(info_dict, dict): + if "ada_delta_score" in info_dict: + delta_metrics = info_dict + elif "score" in info_dict: + collected_infos.append(info_dict) + + # Return the last info dict which contains delta metrics for adaptive agents + if collected_infos: + metric_keys = collected_infos[0].keys() + aggregated = {} + for key in metric_keys: + values = [info.get(key, 0) for info in collected_infos] + aggregated[key] = np.mean(values) + + # Merge delta metrics if they exist + if delta_metrics: + aggregated.update(delta_metrics) + + return aggregated + + return {} diff --git a/pufferlib/ocean/drive/README.md b/pufferlib/ocean/drive/README.md deleted file 
mode 100644 index a37907b206..0000000000 --- a/pufferlib/ocean/drive/README.md +++ /dev/null @@ -1,108 +0,0 @@ -# PufferDrive - -This readme contains several important assumptions and definions about the `PufferDrive` environment. - -## Agent initialization and control - -### `init_mode` - -Determines which agents are **created** in the environment. - -| Option | Description | -| ------------------------ | ---------------------------------------------------------------------------- | -| `create_all_valid` | Create all entities valid at initialization (`traj_valid[init_steps] == 1`). | -| `create_only_controlled` | Create only those agents that are controlled by the policy. | - -### `control_mode` - -Determines which created agents are **controlled** by the policy. - -| Option | Description | -| ----------------------------------------- | ------------------------------------------------------------------------------------------------- | -| `control_vehicles` (default) | Control only valid **vehicles** (not experts, beyond `MIN_DISTANCE_TO_GOAL`, under `MAX_AGENTS`). | -| `control_agents` | Control all valid **agent types** (vehicles, cyclists, pedestrians). | -| `control_tracks_to_predict` *(WOMD only)* | Control agents listed in the `tracks_to_predict` metadata. | - - -## Termination conditions (`done`) - -Episodes are never truncated before reaching `episode_len`. The `goal_behavior` argument controls agent behavior after reaching a goal early: - -* **`goal_behavior=0` (default):** Agents respawn at their initial position after reaching their goal (last valid log position). -* **`goal_behavior=1`:** Agents receive new goals indefinitely after reaching each goal. -* **`goal_behavior=2`:** Agents stop after reaching their goal. - -## Logged performance metrics - -We record multiple performance metrics during training, aggregated over all *active agents* (alive and controlled). 
Key metrics include: - -- `score`: Goals reached cleanly (goal was achieved without collision or going off-road) -- `collision_rate`: Binary flag (0 or 1) if agent hit another vehicle. -- `offroad_rate`: Binary flag (0 or 1) if agent left road bounds. -- `completion_rate`: Whether the agent reached its goal in this episode (even if it collided or went off-road). - - -### Metric aggregation - -The `num_agents` parameter in `drive.ini` defines the total number of agents used to collect experience. -At runtime, **Puffer** uses `num_maps` to create enough environments to populate the buffer with `num_agents`, distributing them evenly across `num_envs`. - -Because agents are respawned immediately after reaching their goal, they remain active throughout the episode. - -At the end of each episode (i.e., when `timestep == TRAJECTORY_LENGTH`), metrics are logged once via: - -```C -if (env->timestep == TRAJECTORY_LENGTH) { - add_log(env); - c_reset(env); - return; -} -``` - -Metrics are normalized and aggregated in `vec_log` (`pufferlib/ocean/env_binding.h`). They are averaged over all active agents across all environments. For example, the aggregated collision rate is computed as: - -$$ -r^{agg}_{\text{collision}} = \frac{\mathbb{I}[\text{collided in episode}]}{N} -$$ - -where $N$ is the number of controlled agents. -This value represents the fraction of agents that collided at least once during the episode. So, cases **A** and **B** below would yield identical off-road and collision rates: - -![alt text](../../resources/drive/examples_a_b.png) - -Since these metrics do not capture *multiple* events per agent, we additionally log the **average number of collision and off-road events per episode**. This is computed as: - -$$ -c^{avg}_{\text{collision}} = \frac{\text{total number of collision events across all agents and environments}}{N} -$$ - -where $N$ is the total number of controlled agents. 
-For example, an `avg_collisions_per_agent` value of 4 indicates that, on average, each agent collides four times per episode. - -### Effect of respawning on metrics - -By default, agents are reset to their initial position when they reach their goal before the episode ends. Upon respawn, `respawn_timestep` is updated from `-1` to the current step index. - -This raises the question: **how does repeated respawning affect aggregated metrics?** - -To begin, note that the environment is a bit different before and after respawn. After an agent respawns, all other agents are "removed" from the environment. As a result, collisions with other agents cannot occur post-respawn. - -This effectively transforms the scenario into a single-agent environment, simplifying the task since the agent no longer needs to coordinate with others. - -![alt text](../../resources/drive/pre_and_post_respawn.png) - -#### `score` - -Consider an episode of 91 steps where an agent is initialized relatively close to the goal position and reaches its goal three times: - -1. **First attempt:** reaches the goal without collisions -2. **Second attempt:** reaches the goal without collisions -3. **Third attempt:** reaches the goal but goes off-road along the way - -![alt text](../../resources/drive/realistic_collision_event_post_respawn.png) - -The highlighted trajectory shows the first attempt. In this case, the recorded score is `0.0` — a single off-road event invalidates the score for the entire episode. This behavior is desired: the score metric is unforgiving. - -#### `offroad_rate` and `collision_rate` - -Same logic holds as above. 
diff --git a/pufferlib/ocean/drive/adaptive.py b/pufferlib/ocean/drive/adaptive.py index 184ce41577..7af77c62db 100644 --- a/pufferlib/ocean/drive/adaptive.py +++ b/pufferlib/ocean/drive/adaptive.py @@ -12,7 +12,12 @@ def __init__(self, **kwargs): kwargs["ini_file"] = "pufferlib/config/ocean/adaptive.ini" kwargs["adaptive_driving_agent"] = True + # Human replay mode: disable co-players, use human trajectories for other agents + human_replay_mode = kwargs.pop("human_replay_mode", False) + if human_replay_mode: + kwargs["co_player_enabled"] = False + kwargs["resample_frequency"] = self.k_scenarios * self.scenario_length self.episode_length = kwargs["resample_frequency"] - # print(f"resample frequency is ", kwargs["resample_frequency"], flush=True) + super().__init__(**kwargs) diff --git a/pufferlib/ocean/drive/binding.c b/pufferlib/ocean/drive/binding.c index 60c03814b7..419bef0b47 100644 --- a/pufferlib/ocean/drive/binding.c +++ b/pufferlib/ocean/drive/binding.c @@ -1,27 +1,29 @@ #define Env Drive #define MY_SHARED #define MY_PUT + +#include #include "binding.h" -static int my_put(Env* env, PyObject* args, PyObject* kwargs) { - PyObject* obs = PyDict_GetItemString(kwargs, "observations"); +static int my_put(Env *env, PyObject *args, PyObject *kwargs) { + PyObject *obs = PyDict_GetItemString(kwargs, "observations"); if (!PyObject_TypeCheck(obs, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Observations must be a NumPy array"); return 1; } - PyArrayObject* observations = (PyArrayObject*)obs; + PyArrayObject *observations = (PyArrayObject *)obs; if (!PyArray_ISCONTIGUOUS(observations)) { PyErr_SetString(PyExc_ValueError, "Observations must be contiguous"); return 1; } env->observations = PyArray_DATA(observations); - PyObject* act = PyDict_GetItemString(kwargs, "actions"); + PyObject *act = PyDict_GetItemString(kwargs, "actions"); if (!PyObject_TypeCheck(act, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Actions must be a NumPy array"); return 1; } - 
PyArrayObject* actions = (PyArrayObject*)act; + PyArrayObject *actions = (PyArrayObject *)act; if (!PyArray_ISCONTIGUOUS(actions)) { PyErr_SetString(PyExc_ValueError, "Actions must be contiguous"); return 1; @@ -32,12 +34,12 @@ static int my_put(Env* env, PyObject* args, PyObject* kwargs) { return 1; } - PyObject* rew = PyDict_GetItemString(kwargs, "rewards"); + PyObject *rew = PyDict_GetItemString(kwargs, "rewards"); if (!PyObject_TypeCheck(rew, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Rewards must be a NumPy array"); return 1; } - PyArrayObject* rewards = (PyArrayObject*)rew; + PyArrayObject *rewards = (PyArrayObject *)rew; if (!PyArray_ISCONTIGUOUS(rewards)) { PyErr_SetString(PyExc_ValueError, "Rewards must be contiguous"); return 1; @@ -48,12 +50,12 @@ static int my_put(Env* env, PyObject* args, PyObject* kwargs) { } env->rewards = PyArray_DATA(rewards); - PyObject* term = PyDict_GetItemString(kwargs, "terminals"); + PyObject *term = PyDict_GetItemString(kwargs, "terminals"); if (!PyObject_TypeCheck(term, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Terminals must be a NumPy array"); return 1; } - PyArrayObject* terminals = (PyArrayObject*)term; + PyArrayObject *terminals = (PyArrayObject *)term; if (!PyArray_ISCONTIGUOUS(terminals)) { PyErr_SetString(PyExc_ValueError, "Terminals must be contiguous"); return 1; @@ -66,23 +68,21 @@ static int my_put(Env* env, PyObject* args, PyObject* kwargs) { return 0; } -static PyObject* my_shared(PyObject* self, PyObject* args, PyObject* kwargs) { +static PyObject *my_shared(PyObject *self, PyObject *args, PyObject *kwargs) { int population_play = unpack(kwargs, "population_play"); - if (population_play){ - return my_shared_population_play(self, args, kwargs); - } - else{ - return my_shared_self_play( self, args, kwargs); + if (population_play) { + return my_shared_population_play(self, args, kwargs); + } else { + return my_shared_self_play(self, args, kwargs); } - } -static int my_init(Env* env, PyObject* 
args, PyObject* kwargs) { +static int my_init(Env *env, PyObject *args, PyObject *kwargs) { env->human_agent_idx = unpack(kwargs, "human_agent_idx"); env->ini_file = unpack_str(kwargs, "ini_file"); env_init_config conf = {0}; - if(ini_parse(env->ini_file, handler, &conf) < 0) { + if (ini_parse(env->ini_file, handler, &conf) < 0) { printf("Error while loading %s", env->ini_file); } if (kwargs && PyDict_GetItemString(kwargs, "scenario_length")) { @@ -95,15 +95,16 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) { env->action_type = conf.action_type; env->dynamics_model = conf.dynamics_model; if (PyDict_GetItemString(kwargs, "dynamics_model")) { - char* dynamics_str = unpack_str(kwargs, "dynamics_model"); + char *dynamics_str = unpack_str(kwargs, "dynamics_model"); env->dynamics_model = (strcmp(dynamics_str, "jerk") == 0) ? JERK : CLASSIC; } env->reward_vehicle_collision = conf.reward_vehicle_collision; env->reward_offroad_collision = conf.reward_offroad_collision; env->reward_goal = conf.reward_goal; env->reward_goal_post_respawn = conf.reward_goal_post_respawn; - env->reward_ade = conf.reward_ade; env->scenario_length = conf.scenario_length; + + env->termination_mode = conf.termination_mode; env->collision_behavior = conf.collision_behavior; env->offroad_behavior = conf.offroad_behavior; env->max_controlled_agents = unpack(kwargs, "max_controlled_agents"); @@ -127,10 +128,10 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) { if (env->population_play) { env->num_co_players = unpack(kwargs, "num_co_players"); - double* co_player_ids_d = unpack_float_array(kwargs, "co_player_ids", &env->num_co_players); + double *co_player_ids_d = unpack_float_array(kwargs, "co_player_ids", &env->num_co_players); if (co_player_ids_d != NULL && env->num_co_players > 0) { - env->co_player_ids = (int*)malloc(env->num_co_players * sizeof(int)); + env->co_player_ids = (int *)malloc(env->num_co_players * sizeof(int)); if (env->co_player_ids == NULL) { 
fprintf(stderr, "Error: Failed to allocate memory for co_player_ids\n"); free(co_player_ids_d); @@ -152,9 +153,9 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) { // Handle ego agents - always as an array env->num_ego_agents = unpack(kwargs, "num_ego_agents"); if (env->num_ego_agents > 0) { - double* ego_agent_ids_d = unpack_float_array(kwargs, "ego_agent_ids", &env->num_ego_agents); + double *ego_agent_ids_d = unpack_float_array(kwargs, "ego_agent_ids", &env->num_ego_agents); if (ego_agent_ids_d != NULL) { - env->ego_agent_ids = (int*)malloc(env->num_ego_agents * sizeof(int)); + env->ego_agent_ids = (int *)malloc(env->num_ego_agents * sizeof(int)); for (int i = 0; i < env->num_ego_agents; i++) { env->ego_agent_ids[i] = (int)ego_agent_ids_d[i]; } @@ -173,16 +174,18 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) { env->ego_agent_ids = NULL; } - env->init_mode = (int)unpack(kwargs, "init_mode"); env->control_mode = (int)unpack(kwargs, "control_mode"); env->goal_behavior = (int)unpack(kwargs, "goal_behavior"); + env->goal_target_distance = (float)unpack(kwargs, "goal_target_distance"); env->goal_radius = (float)unpack(kwargs, "goal_radius"); + env->goal_speed = (float)unpack(kwargs, "goal_speed"); + char *map_dir = unpack_str(kwargs, "map_dir"); int map_id = unpack(kwargs, "map_id"); int max_agents = unpack(kwargs, "max_agents"); int init_steps = unpack(kwargs, "init_steps"); - char map_file[100]; - sprintf(map_file, "resources/drive/binaries/map_%03d.bin", map_id); + char map_file[512]; + snprintf(map_file, sizeof(map_file), "%s/map_%03d.bin", map_dir, map_id); env->num_agents = max_agents; env->map_name = strdup(map_file); env->init_steps = init_steps; @@ -191,18 +194,21 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) { return 0; } -static int my_log(PyObject* dict, Log* log) { +static int my_log(PyObject *dict, Log *log) { assign_to_dict(dict, "n", log->n); + assign_to_dict(dict, "score", log->score); 
assign_to_dict(dict, "offroad_rate", log->offroad_rate); - assign_to_dict(dict, "episode_length", log->episode_length); assign_to_dict(dict, "collision_rate", log->collision_rate); + assign_to_dict(dict, "episode_length", log->episode_length); assign_to_dict(dict, "episode_return", log->episode_return); assign_to_dict(dict, "dnf_rate", log->dnf_rate); - assign_to_dict(dict, "avg_displacement_error", log->avg_displacement_error); assign_to_dict(dict, "completion_rate", log->completion_rate); assign_to_dict(dict, "lane_alignment_rate", log->lane_alignment_rate); - assign_to_dict(dict, "score", log->score); - assign_to_dict(dict, "avg_offroad_per_agent", log->avg_offroad_per_agent); - assign_to_dict(dict, "avg_collisions_per_agent", log->avg_collisions_per_agent); + assign_to_dict(dict, "offroad_per_agent", log->offroad_per_agent); + assign_to_dict(dict, "collisions_per_agent", log->collisions_per_agent); + assign_to_dict(dict, "goals_sampled_this_episode", log->goals_sampled_this_episode); + assign_to_dict(dict, "goals_reached_this_episode", log->goals_reached_this_episode); + assign_to_dict(dict, "speed_at_goal", log->speed_at_goal); + // assign_to_dict(dict, "avg_displacement_error", log->avg_displacement_error); return 0; } diff --git a/pufferlib/ocean/drive/binding.h b/pufferlib/ocean/drive/binding.h index b5ec9ed65d..58d26d4a26 100644 --- a/pufferlib/ocean/drive/binding.h +++ b/pufferlib/ocean/drive/binding.h @@ -1,57 +1,70 @@ #include "drive.h" #include "../env_binding.h" -static PyObject* my_shared_self_play(PyObject* self, PyObject* args, PyObject* kwargs) { +static PyObject *my_shared_self_play(PyObject *self, PyObject *args, PyObject *kwargs) { + char *map_dir = unpack_str(kwargs, "map_dir"); int num_agents = unpack(kwargs, "num_agents"); int num_maps = unpack(kwargs, "num_maps"); int init_mode = unpack(kwargs, "init_mode"); int control_mode = unpack(kwargs, "control_mode"); int init_steps = unpack(kwargs, "init_steps"); + int goal_behavior = unpack(kwargs, 
"goal_behavior"); + float goal_target_distance = unpack(kwargs, "goal_target_distance"); + int use_all_maps = unpack(kwargs, "use_all_maps"); int max_controlled_agents = unpack(kwargs, "max_controlled_agents"); + printf("Generating environments for %d agents using %s maps from %s, num maps %d \n", num_agents, + use_all_maps ? "all" : "random", map_dir, num_maps); + fflush(stdout); + // Use current time and pid for randomness clock_gettime(CLOCK_REALTIME, &ts); - srand(ts.tv_nsec); + srand((unsigned int)(ts.tv_sec ^ ts.tv_nsec ^ getpid())); int total_agent_count = 0; int env_count = 0; - int max_envs = num_agents; + int max_envs = use_all_maps ? num_maps : num_agents; + int map_idx = 0; int maps_checked = 0; - PyObject* agent_offsets = PyList_New(max_envs+1); - PyObject* map_ids = PyList_New(max_envs); + PyObject *agent_offsets = PyList_New(max_envs + 1); + PyObject *map_ids = PyList_New(max_envs); // getting env count - while(total_agent_count < num_agents && env_count < max_envs){ - char map_file[100]; - int map_id = rand() % num_maps; - Drive* env = calloc(1, sizeof(Drive)); + while (use_all_maps ? map_idx < max_envs : total_agent_count < num_agents && env_count < max_envs) { + char map_file[512]; + int map_id = use_all_maps ? 
map_idx++ : rand() % num_maps; + Drive *env = calloc(1, sizeof(Drive)); env->init_mode = init_mode; + env->max_controlled_agents = max_controlled_agents; env->control_mode = control_mode; env->init_steps = init_steps; - env->max_controlled_agents = max_controlled_agents; - sprintf(map_file, "resources/drive/binaries/map_%03d.bin", map_id); + env->goal_behavior = goal_behavior; + env->goal_target_distance = goal_target_distance; + snprintf(map_file, sizeof(map_file), "%s/map_%03d.bin", map_dir, map_id); env->entities = load_map_binary(map_file, env); set_active_agents(env); // Skip map if it doesn't contain any controllable agents - if(env->active_agent_count == 0) { - maps_checked++; - - // Safeguard: if we've checked all available maps and found no active agents, raise an error - if(maps_checked >= num_maps) { - for(int j=0;jnum_entities;j++) { - free_entity(&env->entities[j]); + if (env->active_agent_count == 0) { + if (!use_all_maps) { + maps_checked++; + + // Safeguard: if we've checked all available maps and found no active agents, raise an error + if (maps_checked >= num_maps) { + for (int j = 0; j < env->num_entities; j++) { + free_entity(&env->entities[j]); + } + free(env->entities); + free(env->active_agent_indices); + free(env->static_agent_indices); + free(env->expert_static_agent_indices); + free(env); + Py_DECREF(agent_offsets); + Py_DECREF(map_ids); + char error_msg[256]; + sprintf(error_msg, "No controllable agents found in any of the %d available maps", num_maps); + PyErr_SetString(PyExc_ValueError, error_msg); + return NULL; } - free(env->entities); - free(env->active_agent_indices); - free(env->static_agent_indices); - free(env->expert_static_agent_indices); - free(env); - Py_DECREF(agent_offsets); - Py_DECREF(map_ids); - char error_msg[256]; - sprintf(error_msg, "No controllable agents found in any of the %d available maps", num_maps); - PyErr_SetString(PyExc_ValueError, error_msg); - return NULL; } - for(int j=0;jnum_entities;j++) { + for (int j 
= 0; j < env->num_entities; j++) { free_entity(&env->entities[j]); } free(env->entities); @@ -60,17 +73,17 @@ static PyObject* my_shared_self_play(PyObject* self, PyObject* args, PyObject* k free(env->expert_static_agent_indices); free(env); continue; - } + } // Store map_id - PyObject* map_id_obj = PyLong_FromLong(map_id); + PyObject *map_id_obj = PyLong_FromLong(map_id); PyList_SetItem(map_ids, env_count, map_id_obj); // Store agent offset - PyObject* offset = PyLong_FromLong(total_agent_count); + PyObject *offset = PyLong_FromLong(total_agent_count); PyList_SetItem(agent_offsets, env_count, offset); total_agent_count += env->active_agent_count; env_count++; - for(int j=0;jnum_entities;j++) { + for (int j = 0; j < env->num_entities; j++) { free_entity(&env->entities[j]); } free(env->entities); @@ -79,26 +92,26 @@ static PyObject* my_shared_self_play(PyObject* self, PyObject* args, PyObject* k free(env->expert_static_agent_indices); free(env); } - //printf("Generated %d environments to cover %d agents (requested %d agents)\n", env_count, total_agent_count, num_agents); - if(total_agent_count >= num_agents){ + // printf("Generated %d environments to cover %d agents (requested %d agents)\n", env_count, total_agent_count, + // num_agents); + if (!use_all_maps && total_agent_count >= num_agents) { total_agent_count = num_agents; } - PyObject* final_total_agent_count = PyLong_FromLong(total_agent_count); + PyObject *final_total_agent_count = PyLong_FromLong(total_agent_count); PyList_SetItem(agent_offsets, env_count, final_total_agent_count); - PyObject* final_env_count = PyLong_FromLong(env_count); + PyObject *final_env_count = PyLong_FromLong(env_count); // resize lists - PyObject* resized_agent_offsets = PyList_GetSlice(agent_offsets, 0, env_count + 1); - PyObject* resized_map_ids = PyList_GetSlice(map_ids, 0, env_count); - PyObject* tuple = PyTuple_New(3); + PyObject *resized_agent_offsets = PyList_GetSlice(agent_offsets, 0, env_count + 1); + PyObject 
*resized_map_ids = PyList_GetSlice(map_ids, 0, env_count); + PyObject *tuple = PyTuple_New(3); PyTuple_SetItem(tuple, 0, resized_agent_offsets); PyTuple_SetItem(tuple, 1, resized_map_ids); PyTuple_SetItem(tuple, 2, final_env_count); return tuple; } - -static double* unpack_float_array(PyObject* kwargs, char* key, Py_ssize_t* out_size) { - PyObject* val = PyDict_GetItemString(kwargs, key); +static double *unpack_float_array(PyObject *kwargs, char *key, Py_ssize_t *out_size) { + PyObject *val = PyDict_GetItemString(kwargs, key); if (val == NULL) { char error_msg[100]; snprintf(error_msg, sizeof(error_msg), "Missing required keyword argument '%s'", key); @@ -123,15 +136,14 @@ static double* unpack_float_array(PyObject* kwargs, char* key, Py_ssize_t* out_s return NULL; } - double* array = (double*)malloc(size * sizeof(double)); + double *array = (double *)malloc(size * sizeof(double)); if (array == NULL) { PyErr_SetString(PyExc_MemoryError, "Failed to allocate memory for float array"); return NULL; } - for (Py_ssize_t i = 0; i < size; i++) { - PyObject* item = PySequence_GetItem(val, i); + PyObject *item = PySequence_GetItem(val, i); if (item == NULL) { free(array); return NULL; @@ -168,19 +180,19 @@ static double* unpack_float_array(PyObject* kwargs, char* key, Py_ssize_t* out_s return array; } - -static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObject* kwargs) { +static PyObject *my_shared_population_play(PyObject *self, PyObject *args, PyObject *kwargs) { + char *map_dir = unpack_str(kwargs, "map_dir"); int num_agents = unpack(kwargs, "num_agents"); int num_maps = unpack(kwargs, "num_maps"); int num_ego_agents = unpack(kwargs, "num_ego_agents"); - int init_mode = unpack(kwargs, "init_mode"); + int init_mode = unpack(kwargs, "init_mode"); int population_play = unpack(kwargs, "population_play"); int control_mode = unpack(kwargs, "control_mode"); int init_steps = unpack(kwargs, "init_steps"); int max_controlled_agents = unpack(kwargs, 
"max_controlled_agents"); int max_scenes_per_process = 0; - PyObject* max_envs_obj = PyDict_GetItemString(kwargs, "max_scenes_per_process"); + PyObject *max_envs_obj = PyDict_GetItemString(kwargs, "max_scenes_per_process"); if (max_envs_obj && PyLong_Check(max_envs_obj)) { long v = PyLong_AsLong(max_envs_obj); if (v > 0 && v <= INT_MAX) { @@ -188,17 +200,16 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj } } - // Use current time for randomness + // Use current time + PID for better randomness struct timespec ts; clock_gettime(CLOCK_REALTIME, &ts); - srand(ts.tv_nsec); + srand((unsigned int)(ts.tv_sec ^ ts.tv_nsec ^ getpid())); int num_coplayers = num_agents - num_ego_agents; - printf("Creating worlds for %d total agents (%d egos, %d co-players)\n", - num_agents, num_ego_agents, num_coplayers); + printf("Creating worlds for %d total agents (%d egos, %d co-players)\n", num_agents, num_ego_agents, num_coplayers); // Create shuffled agent role array (0 = coplayer, 1 = ego) - int* agent_roles = malloc(num_agents * sizeof(int)); + int *agent_roles = malloc(num_agents * sizeof(int)); for (int i = 0; i < num_ego_agents; i++) { agent_roles[i] = 1; // ego } @@ -218,23 +229,45 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj int env_count = 0; int total_egos_assigned = 0; int total_coplayers_assigned = 0; + int agent_role_index = 0; // Track position in agent_roles array int max_envs = num_agents; if (max_scenes_per_process > 0 && max_scenes_per_process < max_envs) { max_envs = max_scenes_per_process; } - PyObject* agent_offsets = PyList_New(max_envs + 1); - PyObject* map_ids = PyList_New(max_envs); - PyObject* ego_agent_ids = PyList_New(max_envs); - PyObject* coplayer_ids = PyList_New(max_envs); + PyObject *agent_offsets = PyList_New(max_envs + 1); + PyObject *map_ids = PyList_New(max_envs); + PyObject *ego_agent_ids = PyList_New(max_envs); + PyObject *coplayer_ids = PyList_New(max_envs); + + int 
consecutive_skips = 0; // Safety counter for infinite loop detection + int max_consecutive_skips = num_maps * 3; // Allow trying each map multiple times // Create worlds by randomly sampling maps while (total_agent_count < num_agents && env_count < max_envs) { + // Safety check: if we've skipped too many times in a row, something is wrong + if (consecutive_skips > max_consecutive_skips) { + fprintf(stderr, + "[shared_population_play] ERROR: Too many consecutive skips (%d). " + "All maps may have 0 active agents. agent_role_index=%d, total_agent_count=%d\n", + consecutive_skips, agent_role_index, total_agent_count); + + Py_DECREF(agent_offsets); + Py_DECREF(map_ids); + Py_DECREF(ego_agent_ids); + Py_DECREF(coplayer_ids); + free(agent_roles); + PyErr_Format(PyExc_RuntimeError, + "shared_population_play: unable to find maps with active agents after %d attempts", + consecutive_skips); + return NULL; + } + char map_file[100]; int map_id = rand() % num_maps; - Drive* env = calloc(1, sizeof(Drive)); - sprintf(map_file, "resources/drive/binaries/map_%03d.bin", map_id); + Drive *env = calloc(1, sizeof(Drive)); + snprintf(map_file, sizeof(map_file), "%s/map_%03d.bin", map_dir, map_id); env->entities = load_map_binary(map_file, env); int remaining_capacity = num_agents - total_agent_count; @@ -250,15 +283,27 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj set_active_agents(env); + // CRITICAL FIX: Skip maps with 0 active agents + if (env->active_agent_count == 0) { + printf("Skipping map %d (0 active agents)\n", map_id); + for (int j = 0; j < env->num_entities; j++) { + free_entity(&env->entities[j]); + } + free(env->entities); + free(env->active_agent_indices); + free(env->static_agent_indices); + free(env->expert_static_agent_indices); + free(env); + consecutive_skips++; + continue; + } + int next_total = total_agent_count + env->active_agent_count; if (next_total > num_agents) { int remaining = num_agents - total_agent_count; - 
fprintf(stderr, - "[shared_population_play] ERROR oversubscribed agents: requested=%d remaining=%d map=%d\n", - env->active_agent_count, - remaining, - map_id); - for(int j=0; jnum_entities; j++) { + fprintf(stderr, "[shared_population_play] ERROR oversubscribed agents: requested=%d remaining=%d map=%d\n", + env->active_agent_count, remaining, map_id); + for (int j = 0; j < env->num_entities; j++) { free_entity(&env->entities[j]); } free(env->entities); @@ -272,25 +317,22 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj Py_DECREF(coplayer_ids); free(agent_roles); PyErr_Format(PyExc_RuntimeError, - "shared_population_play oversubscribed: total=%d target=%d map=%d active=%d", - next_total, - num_agents, - map_id, - env->active_agent_count); + "shared_population_play oversubscribed: total=%d target=%d map=%d active=%d", next_total, + num_agents, map_id, env->active_agent_count); return NULL; } // Store map_id - PyObject* map_id_obj = PyLong_FromLong(map_id); + PyObject *map_id_obj = PyLong_FromLong(map_id); PyList_SetItem(map_ids, env_count, map_id_obj); // Store agent offset - PyObject* offset = PyLong_FromLong(total_agent_count); + PyObject *offset = PyLong_FromLong(total_agent_count); PyList_SetItem(agent_offsets, env_count, offset); // Create ego and coplayer lists for this world - PyObject* ego_list = PyList_New(0); - PyObject* coplayer_list = PyList_New(0); + PyObject *ego_list = PyList_New(0); + PyObject *coplayer_list = PyList_New(0); int world_egos = 0; int world_coplayers = 0; @@ -298,9 +340,9 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj // Assign agents from the shuffled roles for (int a = 0; a < env->active_agent_count; a++) { - PyObject* agent_id = PyLong_FromLong(total_agent_count); + PyObject *agent_id = PyLong_FromLong(total_agent_count); - if (agent_roles[total_agent_count] == 1) { + if (agent_roles[agent_role_index] == 1) { // This agent is an ego PyList_Append(ego_list, 
agent_id); world_egos++; @@ -314,22 +356,25 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj Py_DECREF(agent_id); total_agent_count++; + agent_role_index++; } // Enforce constraint: must have at least 1 ego per world (if egos remain) if (world_egos == 0 && remaining_egos > 0) { - fprintf(stderr, - "[shared_population_play] WARNING: World %d has no ego agents but %d egos remain. Skipping world.\n", - env_count, remaining_egos); + fprintf( + stderr, + "[shared_population_play] WARNING: World %d has no ego agents but %d egos remain. Skipping world.\n", + env_count, remaining_egos); // Rollback the agent assignments for this world total_agent_count -= env->active_agent_count; total_coplayers_assigned -= world_coplayers; + agent_role_index -= env->active_agent_count; Py_DECREF(ego_list); Py_DECREF(coplayer_list); - for(int j=0; jnum_entities; j++) { + for (int j = 0; j < env->num_entities; j++) { free_entity(&env->entities[j]); } free(env->entities); @@ -337,18 +382,23 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj free(env->static_agent_indices); free(env->expert_static_agent_indices); free(env); + + consecutive_skips++; continue; // Try another map } + // Successfully created a world, reset skip counter + consecutive_skips = 0; + PyList_SetItem(ego_agent_ids, env_count, ego_list); PyList_SetItem(coplayer_ids, env_count, coplayer_list); - printf("World %d (map %d): %d agents (%d egos, %d co-players)\n", - env_count, map_id, env->active_agent_count, world_egos, world_coplayers); + printf("World %d (map %d): %d agents (%d egos, %d co-players)\n", env_count, map_id, env->active_agent_count, + world_egos, world_coplayers); env_count++; - for(int j=0; jnum_entities; j++) { + for (int j = 0; j < env->num_entities; j++) { free_entity(&env->entities[j]); } free(env->entities); @@ -362,15 +412,15 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj total_agent_count = 
num_agents; } - PyObject* final_total_agent_count = PyLong_FromLong(total_agent_count); + PyObject *final_total_agent_count = PyLong_FromLong(total_agent_count); PyList_SetItem(agent_offsets, env_count, final_total_agent_count); - PyObject* final_env_count = PyLong_FromLong(env_count); + PyObject *final_env_count = PyLong_FromLong(env_count); // Resize lists - PyObject* resized_agent_offsets = PyList_GetSlice(agent_offsets, 0, env_count + 1); - PyObject* resized_map_ids = PyList_GetSlice(map_ids, 0, env_count); - PyObject* resized_ego_ids = PyList_GetSlice(ego_agent_ids, 0, env_count); - PyObject* resized_coplayer_ids = PyList_GetSlice(coplayer_ids, 0, env_count); + PyObject *resized_agent_offsets = PyList_GetSlice(agent_offsets, 0, env_count + 1); + PyObject *resized_map_ids = PyList_GetSlice(map_ids, 0, env_count); + PyObject *resized_ego_ids = PyList_GetSlice(ego_agent_ids, 0, env_count); + PyObject *resized_coplayer_ids = PyList_GetSlice(coplayer_ids, 0, env_count); Py_DECREF(agent_offsets); Py_DECREF(map_ids); @@ -381,15 +431,15 @@ static PyObject* my_shared_population_play(PyObject* self, PyObject* args, PyObj free(agent_roles); // Create a tuple - PyObject* tuple = PyTuple_New(5); + PyObject *tuple = PyTuple_New(5); PyTuple_SetItem(tuple, 0, resized_agent_offsets); PyTuple_SetItem(tuple, 1, resized_map_ids); PyTuple_SetItem(tuple, 2, final_env_count); PyTuple_SetItem(tuple, 3, resized_ego_ids); PyTuple_SetItem(tuple, 4, resized_coplayer_ids); - printf("Total: %d agents across %d worlds (egos: %d, co-players: %d)\n", - total_agent_count, env_count, total_egos_assigned, total_coplayers_assigned); + printf("Total: %d agents across %d worlds (egos: %d, co-players: %d)\n", total_agent_count, env_count, + total_egos_assigned, total_coplayers_assigned); return tuple; } diff --git a/pufferlib/ocean/drive/drive.c b/pufferlib/ocean/drive/drive.c index 5e193c0371..9f6337051c 100644 --- a/pufferlib/ocean/drive/drive.c +++ b/pufferlib/ocean/drive/drive.c @@ -1,7 +1,5 @@ 
-#include "drive.h" #include "drivenet.h" #include -#include "../env_config.h" // Use this test if the network changes to ensure that the forward pass // matches the torch implementation to the 3rd or ideally 4th decimal place @@ -10,22 +8,22 @@ void test_drivenet() { int num_actions = 2; int num_agents = 4; - float* observations = calloc(num_agents*num_obs, sizeof(float)); - for (int i = 0; i < num_obs*num_agents; i++) { + float *observations = calloc(num_agents * num_obs, sizeof(float)); + for (int i = 0; i < num_obs * num_agents; i++) { observations[i] = i % 7; } - int* actions = calloc(num_agents*num_actions, sizeof(int)); + int *actions = calloc(num_agents * num_actions, sizeof(int)); - //Weights* weights = load_weights("resources/drive/puffer_drive_weights.bin"); - Weights* weights = load_weights("puffer_drive_weights.bin"); - DriveNet* net = init_drivenet(weights, num_agents, CLASSIC, false, false, false); + // Weights* weights = load_weights("resources/drive/puffer_drive_weights.bin"); + Weights *weights = load_weights("puffer_drive_weights.bin"); + DriveNet *net = init_drivenet(weights, num_agents, CLASSIC, false, false, false); forward(net, observations, actions); - for (int i = 0; i < num_agents*num_actions; i++) { + for (int i = 0; i < num_agents * num_actions; i++) { printf("idx: %d, action: %d, logits:", i, actions[i]); for (int j = 0; j < num_actions; j++) { - printf(" %.6f", net->actor->output[i*num_actions + j]); + printf(" %.6f", net->actor->output[i * num_actions + j]); } printf("\n"); } @@ -34,14 +32,10 @@ void test_drivenet() { } void demo() { - // Read configuration from INI file - env_init_config conf = {0}; - const char* ini_file = "pufferlib/config/ocean/drive.ini"; - if(ini_parse(ini_file, handler, &conf) < 0) { - fprintf(stderr, "Error: Could not load %s. Cannot determine environment configuration.\n", ini_file); - exit(1); - } + // Note: The settings below are hardcoded for demo purposes. 
Since the policy was + // trained with these exact settings, that changing them may lead to + // weird behavior. Drive env = { .human_agent_idx = 0, .dynamics_model = conf.dynamics_model, @@ -50,7 +44,7 @@ void demo() { .reward_ade = conf.reward_ade, .goal_radius = conf.goal_radius, .dt = conf.dt, - .map_name = "resources/drive/binaries/map_000.bin", + .map_name = "resources/drive/binaries/training/map_000.bin", .init_steps = conf.init_steps, .collision_behavior = conf.collision_behavior, .offroad_behavior = conf.offroad_behavior, @@ -58,50 +52,71 @@ void demo() { allocate(&env); c_reset(&env); c_render(&env); - Weights* weights = load_weights("resources/drive/puffer_drive_weights.bin"); - DriveNet* net = init_drivenet(weights, env.active_agent_count, env.dynamics_model, false, false, false); - //Client* client = make_client(&env); + Weights *weights = load_weights("resources/drive/puffer_drive_weights_carla_town12.bin"); + DriveNet *net = init_drivenet(weights, env.active_agent_count, env.dynamics_model, false, false, false); + int accel_delta = 2; int steer_delta = 4; while (!WindowShouldClose()) { - // Handle camera controls - int (*actions)[2] = (int(*)[2])env.actions; - forward(net, env.observations, env.actions); + int *actions = (int *)env.actions; // Single integer per agent + + forward(net, env.observations, actions); + if (IsKeyDown(KEY_LEFT_SHIFT)) { - actions[env.human_agent_idx][0] = 3; - actions[env.human_agent_idx][1] = 6; - if(IsKeyDown(KEY_UP) || IsKeyDown(KEY_W)){ - actions[env.human_agent_idx][0] += accel_delta; - // Cap acceleration to maximum of 6 - if(actions[env.human_agent_idx][0] > 6) { - actions[env.human_agent_idx][0] = 6; + if (env.dynamics_model == CLASSIC) { + // Classic dynamics: acceleration and steering + int accel_idx = 3; // neutral (0 m/s²) + int steer_idx = 6; // neutral (0.0 steering) + + if (IsKeyDown(KEY_UP) || IsKeyDown(KEY_W)) { + accel_idx += accel_delta; + if (accel_idx > 6) + accel_idx = 6; } - } - if(IsKeyDown(KEY_DOWN) 
|| IsKeyDown(KEY_S)){ - actions[env.human_agent_idx][0] -= accel_delta; - // Cap acceleration to minimum of 0 - if(actions[env.human_agent_idx][0] < 0) { - actions[env.human_agent_idx][0] = 0; + if (IsKeyDown(KEY_DOWN) || IsKeyDown(KEY_S)) { + accel_idx -= accel_delta; + if (accel_idx < 0) + accel_idx = 0; } - } - if(IsKeyDown(KEY_LEFT) || IsKeyDown(KEY_A)){ - actions[env.human_agent_idx][1] += steer_delta; - // Cap steering to minimum of 0 - if(actions[env.human_agent_idx][1] < 0) { - actions[env.human_agent_idx][1] = 0; + if (IsKeyDown(KEY_LEFT) || IsKeyDown(KEY_A)) { + steer_idx += steer_delta; // Increase steering index for left turn + if (steer_idx > 12) + steer_idx = 12; } - } - if(IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D)){ - actions[env.human_agent_idx][1] -= steer_delta; - // Cap steering to maximum of 12 - if(actions[env.human_agent_idx][1] > 12) { - actions[env.human_agent_idx][1] = 12; + if (IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D)) { + steer_idx -= steer_delta; // Decrease steering index for right turn + if (steer_idx < 0) + steer_idx = 0; } - } - if(IsKeyPressed(KEY_TAB)){ - env.human_agent_idx = (env.human_agent_idx + 1) % env.active_agent_count; + + // Encode into single integer: action = accel_idx * 13 + steer_idx + actions[env.human_agent_idx] = accel_idx * 13 + steer_idx; + + } else if (env.dynamics_model == JERK) { + // Jerk dynamics: longitudinal and lateral jerk + // JERK_LONG[4] = {-15.0f, -4.0f, 0.0f, 4.0f} + // JERK_LAT[3] = {-4.0f, 0.0f, 4.0f} + int jerk_long_idx = 2; // neutral (0.0) + int jerk_lat_idx = 1; // neutral (0.0) + + if (IsKeyDown(KEY_UP) || IsKeyDown(KEY_W)) { + jerk_long_idx = 3; // acceleration (4.0) + } + if (IsKeyDown(KEY_DOWN) || IsKeyDown(KEY_S)) { + jerk_long_idx = 0; // hard braking (-15.0) + } + if (IsKeyDown(KEY_LEFT) || IsKeyDown(KEY_A)) { + jerk_lat_idx = 2; // left turn (4.0) + } + if (IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D)) { + jerk_lat_idx = 0; // right turn (-4.0) + } + + // Encode into single integer: action 
= jerk_long_idx * 3 + jerk_lat_idx + actions[env.human_agent_idx] = jerk_long_idx * 3 + jerk_lat_idx; } } + c_step(&env); c_render(&env); } @@ -113,25 +128,15 @@ void demo() { } void performance_test() { - // Read configuration from INI file - env_init_config conf = {0}; - const char* ini_file = "pufferlib/config/ocean/drive.ini"; - if(ini_parse(ini_file, handler, &conf) < 0) { - fprintf(stderr, "Error: Could not load %s. Cannot determine environment configuration.\n", ini_file); - exit(1); - } long test_time = 10; Drive env = { .human_agent_idx = 0, - .dynamics_model = conf.dynamics_model, - .reward_vehicle_collision = conf.reward_vehicle_collision, - .reward_offroad_collision = conf.reward_offroad_collision, - .reward_ade = conf.reward_ade, - .goal_radius = conf.goal_radius, - .dt = conf.dt, - .map_name = "resources/drive/binaries/map_000.bin", - .init_steps = conf.init_steps, + .dynamics_model = CLASSIC, // Classic dynamics + .action_type = 0, // Discrete + .map_name = "resources/drive/binaries/training/map_000.bin", + .dt = 0.1f, + .init_steps = 0, }; clock_t start_time, end_time; double cpu_time_used; @@ -139,33 +144,33 @@ void performance_test() { allocate(&env); c_reset(&env); end_time = clock(); - cpu_time_used = ((double) (end_time - start_time)) / CLOCKS_PER_SEC; + cpu_time_used = ((double)(end_time - start_time)) / CLOCKS_PER_SEC; printf("Init time: %f\n", cpu_time_used); long start = time(NULL); int i = 0; - int (*actions)[2] = (int(*)[2])env.actions; + int (*actions)[2] = (int (*)[2])env.actions; while (time(NULL) - start < test_time) { // Set random actions for all agents - for(int j = 0; j < env.active_agent_count; j++) { + for (int j = 0; j < env.active_agent_count; j++) { int accel = rand() % 7; int steer = rand() % 13; - actions[j][0] = accel; // -1, 0, or 1 - actions[j][1] = steer; // Random steering + actions[j][0] = accel; // -1, 0, or 1 + actions[j][1] = steer; // Random steering } c_step(&env); i++; } long end = time(NULL); - printf("SPS: 
%ld\n", (i*env.active_agent_count) / (end - start)); + printf("SPS: %ld\n", (i * env.active_agent_count) / (end - start)); free_allocated(&env); } int main() { - //performance_test(); + // performance_test(); demo(); - //test_drivenet(); + // test_drivenet(); return 0; } diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h index cb112b9a7e..009b2d6206 100644 --- a/pufferlib/ocean/drive/drive.h +++ b/pufferlib/ocean/drive/drive.h @@ -37,7 +37,7 @@ // Control modes #define CONTROL_VEHICLES 0 #define CONTROL_AGENTS 1 -#define CONTROL_TRACKS_TO_PREDICT 2 +#define CONTROL_WOSAC 2 #define CONTROL_SDC_ONLY 3 // Minimum distance to goal position @@ -64,7 +64,9 @@ // Grid cell size #define GRID_CELL_SIZE 5.0f -#define MAX_ENTITIES_PER_CELL 30 // Depends on resolution of data Formula: 3 * (2 + GRID_CELL_SIZE*sqrt(2)/resolution) => For each entity type in gridmap, diagonal poly-lines -> sqrt(2), include diagonal ends -> 2 +#define MAX_ENTITIES_PER_CELL \ + 30 // Depends on resolution of data Formula: 3 * (2 + GRID_CELL_SIZE*sqrt(2)/resolution) => For each entity type in + // gridmap, diagonal poly-lines -> sqrt(2), include diagonal ends -> 2 // Max road segment observation entities #define MAX_ROAD_SEGMENT_OBSERVATIONS 200 @@ -86,34 +88,55 @@ #define STOP_AGENT 1 #define REMOVE_AGENT 2 -//GOAL BEHAVIOUR +// GOAL BEHAVIOUR #define GOAL_RESPAWN 0 #define GOAL_GENERATE_NEW 1 #define GOAL_STOP 2 +#define PARTNER_FEATURES 7 + +#define ROAD_FEATURES 7 +#define ROAD_FEATURES_ONEHOT 13 +#define PARTNER_FEATURES 7 + +// Ego features depend on dynamics model +#define EGO_FEATURES_CLASSIC 7 +#define EGO_FEATURES_JERK 10 + // Jerk action space (for JERK dynamics model) static const float JERK_LONG[4] = {-15.0f, -4.0f, 0.0f, 4.0f}; static const float JERK_LAT[3] = {-4.0f, 0.0f, 4.0f}; // Classic action space (for CLASSIC dynamics model) static const float ACCELERATION_VALUES[7] = {-4.0000f, -2.6670f, -1.3330f, -0.0000f, 1.3330f, 2.6670f, 4.0000f}; -static const float 
STEERING_VALUES[13] = {-1.000f, -0.833f, -0.667f, -0.500f, -0.333f, -0.167f, 0.000f, 0.167f, 0.333f, 0.500f, 0.667f, 0.833f, 1.000f}; +static const float STEERING_VALUES[13] = {-1.000f, -0.833f, -0.667f, -0.500f, -0.333f, -0.167f, 0.000f, + 0.167f, 0.333f, 0.500f, 0.667f, 0.833f, 1.000f}; static const float offsets[4][2] = { - {-1, 1}, // top-left - {1, 1}, // top-right - {1, -1}, // bottom-right - {-1, -1} // bottom-left - }; + {-1, 1}, // top-left + {1, 1}, // top-right + {1, -1}, // bottom-right + {-1, -1} // bottom-left +}; static const int collision_offsets[25][2] = { - {-2, -2}, {-1, -2}, {0, -2}, {1, -2}, {2, -2}, // Top row - {-2, -1}, {-1, -1}, {0, -1}, {1, -1}, {2, -1}, // Second row - {-2, 0}, {-1, 0}, {0, 0}, {1, 0}, {2, 0}, // Middle row (including center) - {-2, 1}, {-1, 1}, {0, 1}, {1, 1}, {2, 1}, // Fourth row - {-2, 2}, {-1, 2}, {0, 2}, {1, 2}, {2, 2} // Bottom row + {-2, -2}, {-1, -2}, {0, -2}, {1, -2}, {2, -2}, // Top row + {-2, -1}, {-1, -1}, {0, -1}, {1, -1}, {2, -1}, // Second row + {-2, 0}, {-1, 0}, {0, 0}, {1, 0}, {2, 0}, // Middle row (including center) + {-2, 1}, {-1, 1}, {0, 1}, {1, 1}, {2, 1}, // Fourth row + {-2, 2}, {-1, 2}, {0, 2}, {1, 2}, {2, 2} // Bottom row }; +const Color STONE_GRAY = (Color){80, 80, 80, 255}; +const Color PUFF_RED = (Color){187, 0, 0, 255}; +const Color PUFF_CYAN = (Color){0, 187, 187, 255}; +const Color PUFF_WHITE = (Color){241, 241, 241, 241}; +const Color PUFF_BACKGROUND = (Color){6, 24, 24, 255}; +const Color PUFF_BACKGROUND2 = (Color){18, 72, 72, 255}; +const Color LIGHTGREEN = (Color){152, 255, 152, 255}; +const Color LIGHTYELLOW = (Color){255, 255, 152, 255}; +const Color SOFT_YELLOW = (Color){245, 245, 220, 255}; + struct timespec ts; typedef struct Drive Drive; @@ -122,20 +145,22 @@ typedef struct Log Log; typedef struct Graph Graph; typedef struct AdjListNode AdjListNode; typedef struct Co_Player_Log Co_Player_Log; -typedef struct Adaptive_Agent_Log Adaptive_Agent_Log; struct Log { float episode_return; 
float episode_length; float score; + float goals_reached_this_episode; + float goals_sampled_this_episode; float offroad_rate; float collision_rate; - float num_goals_reached; float completion_rate; + float offroad_per_agent; + float collisions_per_agent; float dnf_rate; float n; float lane_alignment_rate; - float avg_displacement_error; + float speed_at_goal; float active_agent_count; float expert_static_agent_count; float static_agent_count; @@ -145,8 +170,6 @@ struct Log { float avg_goal_weight; float avg_entropy_weight; float avg_discount_weight; - float avg_offroad_per_agent; - float avg_collisions_per_agent; }; typedef struct Entity Entity; @@ -155,14 +178,14 @@ struct Entity { int type; int id; int array_size; - float* traj_x; - float* traj_y; - float* traj_z; - float* traj_vx; - float* traj_vy; - float* traj_vz; - float* traj_heading; - int* traj_valid; + float *traj_x; + float *traj_y; + float *traj_z; + float *traj_vx; + float *traj_vy; + float *traj_vz; + float *traj_heading; + int *traj_valid; float width; float length; float height; @@ -173,7 +196,7 @@ struct Entity { float init_goal_y; int mark_as_expert; int collision_state; - float metrics_array[5]; // metrics_array: [collision, offroad, reached_goal, lane_aligned, avg_displacement_error] + float metrics_array[5]; // metrics_array: [collision, offroad, reached_goal, lane_aligned float x; float y; float z; @@ -188,9 +211,9 @@ struct Entity { int respawn_timestep; int respawn_count; int collided_before_goal; - int sampled_new_goal; - int reached_goal_this_episode; - int num_goals_reached; + float goals_reached_this_episode; + float goals_sampled_this_episode; + int current_goal_reached; int active_agent; float cumulative_displacement; int displacement_sample_count; @@ -206,12 +229,12 @@ struct Entity { float steering_angle; float wheelbase; - //population play + // population play bool is_ego; bool is_co_player; }; -void free_entity(Entity* entity){ +void free_entity(Entity *entity) { // free 
trajectory arrays free(entity->traj_x); free(entity->traj_y); @@ -222,22 +245,6 @@ void free_entity(Entity* entity){ free(entity->traj_heading); free(entity->traj_valid); } -struct Co_Player_Log { - float co_player_episode_return; - float co_player_episode_length; - float co_player_perf; - float co_player_score; - float co_player_offroad_rate; - float co_player_collision_rate; - float co_player_clean_collision_rate; - float co_player_num_goals_reached; - float co_player_completion_rate; - float co_player_dnf_rate; - float co_player_lane_alignment_rate; - float co_player_avg_displacement_error; - float co_player_n; -}; - // Utility functions float compute_delta_percent(float first, float last) { @@ -247,51 +254,26 @@ float compute_delta_percent(float first, float last) { return (last - first) / first * 100.0f; } -float relative_distance(float a, float b){ +float relative_distance(float a, float b) { float distance = sqrtf(powf(a - b, 2)); return distance; } -float relative_distance_2d(float x1, float y1, float x2, float y2){ +float relative_distance_2d(float x1, float y1, float x2, float y2) { float dx = x2 - x1; float dy = y2 - y1; - float distance = sqrtf(dx*dx + dy*dy); + float distance = sqrtf(dx * dx + dy * dy); return distance; } float clip(float value, float min, float max) { - if (value < min) return min; - if (value > max) return max; + if (value < min) + return min; + if (value > max) + return max; return value; } -float compute_displacement_error(Entity* agent, int timestep) { - // Check if timestep is within valid range - if (timestep < 0 || timestep >= agent->array_size) { - return 0.0f; - } - - // Check if reference trajectory is valid at this timestep - if (!agent->traj_valid[timestep]) { - return 0.0f; - } - - // Get reference position at current timestep, skip invalid ones - float ref_x = agent->traj_x[timestep]; - float ref_y = agent->traj_y[timestep]; - - if (ref_x == INVALID_POSITION || ref_y == INVALID_POSITION) { - return 0.0f; - } - - // 
Compute deltas: Euclidean distance between actual and reference position - float dx = agent->x - ref_x; - float dy = agent->y - ref_y; - float displacement = sqrtf(dx*dx + dy*dy); - - return displacement; -} - typedef struct GridMapEntity GridMapEntity; struct GridMapEntity { int entity_idx; @@ -308,64 +290,64 @@ struct GridMap { int grid_rows; int cell_size_x; int cell_size_y; - int* cell_entities_count; // number of entities in each cell of the GridMap - GridMapEntity** cells; // list of gridEntities in each cell of the GridMap - + int *cell_entities_count; // number of entities in each cell of the GridMap + GridMapEntity **cells; // list of gridEntities in each cell of the GridMap // Extras/Optimizations int vision_range; - int* neighbor_cache_count; // number of entities in each cells neighbor cache - GridMapEntity** neighbor_cache_entities; // preallocated array to hold neighbor entities + int *neighbor_cache_count; // number of entities in each cells neighbor cache + GridMapEntity **neighbor_cache_entities; // preallocated array to hold neighbor entities }; struct Drive { - Client* client; - float* observations; - float* actions; - float* rewards; - unsigned char* terminals; + Client *client; + float *observations; + float *actions; + float *rewards; + unsigned char *terminals; Log log; - Log* logs; + Log *logs; int num_agents; int active_agent_count; - int* active_agent_indices; + int *active_agent_indices; int action_type; int human_agent_idx; - Entity* entities; - Graph* topology_graph; + Entity *entities; int num_entities; int num_actors; int num_objects; int num_roads; int static_agent_count; - int* static_agent_indices; + int *static_agent_indices; int expert_static_agent_count; - int* expert_static_agent_indices; + int *expert_static_agent_indices; int timestep; int init_steps; int dynamics_model; - GridMap* grid_map; - int* neighbor_offsets; + GridMap *grid_map; + int *neighbor_offsets; int scenario_length; + int termination_mode; float 
reward_vehicle_collision; float reward_offroad_collision; - float reward_ade; - char* map_name; + char *map_name; float world_mean_x; float world_mean_y; float dt; float reward_goal; float reward_goal_post_respawn; float goal_radius; + float goal_speed; int max_controlled_agents; int logs_capacity; int goal_behavior; - char* ini_file; - char* scenario_id; + float goal_target_distance; + char *ini_file; + char *scenario_id; int collision_behavior; int offroad_behavior; int sdc_track_index; int num_tracks_to_predict; - int* tracks_to_predict_indices; + int *tracks_to_predict_indices; int init_mode; int control_mode; @@ -377,60 +359,74 @@ struct Drive { float offroad_weight_ub; float goal_weight_lb; float goal_weight_ub; - float* collision_weights; - float* offroad_weights; - float* goal_weights; + float *collision_weights; + float *offroad_weights; + float *goal_weights; // Entropy conditioning bool use_ec; float entropy_weight_lb; float entropy_weight_ub; - float* entropy_weights; + float *entropy_weights; // Discount conditioning bool use_dc; float discount_weight_lb; float discount_weight_ub; - float* discount_weights; - //fixed population play - Co_Player_Log co_player_log; - Co_Player_Log* co_player_logs; + float *discount_weights; + // fixed population play + Log co_player_log; + Log *co_player_logs; int num_co_players; int num_ego_agents; - int* co_player_ids; - int* ego_agent_ids; + int *co_player_ids; + int *ego_agent_ids; bool population_play; - - }; -void add_log(Drive* env) { - +void add_log(Drive *env) { for (int i = 0; i < env->active_agent_count; i++) { - Entity* e = &env->entities[env->active_agent_indices[i]]; + Entity *e = &env->entities[env->active_agent_indices[i]]; - if (e->is_ego) { - // ALWAYS update regular logs for all ego agents - if (e->reached_goal_this_episode) - env->log.completion_rate += 1.0f; + // Common metrics for all agents + env->log.goals_reached_this_episode += e->goals_reached_this_episode; + env->log.goals_sampled_this_episode 
+= e->goals_sampled_this_episode; + if (e->is_ego) { + // EGO agent logging int offroad = env->logs[i].offroad_rate; env->log.offroad_rate += offroad; int collided = env->logs[i].collision_rate; env->log.collision_rate += collided; - int num_goals_reached = env->logs[i].num_goals_reached; - env->log.num_goals_reached += num_goals_reached; + float offroad_per_agent = env->logs[i].offroad_per_agent; + env->log.offroad_per_agent += offroad_per_agent; + float collisions_per_agent = env->logs[i].collisions_per_agent; + env->log.collisions_per_agent += collisions_per_agent; + + float frac_goal_reached = e->goals_reached_this_episode / e->goals_sampled_this_episode; + + // Calculate threshold based on goals sampled + float threshold = 0.99f; // Default threshold for 1 goal + if (e->goals_sampled_this_episode == 2.0f) { + threshold = 0.5f; // Require ≥50% completion for 2 goals + } else if (e->goals_sampled_this_episode < 5.0f) { + threshold = 0.8f; // Require ≥80% completion for 3-4 goals + } else { + threshold = 0.9f; // Require ≥90% completion for 5+ goals + } + + int collision_occurred = + (env->goal_behavior == GOAL_RESPAWN) ? 
e->collided_before_goal : env->logs[i].collision_rate; - if (e->reached_goal_this_episode && !e->collided_before_goal) { + if (frac_goal_reached > threshold && !collision_occurred) { env->log.score += 1.0f; } - if (!offroad && !collided && !e->reached_goal_this_episode) { + if (!offroad && !collided && frac_goal_reached < 1.0f) { env->log.dnf_rate += 1.0f; } int lane_aligned = env->logs[i].lane_alignment_rate; env->log.lane_alignment_rate += lane_aligned; - float displacement_error = env->logs[i].avg_displacement_error; - env->log.avg_displacement_error += displacement_error; + env->log.speed_at_goal += env->logs[i].speed_at_goal; env->log.episode_length += env->logs[i].episode_length; env->log.episode_return += env->logs[i].episode_return; @@ -438,81 +434,87 @@ void add_log(Drive* env) { env->log.expert_static_agent_count += env->expert_static_agent_count; env->log.static_agent_count += env->static_agent_count; env->log.n += 1.0f; - } - // Process co-player agents (separate if, not else-if!) 
+ if (e->is_co_player && env->co_player_logs != NULL) { - if (e->reached_goal_this_episode) - env->co_player_log.co_player_completion_rate += 1.0f; - - int co_offroad = env->co_player_logs[i].co_player_offroad_rate; - env->co_player_log.co_player_offroad_rate += co_offroad; - int co_collided = env->co_player_logs[i].co_player_collision_rate; - env->co_player_log.co_player_collision_rate += co_collided; - int co_num_goals_reached = env->co_player_logs[i].co_player_num_goals_reached; - env->co_player_log.co_player_num_goals_reached += co_num_goals_reached; - - env->co_player_log.co_player_clean_collision_rate += - env->co_player_logs[i].co_player_clean_collision_rate; - - if (e->reached_goal_this_episode && !e->collided_before_goal) { - env->co_player_log.co_player_score += 1.0f; - env->co_player_log.co_player_perf += 1.0f; + int co_offroad = env->co_player_logs[i].offroad_rate; + env->co_player_log.offroad_rate += co_offroad; + int co_collided = env->co_player_logs[i].collision_rate; + env->co_player_log.collision_rate += co_collided; + float co_offroad_per_agent = env->co_player_logs[i].offroad_per_agent; + env->co_player_log.offroad_per_agent += co_offroad_per_agent; + float co_collisions_per_agent = env->co_player_logs[i].collisions_per_agent; + env->co_player_log.collisions_per_agent += co_collisions_per_agent; + + float co_frac_goal_reached = e->goals_reached_this_episode / e->goals_sampled_this_episode; + + // Calculate threshold for co-players + float co_threshold = 0.99f; + if (e->goals_sampled_this_episode == 2.0f) { + co_threshold = 0.5f; + } else if (e->goals_sampled_this_episode < 5.0f) { + co_threshold = 0.8f; + } else { + co_threshold = 0.9f; } - if (!co_offroad && !co_collided && !e->reached_goal_this_episode) { - env->co_player_log.co_player_dnf_rate += 1.0f; + int co_collision_occurred = + (env->goal_behavior == GOAL_RESPAWN) ? 
e->collided_before_goal : env->co_player_logs[i].collision_rate; + + if (co_frac_goal_reached > co_threshold && !co_collision_occurred) { + env->co_player_log.score += 1.0f; + } + + if (!co_offroad && !co_collided && co_frac_goal_reached < 1.0f) { + env->co_player_log.dnf_rate += 1.0f; } - int co_lane_aligned = env->co_player_logs[i].co_player_lane_alignment_rate; - env->co_player_log.co_player_lane_alignment_rate += co_lane_aligned; - float co_displacement_error = env->co_player_logs[i].co_player_avg_displacement_error; - env->co_player_log.co_player_avg_displacement_error += co_displacement_error; - env->co_player_log.co_player_episode_return += - env->co_player_logs[i].co_player_episode_return; - env->co_player_log.co_player_episode_length += - env->co_player_logs[i].co_player_episode_length; + int co_lane_aligned = env->co_player_logs[i].lane_alignment_rate; + env->co_player_log.lane_alignment_rate += co_lane_aligned; + env->co_player_log.speed_at_goal += env->co_player_logs[i].speed_at_goal; + env->co_player_log.episode_return += env->co_player_logs[i].episode_return; + env->co_player_log.episode_length += env->co_player_logs[i].episode_length; - env->co_player_log.co_player_n += 1.0f; + env->co_player_log.n += 1.0f; } } } struct AdjListNode { int dest; - struct AdjListNode* next; + struct AdjListNode *next; }; struct Graph { int V; - struct AdjListNode** array; + struct AdjListNode **array; }; // Function to create a new adjacency list node -struct AdjListNode* newAdjListNode(int dest) { - struct AdjListNode* newNode = malloc(sizeof(struct AdjListNode)); +struct AdjListNode *newAdjListNode(int dest) { + struct AdjListNode *newNode = malloc(sizeof(struct AdjListNode)); newNode->dest = dest; newNode->next = NULL; return newNode; } // Function to create a graph of V vertices -struct Graph* createGraph(int V) { - struct Graph* graph = malloc(sizeof(struct Graph)); +struct Graph *createGraph(int V) { + struct Graph *graph = malloc(sizeof(struct Graph)); graph->V = 
V; - graph->array = calloc(V, sizeof(struct AdjListNode*)); + graph->array = calloc(V, sizeof(struct AdjListNode *)); return graph; } // Function to get next lanes from a given lane entity index // Returns the number of next lanes found, fills next_lanes array with entity indices -int getNextLanes(struct Graph* graph, int entity_idx, int* next_lanes, int max_lanes) { +int getNextLanes(struct Graph *graph, int entity_idx, int *next_lanes, int max_lanes) { if (!graph || entity_idx < 0 || entity_idx >= graph->V) { return 0; } int count = 0; - struct AdjListNode* node = graph->array[entity_idx]; + struct AdjListNode *node = graph->array[entity_idx]; while (node && count < max_lanes) { next_lanes[count] = node->dest; @@ -524,13 +526,14 @@ int getNextLanes(struct Graph* graph, int entity_idx, int* next_lanes, int max_l } // Function to free the topology graph -void freeTopologyGraph(struct Graph* graph) { - if (!graph) return; +void freeTopologyGraph(struct Graph *graph) { + if (!graph) + return; for (int i = 0; i < graph->V; i++) { - struct AdjListNode* node = graph->array[i]; + struct AdjListNode *node = graph->array[i]; while (node) { - struct AdjListNode* temp = node; + struct AdjListNode *temp = node; node = node->next; free(temp); } @@ -540,11 +543,10 @@ void freeTopologyGraph(struct Graph* graph) { free(graph); } - -Entity* load_map_binary(const char* filename, Drive* env) { - FILE* file = fopen(filename, "rb"); - if (!file) return NULL; - +Entity *load_map_binary(const char *filename, Drive *env) { + FILE *file = fopen(filename, "rb"); + if (!file) + return NULL; // Read sdc_track_index fread(&env->sdc_track_index, sizeof(int), 1, file); @@ -552,7 +554,7 @@ Entity* load_map_binary(const char* filename, Drive* env) { // Read tracks_to_predict fread(&env->num_tracks_to_predict, sizeof(int), 1, file); if (env->num_tracks_to_predict > 0) { - env->tracks_to_predict_indices = (int*)malloc(env->num_tracks_to_predict * sizeof(int)); + env->tracks_to_predict_indices = 
(int *)malloc(env->num_tracks_to_predict * sizeof(int)); for (int i = 0; i < env->num_tracks_to_predict; i++) { fread(&env->tracks_to_predict_indices[i], sizeof(int), 1, file); @@ -564,25 +566,26 @@ Entity* load_map_binary(const char* filename, Drive* env) { fread(&env->num_objects, sizeof(int), 1, file); fread(&env->num_roads, sizeof(int), 1, file); env->num_entities = env->num_objects + env->num_roads; - Entity* entities = (Entity*)malloc(env->num_entities * sizeof(Entity)); + Entity *entities = (Entity *)malloc(env->num_entities * sizeof(Entity)); for (int i = 0; i < env->num_entities; i++) { - // Read base entity data + // Read base entity data fread(&entities[i].scenario_id, sizeof(int), 1, file); fread(&entities[i].type, sizeof(int), 1, file); fread(&entities[i].id, sizeof(int), 1, file); fread(&entities[i].array_size, sizeof(int), 1, file); // Allocate arrays based on type int size = entities[i].array_size; - entities[i].traj_x = (float*)malloc(size * sizeof(float)); - entities[i].traj_y = (float*)malloc(size * sizeof(float)); - entities[i].traj_z = (float*)malloc(size * sizeof(float)); - if (entities[i].type == VEHICLE || entities[i].type == PEDESTRIAN || entities[i].type == CYCLIST) { // Object type + entities[i].traj_x = (float *)malloc(size * sizeof(float)); + entities[i].traj_y = (float *)malloc(size * sizeof(float)); + entities[i].traj_z = (float *)malloc(size * sizeof(float)); + if (entities[i].type == VEHICLE || entities[i].type == PEDESTRIAN || + entities[i].type == CYCLIST) { // Object type // Allocate arrays for object-specific data - entities[i].traj_vx = (float*)malloc(size * sizeof(float)); - entities[i].traj_vy = (float*)malloc(size * sizeof(float)); - entities[i].traj_vz = (float*)malloc(size * sizeof(float)); - entities[i].traj_heading = (float*)malloc(size * sizeof(float)); - entities[i].traj_valid = (int*)malloc(size * sizeof(int)); + entities[i].traj_vx = (float *)malloc(size * sizeof(float)); + entities[i].traj_vy = (float *)malloc(size 
* sizeof(float)); + entities[i].traj_vz = (float *)malloc(size * sizeof(float)); + entities[i].traj_heading = (float *)malloc(size * sizeof(float)); + entities[i].traj_valid = (int *)malloc(size * sizeof(int)); } else { // Roads don't use these arrays entities[i].traj_vx = NULL; @@ -595,7 +598,8 @@ Entity* load_map_binary(const char* filename, Drive* env) { fread(entities[i].traj_x, sizeof(float), size, file); fread(entities[i].traj_y, sizeof(float), size, file); fread(entities[i].traj_z, sizeof(float), size, file); - if (entities[i].type == VEHICLE || entities[i].type == PEDESTRIAN || entities[i].type == CYCLIST) { // Object type + if (entities[i].type == VEHICLE || entities[i].type == PEDESTRIAN || + entities[i].type == CYCLIST) { // Object type fread(entities[i].traj_vx, sizeof(float), size, file); fread(entities[i].traj_vy, sizeof(float), size, file); fread(entities[i].traj_vz, sizeof(float), size, file); @@ -616,32 +620,31 @@ Entity* load_map_binary(const char* filename, Drive* env) { return entities; } -void set_start_position(Drive* env){ - //InitWindow(800, 600, "GPU Drive"); - //BeginDrawing(); - for(int i = 0; i < env->num_entities; i++){ +void set_start_position(Drive *env) { + for (int i = 0; i < env->num_entities; i++) { int is_active = 0; - for(int j = 0; j < env->active_agent_count; j++){ - if(env->active_agent_indices[j] == i){ + for (int j = 0; j < env->active_agent_count; j++) { + if (env->active_agent_indices[j] == i) { is_active = 1; break; } } - Entity* e = &env->entities[i]; + Entity *e = &env->entities[i]; // Clamp init_steps to ensure we don't go out of bounds int step = env->init_steps; - if (step >= e->array_size) step = e->array_size - 1; - if (step < 0) step = 0; + if (step >= e->array_size) + step = e->array_size - 1; + if (step < 0) + step = 0; e->x = e->traj_x[step]; e->y = e->traj_y[step]; e->z = e->traj_z[step]; - - if(e->type > CYCLIST || e->type == 0){ + if (e->type > CYCLIST || e->type == 0) { continue; } - if(is_active == 0){ + 
if (is_active == 0) { e->vx = 0; e->vy = 0; e->vz = 0; @@ -656,15 +659,12 @@ void set_start_position(Drive* env){ e->heading_y = sinf(e->heading); e->valid = e->traj_valid[env->init_steps]; e->collision_state = 0; - e->metrics_array[COLLISION_IDX] = 0.0f; // vehicle collision - e->metrics_array[OFFROAD_IDX] = 0.0f; // offroad + e->metrics_array[COLLISION_IDX] = 0.0f; // vehicle collision + e->metrics_array[OFFROAD_IDX] = 0.0f; // offroad e->metrics_array[REACHED_GOAL_IDX] = 0.0f; // reached goal e->metrics_array[LANE_ALIGNED_IDX] = 0.0f; // lane aligned - e->metrics_array[AVG_DISPLACEMENT_ERROR_IDX] = 0.0f; // avg displacement error - e->cumulative_displacement = 0.0f; - e->displacement_sample_count = 0; e->respawn_timestep = -1; - e->stopped = 0; + e->stopped = 0; e->removed = 0; e->respawn_count = 0; @@ -676,33 +676,35 @@ void set_start_position(Drive* env){ e->steering_angle = 0.0f; e->wheelbase = 0.6f * e->length; } - //EndDrawing(); } -int getGridIndex(Drive* env, float x1, float y1) { - if (env->grid_map->top_left_x >= env->grid_map->bottom_right_x || env->grid_map->bottom_right_y >= env->grid_map->top_left_y) { - return -1; // Invalid grid coordinates +int getGridIndex(Drive *env, float x1, float y1) { + if (env->grid_map->top_left_x >= env->grid_map->bottom_right_x || + env->grid_map->bottom_right_y >= env->grid_map->top_left_y) { + return -1; // Invalid grid coordinates } - float relativeX = x1 - env->grid_map->top_left_x; // Distance from left - float relativeY = y1 - env->grid_map->bottom_right_y; // Distance from bottom - int gridX = (int)(relativeX / GRID_CELL_SIZE); // Column index - int gridY = (int)(relativeY / GRID_CELL_SIZE); // Row index + float relativeX = x1 - env->grid_map->top_left_x; // Distance from left + float relativeY = y1 - env->grid_map->bottom_right_y; // Distance from bottom + int gridX = (int)(relativeX / GRID_CELL_SIZE); // Column index + int gridY = (int)(relativeY / GRID_CELL_SIZE); // Row index if (gridX < 0 || gridX >= 
env->grid_map->grid_cols || gridY < 0 || gridY >= env->grid_map->grid_rows) { - return -1; // Return -1 for out of bounds + return -1; // Return -1 for out of bounds } - int index = (gridY*env->grid_map->grid_cols) + gridX; + int index = (gridY * env->grid_map->grid_cols) + gridX; return index; } -void add_entity_to_grid(Drive* env, int grid_index, int entity_idx, int geometry_idx, int* cell_entities_insert_index){ - if(grid_index == -1){ +void add_entity_to_grid(Drive *env, int grid_index, int entity_idx, int geometry_idx, int *cell_entities_insert_index) { + if (grid_index == -1) { return; } int count = cell_entities_insert_index[grid_index]; - if(count >= env->grid_map->cell_entities_count[grid_index]) { - printf("Error: Exceeded precomputed entity count for grid cell %d. Current count: %d, Max count(Precomputed): %d\n", grid_index, count, env->grid_map->cell_entities_count[grid_index]); + if (count >= env->grid_map->cell_entities_count[grid_index]) { + printf("Error: Exceeded precomputed entity count for grid cell %d. 
Current count: %d, Max count(Precomputed): " + "%d\n", + grid_index, count, env->grid_map->cell_entities_count[grid_index]); return; } @@ -711,72 +713,9 @@ void add_entity_to_grid(Drive* env, int grid_index, int entity_idx, int geometry cell_entities_insert_index[grid_index] = count + 1; } - -void init_topology_graph(Drive* env){ - // Count ROAD_LANE entities - int road_lane_count = 0; - for(int i = 0; i < env->num_entities; i++){ - if(env->entities[i].type == ROAD_LANE){ - road_lane_count++; - } - } - - if(road_lane_count == 0){ - env->topology_graph = NULL; - return; - } - - // Create graph with all entities as vertices (we'll only use ROAD_LANE indices) - env->topology_graph = createGraph(env->num_entities); - - // Connect ROAD_LANE entities based on geometric connectivity - for(int i = 0; i < env->num_entities; i++){ - if(env->entities[i].type != ROAD_LANE) continue; - - Entity* lane_i = &env->entities[i]; - if(lane_i->array_size < 2) continue; // Need at least 2 points - - // Get end point of current lane - float end_x = lane_i->traj_x[lane_i->array_size - 1]; - float end_y = lane_i->traj_y[lane_i->array_size - 1]; - float end_vector_x = lane_i->traj_x[lane_i->array_size - 1] - lane_i->traj_x[lane_i->array_size - 2]; - float end_vector_y = lane_i->traj_y[lane_i->array_size - 1] - lane_i->traj_y[lane_i->array_size - 2]; - float end_heading = atan2f(end_vector_y, end_vector_x); - - // Find lanes that start near this lane's end - for(int j = 0; j < env->num_entities; j++){ - if(i == j || env->entities[j].type != ROAD_LANE) continue; - - Entity* lane_j = &env->entities[j]; - if(lane_j->array_size < 2) continue; - - // Get start point of potential next lane - float start_x = lane_j->traj_x[0]; - float start_y = lane_j->traj_y[0]; - float start_vector_x = lane_j->traj_x[1] - lane_j->traj_x[0]; - float start_vector_y = lane_j->traj_y[1] - lane_j->traj_y[0]; - float start_heading = atan2f(start_vector_y, start_vector_x); - - // Check if end of lane_i is close to start 
of lane_j - float distance = relative_distance_2d(end_x, end_y, start_x, start_y); - float heading_diff = fabsf(end_heading - start_heading); - - // Lane connectivity thresholds: - // - 0.01m distance: lanes must connect within 1cm (very strict for clean topology) - // - 0.1 (~5.7 degrees) heading difference: allow slight curves - if(distance < 0.01f && heading_diff < 0.1f){ - // Add directed edge from i to j (lane i connects to lane j) - struct AdjListNode* node = newAdjListNode(j); - node->next = env->topology_graph->array[i]; - env->topology_graph->array[i] = node; - } - } - } -} - -void init_grid_map(Drive* env){ +void init_grid_map(Drive *env) { // Allocate memory for the grid map structure - env->grid_map = (GridMap*)malloc(sizeof(GridMap)); + env->grid_map = (GridMap *)malloc(sizeof(GridMap)); // Find top left and bottom right points of the map float top_left_x; @@ -784,23 +723,29 @@ void init_grid_map(Drive* env){ float bottom_right_x; float bottom_right_y; int first_valid_point = 0; - for(int i = 0; i < env->num_entities; i++){ - if(env->entities[i].type > 3 && env->entities[i].type < 7){ + for (int i = 0; i < env->num_entities; i++) { + if (env->entities[i].type > 3 && env->entities[i].type < 7) { // Check all points in the trajectory for road elements - Entity* e = &env->entities[i]; - for(int j = 0; j < e->array_size; j++){ - if(e->traj_x[j] == INVALID_POSITION) continue; - if(e->traj_y[j] == INVALID_POSITION) continue; - if(!first_valid_point) { + Entity *e = &env->entities[i]; + for (int j = 0; j < e->array_size; j++) { + if (e->traj_x[j] == INVALID_POSITION) + continue; + if (e->traj_y[j] == INVALID_POSITION) + continue; + if (!first_valid_point) { top_left_x = bottom_right_x = e->traj_x[j]; top_left_y = bottom_right_y = e->traj_y[j]; first_valid_point = true; continue; } - if(e->traj_x[j] < top_left_x) top_left_x = e->traj_x[j]; - if(e->traj_x[j] > bottom_right_x) bottom_right_x = e->traj_x[j]; - if(e->traj_y[j] > top_left_y) top_left_y = 
e->traj_y[j]; - if(e->traj_y[j] < bottom_right_y) bottom_right_y = e->traj_y[j]; + if (e->traj_x[j] < top_left_x) + top_left_x = e->traj_x[j]; + if (e->traj_x[j] > bottom_right_x) + bottom_right_x = e->traj_x[j]; + if (e->traj_y[j] > top_left_y) + top_left_y = e->traj_y[j]; + if (e->traj_y[j] < bottom_right_y) + bottom_right_y = e->traj_y[j]; } } } @@ -817,41 +762,43 @@ void init_grid_map(Drive* env){ float grid_height = top_left_y - bottom_right_y; env->grid_map->grid_cols = ceil(grid_width / GRID_CELL_SIZE); env->grid_map->grid_rows = ceil(grid_height / GRID_CELL_SIZE); - int grid_cell_count = env->grid_map->grid_cols*env->grid_map->grid_rows; - env->grid_map->cells = (GridMapEntity**)calloc(grid_cell_count, sizeof(GridMapEntity*)); - env->grid_map->cell_entities_count = (int*)calloc(grid_cell_count, sizeof(int)); + int grid_cell_count = env->grid_map->grid_cols * env->grid_map->grid_rows; + env->grid_map->cells = (GridMapEntity **)calloc(grid_cell_count, sizeof(GridMapEntity *)); + env->grid_map->cell_entities_count = (int *)calloc(grid_cell_count, sizeof(int)); // Calculate number of entities in each grid cell - for(int i = 0; i < env->num_entities; i++){ - if(env->entities[i].type > 3 && env->entities[i].type < 7){ - for(int j = 0; j < env->entities[i].array_size - 1; j++){ - float x_center = (env->entities[i].traj_x[j] + env->entities[i].traj_x[j+1]) / 2; - float y_center = (env->entities[i].traj_y[j] + env->entities[i].traj_y[j+1]) / 2; + for (int i = 0; i < env->num_entities; i++) { + if (env->entities[i].type > 3 && env->entities[i].type < 7) { + for (int j = 0; j < env->entities[i].array_size - 1; j++) { + float x_center = (env->entities[i].traj_x[j] + env->entities[i].traj_x[j + 1]) / 2; + float y_center = (env->entities[i].traj_y[j] + env->entities[i].traj_y[j + 1]) / 2; int grid_index = getGridIndex(env, x_center, y_center); env->grid_map->cell_entities_count[grid_index]++; } } } - int cell_entities_insert_index[grid_cell_count]; // Helper array for 
insertion index + int cell_entities_insert_index[grid_cell_count]; // Helper array for insertion index memset(cell_entities_insert_index, 0, grid_cell_count * sizeof(int)); // Initialize grid cells - for(int grid_index = 0; grid_index < grid_cell_count; grid_index++){ - env->grid_map->cells[grid_index] = (GridMapEntity*)calloc(env->grid_map->cell_entities_count[grid_index], sizeof(GridMapEntity)); + for (int grid_index = 0; grid_index < grid_cell_count; grid_index++) { + env->grid_map->cells[grid_index] = + (GridMapEntity *)calloc(env->grid_map->cell_entities_count[grid_index], sizeof(GridMapEntity)); } - for(int i = 0;inum_entities; i++){ - if(env->entities[i].type > 3 && env->entities[i].type < 7){ // NOTE: Only Road Edges, Lines, and Lanes in grid map - for(int j = 0; j < env->entities[i].array_size - 1; j++){ - float x_center = (env->entities[i].traj_x[j] + env->entities[i].traj_x[j+1]) / 2; - float y_center = (env->entities[i].traj_y[j] + env->entities[i].traj_y[j+1]) / 2; + for (int i = 0; i < env->num_entities; i++) { + if (env->entities[i].type > 3 && + env->entities[i].type < 7) { // NOTE: Only Road Edges, Lines, and Lanes in grid map + for (int j = 0; j < env->entities[i].array_size - 1; j++) { + float x_center = (env->entities[i].traj_x[j] + env->entities[i].traj_x[j + 1]) / 2; + float y_center = (env->entities[i].traj_y[j] + env->entities[i].traj_y[j + 1]) / 2; int grid_index = getGridIndex(env, x_center, y_center); add_entity_to_grid(env, grid_index, i, j, cell_entities_insert_index); } @@ -859,24 +806,24 @@ void init_grid_map(Drive* env){ } } -void init_neighbor_offsets(Drive* env) { +void init_neighbor_offsets(Drive *env) { // Allocate memory for the offsets - env->neighbor_offsets = (int*)calloc(env->grid_map->vision_range*env->grid_map->vision_range*2, sizeof(int)); + env->neighbor_offsets = (int *)calloc(env->grid_map->vision_range * env->grid_map->vision_range * 2, sizeof(int)); // neighbor offsets in a spiral pattern int dx[] = {1, 0, -1, 0}; 
int dy[] = {0, 1, 0, -1}; - int x = 0; // Current x offset - int y = 0; // Current y offset - int dir = 0; // Current direction (0: right, 1: up, 2: left, 3: down) - int steps_to_take = 1; // Number of steps in current direction - int steps_taken = 0; // Steps taken in current direction + int x = 0; // Current x offset + int y = 0; // Current y offset + int dir = 0; // Current direction (0: right, 1: up, 2: left, 3: down) + int steps_to_take = 1; // Number of steps in current direction + int steps_taken = 0; // Steps taken in current direction int segments_completed = 0; // Count of direction segments completed - int total = 0; // Total offsets added - int max_offsets = env->grid_map->vision_range*env->grid_map->vision_range; + int total = 0; // Total offsets added + int max_offsets = env->grid_map->vision_range * env->grid_map->vision_range; // Start at center (0,0) int curr_idx = 0; - env->neighbor_offsets[curr_idx++] = 0; // x offset - env->neighbor_offsets[curr_idx++] = 0; // y offset + env->neighbor_offsets[curr_idx++] = 0; // x offset + env->neighbor_offsets[curr_idx++] = 0; // y offset total++; // Generate spiral pattern while (total < max_offsets) { @@ -884,16 +831,17 @@ void init_neighbor_offsets(Drive* env) { x += dx[dir]; y += dy[dir]; // Only add if within vision range bounds - if (abs(x) <= env->grid_map->vision_range/2 && abs(y) <= env->grid_map->vision_range/2) { + if (abs(x) <= env->grid_map->vision_range / 2 && abs(y) <= env->grid_map->vision_range / 2) { env->neighbor_offsets[curr_idx++] = x; env->neighbor_offsets[curr_idx++] = y; total++; } steps_taken++; // Check if we need to change direction - if(steps_taken != steps_to_take) continue; - steps_taken = 0; // Reset steps taken - dir = (dir + 1) % 4; // Change direction (clockwise: right->up->left->down) + if (steps_taken != steps_to_take) + continue; + steps_taken = 0; // Reset steps taken + dir = (dir + 1) % 4; // Change direction (clockwise: right->up->left->down) segments_completed++; // 
Increase step length every two direction changes if (segments_completed % 2 == 0) { @@ -902,65 +850,64 @@ void init_neighbor_offsets(Drive* env) { } } -void cache_neighbor_offsets(Drive* env){ +void cache_neighbor_offsets(Drive *env) { int count = 0; - int cell_count = env->grid_map->grid_cols*env->grid_map->grid_rows; - env->grid_map->neighbor_cache_entities = (GridMapEntity**)calloc(cell_count, sizeof(GridMapEntity*)); - env->grid_map->neighbor_cache_count = (int*)calloc(cell_count + 1, sizeof(int)); - for(int i = 0; i < cell_count; i++){ - int cell_x = i % env->grid_map->grid_cols; // Convert to 2D coordinates + int cell_count = env->grid_map->grid_cols * env->grid_map->grid_rows; + env->grid_map->neighbor_cache_entities = (GridMapEntity **)calloc(cell_count, sizeof(GridMapEntity *)); + env->grid_map->neighbor_cache_count = (int *)calloc(cell_count + 1, sizeof(int)); + for (int i = 0; i < cell_count; i++) { + int cell_x = i % env->grid_map->grid_cols; // Convert to 2D coordinates int cell_y = i / env->grid_map->grid_cols; int current_cell_neighbor_count = 0; - for(int j = 0; j < env->grid_map->vision_range*env->grid_map->vision_range; j++){ - int x = cell_x + env->neighbor_offsets[j*2]; - int y = cell_y + env->neighbor_offsets[j*2+1]; - int grid_index = env->grid_map->grid_cols*y + x; - if(x < 0 || x >= env->grid_map->grid_cols || y < 0 || y >= env->grid_map->grid_rows) continue; + for (int j = 0; j < env->grid_map->vision_range * env->grid_map->vision_range; j++) { + int x = cell_x + env->neighbor_offsets[j * 2]; + int y = cell_y + env->neighbor_offsets[j * 2 + 1]; + int grid_index = env->grid_map->grid_cols * y + x; + if (x < 0 || x >= env->grid_map->grid_cols || y < 0 || y >= env->grid_map->grid_rows) + continue; int grid_count = env->grid_map->cell_entities_count[grid_index]; current_cell_neighbor_count += grid_count; } env->grid_map->neighbor_cache_count[i] = current_cell_neighbor_count; count += current_cell_neighbor_count; - if(current_cell_neighbor_count 
== 0) { + if (current_cell_neighbor_count == 0) { env->grid_map->neighbor_cache_entities[i] = NULL; continue; } - env->grid_map->neighbor_cache_entities[i] = (GridMapEntity*)calloc(current_cell_neighbor_count, sizeof(GridMapEntity)); + env->grid_map->neighbor_cache_entities[i] = + (GridMapEntity *)calloc(current_cell_neighbor_count, sizeof(GridMapEntity)); } env->grid_map->neighbor_cache_count[cell_count] = count; - for(int i = 0; i < cell_count; i ++){ - int cell_x = i % env->grid_map->grid_cols; // Convert to 2D coordinates + for (int i = 0; i < cell_count; i++) { + int cell_x = i % env->grid_map->grid_cols; // Convert to 2D coordinates int cell_y = i / env->grid_map->grid_cols; int base_index = 0; - for(int j = 0; j < env->grid_map->vision_range*env->grid_map->vision_range; j++){ - int x = cell_x + env->neighbor_offsets[j*2]; - int y = cell_y + env->neighbor_offsets[j*2+1]; - int grid_index = env->grid_map->grid_cols*y + x; - if(x < 0 || x >= env->grid_map->grid_cols || y < 0 || y >= env->grid_map->grid_rows) continue; + for (int j = 0; j < env->grid_map->vision_range * env->grid_map->vision_range; j++) { + int x = cell_x + env->neighbor_offsets[j * 2]; + int y = cell_y + env->neighbor_offsets[j * 2 + 1]; + int grid_index = env->grid_map->grid_cols * y + x; + if (x < 0 || x >= env->grid_map->grid_cols || y < 0 || y >= env->grid_map->grid_rows) + continue; int grid_count = env->grid_map->cell_entities_count[grid_index]; // Skip if no entities or source is NULL - if(grid_count == 0 || env->grid_map->cells[grid_index] == NULL) { + if (grid_count == 0 || env->grid_map->cells[grid_index] == NULL) { continue; } int src_idx = grid_index; int dst_idx = base_index; // Copy grid_count pairs (entity_idx, geometry_idx) at once - memcpy(&env->grid_map->neighbor_cache_entities[i][dst_idx], - env->grid_map->cells[src_idx], - grid_count * sizeof(GridMapEntity)); - // for(int k = 0; k < grid_count; k++){ - // env->grid_map->neighbor_cache_entities[i][dst_idx + k] = 
env->grid_map->cells[src_idx][k]; - // } + memcpy(&env->grid_map->neighbor_cache_entities[i][dst_idx], env->grid_map->cells[src_idx], + grid_count * sizeof(GridMapEntity)); base_index += grid_count; } } } -int get_neighbor_cache_entities(Drive* env, int cell_idx, GridMapEntity* entities, int max_entities) { - GridMap* grid_map = env->grid_map; +int get_neighbor_cache_entities(Drive *env, int cell_idx, GridMapEntity *entities, int max_entities) { + GridMap *grid_map = env->grid_map; if (cell_idx < 0 || cell_idx >= (grid_map->grid_cols * grid_map->grid_rows)) { return 0; // Invalid cell index } @@ -974,14 +921,15 @@ int get_neighbor_cache_entities(Drive* env, int cell_idx, GridMapEntity* entitie return count; } -void set_means(Drive* env) { +void set_means(Drive *env) { float mean_x = 0.0f; float mean_y = 0.0f; int64_t point_count = 0; // Compute single mean for all entities (vehicles and roads) for (int i = 0; i < env->num_entities; i++) { - if (env->entities[i].type == VEHICLE || env->entities[i].type == PEDESTRIAN || env->entities[i].type == CYCLIST) { + if (env->entities[i].type == VEHICLE || env->entities[i].type == PEDESTRIAN || + env->entities[i].type == CYCLIST) { for (int j = 0; j < env->entities[i].array_size; j++) { // Assume a validity flag exists (e.g., valid[j]); adjust if not available if (env->entities[i].traj_valid[j]) { // Add validity check if applicable @@ -1001,9 +949,11 @@ void set_means(Drive* env) { env->world_mean_x = mean_x; env->world_mean_y = mean_y; for (int i = 0; i < env->num_entities; i++) { - if (env->entities[i].type == VEHICLE || env->entities[i].type == PEDESTRIAN || env->entities[i].type == CYCLIST || env->entities[i].type >= 4) { + if (env->entities[i].type == VEHICLE || env->entities[i].type == PEDESTRIAN || + env->entities[i].type == CYCLIST || env->entities[i].type >= 4) { for (int j = 0; j < env->entities[i].array_size; j++) { - if(env->entities[i].traj_x[j] == INVALID_POSITION) continue; + if (env->entities[i].traj_x[j] == 
INVALID_POSITION) + continue; env->entities[i].traj_x[j] -= mean_x; env->entities[i].traj_y[j] -= mean_y; } @@ -1011,11 +961,10 @@ void set_means(Drive* env) { env->entities[i].goal_position_y -= mean_y; } } - } -void move_expert(Drive* env, float* actions, int agent_idx){ - Entity* agent = &env->entities[agent_idx]; +void move_expert(Drive *env, float *actions, int agent_idx) { + Entity *agent = &env->entities[agent_idx]; int t = env->timestep; if (t < 0 || t >= agent->array_size) { agent->x = INVALID_POSITION; @@ -1058,7 +1007,8 @@ bool check_line_intersection(float p1[2], float p2[2], float q1[2], float q2[2]) float cross = dx1 * dy2 - dy1 * dx2; // If lines are parallel - if (cross == 0) return false; + if (cross == 0) + return false; // Calculate relative vectors between start points float dx3 = p1[0] - q1[0]; @@ -1072,10 +1022,12 @@ bool check_line_intersection(float p1[2], float p2[2], float q1[2], float q2[2]) return (s >= 0 && s <= 1 && t >= 0 && t <= 1); } -int checkNeighbors(Drive* env, float x, float y, GridMapEntity* entity_list, int max_size, const int (*local_offsets)[2], int offset_size) { +int checkNeighbors(Drive *env, float x, float y, GridMapEntity *entity_list, int max_size, + const int (*local_offsets)[2], int offset_size) { // Get the grid index for the given position (x, y) int index = getGridIndex(env, x, y); - if (index == -1) return 0; // Return 0 size if position invalid + if (index == -1) + return 0; // Return 0 size if position invalid // Calculate 2D grid coordinates int cellsX = env->grid_map->grid_cols; int gridX = index % cellsX; @@ -1086,7 +1038,8 @@ int checkNeighbors(Drive* env, float x, float y, GridMapEntity* entity_list, int int nx = gridX + local_offsets[i][0]; int ny = gridY + local_offsets[i][1]; // Ensure the neighbor is within grid bounds - if(nx < 0 || nx >= env->grid_map->grid_cols || ny < 0 || ny >= env->grid_map->grid_rows) continue; + if (nx < 0 || nx >= env->grid_map->grid_cols || ny < 0 || ny >= 
env->grid_map->grid_rows) + continue; int neighborIndex = ny * env->grid_map->grid_cols + nx; int count = env->grid_map->cell_entities_count[neighborIndex]; // Add entities from this cell to the list @@ -1101,7 +1054,7 @@ int checkNeighbors(Drive* env, float x, float y, GridMapEntity* entity_list, int return entity_list_count; } -int check_aabb_collision(Entity* car1, Entity* car2) { +int check_aabb_collision(Entity *car1, Entity *car2) { // Get car corners in world space float cos1 = car1->heading_x; float sin1 = car1->heading_y; @@ -1119,79 +1072,83 @@ int check_aabb_collision(Entity* car1, Entity* car2) { {car1->x + (half_len1 * cos1 - half_width1 * sin1), car1->y + (half_len1 * sin1 + half_width1 * cos1)}, {car1->x + (half_len1 * cos1 + half_width1 * sin1), car1->y + (half_len1 * sin1 - half_width1 * cos1)}, {car1->x + (-half_len1 * cos1 - half_width1 * sin1), car1->y + (-half_len1 * sin1 + half_width1 * cos1)}, - {car1->x + (-half_len1 * cos1 + half_width1 * sin1), car1->y + (-half_len1 * sin1 - half_width1 * cos1)} - }; + {car1->x + (-half_len1 * cos1 + half_width1 * sin1), car1->y + (-half_len1 * sin1 - half_width1 * cos1)}}; // Calculate car2's corners in world space float car2_corners[4][2] = { {car2->x + (half_len2 * cos2 - half_width2 * sin2), car2->y + (half_len2 * sin2 + half_width2 * cos2)}, {car2->x + (half_len2 * cos2 + half_width2 * sin2), car2->y + (half_len2 * sin2 - half_width2 * cos2)}, {car2->x + (-half_len2 * cos2 - half_width2 * sin2), car2->y + (-half_len2 * sin2 + half_width2 * cos2)}, - {car2->x + (-half_len2 * cos2 + half_width2 * sin2), car2->y + (-half_len2 * sin2 - half_width2 * cos2)} - }; + {car2->x + (-half_len2 * cos2 + half_width2 * sin2), car2->y + (-half_len2 * sin2 - half_width2 * cos2)}}; // Get the axes to check (normalized vectors perpendicular to each edge) float axes[4][2] = { - {cos1, sin1}, // Car1's length axis - {-sin1, cos1}, // Car1's width axis - {cos2, sin2}, // Car2's length axis - {-sin2, cos2} // Car2's width 
axis + {cos1, sin1}, // Car1's length axis + {-sin1, cos1}, // Car1's width axis + {cos2, sin2}, // Car2's length axis + {-sin2, cos2} // Car2's width axis }; // Check each axis - for(int i = 0; i < 4; i++) { + for (int i = 0; i < 4; i++) { float min1 = INFINITY, max1 = -INFINITY; float min2 = INFINITY, max2 = -INFINITY; // Project car1's corners onto the axis - for(int j = 0; j < 4; j++) { + for (int j = 0; j < 4; j++) { float proj = car1_corners[j][0] * axes[i][0] + car1_corners[j][1] * axes[i][1]; min1 = fminf(min1, proj); max1 = fmaxf(max1, proj); } // Project car2's corners onto the axis - for(int j = 0; j < 4; j++) { + for (int j = 0; j < 4; j++) { float proj = car2_corners[j][0] * axes[i][0] + car2_corners[j][1] * axes[i][1]; min2 = fminf(min2, proj); max2 = fmaxf(max2, proj); } // If there's a gap on this axis, the boxes don't intersect - if(max1 < min2 || min1 > max2) { - return 0; // No collision + if (max1 < min2 || min1 > max2) { + return 0; // No collision } } // If we get here, there's no separating axis, so the boxes intersect - return 1; // Collision + return 1; // Collision } -int collision_check(Drive* env, int agent_idx) { - Entity* agent = &env->entities[agent_idx]; +int collision_check(Drive *env, int agent_idx) { + Entity *agent = &env->entities[agent_idx]; - if(agent->x == INVALID_POSITION ) return -1; + if (agent->x == INVALID_POSITION) + return -1; int car_collided_with_index = -1; - if (agent->respawn_timestep != -1) return car_collided_with_index; // Skip respawning entities + if (agent->respawn_timestep != -1) + return car_collided_with_index; // Skip respawning entities - for(int i = 0; i < MAX_AGENTS; i++){ + for (int i = 0; i < MAX_AGENTS; i++) { int index = -1; - if(i < env->active_agent_count){ + if (i < env->active_agent_count) { index = env->active_agent_indices[i]; - } else if (i < env->num_actors){ + } else if (i < env->num_actors) { index = env->static_agent_indices[i - env->active_agent_count]; } - if(index == -1) continue; - 
if(index == agent_idx) continue; - Entity* entity = &env->entities[index]; - if (entity->respawn_timestep != -1) continue; // Skip respawning entities + if (index == -1) + continue; + if (index == agent_idx) + continue; + Entity *entity = &env->entities[index]; + if (entity->respawn_timestep != -1) + continue; // Skip respawning entities float x1 = entity->x; float y1 = entity->y; - float dist = ((x1 - agent->x)*(x1 - agent->x) + (y1 - agent->y)*(y1 - agent->y)); - if(dist > 225.0f) continue; - if(check_aabb_collision(agent, entity)) { + float dist = ((x1 - agent->x) * (x1 - agent->x) + (y1 - agent->y) * (y1 - agent->y)); + if (dist > 225.0f) + continue; + if (check_aabb_collision(agent, entity)) { car_collided_with_index = index; break; } @@ -1200,13 +1157,16 @@ int collision_check(Drive* env, int agent_idx) { return car_collided_with_index; } -int check_lane_aligned(Entity* car, Entity* lane, int geometry_idx) { +int check_lane_aligned(Entity *car, Entity *lane, int geometry_idx) { // Validate lane geometry length - if (!lane || lane->array_size < 2) return 0; + if (!lane || lane->array_size < 2) + return 0; // Clamp geometry index to valid segment range [0, array_size-2] - if (geometry_idx < 0) geometry_idx = 0; - if (geometry_idx >= lane->array_size - 1) geometry_idx = lane->array_size - 2; + if (geometry_idx < 0) + geometry_idx = 0; + if (geometry_idx >= lane->array_size - 1) + geometry_idx = lane->array_size - 2; // Compute local lane segment heading float heading_x1, heading_y1; @@ -1227,25 +1187,27 @@ int check_lane_aligned(Entity* car, Entity* lane, int geometry_idx) { float heading = (heading_1 + heading_2) / 2.0f; // Normalize to [-pi, pi] - if (heading > M_PI) heading -= 2.0f * M_PI; - if (heading < -M_PI) heading += 2.0f * M_PI; + if (heading > M_PI) + heading -= 2.0f * M_PI; + if (heading < -M_PI) + heading += 2.0f * M_PI; // Compute heading difference float car_heading = car->heading; // radians float heading_diff = fabsf(car_heading - heading); - if 
(heading_diff > M_PI) heading_diff = 2.0f * M_PI - heading_diff; + if (heading_diff > M_PI) + heading_diff = 2.0f * M_PI - heading_diff; // within 15 degrees return (heading_diff < (M_PI / 12.0f)) ? 1 : 0; } -void reset_agent_metrics(Drive* env, int agent_idx){ - Entity* agent = &env->entities[agent_idx]; - agent->metrics_array[COLLISION_IDX] = 0.0f; // vehicle collision - agent->metrics_array[OFFROAD_IDX] = 0.0f; // offroad +void reset_agent_metrics(Drive *env, int agent_idx) { + Entity *agent = &env->entities[agent_idx]; + agent->metrics_array[COLLISION_IDX] = 0.0f; // vehicle collision + agent->metrics_array[OFFROAD_IDX] = 0.0f; // offroad agent->metrics_array[LANE_ALIGNED_IDX] = 0.0f; // lane aligned - agent->metrics_array[AVG_DISPLACEMENT_ERROR_IDX] = 0.0f; agent->collision_state = 0; } @@ -1262,8 +1224,10 @@ float point_to_segment_distance_2d(float px, float py, float x1, float y1, float float t = ((px - x1) * dx + (py - y1) * dy) / (dx * dx + dy * dy); // Clamp t to the segment - if (t < 0) t = 0; - else if (t > 1) t = 1; + if (t < 0) + t = 0; + else if (t > 1) + t = 1; // Find the closest point on the segment float closestX = x1 + t * dx; @@ -1273,28 +1237,17 @@ float point_to_segment_distance_2d(float px, float py, float x1, float y1, float return sqrtf((px - closestX) * (px - closestX) + (py - closestY) * (py - closestY)); } -void compute_agent_metrics(Drive* env, int agent_idx) { - Entity* agent = &env->entities[agent_idx]; +void compute_agent_metrics(Drive *env, int agent_idx) { + Entity *agent = &env->entities[agent_idx]; reset_agent_metrics(env, agent_idx); - if(agent->x == INVALID_POSITION ) return; // invalid agent position - - // Compute displacement error - float displacement_error = compute_displacement_error(agent, env->timestep); - - if (displacement_error > 0.0f) { // Only count valid displacements - agent->cumulative_displacement += displacement_error; - agent->displacement_sample_count++; - - // Compute running average - 
agent->metrics_array[AVG_DISPLACEMENT_ERROR_IDX] = - agent->cumulative_displacement / agent->displacement_sample_count; - } + if (agent->x == INVALID_POSITION) + return; // invalid agent position int collided = 0; - float half_length = agent->length/2.0f; - float half_width = agent->width/2.0f; + float half_length = agent->length / 2.0f; + float half_width = agent->width / 2.0f; float cos_heading = cosf(agent->heading); float sin_heading = sinf(agent->heading); float min_distance = (float)INT16_MAX; @@ -1304,20 +1257,25 @@ void compute_agent_metrics(Drive* env, int agent_idx) { float corners[4][2]; for (int i = 0; i < 4; i++) { - corners[i][0] = agent->x + (offsets[i][0]*half_length*cos_heading - offsets[i][1]*half_width*sin_heading); - corners[i][1] = agent->y + (offsets[i][0]*half_length*sin_heading + offsets[i][1]*half_width*cos_heading); + corners[i][0] = + agent->x + (offsets[i][0] * half_length * cos_heading - offsets[i][1] * half_width * sin_heading); + corners[i][1] = + agent->y + (offsets[i][0] * half_length * sin_heading + offsets[i][1] * half_width * cos_heading); } - GridMapEntity entity_list[MAX_ENTITIES_PER_CELL*25]; // Array big enough for all neighboring cells - int list_size = checkNeighbors(env, agent->x, agent->y, entity_list, MAX_ENTITIES_PER_CELL*25, collision_offsets, 25); - for (int i = 0; i < list_size ; i++) { - if(entity_list[i].entity_idx == -1) continue; - if(entity_list[i].entity_idx == agent_idx) continue; - Entity* entity; + GridMapEntity entity_list[MAX_ENTITIES_PER_CELL * 25]; // Array big enough for all neighboring cells + int list_size = + checkNeighbors(env, agent->x, agent->y, entity_list, MAX_ENTITIES_PER_CELL * 25, collision_offsets, 25); + for (int i = 0; i < list_size; i++) { + if (entity_list[i].entity_idx == -1) + continue; + if (entity_list[i].entity_idx == agent_idx) + continue; + Entity *entity; entity = &env->entities[entity_list[i].entity_idx]; // Check for offroad collision with road edges - if(entity->type == 
ROAD_EDGE) { + if (entity->type == ROAD_EDGE) { int geometry_idx = entity_list[i].geometry_idx; float start[2] = {entity->traj_x[geometry_idx], entity->traj_y[geometry_idx]}; float end[2] = {entity->traj_x[geometry_idx + 1], entity->traj_y[geometry_idx + 1]}; @@ -1330,10 +1288,11 @@ void compute_agent_metrics(Drive* env, int agent_idx) { } } - if (collided == OFFROAD) break; + if (collided == OFFROAD) + break; // Find closest point on the road centerline to the agent - if(entity->type == ROAD_LANE) { + if (entity->type == ROAD_LANE) { int entity_idx = entity_list[i].entity_idx; int geometry_idx = entity_list[i].geometry_idx; @@ -1341,13 +1300,15 @@ void compute_agent_metrics(Drive* env, int agent_idx) { float end[2] = {entity->traj_x[geometry_idx + 1], entity->traj_y[geometry_idx + 1]}; float dist = point_to_segment_distance_2d(agent->x, agent->y, start[0], start[1], end[0], end[1]); - float heading_diff = fabsf(atan2f(end[1]-start[1], end[0]-start[0]) - agent->heading); + float heading_diff = fabsf(atan2f(end[1] - start[1], end[0] - start[0]) - agent->heading); // Normalize heading difference to [0, pi] - if (heading_diff > M_PI) heading_diff = 2.0f * M_PI - heading_diff; + if (heading_diff > M_PI) + heading_diff = 2.0f * M_PI - heading_diff; // Penalize if heading differs by more than 30 degrees - if (heading_diff > (M_PI / 6.0f)) dist += 3.0f; + if (heading_diff > (M_PI / 6.0f)) + dist += 3.0f; if (dist < min_distance) { min_distance = dist; @@ -1364,56 +1325,80 @@ void compute_agent_metrics(Drive* env, int agent_idx) { agent->current_lane_idx = -1; } else { agent->current_lane_idx = closest_lane_entity_idx; - - int lane_aligned = check_lane_aligned(agent, &env->entities[closest_lane_entity_idx], closest_lane_geometry_idx); + int lane_aligned = + check_lane_aligned(agent, &env->entities[closest_lane_entity_idx], closest_lane_geometry_idx); agent->metrics_array[LANE_ALIGNED_IDX] = lane_aligned; } // Check for vehicle collisions int car_collided_with_index = 
collision_check(env, agent_idx); - if (car_collided_with_index != -1) collided = VEHICLE_COLLISION; + if (car_collided_with_index != -1) + collided = VEHICLE_COLLISION; agent->collision_state = collided; + if (collided == VEHICLE_COLLISION) { + if (env->collision_behavior == STOP_AGENT && !agent->stopped) { + agent->stopped = 1; + agent->vx = agent->vy = 0.0f; + } else if (env->collision_behavior == REMOVE_AGENT && !agent->removed) { + Entity *agent_collided = &env->entities[car_collided_with_index]; + agent->removed = 1; + agent_collided->removed = 1; + agent->x = agent->y = -10000.0f; + agent_collided->x = agent_collided->y = -10000.0f; + } + } + if (collided == OFFROAD) { + agent->metrics_array[OFFROAD_IDX] = 1.0f; + if (env->offroad_behavior == STOP_AGENT && !agent->stopped) { + agent->stopped = 1; + agent->vx = agent->vy = 0.0f; + } else if (env->offroad_behavior == REMOVE_AGENT && !agent->removed) { + agent->removed = 1; + agent->x = agent->y = -10000.0f; + } + } + return; } -bool should_control_agent(Drive* env, int agent_idx){ - +bool should_control_agent(Drive *env, int agent_idx) { // Check if we have room for more agents or are already at capacity if (env->active_agent_count >= env->num_agents) { return false; } - Entity* entity = &env->entities[agent_idx]; + Entity *entity = &env->entities[agent_idx]; - // Shrink agent size for collision checking - entity->width *= 0.7f; // TODO: Move this somewhere else + // TODO: Move this elsewhere or remove + entity->width *= 0.7f; entity->length *= 0.7f; if (env->control_mode == CONTROL_SDC_ONLY) { - return (agent_idx == env->sdc_track_index); + return agent_idx == env->sdc_track_index; } - // Special mode: control only agents in prediction track list - if (env->control_mode == CONTROL_TRACKS_TO_PREDICT) { - for (int j = 0; j < env->num_tracks_to_predict; j++) { - if (env->tracks_to_predict_indices[j] == agent_idx) { - return true; - } - } - return false; - } + bool is_vehicle = (entity->type == VEHICLE); + bool 
is_ped_or_bike = (entity->type == PEDESTRIAN || entity->type == CYCLIST); + bool type_is_valid = false; + + switch (env->control_mode) { + case CONTROL_WOSAC: + // Valid types only, ignore expert flag and goal distance + return (is_vehicle || is_ped_or_bike); - // Standard mode: check type, distance to goal, and expert status - bool type_is_controllable = false; - if (env->control_mode == CONTROL_VEHICLES) { - type_is_controllable = (entity->type == VEHICLE); - } else { // CONTROL_AGENTS mode - type_is_controllable = (entity->type == VEHICLE || entity->type == PEDESTRIAN || entity->type == CYCLIST); + case CONTROL_VEHICLES: + type_is_valid = is_vehicle; + break; + + default: + type_is_valid = (is_vehicle || is_ped_or_bike); + break; } - if (!type_is_controllable || entity->mark_as_expert) { + // Filter invalid types or experts + if (!type_is_valid || entity->mark_as_expert) { return false; } @@ -1431,26 +1416,24 @@ bool should_control_agent(Drive* env, int agent_idx){ return distance_to_goal >= MIN_DISTANCE_TO_GOAL; } -void set_active_agents(Drive* env){ - +void set_active_agents(Drive *env) { // Initialize - env->active_agent_count = 0; // Policy-controlled agents - env->static_agent_count = 0; // Non-moving background agents + env->active_agent_count = 0; // Policy-controlled agents + env->static_agent_count = 0; // Non-moving background agents env->expert_static_agent_count = 0; // Expert replay agents (non-controlled) - env->num_actors = 0; // Total agents created + env->num_actors = 0; // Total agents created int active_agent_indices[MAX_AGENTS]; int static_agent_indices[MAX_AGENTS]; int expert_static_agent_indices[MAX_AGENTS]; - if(env->num_agents == 0){ + if (env->num_agents == 0) { env->num_agents = MAX_AGENTS; } - // Iterate through entities to find agents to create and/or control - for(int i = 0; i < env->num_objects && env->num_actors < MAX_AGENTS; i++){ + for (int i = 0; i < env->num_objects && env->num_actors < MAX_AGENTS; i++) { - Entity* entity = 
&env->entities[i]; + Entity *entity = &env->entities[i]; // Skip if not valid at initialization if (entity->traj_valid[env->init_steps] != 1) { @@ -1460,14 +1443,15 @@ void set_active_agents(Drive* env){ // Determine if entity should be created bool should_create = false; if (env->init_mode == INIT_ALL_VALID) { - should_create = true; // All valid entities + should_create = true; // All valid entities } else if (env->control_mode == CONTROL_VEHICLES) { should_create = (entity->type == VEHICLE); - } else { // Control all agents + } else { // Control all agents should_create = (entity->type == VEHICLE || entity->type == PEDESTRIAN || entity->type == CYCLIST); } - if (!should_create) continue; + if (!should_create) + continue; env->num_actors++; @@ -1475,12 +1459,13 @@ void set_active_agents(Drive* env){ bool is_controlled = false; is_controlled = should_control_agent(env, i); - if (is_controlled && env->active_agent_count >= env->max_controlled_agents && env->max_controlled_agents != -1) { + if (is_controlled && env->active_agent_count >= env->max_controlled_agents && + env->max_controlled_agents != -1) { is_controlled = false; entity->mark_as_expert = 1; } - if(is_controlled){ + if (is_controlled) { active_agent_indices[env->active_agent_count] = i; env->active_agent_count++; env->entities[i].active_agent = 1; @@ -1488,7 +1473,7 @@ void set_active_agents(Drive* env){ static_agent_indices[env->static_agent_count] = i; env->static_agent_count++; env->entities[i].active_agent = 0; - if(env->entities[i].mark_as_expert == 1 || env->active_agent_count == env->num_agents) { + if (env->entities[i].mark_as_expert == 1 || env->active_agent_count == env->num_agents) { expert_static_agent_indices[env->expert_static_agent_count] = i; env->expert_static_agent_count++; env->entities[i].mark_as_expert = 1; @@ -1497,24 +1482,24 @@ void set_active_agents(Drive* env){ } // Set up initial active agents - env->active_agent_indices = (int*)malloc(env->active_agent_count * sizeof(int)); - 
env->static_agent_indices = (int*)malloc(env->static_agent_count * sizeof(int)); - env->expert_static_agent_indices = (int*)malloc(env->expert_static_agent_count * sizeof(int)); - for(int i=0;iactive_agent_count;i++){ + env->active_agent_indices = (int *)malloc(env->active_agent_count * sizeof(int)); + env->static_agent_indices = (int *)malloc(env->static_agent_count * sizeof(int)); + env->expert_static_agent_indices = (int *)malloc(env->expert_static_agent_count * sizeof(int)); + for (int i = 0; i < env->active_agent_count; i++) { env->active_agent_indices[i] = active_agent_indices[i]; } - for(int i=0;istatic_agent_count;i++){ + for (int i = 0; i < env->static_agent_count; i++) { env->static_agent_indices[i] = static_agent_indices[i]; } - for(int i=0;iexpert_static_agent_count;i++){ + for (int i = 0; i < env->expert_static_agent_count; i++) { env->expert_static_agent_indices[i] = expert_static_agent_indices[i]; } return; } -void remove_bad_trajectories(Drive* env){ +void remove_bad_trajectories(Drive *env) { - if (env->control_mode != CONTROL_TRACKS_TO_PREDICT) { + if (env->control_mode != CONTROL_WOSAC) { return; // Leave all trajectories in WOSAC control mode } @@ -1526,22 +1511,23 @@ void remove_bad_trajectories(Drive* env){ collided_with_indices[i] = -1; } // move experts through trajectories to check for collisions and remove as illegal agents - for(int t = 0; t < env->scenario_length; t++){ - for(int i = 0; i < env->active_agent_count; i++){ + for (int t = 0; t < env->scenario_length; t++) { + for (int i = 0; i < env->active_agent_count; i++) { int agent_idx = env->active_agent_indices[i]; move_expert(env, env->actions, agent_idx); } - for(int i = 0; i < env->expert_static_agent_count; i++){ + for (int i = 0; i < env->expert_static_agent_count; i++) { int expert_idx = env->expert_static_agent_indices[i]; - if(env->entities[expert_idx].x == INVALID_POSITION) continue; + if (env->entities[expert_idx].x == INVALID_POSITION) + continue; move_expert(env, 
env->actions, expert_idx); } // check collisions - for(int i = 0; i < env->active_agent_count; i++){ + for (int i = 0; i < env->active_agent_count; i++) { int agent_idx = env->active_agent_indices[i]; env->entities[agent_idx].collision_state = 0; int collided_with_index = collision_check(env, agent_idx); - if((collided_with_index >= 0) && collided_agents[i] == 0){ + if ((collided_with_index >= 0) && collided_agents[i] == 0) { collided_agents[i] = 1; collided_with_indices[i] = collided_with_index; } @@ -1549,11 +1535,13 @@ void remove_bad_trajectories(Drive* env){ env->timestep++; } - for(int i = 0; i< env->active_agent_count; i++){ - if(collided_with_indices[i] == -1) continue; - for(int j = 0; j < env->static_agent_count; j++){ + for (int i = 0; i < env->active_agent_count; i++) { + if (collided_with_indices[i] == -1) + continue; + for (int j = 0; j < env->static_agent_count; j++) { int static_agent_idx = env->static_agent_indices[j]; - if(static_agent_idx != collided_with_indices[i]) continue; + if (static_agent_idx != collided_with_indices[i]) + continue; env->entities[static_agent_idx].traj_x[0] = INVALID_POSITION; env->entities[static_agent_idx].traj_y[0] = INVALID_POSITION; } @@ -1561,15 +1549,15 @@ void remove_bad_trajectories(Drive* env){ env->timestep = 0; } -void init_goal_positions(Drive* env){ - for(int x = 0;xactive_agent_count; x++){ +void init_goal_positions(Drive *env) { + for (int x = 0; x < env->active_agent_count; x++) { int agent_idx = env->active_agent_indices[x]; env->entities[agent_idx].init_goal_x = env->entities[agent_idx].goal_position_x; env->entities[agent_idx].init_goal_y = env->entities[agent_idx].goal_position_y; } } -void assign_ego_and_coplayer_roles(Drive* env) { +void assign_ego_and_coplayer_roles(Drive *env) { if (!env->population_play || env->num_ego_agents == 0) { for (int i = 0; i < env->num_entities; i++) { if (!env->entities[i].mark_as_expert) { @@ -1605,17 +1593,13 @@ void assign_ego_and_coplayer_roles(Drive* env) { } } - - 
- -void init(Drive* env){ +void init(Drive *env) { env->human_agent_idx = 0; env->timestep = 0; env->entities = load_map_binary(env->map_name, env); set_means(env); init_grid_map(env); - if (env->goal_behavior==GOAL_GENERATE_NEW) init_topology_graph(env); - env->grid_map->vision_range = 21; + env->grid_map->vision_range = 21; // TODO: Why is this hardcoded? init_neighbor_offsets(env); cache_neighbor_offsets(env); env->logs_capacity = 0; @@ -1625,34 +1609,32 @@ void init(Drive* env){ set_start_position(env); init_goal_positions(env); assign_ego_and_coplayer_roles(env); - env->logs = (Log*)calloc(env->active_agent_count, sizeof(Log)); + env->logs = (Log *)calloc(env->active_agent_count, sizeof(Log)); // Always allocate weight arrays for consistency - env->collision_weights = (float*)calloc(env->active_agent_count, sizeof(float)); - env->offroad_weights = (float*)calloc(env->active_agent_count, sizeof(float)); - env->goal_weights = (float*)calloc(env->active_agent_count, sizeof(float)); - env->entropy_weights = (float*)calloc(env->active_agent_count, sizeof(float)); - env->discount_weights = (float*)calloc(env->active_agent_count, sizeof(float)); + env->collision_weights = (float *)calloc(env->active_agent_count, sizeof(float)); + env->offroad_weights = (float *)calloc(env->active_agent_count, sizeof(float)); + env->goal_weights = (float *)calloc(env->active_agent_count, sizeof(float)); + env->entropy_weights = (float *)calloc(env->active_agent_count, sizeof(float)); + env->discount_weights = (float *)calloc(env->active_agent_count, sizeof(float)); if (env->population_play) { - if (env->co_player_logs) { free(env->co_player_logs); env->co_player_logs = NULL; } if (env->active_agent_count > 0) { - env->co_player_logs = (Co_Player_Log*)calloc(env->active_agent_count, sizeof(Co_Player_Log)); + env->co_player_logs = (Log *)calloc(env->active_agent_count, sizeof(Log)); } else { env->co_player_logs = NULL; } - memset(&env->co_player_log, 0, sizeof(Co_Player_Log)); + 
memset(&env->co_player_log, 0, sizeof(Log)); } if (env->population_play) { - if (env->co_player_logs) { free(env->co_player_logs); env->co_player_logs = NULL; @@ -1661,38 +1643,37 @@ void init(Drive* env){ // Always allocate for all active agents, not just co-players // because we index by x which goes from 0 to active_agent_count-1 if (env->active_agent_count > 0) { - env->co_player_logs = (Co_Player_Log*)calloc(env->active_agent_count, sizeof(Co_Player_Log)); + env->co_player_logs = (Log *)calloc(env->active_agent_count, sizeof(Log)); } else { env->co_player_logs = NULL; } - memset(&env->co_player_log, 0, sizeof(Co_Player_Log)); + memset(&env->co_player_log, 0, sizeof(Log)); } } - -void c_close(Drive* env){ +void c_close(Drive *env) { if (env->population_play && env->co_player_logs != NULL) { free(env->co_player_logs); free(env->co_player_ids); free(env->ego_agent_ids); } - for(int i = 0; i < env->num_entities; i++){ + for (int i = 0; i < env->num_entities; i++) { free_entity(&env->entities[i]); } free(env->entities); free(env->active_agent_indices); free(env->logs); // GridMap cleanup - int grid_cell_count = env->grid_map->grid_cols*env->grid_map->grid_rows; - for(int grid_index = 0; grid_index < grid_cell_count; grid_index++){ + int grid_cell_count = env->grid_map->grid_cols * env->grid_map->grid_rows; + for (int grid_index = 0; grid_index < grid_cell_count; grid_index++) { free(env->grid_map->cells[grid_index]); } free(env->grid_map->cells); free(env->grid_map->cell_entities_count); free(env->neighbor_offsets); - for(int i = 0; i < grid_cell_count; i++){ + for (int i = 0; i < grid_cell_count; i++) { free(env->grid_map->neighbor_cache_entities[i]); } free(env->grid_map->neighbor_cache_entities); @@ -1700,34 +1681,30 @@ void c_close(Drive* env){ free(env->grid_map); free(env->static_agent_indices); free(env->expert_static_agent_indices); - freeTopologyGraph(env->topology_graph); - // free(env->map_name); free(env->ini_file); - } -void allocate(Drive* env){ +void 
allocate(Drive *env) { init(env); int base_ego_dim = (env->dynamics_model == JERK) ? 10 : 7; int conditioning_dims = (env->use_rc ? 3 : 0) + (env->use_ec ? 1 : 0) + (env->use_dc ? 1 : 0); int ego_dim = base_ego_dim + conditioning_dims; // Always allocate weight arrays for consistency - env->collision_weights = (float*)calloc(env->active_agent_count, sizeof(float)); - env->offroad_weights = (float*)calloc(env->active_agent_count, sizeof(float)); - env->goal_weights = (float*)calloc(env->active_agent_count, sizeof(float)); - env->entropy_weights = (float*)calloc(env->active_agent_count, sizeof(float)); - env->discount_weights = (float*)calloc(env->active_agent_count, sizeof(float)); - - int max_obs = ego_dim + 7*(MAX_AGENTS - 1) + 7*MAX_ROAD_SEGMENT_OBSERVATIONS; - env->observations = (float*)calloc(env->active_agent_count*max_obs, sizeof(float)); - env->actions = (float*)calloc(env->active_agent_count*2, sizeof(float)); - env->rewards = (float*)calloc(env->active_agent_count, sizeof(float)); - env->terminals= (unsigned char*)calloc(env->active_agent_count, sizeof(unsigned char)); - + env->collision_weights = (float *)calloc(env->active_agent_count, sizeof(float)); + env->offroad_weights = (float *)calloc(env->active_agent_count, sizeof(float)); + env->goal_weights = (float *)calloc(env->active_agent_count, sizeof(float)); + env->entropy_weights = (float *)calloc(env->active_agent_count, sizeof(float)); + env->discount_weights = (float *)calloc(env->active_agent_count, sizeof(float)); + + int max_obs = ego_dim + 7 * (MAX_AGENTS - 1) + 7 * MAX_ROAD_SEGMENT_OBSERVATIONS; + env->observations = (float *)calloc(env->active_agent_count * max_obs, sizeof(float)); + env->actions = (float *)calloc(env->active_agent_count * 2, sizeof(float)); + env->rewards = (float *)calloc(env->active_agent_count, sizeof(float)); + env->terminals = (unsigned char *)calloc(env->active_agent_count, sizeof(unsigned char)); } -void free_allocated(Drive* env){ +void free_allocated(Drive *env) { 
free(env->observations); free(env->actions); free(env->rewards); @@ -1745,24 +1722,27 @@ void free_allocated(Drive* env){ float clipSpeed(float speed) { const float maxSpeed = MAX_SPEED; - if (speed > maxSpeed) return maxSpeed; - if (speed < -maxSpeed) return -maxSpeed; + if (speed > maxSpeed) + return maxSpeed; + if (speed < -maxSpeed) + return -maxSpeed; return speed; } -float normalize_heading(float heading){ - if(heading > M_PI) heading -= 2*M_PI; - if(heading < -M_PI) heading += 2*M_PI; +float normalize_heading(float heading) { + if (heading > M_PI) + heading -= 2 * M_PI; + if (heading < -M_PI) + heading += 2 * M_PI; return heading; } -float normalize_value(float value, float min, float max){ - return (value - min) / (max - min); -} +float normalize_value(float value, float min, float max) { return (value - min) / (max - min); } -void move_dynamics(Drive* env, int action_idx, int agent_idx){ - Entity* agent = &env->entities[agent_idx]; - if (agent->removed) return; +void move_dynamics(Drive *env, int action_idx, int agent_idx) { + Entity *agent = &env->entities[agent_idx]; + if (agent->removed) + return; if (agent->stopped) { agent->vx = 0.0f; @@ -1776,7 +1756,7 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){ float steering = 0.0f; if (env->action_type == 1) { // continuous - float (*action_array_f)[2] = (float(*)[2])env->actions; + float (*action_array_f)[2] = (float (*)[2])env->actions; acceleration = action_array_f[action_idx][0]; steering = action_array_f[action_idx][1]; @@ -1784,8 +1764,7 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){ steering *= STEERING_VALUES[12]; } else { // discrete // Interpret action as a single integer: a = accel_idx * num_steer + steer_idx - int* action_array = (int*)env->actions; - int num_accel = sizeof(ACCELERATION_VALUES) / sizeof(ACCELERATION_VALUES[0]); + int *action_array = (int *)env->actions; int num_steer = sizeof(STEERING_VALUES) / sizeof(STEERING_VALUES[0]); int action_val = 
action_array[action_idx]; int acceleration_index = action_val / num_steer; @@ -1801,27 +1780,28 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){ float vx = agent->vx; float vy = agent->vy; - // Calculate current speed - float speed = sqrtf(vx*vx + vy*vy); + // Calculate current speed (signed based on direction relative to heading) + float speed_magnitude = sqrtf(vx * vx + vy * vy); + float v_dot_heading = vx * agent->heading_x + vy * agent->heading_y; + float signed_speed = copysignf(speed_magnitude, v_dot_heading); // Update speed with acceleration - speed = speed + acceleration*env->dt; - speed = clipSpeed(speed); - + signed_speed = signed_speed + acceleration * env->dt; + signed_speed = clipSpeed(signed_speed); // Compute yaw rate - float beta = tanh(.5*tanf(steering)); + float beta = tanh(.5 * tanf(steering)); // New heading - float yaw_rate = (speed*cosf(beta)*tanf(steering)) / agent->length; + float yaw_rate = (signed_speed * cosf(beta) * tanf(steering)) / agent->length; // New velocity - float new_vx = speed*cosf(heading + beta); - float new_vy = speed*sinf(heading + beta); + float new_vx = signed_speed * cosf(heading + beta); + float new_vy = signed_speed * sinf(heading + beta); // Update position - x = x + (new_vx*env->dt); - y = y + (new_vy*env->dt); - heading = heading + yaw_rate*env->dt; + x = x + (new_vx * env->dt); + y = y + (new_vy * env->dt); + heading = heading + yaw_rate * env->dt; // Apply updates to the agent's state agent->x = x; @@ -1836,23 +1816,26 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){ // Extract action components float a_long, a_lat; if (env->action_type == 1) { // continuous - float (*action_array_f)[2] = (float(*)[2])env->actions; + float (*action_array_f)[2] = (float (*)[2])env->actions; // Asymmetric scaling for longitudinal jerk to match discrete action space // Discrete: JERK_LONG = [-15, -4, 0, 4] (more braking than acceleration) - float a_long_action = action_array_f[action_idx][0]; // 
[-1, 1] + float a_long_action = action_array_f[action_idx][0]; // [-1, 1] if (a_long_action < 0) { - a_long = a_long_action * (-JERK_LONG[0]); // Negative: [-1, 0] → [-15, 0] (braking) + a_long = a_long_action * (-JERK_LONG[0]); // Negative: [-1, 0] → [-15, 0] (braking) } else { - a_long = a_long_action * JERK_LONG[3]; // Positive: [0, 1] → [0, 4] (acceleration) + a_long = a_long_action * JERK_LONG[3]; // Positive: [0, 1] → [0, 4] (acceleration) } // Symmetric scaling for lateral jerk a_lat = action_array_f[action_idx][1] * JERK_LAT[2]; } else { // discrete - int (*action_array)[2] = (int(*)[2])env->actions; - int a_long_idx = action_array[action_idx][0]; - int a_lat_idx = action_array[action_idx][1]; + // Interpret action as a single integer: a = long_idx * num_lat + lat_idx + int *action_array = (int *)env->actions; + int num_lat = sizeof(JERK_LAT) / sizeof(JERK_LAT[0]); + int action_val = action_array[action_idx]; + int a_long_idx = action_val / num_lat; + int a_lat_idx = action_val % num_lat; a_long = JERK_LONG[a_long_idx]; a_lat = JERK_LAT[a_lat_idx]; } @@ -1876,7 +1859,7 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){ // Calculate new velocity float v_dot_heading = agent->vx * agent->heading_x + agent->vy * agent->heading_y; - float signed_v = copysignf(sqrtf(agent->vx*agent->vx + agent->vy*agent->vy), v_dot_heading); + float signed_v = copysignf(sqrtf(agent->vx * agent->vx + agent->vy * agent->vy), v_dot_heading); float v_new = signed_v + 0.5f * (a_long_new + agent->a_long) * env->dt; // Make it easy to stop with 0 vel @@ -1931,10 +1914,23 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){ return; } -void c_get_global_agent_state(Drive* env, float* x_out, float* y_out, float* z_out, float* heading_out, int* id_out) { - for(int i = 0; i < env->active_agent_count; i++){ +static inline int get_track_id_or_placeholder(Drive *env, int agent_idx) { + if (env->tracks_to_predict_indices == NULL || env->num_tracks_to_predict == 0) { + 
return -1; + } + for (int k = 0; k < env->num_tracks_to_predict; k++) { + if (env->tracks_to_predict_indices[k] == agent_idx) { + return env->tracks_to_predict_indices[k]; + } + } + return -1; +} + +void c_get_global_agent_state(Drive *env, float *x_out, float *y_out, float *z_out, float *heading_out, int *id_out, + float *length_out, float *width_out) { + for (int i = 0; i < env->active_agent_count; i++) { int agent_idx = env->active_agent_indices[i]; - Entity* agent = &env->entities[agent_idx]; + Entity *agent = &env->entities[agent_idx]; // For WOSAC, we need the original world coordinates, so we add the world means back x_out[i] = agent->x + env->world_mean_x; @@ -1945,14 +1941,15 @@ void c_get_global_agent_state(Drive* env, float* x_out, float* y_out, float* z_o } } -void c_get_global_ground_truth_trajectories(Drive* env, float* x_out, float* y_out, float* z_out, float* heading_out, int* valid_out, int* id_out, int* scenario_id_out) { - for(int i = 0; i < env->active_agent_count; i++){ +void c_get_global_ground_truth_trajectories(Drive *env, float *x_out, float *y_out, float *z_out, float *heading_out, + int *valid_out, int *id_out, int *scenario_id_out) { + for (int i = 0; i < env->active_agent_count; i++) { int agent_idx = env->active_agent_indices[i]; - Entity* agent = &env->entities[agent_idx]; + Entity *agent = &env->entities[agent_idx]; id_out[i] = env->tracks_to_predict_indices[i]; scenario_id_out[i] = agent->scenario_id; - for(int t = env->init_steps; t < agent->array_size; t++){ + for (int t = env->init_steps; t < agent->array_size; t++) { int out_idx = i * (agent->array_size - env->init_steps) + (t - env->init_steps); // Add world means back to get original world coordinates x_out[out_idx] = agent->traj_x[t] + env->world_mean_x; @@ -1964,35 +1961,67 @@ void c_get_global_ground_truth_trajectories(Drive* env, float* x_out, float* y_o } } -void compute_observations(Drive* env) { +void c_get_road_edge_counts(Drive *env, int *num_polylines_out, int 
*total_points_out) { + int count = 0, points = 0; + for (int i = env->num_objects; i < env->num_entities; i++) { + if (env->entities[i].type == ROAD_EDGE) { + count++; + points += env->entities[i].array_size; + } + } + *num_polylines_out = count; + *total_points_out = points; +} + +void c_get_road_edge_polylines(Drive *env, float *x_out, float *y_out, int *lengths_out, int *scenario_ids_out) { + int poly_idx = 0, pt_idx = 0; + for (int i = env->num_objects; i < env->num_entities; i++) { + Entity *e = &env->entities[i]; + if (e->type == ROAD_EDGE) { + lengths_out[poly_idx] = e->array_size; + scenario_ids_out[poly_idx] = e->scenario_id; + for (int j = 0; j < e->array_size; j++) { + x_out[pt_idx] = e->traj_x[j] + env->world_mean_x; + y_out[pt_idx] = e->traj_y[j] + env->world_mean_y; + pt_idx++; + } + poly_idx++; + } + } +} + +void compute_observations(Drive *env) { int base_ego_dim = (env->dynamics_model == JERK) ? 10 : 7; int conditioning_dims = (env->use_rc ? 3 : 0) + (env->use_ec ? 1 : 0) + (env->use_dc ? 
1 : 0); int ego_dim = base_ego_dim + conditioning_dims; - int max_obs = ego_dim + 7*(MAX_AGENTS - 1) + 7*MAX_ROAD_SEGMENT_OBSERVATIONS; + int max_obs = ego_dim + 7 * (MAX_AGENTS - 1) + 7 * MAX_ROAD_SEGMENT_OBSERVATIONS; - memset(env->observations, 0, max_obs*env->active_agent_count*sizeof(float)); - float (*observations)[max_obs] = (float(*)[max_obs])env->observations; + memset(env->observations, 0, max_obs * env->active_agent_count * sizeof(float)); + float (*observations)[max_obs] = (float (*)[max_obs])env->observations; - for(int i = 0; i < env->active_agent_count; i++) { - float* obs = &observations[i][0]; - Entity* ego_entity = &env->entities[env->active_agent_indices[i]]; - if(ego_entity->type > 3) break; + for (int i = 0; i < env->active_agent_count; i++) { + float *obs = &observations[i][0]; + Entity *ego_entity = &env->entities[env->active_agent_indices[i]]; + if (ego_entity->type > 3) + break; float cos_heading = ego_entity->heading_x; float sin_heading = ego_entity->heading_y; - float ego_speed = sqrtf(ego_entity->vx*ego_entity->vx + ego_entity->vy*ego_entity->vy); + float speed_magnitude = sqrtf(ego_entity->vx * ego_entity->vx + ego_entity->vy * ego_entity->vy); + float v_dot_heading = ego_entity->vx * ego_entity->heading_x + ego_entity->vy * ego_entity->heading_y; + float signed_speed = copysignf(speed_magnitude, v_dot_heading); // Set goal distances float goal_x = ego_entity->goal_position_x - ego_entity->x; float goal_y = ego_entity->goal_position_y - ego_entity->y; // Rotate to ego vehicle's frame - float rel_goal_x = goal_x*cos_heading + goal_y*sin_heading; - float rel_goal_y = -goal_x*sin_heading + goal_y*cos_heading; + float rel_goal_x = goal_x * cos_heading + goal_y * sin_heading; + float rel_goal_y = -goal_x * sin_heading + goal_y * cos_heading; - obs[0] = rel_goal_x* 0.005f; - obs[1] = rel_goal_y* 0.005f; - obs[2] = ego_speed / MAX_SPEED; + obs[0] = rel_goal_x * 0.005f; + obs[1] = rel_goal_y * 0.005f; + obs[2] = signed_speed / MAX_SPEED; 
obs[3] = ego_entity->width / MAX_VEH_WIDTH; obs[4] = ego_entity->length / MAX_VEH_LEN; obs[5] = (ego_entity->collision_state > 0) ? 1.0f : 0.0f; @@ -2001,7 +2030,8 @@ void compute_observations(Drive* env) { if (env->dynamics_model == JERK) { obs[7] = ego_entity->steering_angle / M_PI; // Asymmetric normalization for a_long to match action space - obs[8] = (ego_entity->a_long < 0) ? ego_entity->a_long / (-JERK_LONG[0]) : ego_entity->a_long / JERK_LONG[3]; + obs[8] = + (ego_entity->a_long < 0) ? ego_entity->a_long / (-JERK_LONG[0]) : ego_entity->a_long / JERK_LONG[3]; obs[9] = ego_entity->a_lat / JERK_LAT[2]; } @@ -2021,86 +2051,97 @@ void compute_observations(Drive* env) { // Relative Pos of other cars int cars_seen = 0; - for(int j = 0; j < MAX_AGENTS; j++) { + for (int j = 0; j < MAX_AGENTS; j++) { int index = -1; - if(j < env->active_agent_count){ + if (j < env->active_agent_count) { index = env->active_agent_indices[j]; - } else if (j < env->num_actors){ + } else if (j < env->num_actors) { index = env->static_agent_indices[j - env->active_agent_count]; } - if(index == -1) continue; - if(env->entities[index].type > 3) break; - if(index == env->active_agent_indices[i]) continue; // Skip self, but don't increment obs_idx - Entity* other_entity = &env->entities[index]; - if(ego_entity->respawn_timestep != -1) continue; - if(other_entity->respawn_timestep != -1) continue; + if (index == -1) + continue; + if (env->entities[index].type > 3) + break; + if (index == env->active_agent_indices[i]) + continue; // Skip self, but don't increment obs_idx + Entity *other_entity = &env->entities[index]; + if (ego_entity->respawn_timestep != -1) + continue; + if (other_entity->respawn_timestep != -1) + continue; // Store original relative positions float dx = other_entity->x - ego_entity->x; float dy = other_entity->y - ego_entity->y; - float dist = (dx*dx + dy*dy); - if(dist > 2500.0f) continue; + float dist = (dx * dx + dy * dy); + if (dist > 2500.0f) + continue; // Rotate to 
ego vehicle's frame - float rel_x = dx*cos_heading + dy*sin_heading; - float rel_y = -dx*sin_heading + dy*cos_heading; + float rel_x = dx * cos_heading + dy * sin_heading; + float rel_y = -dx * sin_heading + dy * cos_heading; // Store observations with correct indexing obs[obs_idx] = rel_x * 0.02f; - // Add conditioning weights to observations + // Add conditioning weights to observations obs[obs_idx + 1] = rel_y * 0.02f; obs[obs_idx + 2] = other_entity->width / MAX_VEH_WIDTH; obs[obs_idx + 3] = other_entity->length / MAX_VEH_LEN; // relative heading - float rel_heading_x = other_entity->heading_x * ego_entity->heading_x + - other_entity->heading_y * ego_entity->heading_y; // cos(a-b) = cos(a)cos(b) + sin(a)sin(b) - float rel_heading_y = other_entity->heading_y * ego_entity->heading_x - - other_entity->heading_x * ego_entity->heading_y; // sin(a-b) = sin(a)cos(b) - cos(a)sin(b) + float rel_heading_x = + other_entity->heading_x * ego_entity->heading_x + + other_entity->heading_y * ego_entity->heading_y; // cos(a-b) = cos(a)cos(b) + sin(a)sin(b) + float rel_heading_y = + other_entity->heading_y * ego_entity->heading_x - + other_entity->heading_x * ego_entity->heading_y; // sin(a-b) = sin(a)cos(b) - cos(a)sin(b) obs[obs_idx + 4] = rel_heading_x; obs[obs_idx + 5] = rel_heading_y; - // obs[obs_idx + 4] = cosf(rel_heading) / MAX_ORIENTATION_RAD; - // obs[obs_idx + 5] = sinf(rel_heading) / MAX_ORIENTATION_RAD; - // // relative speed - float other_speed = sqrtf(other_entity->vx*other_entity->vx + other_entity->vy*other_entity->vy); - obs[obs_idx + 6] = other_speed / MAX_SPEED; + + // relative speed + float other_speed_magnitude = + sqrtf(other_entity->vx * other_entity->vx + other_entity->vy * other_entity->vy); + float other_v_dot_heading = + other_entity->vx * other_entity->heading_x + other_entity->vy * other_entity->heading_y; + float other_signed_speed = copysignf(other_speed_magnitude, other_v_dot_heading); + obs[obs_idx + 6] = other_signed_speed / MAX_SPEED; 
cars_seen++; - obs_idx += 7; // Move to next observation slot + obs_idx += 7; // Move to next observation slot } int remaining_partner_obs = (MAX_AGENTS - 1 - cars_seen) * 7; memset(&obs[obs_idx], 0, remaining_partner_obs * sizeof(float)); obs_idx += remaining_partner_obs; // map observations - GridMapEntity entity_list[MAX_ENTITIES_PER_CELL*25]; + GridMapEntity entity_list[MAX_ENTITIES_PER_CELL * 25]; int grid_idx = getGridIndex(env, ego_entity->x, ego_entity->y); int list_size = get_neighbor_cache_entities(env, grid_idx, entity_list, MAX_ROAD_SEGMENT_OBSERVATIONS); - for(int k = 0; k < list_size; k++) { + for (int k = 0; k < list_size; k++) { int entity_idx = entity_list[k].entity_idx; int geometry_idx = entity_list[k].geometry_idx; // Validate entity_idx before accessing - if(entity_idx < 0 || entity_idx >= env->num_entities) { - printf("ERROR: Invalid entity_idx %d (max: %d)\n", entity_idx, env->num_entities-1); + if (entity_idx < 0 || entity_idx >= env->num_entities) { + printf("ERROR: Invalid entity_idx %d (max: %d)\n", entity_idx, env->num_entities - 1); continue; } - Entity* entity = &env->entities[entity_idx]; + Entity *entity = &env->entities[entity_idx]; // Validate geometry_idx before accessing - if(geometry_idx < 0 || geometry_idx >= entity->array_size) { - printf("ERROR: Invalid geometry_idx %d for entity %d (max: %d)\n", - geometry_idx, entity_idx, entity->array_size-1); + if (geometry_idx < 0 || geometry_idx >= entity->array_size) { + printf("ERROR: Invalid geometry_idx %d for entity %d (max: %d)\n", geometry_idx, entity_idx, + entity->array_size - 1); continue; } float start_x = entity->traj_x[geometry_idx]; float start_y = entity->traj_y[geometry_idx]; - float end_x = entity->traj_x[geometry_idx+1]; - float end_y = entity->traj_y[geometry_idx+1]; + float end_x = entity->traj_x[geometry_idx + 1]; + float end_y = entity->traj_y[geometry_idx + 1]; float mid_x = (start_x + end_x) / 2.0f; float mid_y = (start_y + end_y) / 2.0f; float rel_x = mid_x - 
ego_entity->x; float rel_y = mid_y - ego_entity->y; - float x_obs = rel_x*cos_heading + rel_y*sin_heading; - float y_obs = -rel_x*sin_heading + rel_y*cos_heading; + float x_obs = rel_x * cos_heading + rel_y * sin_heading; + float y_obs = -rel_x * sin_heading + rel_y * cos_heading; float length = relative_distance_2d(mid_x, mid_y, end_x, end_y); float width = 0.1; // Calculate angle from ego to midpoint (vector from ego to midpoint) @@ -2108,14 +2149,14 @@ void compute_observations(Drive* env) { float dy = end_y - mid_y; float dx_norm = dx; float dy_norm = dy; - float hypot = sqrtf(dx*dx + dy*dy); - if(hypot > 0) { + float hypot = sqrtf(dx * dx + dy * dy); + if (hypot > 0) { dx_norm /= hypot; dy_norm /= hypot; } // Compute sin and cos of relative angle directly without atan2f - float cos_angle = dx_norm*cos_heading + dy_norm*sin_heading; - float sin_angle = -dx_norm*sin_heading + dy_norm*cos_heading; + float cos_angle = dx_norm * cos_heading + dy_norm * sin_heading; + float sin_angle = -dx_norm * sin_heading + dy_norm * cos_heading; obs[obs_idx] = x_obs * 0.02f; obs[obs_idx + 1] = y_obs * 0.02f; obs[obs_idx + 2] = length / MAX_ROAD_SEGMENT_LENGTH; @@ -2131,170 +2172,99 @@ void compute_observations(Drive* env) { } } -static int find_forward_projection_on_lane(Entity* lane, Entity* agent, int* out_segment_idx, float* out_fraction) { - int best_idx = -1; - float best_dist_sq = 1e30f; - - for (int i = 1; i < lane->array_size; i++) { - float x0 = lane->traj_x[i - 1]; - float y0 = lane->traj_y[i - 1]; - float x1 = lane->traj_x[i]; - float y1 = lane->traj_y[i]; - float dx = x1 - x0; - float dy = y1 - y0; - float seg_len_sq = dx * dx + dy * dy; - if (seg_len_sq < 1e-6f) continue; - - float to_agent_x = agent->x - x0; - float to_agent_y = agent->y - y0; - float t = (to_agent_x * dx + to_agent_y * dy) / seg_len_sq; - if (t < 0.0f) t = 0.0f; - else if (t > 1.0f) t = 1.0f; - - float proj_x = x0 + t * dx; - float proj_y = y0 + t * dy; - - float rel_x = proj_x - agent->x; - float 
rel_y = proj_y - agent->y; - float forward = rel_x * agent->heading_x + rel_y * agent->heading_y; - if (forward < 0.0f) continue; +void sample_new_goal(Drive *env, int agent_idx) { + // Samples a new goal position based on the existing road lane points + Entity *agent = &env->entities[agent_idx]; + float best_x = agent->x; + float best_y = agent->y; + float best_distance_error = 1e30f; - float dist_sq = rel_x * rel_x + rel_y * rel_y; - if (dist_sq < best_dist_sq) { - best_dist_sq = dist_sq; - best_idx = i; - *out_fraction = t; - } - } - - if (best_idx != -1) { - *out_segment_idx = best_idx; - return 1; - } - - return 0; -} + // Sample points from all road lanes + for (int i = env->num_objects; i < env->num_entities; i++) { + if (env->entities[i].type != ROAD_LANE) + continue; -void compute_new_goal(Drive* env, int agent_idx) { - Entity* agent = &env->entities[agent_idx]; - int current_lane = agent->current_lane_idx; + Entity *lane = &env->entities[i]; - if (current_lane == -1) return; // No current lane + // Check every point in the lane + for (int j = 0; j < lane->array_size; j++) { + float point_x = lane->traj_x[j]; + float point_y = lane->traj_y[j]; - // Target distance: 40m ahead along the lane topology from agent's current position - float target_distance = 40.0f; - int current_entity = current_lane; - Entity* lane = &env->entities[current_entity]; + // Calculate vector from agent to point + float to_point_x = point_x - agent->x; + float to_point_y = point_y - agent->y; - int initial_segment_idx = 1; - float initial_fraction = 0.0f; - if (!find_forward_projection_on_lane(lane, agent, &initial_segment_idx, &initial_fraction)) { - int forward_idx = -1; - for (int i = 0; i < lane->array_size; i++) { - float to_point_x = lane->traj_x[i] - agent->x; - float to_point_y = lane->traj_y[i] - agent->y; + // Check if point is ahead of agent float dot = to_point_x * agent->heading_x + to_point_y * agent->heading_y; - if (dot > 0.0f) { - forward_idx = i; - break; - } - } - 
- if (forward_idx == -1) { - agent->goal_position_x = lane->traj_x[lane->array_size - 1]; - agent->goal_position_y = lane->traj_y[lane->array_size - 1]; - agent->sampled_new_goal = 0; - return; - } - - initial_segment_idx = forward_idx; - if (initial_segment_idx == 0) initial_segment_idx = 1; - initial_fraction = 0.0f; - } - - float remaining_distance = target_distance; - int first_lane = 1; - - // Traverse the topology graph starting from the vehicle's position forward - while (current_entity != -1) { - lane = &env->entities[current_entity]; - - int start_idx = first_lane ? initial_segment_idx : 1; - // Ensure start_idx is at least 1 to avoid accessing traj_x[i-1] with i=0 - if (start_idx < 1) start_idx = 1; - first_lane = 0; + if (dot <= 0.0f) + continue; - for (int i = start_idx; i < lane->array_size; i++) { - float prev_x = lane->traj_x[i - 1]; - float prev_y = lane->traj_y[i - 1]; - float next_x = lane->traj_x[i]; - float next_y = lane->traj_y[i]; - float seg_dx = next_x - prev_x; - float seg_dy = next_y - prev_y; - float segment_length = relative_distance_2d(prev_x, prev_y, next_x, next_y); + // Calculate distance to point + float distance = sqrtf(to_point_x * to_point_x + to_point_y * to_point_y); - if (remaining_distance <= segment_length) { - agent->goal_position_x = next_x; - agent->goal_position_y = next_y; - agent->sampled_new_goal = 0; - return; + // Find point closest to target distance + float distance_error = fabsf(distance - env->goal_target_distance); + if (distance_error < best_distance_error) { + best_distance_error = distance_error; + best_x = point_x; + best_y = point_y; } - - remaining_distance -= segment_length; - } - - int connected_lanes[5]; - int num_connected = getNextLanes(env->topology_graph, current_entity, connected_lanes, 5); - - if (num_connected == 0) { - agent->goal_position_x = lane->traj_x[lane->array_size - 1]; - agent->goal_position_y = lane->traj_y[lane->array_size - 1]; - agent->sampled_new_goal = 0; - return; // No further 
lanes to traverse } + } - int random_idx = agent_idx % num_connected; - current_entity = connected_lanes[random_idx]; + // If no valid goal found, use another agent's initial goal + if (best_distance_error >= 1e30f && env->active_agent_count > 1) { + int other_idx = env->active_agent_indices[(agent_idx + 1) % env->active_agent_count]; + best_x = env->entities[other_idx].init_goal_x; + best_y = env->entities[other_idx].init_goal_y; } + + agent->goal_position_x = best_x; + agent->goal_position_y = best_y; + agent->goals_sampled_this_episode += 1; } -void c_reset(Drive* env){ +void c_reset(Drive *env) { env->timestep = env->init_steps; set_start_position(env); - // Initialize all conditioning weights even when no conditioning (lb=ub) - for(int i = 0; i < env->active_agent_count; i++) { - env->collision_weights[i] = ((float)rand() / RAND_MAX) * (env->collision_weight_ub - env->collision_weight_lb) + env->collision_weight_lb; - env->offroad_weights[i] = ((float)rand() / RAND_MAX) * (env->offroad_weight_ub - env->offroad_weight_lb) + env->offroad_weight_lb; - env->goal_weights[i] = ((float)rand() / RAND_MAX) * (env->goal_weight_ub - env->goal_weight_lb) + env->goal_weight_lb; - env->entropy_weights[i] = ((float)rand() / RAND_MAX) * (env->entropy_weight_ub - env->entropy_weight_lb) + env->entropy_weight_lb; - env->discount_weights[i] = ((float)rand() / RAND_MAX) * (env->discount_weight_ub - env->discount_weight_lb) + env->discount_weight_lb; - } - - for(int x = 0;xactive_agent_count; x++){ + for (int i = 0; i < env->active_agent_count; i++) { + env->collision_weights[i] = ((float)rand() / RAND_MAX) * (env->collision_weight_ub - env->collision_weight_lb) + + env->collision_weight_lb; + env->offroad_weights[i] = + ((float)rand() / RAND_MAX) * (env->offroad_weight_ub - env->offroad_weight_lb) + env->offroad_weight_lb; + env->goal_weights[i] = + ((float)rand() / RAND_MAX) * (env->goal_weight_ub - env->goal_weight_lb) + env->goal_weight_lb; + env->entropy_weights[i] = + 
((float)rand() / RAND_MAX) * (env->entropy_weight_ub - env->entropy_weight_lb) + env->entropy_weight_lb; + env->discount_weights[i] = + ((float)rand() / RAND_MAX) * (env->discount_weight_ub - env->discount_weight_lb) + env->discount_weight_lb; + } + + for (int x = 0; x < env->active_agent_count; x++) { env->logs[x] = (Log){0}; int agent_idx = env->active_agent_indices[x]; env->entities[agent_idx].respawn_timestep = -1; env->entities[agent_idx].respawn_count = 0; env->entities[agent_idx].collided_before_goal = 0; - env->entities[agent_idx].reached_goal_this_episode = 0; + env->entities[agent_idx].goals_reached_this_episode = 0.0f; + // Initialize to 1 because there is one goal in the data file + env->entities[agent_idx].goals_sampled_this_episode = 1.0f; + env->entities[agent_idx].current_goal_reached = 0; env->entities[agent_idx].metrics_array[COLLISION_IDX] = 0.0f; env->entities[agent_idx].metrics_array[OFFROAD_IDX] = 0.0f; env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX] = 0.0f; env->entities[agent_idx].metrics_array[LANE_ALIGNED_IDX] = 0.0f; - env->entities[agent_idx].metrics_array[AVG_DISPLACEMENT_ERROR_IDX] = 0.0f; - env->entities[agent_idx].cumulative_displacement = 0.0f; - env->entities[agent_idx].displacement_sample_count = 0; - env->entities[agent_idx].stopped = 0; + env->entities[agent_idx].stopped = 0; env->entities[agent_idx].removed = 0; - if (env->goal_behavior==GOAL_GENERATE_NEW) { + if (env->goal_behavior == GOAL_GENERATE_NEW) { env->entities[agent_idx].goal_position_x = env->entities[agent_idx].init_goal_x; env->entities[agent_idx].goal_position_y = env->entities[agent_idx].init_goal_y; - env->entities[agent_idx].sampled_new_goal = 0; } - if (env->population_play){ - env->co_player_logs[x] = (Co_Player_Log){0}; + if (env->population_play) { + env->co_player_logs[x] = (Log){0}; } compute_agent_metrics(env, agent_idx); @@ -2302,7 +2272,7 @@ void c_reset(Drive* env){ compute_observations(env); } -void respawn_agent(Drive* env, int agent_idx){ 
+void respawn_agent(Drive *env, int agent_idx) { env->entities[agent_idx].x = env->entities[agent_idx].traj_x[0]; env->entities[agent_idx].y = env->entities[agent_idx].traj_y[0]; env->entities[agent_idx].heading = env->entities[agent_idx].traj_heading[0]; @@ -2314,10 +2284,9 @@ void respawn_agent(Drive* env, int agent_idx){ env->entities[agent_idx].metrics_array[OFFROAD_IDX] = 0.0f; env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX] = 0.0f; env->entities[agent_idx].metrics_array[LANE_ALIGNED_IDX] = 0.0f; - env->entities[agent_idx].metrics_array[AVG_DISPLACEMENT_ERROR_IDX] = 0.0f; - env->entities[agent_idx].cumulative_displacement = 0.0f; - env->entities[agent_idx].displacement_sample_count = 0; + env->entities[agent_idx].respawn_timestep = env->timestep; + env->entities[agent_idx].collided_before_goal = 0; env->entities[agent_idx].stopped = 0; env->entities[agent_idx].removed = 0; env->entities[agent_idx].a_long = 0.0f; @@ -2327,43 +2296,55 @@ void respawn_agent(Drive* env, int agent_idx){ env->entities[agent_idx].steering_angle = 0.0f; } -void c_step(Drive* env){ +void c_step(Drive *env) { memset(env->rewards, 0, env->active_agent_count * sizeof(float)); memset(env->terminals, 0, env->active_agent_count * sizeof(unsigned char)); env->timestep++; - if(env->timestep == env->scenario_length){ + + int originals_remaining = 0; + for (int i = 0; i < env->active_agent_count; i++) { + int agent_idx = env->active_agent_indices[i]; + // Keep flag true if there is at least one agent that has not been respawned yet + if (env->entities[agent_idx].respawn_count == 0) { + originals_remaining = 1; + break; + } + } + + if (env->timestep == env->scenario_length || (!originals_remaining && env->termination_mode == 1)) { add_log(env); c_reset(env); - return; } // Move static experts for (int i = 0; i < env->expert_static_agent_count; i++) { int expert_idx = env->expert_static_agent_indices[i]; - if(env->entities[expert_idx].x == INVALID_POSITION) continue; + if 
(env->entities[expert_idx].x == INVALID_POSITION) + continue; move_expert(env, env->actions, expert_idx); } // Process actions for all active agents - for(int i = 0; i < env->active_agent_count; i++){ + for (int i = 0; i < env->active_agent_count; i++) { int agent_idx = env->active_agent_indices[i]; env->entities[agent_idx].collision_state = 0; + move_dynamics(env, i, agent_idx); // Update logs based on agent type - use i directly as log index - if(env->entities[agent_idx].is_ego){ + if (env->entities[agent_idx].is_ego) { env->logs[i].score = 0.0f; env->logs[i].episode_length += 1; - } else if(env->entities[agent_idx].is_co_player){ - env->co_player_logs[i].co_player_score = 0.0f; - env->co_player_logs[i].co_player_episode_length += 1; + } else if (env->entities[agent_idx].is_co_player) { + env->co_player_logs[i].score = 0.0f; + env->co_player_logs[i].episode_length += 1; } } // Compute metrics and rewards - for(int i = 0; i < env->active_agent_count; i++){ + for (int i = 0; i < env->active_agent_count; i++) { int agent_idx = env->active_agent_indices[i]; env->entities[agent_idx].collision_state = 0; @@ -2372,143 +2353,129 @@ void c_step(Drive* env){ int collision_state = env->entities[agent_idx].collision_state; int is_ego = env->entities[agent_idx].is_ego; int is_co_player = env->entities[agent_idx].is_co_player; + int reached_goal = env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX]; // Handle collisions - SAME REWARD for both ego and co-players - if(collision_state > 0){ - if(collision_state == VEHICLE_COLLISION){ + if (collision_state > 0) { + if (collision_state == VEHICLE_COLLISION) { env->rewards[i] = env->collision_weights[i]; - if(is_ego){ + if (is_ego) { env->logs[i].episode_return += env->collision_weights[i]; env->logs[i].collision_rate = 1.0f; - env->logs[i].avg_collisions_per_agent += 1.0f; - } else if(is_co_player){ - env->co_player_logs[i].co_player_episode_return += env->collision_weights[i]; - env->co_player_logs[i].co_player_collision_rate 
= 1.0f; + env->logs[i].collisions_per_agent += 1.0f; + } else if (is_co_player) { + env->co_player_logs[i].episode_return += env->collision_weights[i]; + env->co_player_logs[i].collision_rate = 1.0f; + env->co_player_logs[i].collisions_per_agent += 1.0f; } - } else if(collision_state == OFFROAD){ + } else if (collision_state == OFFROAD) { env->rewards[i] = env->offroad_weights[i]; - if(is_ego){ + if (is_ego) { env->logs[i].episode_return += env->offroad_weights[i]; env->logs[i].offroad_rate = 1.0f; - env->logs[i].avg_offroad_per_agent += 1.0f; // ADD THIS - } else if(is_co_player){ - env->co_player_logs[i].co_player_episode_return += env->offroad_weights[i]; - env->co_player_logs[i].co_player_offroad_rate = 1.0f; - env->logs[i].avg_offroad_per_agent += 1.0f; - } + env->logs[i].offroad_per_agent += 1.0f; + } else if (is_co_player) { + env->co_player_logs[i].episode_return += env->offroad_weights[i]; + env->co_player_logs[i].offroad_rate = 1.0f; + env->co_player_logs[i].offroad_per_agent += 1.0f; + } } - if(!env->entities[agent_idx].reached_goal_this_episode){ + if (!reached_goal) { env->entities[agent_idx].collided_before_goal = 1; } } - // Handle goal reward - SAME REWARD for both ego and co-players - float distance_to_goal = relative_distance_2d( - env->entities[agent_idx].x, - env->entities[agent_idx].y, - env->entities[agent_idx].goal_position_x, - env->entities[agent_idx].goal_position_y + // Handle goal reward - NEW INCOMING LOGIC with speed check + float distance_to_goal = + relative_distance_2d(env->entities[agent_idx].x, env->entities[agent_idx].y, + env->entities[agent_idx].goal_position_x, env->entities[agent_idx].goal_position_y); + + float current_speed = sqrtf(env->entities[agent_idx].vx * env->entities[agent_idx].vx + + env->entities[agent_idx].vy * env->entities[agent_idx].vy); - ); + // Reward agent if it is within X meters of goal and speed is below threshold + bool within_distance = distance_to_goal < env->goal_radius; + bool within_speed = 
current_speed <= env->goal_speed; - if(distance_to_goal < env->goal_radius){ - if (env->goal_behavior == GOAL_RESPAWN && env->entities[agent_idx].respawn_timestep != -1){ + if (within_distance && within_speed && !env->entities[agent_idx].current_goal_reached) { + if (env->goal_behavior == GOAL_RESPAWN && env->entities[agent_idx].respawn_timestep != -1) { float scaled_post_respawn_reward = env->reward_goal_post_respawn * env->goal_weights[i]; env->rewards[i] += scaled_post_respawn_reward; - if(is_ego){ + + if (is_ego) { env->logs[i].episode_return += scaled_post_respawn_reward; - } else if(is_co_player){ - env->co_player_logs[i].co_player_episode_return += scaled_post_respawn_reward; + } else if (is_co_player) { + env->co_player_logs[i].episode_return += scaled_post_respawn_reward; } - } else if (env->goal_behavior == GOAL_GENERATE_NEW) { + env->entities[agent_idx].current_goal_reached = 1; + } else if (env->goal_behavior == GOAL_GENERATE_NEW && (!env->entities[agent_idx].current_goal_reached)) { env->rewards[i] += env->goal_weights[i]; - env->entities[agent_idx].sampled_new_goal = 1; - if(is_ego){ + + if (is_ego) { env->logs[i].episode_return += env->goal_weights[i]; - env->logs[i].num_goals_reached += 1; - } else if(is_co_player){ - env->co_player_logs[i].co_player_episode_return += env->goal_weights[i]; - env->co_player_logs[i].co_player_num_goals_reached += 1; + } else if (is_co_player) { + env->co_player_logs[i].episode_return += env->goal_weights[i]; } + + sample_new_goal(env, agent_idx); + env->entities[agent_idx].current_goal_reached = 0; + env->entities[agent_idx].goals_reached_this_episode += 1.0f; } else { // Zero out the velocity so that the agent stops at the goal env->rewards[i] = env->goal_weights[i]; - if(is_ego){ + + if (is_ego) { env->logs[i].episode_return = env->goal_weights[i]; - env->logs[i].num_goals_reached = 1; - } else if(is_co_player){ - env->co_player_logs[i].co_player_episode_return = env->goal_weights[i]; - 
env->co_player_logs[i].co_player_num_goals_reached = 1; + } else if (is_co_player) { + env->co_player_logs[i].episode_return = env->goal_weights[i]; } + env->entities[agent_idx].stopped = 1; - env->entities[agent_idx].vx=env->entities[agent_idx].vy = 0.0f; + env->entities[agent_idx].vx = env->entities[agent_idx].vy = 0.0f; + env->entities[agent_idx].goals_reached_this_episode += 1.0f; } - env->entities[agent_idx].reached_goal_this_episode = 1; env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX] = 1.0f; - } - if(env->entities[agent_idx].sampled_new_goal && env->goal_behavior == GOAL_GENERATE_NEW){ - compute_new_goal(env, agent_idx); + if (is_ego) { + env->logs[i].speed_at_goal = current_speed; + } else if (is_co_player) { + env->co_player_logs[i].speed_at_goal = current_speed; + } } int lane_aligned = env->entities[agent_idx].metrics_array[LANE_ALIGNED_IDX]; - if(is_ego){ + if (is_ego) { env->logs[i].lane_alignment_rate = lane_aligned; - } else if(is_co_player){ - env->co_player_logs[i].co_player_lane_alignment_rate = lane_aligned; - } - - float current_ade = env->entities[agent_idx].metrics_array[AVG_DISPLACEMENT_ERROR_IDX]; - if(current_ade > 0.0f && env->reward_ade != 0.0f){ - float ade_reward = env->reward_ade * current_ade; - env->rewards[i] += ade_reward; - - if(is_ego){ - env->logs[i].episode_return += ade_reward; - env->logs[i].avg_displacement_error = current_ade; - } else if(is_co_player){ - env->co_player_logs[i].co_player_episode_return += ade_reward; - env->co_player_logs[i].co_player_avg_displacement_error = current_ade; - } + } else if (is_co_player) { + env->co_player_logs[i].lane_alignment_rate = lane_aligned; } } - if (env->goal_behavior==GOAL_RESPAWN) { - for(int i = 0; i < env->active_agent_count; i++){ + if (env->goal_behavior == GOAL_RESPAWN) { + for (int i = 0; i < env->active_agent_count; i++) { int agent_idx = env->active_agent_indices[i]; int reached_goal = env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX]; - if(reached_goal){ + 
if (reached_goal) { respawn_agent(env, agent_idx); env->entities[agent_idx].respawn_count++; } } - } - else if (env->goal_behavior==GOAL_STOP) { - for(int i = 0; i < env->active_agent_count; i++){ + } else if (env->goal_behavior == GOAL_STOP) { + for (int i = 0; i < env->active_agent_count; i++) { int agent_idx = env->active_agent_indices[i]; int reached_goal = env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX]; - if(reached_goal){ + if (reached_goal) { env->entities[agent_idx].stopped = 1; - env->entities[agent_idx].vx=env->entities[agent_idx].vy = 0.0f; + env->entities[agent_idx].vx = env->entities[agent_idx].vy = 0.0f; } } } - compute_observations(env); } - -const Color STONE_GRAY = (Color){80, 80, 80, 255}; -const Color PUFF_RED = (Color){187, 0, 0, 255}; -const Color PUFF_CYAN = (Color){0, 187, 187, 255}; -const Color PUFF_WHITE = (Color){241, 241, 241, 241}; -const Color PUFF_BACKGROUND = (Color){6, 24, 24, 255}; -const Color PUFF_BACKGROUND2 = (Color){18, 72, 72, 255}; -const Color LIGHTGREEN = (Color){152, 255, 152, 255}; - typedef struct Client Client; struct Client { float width; @@ -2518,17 +2485,20 @@ struct Client { float camera_zoom; Camera3D camera; Model cars[6]; - int car_assignments[MAX_AGENTS]; // To keep car model assignments consistent per vehicle + Model cyclist; + Model pedestrian; + ModelAnimation *cycle_anim; + int car_assignments[MAX_AGENTS]; // To keep car model assignments consistent per vehicle Vector3 default_camera_position; Vector3 default_camera_target; }; -Client* make_client(Drive* env){ - Client* client = (Client*)calloc(1, sizeof(Client)); +Client *make_client(Drive *env) { + Client *client = (Client *)calloc(1, sizeof(Client)); client->width = 1280; client->height = 704; SetConfigFlags(FLAG_MSAA_4X_HINT); - InitWindow(client->width, client->height, "PufferLib Ray GPU Drive"); + InitWindow(client->width, client->height, "PufferDrive"); SetTargetFPS(30); client->puffers = LoadTexture("resources/puffers_128.png"); 
client->cars[0] = LoadModel("resources/drive/RedCar.glb"); @@ -2537,26 +2507,30 @@ Client* make_client(Drive* env){ client->cars[3] = LoadModel("resources/drive/YellowCar.glb"); client->cars[4] = LoadModel("resources/drive/GreenCar.glb"); client->cars[5] = LoadModel("resources/drive/GreyCar.glb"); + client->cyclist = LoadModel("resources/drive/cyclist.glb"); + client->pedestrian = LoadModel("resources/drive/pedestrian.glb"); + int animCountCyc = 0; + client->cycle_anim = LoadModelAnimations("resources/drive/cyclist.glb", &animCountCyc); for (int i = 0; i < MAX_AGENTS; i++) { client->car_assignments[i] = (rand() % 4) + 1; } // Get initial target position from first active agent Vector3 target_pos = { 0, - 0, // Y is up - 1 // Z is depth + 0, // Y is up + 1 // Z is depth }; // Set up camera to look at target from above and behind client->default_camera_position = (Vector3){ - 0, // Same X as target - 120.0f, // 20 units above target - 175.0f // 20 units behind target + 0, // Same X as target + 120.0f, // 20 units above target + 175.0f // 20 units behind target }; client->default_camera_target = target_pos; client->camera.position = client->default_camera_position; client->camera.target = client->default_camera_target; - client->camera.up = (Vector3){ 0.0f, -1.0f, 0.0f }; // Y is up + client->camera.up = (Vector3){0.0f, -1.0f, 0.0f}; // Y is up client->camera.fovy = 45.0f; client->camera.projection = CAMERA_PERSPECTIVE; client->camera_zoom = 1.0f; @@ -2564,7 +2538,7 @@ Client* make_client(Drive* env){ } // Camera control functions -void handle_camera_controls(Client* client) { +void handle_camera_controls(Client *client) { static Vector2 prev_mouse_pos = {0}; static bool is_dragging = false; float camera_move_speed = 0.5f; @@ -2581,10 +2555,8 @@ void handle_camera_controls(Client* client) { if (is_dragging) { Vector2 current_mouse_pos = GetMousePosition(); - Vector2 delta = { - (current_mouse_pos.x - prev_mouse_pos.x) * camera_move_speed, - -(current_mouse_pos.y - 
prev_mouse_pos.y) * camera_move_speed - }; + Vector2 delta = {(current_mouse_pos.x - prev_mouse_pos.x) * camera_move_speed, + -(current_mouse_pos.y - prev_mouse_pos.y) * camera_move_speed}; // Update camera position (only X and Y) client->camera.position.x += delta.x; @@ -2602,11 +2574,9 @@ void handle_camera_controls(Client* client) { if (wheel != 0) { float zoom_factor = 1.0f - (wheel * 0.1f); // Calculate the current direction vector from target to position - Vector3 direction = { - client->camera.position.x - client->camera.target.x, - client->camera.position.y - client->camera.target.y, - client->camera.position.z - client->camera.target.z - }; + Vector3 direction = {client->camera.position.x - client->camera.target.x, + client->camera.position.y - client->camera.target.y, + client->camera.position.z - client->camera.target.z}; // Scale the direction vector by the zoom factor direction.x *= zoom_factor; @@ -2620,28 +2590,27 @@ void handle_camera_controls(Client* client) { } } -void draw_agent_obs(Drive* env, int agent_index, int mode, int obs_only, int lasers){ +void draw_agent_obs(Drive *env, int agent_index, int mode, int obs_only, int lasers) { // Diamond dimensions - float diamond_height = 3.0f; // Total height of diamond - float diamond_width = 1.5f; // Width of diamond - float diamond_z = 8.0f; // Base Z position + float diamond_height = 3.0f; // Total height of diamond + float diamond_width = 1.5f; // Width of diamond + float diamond_z = 8.0f; // Base Z position // Define diamond points - Vector3 top_point = (Vector3){0.0f, 0.0f, diamond_z + diamond_height/2}; // Top point - Vector3 bottom_point = (Vector3){0.0f, 0.0f, diamond_z - diamond_height/2}; // Bottom point - Vector3 front_point = (Vector3){0.0f, diamond_width/2, diamond_z}; // Front point - Vector3 back_point = (Vector3){0.0f, -diamond_width/2, diamond_z}; // Back point - Vector3 left_point = (Vector3){-diamond_width/2, 0.0f, diamond_z}; // Left point - Vector3 right_point = 
(Vector3){diamond_width/2, 0.0f, diamond_z}; // Right point + Vector3 top_point = (Vector3){0.0f, 0.0f, diamond_z + diamond_height / 2}; // Top point + Vector3 bottom_point = (Vector3){0.0f, 0.0f, diamond_z - diamond_height / 2}; // Bottom point + Vector3 front_point = (Vector3){0.0f, diamond_width / 2, diamond_z}; // Front point + Vector3 back_point = (Vector3){0.0f, -diamond_width / 2, diamond_z}; // Back point + Vector3 left_point = (Vector3){-diamond_width / 2, 0.0f, diamond_z}; // Left point + Vector3 right_point = (Vector3){diamond_width / 2, 0.0f, diamond_z}; // Right point // Draw the diamond faces // Top pyramid - - if(mode ==0){ - DrawTriangle3D(top_point, front_point, right_point, PUFF_CYAN); // Front-right face - DrawTriangle3D(top_point, right_point, back_point, PUFF_CYAN); // Back-right face - DrawTriangle3D(top_point, back_point, left_point, PUFF_CYAN); // Back-left face - DrawTriangle3D(top_point, left_point, front_point, PUFF_CYAN); // Front-left face + if (mode == 0) { + DrawTriangle3D(top_point, front_point, right_point, PUFF_CYAN); // Front-right face + DrawTriangle3D(top_point, right_point, back_point, PUFF_CYAN); // Back-right face + DrawTriangle3D(top_point, back_point, left_point, PUFF_CYAN); // Back-left face + DrawTriangle3D(top_point, left_point, front_point, PUFF_CYAN); // Front-left face // Bottom pyramid DrawTriangle3D(bottom_point, right_point, front_point, PUFF_CYAN); // Front-right face @@ -2649,14 +2618,14 @@ void draw_agent_obs(Drive* env, int agent_index, int mode, int obs_only, int las DrawTriangle3D(bottom_point, left_point, back_point, PUFF_CYAN); // Back-left face DrawTriangle3D(bottom_point, front_point, left_point, PUFF_CYAN); // Front-left face } - if(!IsKeyDown(KEY_LEFT_CONTROL) && obs_only==0){ + if (!IsKeyDown(KEY_LEFT_CONTROL) && obs_only == 0) { return; } int ego_dim = (env->dynamics_model == JERK) ? 
10 : 7; - int max_obs = ego_dim + 7*(MAX_AGENTS - 1) + 7*MAX_ROAD_SEGMENT_OBSERVATIONS; - float (*observations)[max_obs] = (float(*)[max_obs])env->observations; - float* agent_obs = &observations[agent_index][0]; + int max_obs = ego_dim + 7 * (MAX_AGENTS - 1) + 7 * MAX_ROAD_SEGMENT_OBSERVATIONS; + float (*observations)[max_obs] = (float (*)[max_obs])env->observations; + float *agent_obs = &observations[agent_index][0]; // self int active_idx = env->active_agent_indices[agent_index]; float heading_self_x = env->entities[active_idx].heading_x; @@ -2666,82 +2635,64 @@ void draw_agent_obs(Drive* env, int agent_index, int mode, int obs_only, int las // draw goal float goal_x = agent_obs[0] * 200; float goal_y = agent_obs[1] * 200; - if(mode == 0 ){ + if (mode == 0) { DrawSphere((Vector3){goal_x, goal_y, 1}, 0.5f, LIGHTGREEN); - DrawCircle3D((Vector3){goal_x, goal_y, 0.1f}, env->goal_radius, (Vector3){0, 0, 1}, 90.0f, Fade(LIGHTGREEN, 0.3f)); + DrawCircle3D((Vector3){goal_x, goal_y, 0.1f}, env->goal_radius, (Vector3){0, 0, 1}, 90.0f, + Fade(LIGHTGREEN, 0.3f)); } - if (mode == 1){ - float goal_x_world = px + (goal_x * heading_self_x - goal_y*heading_self_y); - float goal_y_world = py + (goal_x * heading_self_y + goal_y*heading_self_x); + if (mode == 1) { + float goal_x_world = px + (goal_x * heading_self_x - goal_y * heading_self_y); + float goal_y_world = py + (goal_x * heading_self_y + goal_y * heading_self_x); DrawSphere((Vector3){goal_x_world, goal_y_world, 1}, 0.5f, LIGHTGREEN); - DrawCircle3D((Vector3){goal_x_world, goal_y_world, 0.1f}, env->goal_radius, (Vector3){0, 0, 1}, 90.0f, Fade(LIGHTGREEN, 0.3f)); + DrawCircle3D((Vector3){goal_x_world, goal_y_world, 0.1f}, env->goal_radius, (Vector3){0, 0, 1}, 90.0f, + Fade(LIGHTGREEN, 0.3f)); } // First draw other agent observations - int obs_idx = ego_dim; // Start after ego obs - for(int j = 0; j < MAX_AGENTS - 1; j++) { - if(agent_obs[obs_idx] == 0 || agent_obs[obs_idx + 1] == 0) { - obs_idx += 7; // Move to next agent 
observation + int obs_idx = ego_dim; // Start after ego obs + for (int j = 0; j < MAX_AGENTS - 1; j++) { + if (agent_obs[obs_idx] == 0 || agent_obs[obs_idx + 1] == 0) { + obs_idx += 7; // Move to next agent observation continue; } // Draw position of other agents float x = agent_obs[obs_idx] * 50; float y = agent_obs[obs_idx + 1] * 50; - if(lasers && mode == 0){ - DrawLine3D( - (Vector3){0, 0, 0}, - (Vector3){x, y, 1}, - ORANGE - ); - } - - float partner_x = px + (x*heading_self_x - y*heading_self_y); - float partner_y = py + (x*heading_self_y + y*heading_self_x); - if(lasers && mode ==1){ - DrawLine3D( - (Vector3){px, py, 1}, - (Vector3){partner_x,partner_y,1}, - ORANGE - ); - } - - float half_width = 0.5*agent_obs[obs_idx + 2]*MAX_VEH_WIDTH; - float half_len = 0.5*agent_obs[obs_idx + 3]*MAX_VEH_LEN; + if (lasers && mode == 0) { + DrawLine3D((Vector3){0, 0, 0}, (Vector3){x, y, 1}, ORANGE); + } + + float partner_x = px + (x * heading_self_x - y * heading_self_y); + float partner_y = py + (x * heading_self_y + y * heading_self_x); + if (lasers && mode == 1) { + DrawLine3D((Vector3){px, py, 1}, (Vector3){partner_x, partner_y, 1}, ORANGE); + } + + float half_width = 0.5 * agent_obs[obs_idx + 2] * MAX_VEH_WIDTH; + float half_len = 0.5 * agent_obs[obs_idx + 3] * MAX_VEH_LEN; float theta_x = agent_obs[obs_idx + 4]; float theta_y = agent_obs[obs_idx + 5]; float partner_angle = atan2f(theta_y, theta_x); float cos_heading = cosf(partner_angle); float sin_heading = sinf(partner_angle); Vector3 corners[4] = { - (Vector3){ - x + (half_len * cos_heading - half_width * sin_heading), - y + (half_len * sin_heading + half_width * cos_heading), - 1 - }, - (Vector3){ - x + (half_len * cos_heading + half_width * sin_heading), - y + (half_len * sin_heading - half_width * cos_heading), - 1 - }, - (Vector3){ - x + (-half_len * cos_heading + half_width * sin_heading), - y + (-half_len * sin_heading - half_width * cos_heading), - 1 - }, - (Vector3){ - x + (-half_len * cos_heading - 
half_width * sin_heading), - y + (-half_len * sin_heading + half_width * cos_heading), - 1 - }, + (Vector3){x + (half_len * cos_heading - half_width * sin_heading), + y + (half_len * sin_heading + half_width * cos_heading), 1}, + (Vector3){x + (half_len * cos_heading + half_width * sin_heading), + y + (half_len * sin_heading - half_width * cos_heading), 1}, + (Vector3){x + (-half_len * cos_heading + half_width * sin_heading), + y + (-half_len * sin_heading - half_width * cos_heading), 1}, + (Vector3){x + (-half_len * cos_heading - half_width * sin_heading), + y + (-half_len * sin_heading + half_width * cos_heading), 1}, }; - if(mode ==0){ + if (mode == 0) { for (int j = 0; j < 4; j++) { - DrawLine3D(corners[j], corners[(j+1)%4], ORANGE); + DrawLine3D(corners[j], corners[(j + 1) % 4], ORANGE); } } - if(mode ==1){ + if (mode == 1) { Vector3 world_corners[4]; for (int j = 0; j < 4; j++) { float lx = corners[j].x; @@ -2752,90 +2703,74 @@ void draw_agent_obs(Drive* env, int agent_index, int mode, int obs_only, int las world_corners[j].z = 1; } for (int j = 0; j < 4; j++) { - DrawLine3D(world_corners[j], world_corners[(j+1)%4], ORANGE); + DrawLine3D(world_corners[j], world_corners[(j + 1) % 4], ORANGE); } } // draw an arrow above the car pointing in the direction that the partner is going - float arrow_length = 7.5f; - float arrow_x = x + arrow_length*cosf(partner_angle); - float arrow_y = y + arrow_length*sinf(partner_angle); + float arrow_length = 2.5f; + float arrow_x = x + arrow_length * cosf(partner_angle); + float arrow_y = y + arrow_length * sinf(partner_angle); float arrow_x_world; float arrow_y_world; - if(mode ==0){ - DrawLine3D((Vector3){x, y, 1}, (Vector3){arrow_x, arrow_y, 1}, PUFF_WHITE); + if (mode == 0) { + DrawLine3D((Vector3){x, y, 0.0}, (Vector3){arrow_x, arrow_y, 0.0}, PUFF_WHITE); } - if(mode == 1){ - arrow_x_world = px + (arrow_x * heading_self_x - arrow_y*heading_self_y); - arrow_y_world = py + (arrow_x * heading_self_y + arrow_y*heading_self_x); + 
if (mode == 1) { + arrow_x_world = px + (arrow_x * heading_self_x - arrow_y * heading_self_y); + arrow_y_world = py + (arrow_x * heading_self_y + arrow_y * heading_self_x); DrawLine3D((Vector3){partner_x, partner_y, 1}, (Vector3){arrow_x_world, arrow_y_world, 1}, PUFF_WHITE); } // Calculate perpendicular offsets for arrow head - float arrow_size = 2.0f; // Size of the arrow head + float arrow_size = 0.3f; // Size of the arrow head float dx = arrow_x - x; float dy = arrow_y - y; - float length = sqrtf(dx*dx + dy*dy); + float length = sqrtf(dx * dx + dy * dy); if (length > 0) { // Normalize direction vector dx /= length; dy /= length; // Calculate perpendicular vector - float perp_x = -dy * arrow_size; float perp_y = dx * arrow_size; - float arrow_x_end1 = arrow_x - dx*arrow_size + perp_x; - float arrow_y_end1 = arrow_y - dy*arrow_size + perp_y; - float arrow_x_end2 = arrow_x - dx*arrow_size - perp_x; - float arrow_y_end2 = arrow_y - dy*arrow_size - perp_y; + float arrow_x_end1 = arrow_x - dx * arrow_size + perp_x; + float arrow_y_end1 = arrow_y - dy * arrow_size + perp_y; + float arrow_x_end2 = arrow_x - dx * arrow_size - perp_x; + float arrow_y_end2 = arrow_y - dy * arrow_size - perp_y; // Draw the two lines forming the arrow head - if(mode ==0){ - DrawLine3D( - (Vector3){arrow_x, arrow_y, 1}, - (Vector3){arrow_x_end1, arrow_y_end1, 1}, - PUFF_WHITE - ); - DrawLine3D( - (Vector3){arrow_x, arrow_y, 1}, - (Vector3){arrow_x_end2, arrow_y_end2, 1}, - PUFF_WHITE - ); + if (mode == 0) { + DrawLine3D((Vector3){arrow_x, arrow_y, 0.0}, (Vector3){arrow_x_end1, arrow_y_end1, 0.0}, PUFF_WHITE); + DrawLine3D((Vector3){arrow_x, arrow_y, 0.0}, (Vector3){arrow_x_end2, arrow_y_end2, 0.0}, PUFF_WHITE); } - if(mode==1){ - float arrow_x_end1_world = px + (arrow_x_end1 * heading_self_x - arrow_y_end1*heading_self_y); - float arrow_y_end1_world = py + (arrow_x_end1 * heading_self_y + arrow_y_end1*heading_self_x); - float arrow_x_end2_world = px + (arrow_x_end2 * heading_self_x - 
arrow_y_end2*heading_self_y); - float arrow_y_end2_world = py + (arrow_x_end2 * heading_self_y + arrow_y_end2*heading_self_x); - DrawLine3D( - (Vector3){arrow_x_world, arrow_y_world, 1}, - (Vector3){arrow_x_end1_world, arrow_y_end1_world, 1}, - PUFF_WHITE - ); - DrawLine3D( - (Vector3){arrow_x_world, arrow_y_world, 1}, - (Vector3){arrow_x_end2_world, arrow_y_end2_world, 1}, - PUFF_WHITE - ); - + if (mode == 1) { + float arrow_x_end1_world = px + (arrow_x_end1 * heading_self_x - arrow_y_end1 * heading_self_y); + float arrow_y_end1_world = py + (arrow_x_end1 * heading_self_y + arrow_y_end1 * heading_self_x); + float arrow_x_end2_world = px + (arrow_x_end2 * heading_self_x - arrow_y_end2 * heading_self_y); + float arrow_y_end2_world = py + (arrow_x_end2 * heading_self_y + arrow_y_end2 * heading_self_x); + DrawLine3D((Vector3){arrow_x_world, arrow_y_world, 0.0}, + (Vector3){arrow_x_end1_world, arrow_y_end1_world, 0.0}, PUFF_WHITE); + DrawLine3D((Vector3){arrow_x_world, arrow_y_world, 0.0}, + (Vector3){arrow_x_end2_world, arrow_y_end2_world, 0.0}, PUFF_WHITE); } } - obs_idx += 7; // Move to next agent observation (7 values per agent) + obs_idx += PARTNER_FEATURES; // Move to next agent observation (7 values per agent) } // Then draw map observations - int map_start_idx = 7 + 7*(MAX_AGENTS - 1); // Start after agent observations - for(int k = 0; k < MAX_ROAD_SEGMENT_OBSERVATIONS; k++) { // Loop through potential map entities - int entity_idx = map_start_idx + k*7; - if(agent_obs[entity_idx] == 0 && agent_obs[entity_idx + 1] == 0){ + int map_start_idx = ego_dim + PARTNER_FEATURES * (MAX_AGENTS - 1); // Start after agent observations + for (int k = 0; k < MAX_ROAD_SEGMENT_OBSERVATIONS; k++) { // Loop through potential map entities + int entity_idx = map_start_idx + k * 7; + if (agent_obs[entity_idx] == 0 && agent_obs[entity_idx + 1] == 0) { continue; } - Color lineColor = BLUE; // Default color + Color lineColor = BLUE; // Default color int entity_type = 
(int)agent_obs[entity_idx + 6]; // Choose color based on entity type - if(entity_type+4 != ROAD_EDGE){ + if (entity_type + 4 != ROAD_EDGE) { continue; } lineColor = PUFF_CYAN; @@ -2848,88 +2783,60 @@ void draw_agent_obs(Drive* env, int agent_index, int mode, int obs_only, int las float segment_length = agent_obs[entity_idx + 2] * MAX_ROAD_SEGMENT_LENGTH; // Calculate endpoint using the relative angle directly // Calculate endpoint directly - float x_start = x_middle - segment_length*cosf(rel_angle); - float y_start = y_middle - segment_length*sinf(rel_angle); - float x_end = x_middle + segment_length*cosf(rel_angle); - float y_end = y_middle + segment_length*sinf(rel_angle); - - - if(lasers && mode ==0){ - DrawLine3D((Vector3){0,0,0}, (Vector3){x_middle, y_middle, 1}, lineColor); - } - - if(mode ==1){ - float x_middle_world = px + (x_middle*heading_self_x - y_middle*heading_self_y); - float y_middle_world = py + (x_middle*heading_self_y + y_middle*heading_self_x); - float x_start_world = px + (x_start*heading_self_x - y_start*heading_self_y); - float y_start_world = py + (x_start*heading_self_y + y_start*heading_self_x); - float x_end_world = px + (x_end*heading_self_x - y_end*heading_self_y); - float y_end_world = py + (x_end*heading_self_y + y_end*heading_self_x); + float x_start = x_middle - segment_length * cosf(rel_angle); + float y_start = y_middle - segment_length * sinf(rel_angle); + float x_end = x_middle + segment_length * cosf(rel_angle); + float y_end = y_middle + segment_length * sinf(rel_angle); + + if (lasers && mode == 0) { + DrawLine3D((Vector3){0, 0, 0}, (Vector3){x_middle, y_middle, 1}, lineColor); + } + + if (mode == 1) { + float x_middle_world = px + (x_middle * heading_self_x - y_middle * heading_self_y); + float y_middle_world = py + (x_middle * heading_self_y + y_middle * heading_self_x); + float x_start_world = px + (x_start * heading_self_x - y_start * heading_self_y); + float y_start_world = py + (x_start * heading_self_y + y_start * 
heading_self_x); + float x_end_world = px + (x_end * heading_self_x - y_end * heading_self_y); + float y_end_world = py + (x_end * heading_self_y + y_end * heading_self_x); DrawCube((Vector3){x_middle_world, y_middle_world, 1}, 0.5f, 0.5f, 0.5f, lineColor); DrawLine3D((Vector3){x_start_world, y_start_world, 1}, (Vector3){x_end_world, y_end_world, 1}, BLUE); - if(lasers) DrawLine3D((Vector3){px,py,1}, (Vector3){x_middle_world, y_middle_world, 1}, lineColor); + if (lasers) + DrawLine3D((Vector3){px, py, 1}, (Vector3){x_middle_world, y_middle_world, 1}, lineColor); } - if(mode ==0){ + if (mode == 0) { DrawCube((Vector3){x_middle, y_middle, 1}, 0.5f, 0.5f, 0.5f, lineColor); DrawLine3D((Vector3){x_start, y_start, 1}, (Vector3){x_end, y_end, 1}, BLUE); } } } -void draw_road_edge(Drive* env, float start_x, float start_y, float end_x, float end_y){ - Color CURB_TOP = (Color){220, 220, 220, 255}; // Top surface - lightest - Color CURB_SIDE = (Color){180, 180, 180, 255}; // Side faces - medium +void draw_road_edge(Drive *env, float start_x, float start_y, float end_x, float end_y) { + Color CURB_TOP = (Color){220, 220, 220, 255}; // Top surface - lightest + Color CURB_SIDE = (Color){180, 180, 180, 255}; // Side faces - medium Color CURB_BOTTOM = (Color){160, 160, 160, 255}; - // Calculate curb dimensions - float curb_height = 0.5f; // Height of the curb - float curb_width = 0.3f; // Width/thickness of the curb - float road_z = 0.2f; // Ensure z-level for roads is below agents + // Calculate curb dimensions + float curb_height = 0.5f; // Height of the curb + float curb_width = 0.3f; // Width/thickness of the curb + float road_z = 0.0f; // Ensure z-level for roads is below agents // Calculate direction vector between start and end - Vector3 direction = { - end_x - start_x, - end_y - start_y, - 0.0f - }; + Vector3 direction = {end_x - start_x, end_y - start_y, 0.0f}; // Calculate length of the segment float length = sqrtf(direction.x * direction.x + direction.y * direction.y); 
// Normalize direction vector - Vector3 normalized_dir = { - direction.x / length, - direction.y / length, - 0.0f - }; + Vector3 normalized_dir = {direction.x / length, direction.y / length, 0.0f}; // Calculate perpendicular vector for width - Vector3 perpendicular = { - -normalized_dir.y, - normalized_dir.x, - 0.0f - }; + Vector3 perpendicular = {-normalized_dir.y, normalized_dir.x, 0.0f}; // Calculate the four bottom corners of the curb - Vector3 b1 = { - start_x - perpendicular.x * curb_width/2, - start_y - perpendicular.y * curb_width/2, - road_z - }; - Vector3 b2 = { - start_x + perpendicular.x * curb_width/2, - start_y + perpendicular.y * curb_width/2, - road_z - }; - Vector3 b3 = { - end_x + perpendicular.x * curb_width/2, - end_y + perpendicular.y * curb_width/2, - road_z - }; - Vector3 b4 = { - end_x - perpendicular.x * curb_width/2, - end_y - perpendicular.y * curb_width/2, - road_z - }; + Vector3 b1 = {start_x - perpendicular.x * curb_width / 2, start_y - perpendicular.y * curb_width / 2, road_z}; + Vector3 b2 = {start_x + perpendicular.x * curb_width / 2, start_y + perpendicular.y * curb_width / 2, road_z}; + Vector3 b3 = {end_x + perpendicular.x * curb_width / 2, end_y + perpendicular.y * curb_width / 2, road_z}; + Vector3 b4 = {end_x - perpendicular.x * curb_width / 2, end_y - perpendicular.y * curb_width / 2, road_z}; // Draw the curb faces // Bottom face @@ -2955,56 +2862,58 @@ void draw_road_edge(Drive* env, float start_x, float start_y, float end_x, float DrawTriangle3D(t4, t1, b1, CURB_SIDE); } -void draw_scene(Drive* env, Client* client, int mode, int obs_only, int lasers, int show_grid){ - // Draw a grid to help with orientation - // DrawGrid(20, 1.0f); - DrawLine3D((Vector3){env->grid_map->top_left_x, env->grid_map->top_left_y, 0}, (Vector3){env->grid_map->bottom_right_x, env->grid_map->top_left_y, 0}, PUFF_CYAN); - DrawLine3D((Vector3){env->grid_map->top_left_x, env->grid_map->bottom_right_y, 0}, (Vector3){env->grid_map->top_left_x, 
env->grid_map->top_left_y, 0}, PUFF_CYAN); - DrawLine3D((Vector3){env->grid_map->bottom_right_x, env->grid_map->bottom_right_y, 0}, (Vector3){env->grid_map->bottom_right_x, env->grid_map->top_left_y, 0}, PUFF_CYAN); - DrawLine3D((Vector3){env->grid_map->top_left_x, env->grid_map->bottom_right_y, 0}, (Vector3){env->grid_map->bottom_right_x, env->grid_map->bottom_right_y, 0}, PUFF_CYAN); - for(int i = 0; i < env->num_entities; i++) { +void draw_scene(Drive *env, Client *client, int mode, int obs_only, int lasers, int show_grid) { + + if (show_grid) { + float grid_start_x = env->grid_map->top_left_x; + float grid_start_y = env->grid_map->bottom_right_y; + for (int i = 0; i < env->grid_map->grid_cols; i++) { + for (int j = 0; j < env->grid_map->grid_rows; j++) { + float x = grid_start_x + i * GRID_CELL_SIZE; + float y = grid_start_y + j * GRID_CELL_SIZE; + DrawCubeWires((Vector3){x + GRID_CELL_SIZE / 2, y + GRID_CELL_SIZE / 2, 0.0f}, GRID_CELL_SIZE, + GRID_CELL_SIZE, 0.1f, Fade(PUFF_BACKGROUND2, 0.3f)); + } + } + } + + // Draw a grid to help with orientation + for (int i = 0; i < env->num_entities; i++) { // Draw objects - if(env->entities[i].type == VEHICLE || env->entities[i].type == PEDESTRIAN || env->entities[i].type == CYCLIST) { + if (env->entities[i].type == VEHICLE || env->entities[i].type == PEDESTRIAN || + env->entities[i].type == CYCLIST) { // Check if this vehicle is an active agent bool is_active_agent = false; bool is_static_agent = false; int agent_index = -1; - for(int j = 0; j < env->active_agent_count; j++) { - if(env->active_agent_indices[j] == i) { + for (int j = 0; j < env->active_agent_count; j++) { + if (env->active_agent_indices[j] == i) { is_active_agent = true; agent_index = j; break; } } - for(int j = 0; j < env->static_agent_count; j++) { - if(env->static_agent_indices[j] == i) { + for (int j = 0; j < env->static_agent_count; j++) { + if (env->static_agent_indices[j] == i) { is_static_agent = true; break; } } // HIDE CARS ON RESPAWN - 
IMPORTANT TO KNOW VISUAL SETTING - if((!is_active_agent && !is_static_agent) || env->entities[i].respawn_timestep != -1){ + if ((!is_active_agent && !is_static_agent) || env->entities[i].respawn_timestep != -1) { continue; } Vector3 position; float heading; - position = (Vector3){ - env->entities[i].x, - env->entities[i].y, - 1 - }; + position = (Vector3){env->entities[i].x, env->entities[i].y, 1.1}; heading = env->entities[i].heading; // Create size vector - Vector3 size = { - env->entities[i].length, - env->entities[i].width, - env->entities[i].height - }; + Vector3 size = {env->entities[i].length, env->entities[i].width, env->entities[i].height}; bool is_expert = (!is_active_agent) && (env->entities[i].mark_as_expert == 1); // Save current transform - if(mode==1){ + if (mode == 1) { float cos_heading = env->entities[i].heading_x; float sin_heading = env->entities[i].heading_y; @@ -3014,220 +2923,176 @@ void draw_scene(Drive* env, Client* client, int mode, int obs_only, int lasers, // Calculate the four corners of the collision box Vector3 corners[4] = { - (Vector3){ - position.x + (half_len * cos_heading - half_width * sin_heading), - position.y + (half_len * sin_heading + half_width * cos_heading), - position.z - }, - - - (Vector3){ - position.x + (half_len * cos_heading + half_width * sin_heading), - position.y + (half_len * sin_heading - half_width * cos_heading), - position.z - }, - (Vector3){ - position.x + (-half_len * cos_heading + half_width * sin_heading), - position.y + (-half_len * sin_heading - half_width * cos_heading), - position.z - }, - (Vector3){ - position.x + (-half_len * cos_heading - half_width * sin_heading), - position.y + (-half_len * sin_heading + half_width * cos_heading), - position.z - }, - + (Vector3){position.x + (half_len * cos_heading - half_width * sin_heading), + position.y + (half_len * sin_heading + half_width * cos_heading), position.z}, + (Vector3){position.x + (half_len * cos_heading + half_width * sin_heading), + 
position.y + (half_len * sin_heading - half_width * cos_heading), position.z}, + (Vector3){position.x + (-half_len * cos_heading + half_width * sin_heading), + position.y + (-half_len * sin_heading - half_width * cos_heading), position.z}, + (Vector3){position.x + (-half_len * cos_heading - half_width * sin_heading), + position.y + (-half_len * sin_heading + half_width * cos_heading), position.z}, }; - if(agent_index == env->human_agent_idx && !env->entities[agent_index].metrics_array[REACHED_GOAL_IDX]) { + if (agent_index == env->human_agent_idx && + !env->entities[agent_index].metrics_array[REACHED_GOAL_IDX]) { draw_agent_obs(env, agent_index, mode, obs_only, lasers); } - if((obs_only || IsKeyDown(KEY_LEFT_CONTROL)) && agent_index != env->human_agent_idx){ + + if ((obs_only || IsKeyDown(KEY_LEFT_CONTROL)) && agent_index != env->human_agent_idx) { continue; } // --- Draw the car --- - - Vector3 carPos = { position.x, position.y, position.z }; - Color car_color = GRAY; // default for static - if (is_expert) car_color = GOLD; // expert replay - if (is_active_agent) car_color = BLUE; // policy-controlled - if (is_active_agent && env->entities[i].collision_state > 0) car_color = RED; + Color car_color = GRAY; // default for static + if (is_expert) + car_color = GOLD; // expert replay + if (is_active_agent) + car_color = BLUE; // policy-controlled + if (is_active_agent && env->entities[i].collision_state > 0) + car_color = RED; rlSetLineWidth(3.0f); for (int j = 0; j < 4; j++) { - DrawLine3D(corners[j], corners[(j+1)%4], car_color); + DrawLine3D(corners[j], corners[(j + 1) % 4], car_color); } // --- Draw a heading arrow pointing forward --- Vector3 arrowStart = position; - Vector3 arrowEnd = { - position.x + cos_heading * half_len * 1.5f, // extend arrow beyond car - position.y + sin_heading * half_len * 1.5f, - position.z - }; + Vector3 arrowEnd = {position.x + cos_heading * half_len * 1.5f, // extend arrow beyond car + position.y + sin_heading * half_len * 1.5f, 
position.z}; DrawLine3D(arrowStart, arrowEnd, car_color); - DrawSphere(arrowEnd, 0.2f, car_color); // arrow tip + DrawSphere(arrowEnd, 0.2f, car_color); // arrow tip - } - else { + } else { // Agent view rlPushMatrix(); // Translate to position, rotate around Y axis, then draw rlTranslatef(position.x, position.y, position.z); - rlRotatef(heading*RAD2DEG, 0.0f, 0.0f, 1.0f); // Convert radians to degrees - // Determine color based on status - Color object_color = PUFF_BACKGROUND2; // fill color unused for model tint - Color outline_color = PUFF_CYAN; // not used for model tint - Model car_model = client->cars[5]; - if(is_active_agent){ - car_model = client->cars[client->car_assignments[i %64]]; - } - if(agent_index == env->human_agent_idx){ - object_color = PUFF_CYAN; - outline_color = PUFF_WHITE; - } - if(is_active_agent && env->entities[i].collision_state > 0) { - car_model = client->cars[0]; // Collided agent + rlRotatef(heading * RAD2DEG, 0.0f, 0.0f, 1.0f); // Convert radians to degrees + + // Select car model (skip index 0) + Model car_model = client->cars[(i % 5) + 1]; // Cycles through indices 1-5 + + if (agent_index == env->human_agent_idx) { + car_model = client->cars[0]; // Ego agent always uses red car + } else if (is_active_agent) { + + car_model = client->cars[(i % 5) + 1]; + + if (env->entities[i].collision_state > 0) { + car_model = client->cars[0]; // Collided agents use red + } } - // Draw obs for human selected agent - if(agent_index == env->human_agent_idx && !env->entities[agent_index].metrics_array[REACHED_GOAL_IDX]) { + // Draw obs for selected agent index + if (agent_index == env->human_agent_idx && + (!env->entities[agent_index].metrics_array[REACHED_GOAL_IDX] || + env->goal_behavior == GOAL_GENERATE_NEW || env->goal_behavior == GOAL_STOP)) { draw_agent_obs(env, agent_index, mode, obs_only, lasers); } + // Draw cube for cars static and active // Calculate scale factors based on desired size and model dimensions - BoundingBox bounds = 
GetModelBoundingBox(car_model); - Vector3 model_size = { - bounds.max.x - bounds.min.x, - bounds.max.y - bounds.min.y, - bounds.max.z - bounds.min.z - }; - Vector3 scale = { - size.x / model_size.x, - size.y / model_size.y, - size.z / model_size.z - }; - if((obs_only || IsKeyDown(KEY_LEFT_CONTROL)) && agent_index != env->human_agent_idx){ - rlPopMatrix(); - continue; + Vector3 model_size = {bounds.max.x - bounds.min.x, bounds.max.y - bounds.min.y, + bounds.max.z - bounds.min.z}; + Vector3 scale = {size.x / model_size.x, size.y / model_size.y, size.z / model_size.z}; + // if((obs_only || IsKeyDown(KEY_LEFT_CONTROL)) && agent_index != env->human_agent_idx){ + // rlPopMatrix(); + // continue; + // } + if (env->entities[i].type == CYCLIST) { + scale = (Vector3){0.01, 0.01, 0.01}; + car_model = client->cyclist; + } + if (env->entities[i].type == PEDESTRIAN) { + scale = (Vector3){2, 2, 2}; + car_model = client->pedestrian; } - DrawModelEx(car_model, (Vector3){0, 0, 0}, (Vector3){1, 0, 0}, 90.0f, scale, WHITE); { - float cos_heading = env->entities[i].heading_x; - float sin_heading = env->entities[i].heading_y; float half_len = env->entities[i].length * 0.5f; float half_width = env->entities[i].width * 0.5f; Vector3 corners[4] = { - (Vector3){ 0 + ( half_len * cos_heading - half_width * sin_heading), 0 + ( half_len * sin_heading + half_width * cos_heading), 0 }, - (Vector3){ 0 + ( half_len * cos_heading + half_width * sin_heading), 0 + ( half_len * sin_heading - half_width * cos_heading), 0 }, - (Vector3){ 0 + (-half_len * cos_heading + half_width * sin_heading), 0 + (-half_len * sin_heading - half_width * cos_heading), 0 }, - (Vector3){ 0 + (-half_len * cos_heading - half_width * sin_heading), 0 + (-half_len * sin_heading + half_width * cos_heading), 0 }, + (Vector3){half_len, -half_width, 0}, // Front-left + (Vector3){half_len, half_width, 0}, // Front-right + (Vector3){-half_len, half_width, 0}, // Back-right + (Vector3){-half_len, -half_width, 0}, // Back-left }; - 
Color wire_color = GRAY; // static - if (!is_active_agent && env->entities[i].mark_as_expert == 1) wire_color = GOLD; // expert replay - if (is_active_agent) wire_color = BLUE; // policy - if (is_active_agent && env->entities[i].collision_state > 0) wire_color = RED; + Color wire_color = GRAY; // static + if (!is_active_agent && env->entities[i].mark_as_expert == 1) + wire_color = GOLD; // expert replay + if (is_active_agent) + wire_color = BLUE; // policy + if (is_active_agent && env->entities[i].collision_state > 0) + wire_color = RED; rlSetLineWidth(2.0f); for (int j = 0; j < 4; j++) { - DrawLine3D(corners[j], corners[(j+1)%4], wire_color); + DrawLine3D(corners[j], corners[(j + 1) % 4], wire_color); } } rlPopMatrix(); } // FPV Camera Control - if(IsKeyDown(KEY_SPACE) && env->human_agent_idx== agent_index){ - if(env->entities[agent_index].metrics_array[REACHED_GOAL_IDX]){ - env->human_agent_idx = rand() % env->active_agent_count; - } - Vector3 camera_position = (Vector3){ - position.x - (25.0f * cosf(heading)), - position.y - (25.0f * sinf(heading)), - position.z + 15 - }; + if (IsKeyDown(KEY_SPACE) && env->human_agent_idx == agent_index) { + Vector3 camera_position = (Vector3){position.x - (25.0f * cosf(heading)), + position.y - (25.0f * sinf(heading)), position.z + 15}; - Vector3 camera_target = (Vector3){ - position.x + 40.0f * cosf(heading), - position.y + 40.0f * sinf(heading), - position.z - 5.0f - }; + Vector3 camera_target = (Vector3){position.x + 40.0f * cosf(heading), + position.y + 40.0f * sinf(heading), position.z - 5.0f}; client->camera.position = camera_position; client->camera.target = camera_target; client->camera.up = (Vector3){0, 0, 1}; } - if(IsKeyReleased(KEY_SPACE)){ + if (IsKeyReleased(KEY_SPACE)) { client->camera.position = client->default_camera_position; client->camera.target = client->default_camera_target; client->camera.up = (Vector3){0, 0, 1}; } // Draw goal position for active agents - - if(!is_active_agent || env->entities[i].valid 
== 0) { + if (!is_active_agent || env->entities[i].valid == 0) { continue; } - if(!IsKeyDown(KEY_LEFT_CONTROL) && obs_only==0){ - DrawSphere((Vector3){ - env->entities[i].goal_position_x, - env->entities[i].goal_position_y, - 1 - }, 0.5f, DARKGREEN); - - DrawCircle3D((Vector3){ - env->entities[i].goal_position_x, - env->entities[i].goal_position_y, - 0.1f - }, env->goal_radius, (Vector3){0, 0, 1}, 90.0f, Fade(LIGHTGREEN, 0.3f)); + if (!IsKeyDown(KEY_LEFT_CONTROL) && obs_only == 0) { + DrawSphere((Vector3){env->entities[i].goal_position_x, env->entities[i].goal_position_y, 1}, 0.5f, + DARKGREEN); + + DrawCircle3D((Vector3){env->entities[i].goal_position_x, env->entities[i].goal_position_y, 0.1f}, + env->goal_radius, (Vector3){0, 0, 1}, 90.0f, Fade(LIGHTGREEN, 0.9f)); } } // Draw road elements - if(env->entities[i].type <=3 && env->entities[i].type >= 7){ + if (env->entities[i].type <= 3 && env->entities[i].type >= 7) { continue; } - for(int j = 0; j < env->entities[i].array_size - 1; j++) { - Vector3 start = { - env->entities[i].traj_x[j], - env->entities[i].traj_y[j], - 1 - }; - Vector3 end = { - env->entities[i].traj_x[j + 1], - env->entities[i].traj_y[j + 1], - 1 - }; + for (int j = 0; j < env->entities[i].array_size - 1; j++) { + Vector3 start = {env->entities[i].traj_x[j], env->entities[i].traj_y[j], 1}; + Vector3 end = {env->entities[i].traj_x[j + 1], env->entities[i].traj_y[j + 1], 1}; Color lineColor = GRAY; - if (env->entities[i].type == ROAD_LANE) lineColor = GRAY; - else if (env->entities[i].type == ROAD_LINE) lineColor = BLUE; - else if (env->entities[i].type == ROAD_EDGE) lineColor = WHITE; - else if (env->entities[i].type == DRIVEWAY) lineColor = RED; - if(env->entities[i].type != ROAD_EDGE){ - continue; - } - if(!IsKeyDown(KEY_LEFT_CONTROL) && obs_only==0){ - draw_road_edge(env, start.x, start.y, end.x, end.y); + if (env->entities[i].type == ROAD_LANE) + lineColor = Fade(SOFT_YELLOW, 0.25f); + else if (env->entities[i].type == ROAD_LINE) + lineColor = 
WHITE; + else if (env->entities[i].type == ROAD_EDGE) + lineColor = WHITE; + else if (env->entities[i].type == DRIVEWAY) + lineColor = RED; + + if (!IsKeyDown(KEY_LEFT_CONTROL) && obs_only == 0) { + if (env->entities[i].type == ROAD_EDGE) { + draw_road_edge(env, start.x, start.y, end.x, end.y); + } else if (env->entities[i].type == ROAD_LANE || env->entities[i].type == ROAD_LINE) { + // Draw road lanes and lines as purple lines + rlSetLineWidth(2.0f); + DrawLine3D(start, end, lineColor); + } } } } - if(show_grid) { - // Draw grid cells using the stored bounds - float grid_start_x = env->grid_map->top_left_x; - float grid_start_y = env->grid_map->bottom_right_y; - for(int i = 0; i < env->grid_map->grid_cols; i++) { - for(int j = 0; j < env->grid_map->grid_rows; j++) { - float x = grid_start_x + i*GRID_CELL_SIZE; - float y = grid_start_y + j*GRID_CELL_SIZE; - DrawCubeWires( - (Vector3){x + GRID_CELL_SIZE/2, y + GRID_CELL_SIZE/2, 1}, - GRID_CELL_SIZE, GRID_CELL_SIZE, 0.1f, PUFF_BACKGROUND2); - } - } - } EndMode3D(); // Draw track indices for the tracks to predict - if (mode == 1 && env->control_mode == CONTROL_TRACKS_TO_PREDICT) { - float map_width = env->grid_map->bottom_right_x - env->grid_map->top_left_x; + if (mode == 1 && env->control_mode == CONTROL_WOSAC) { float map_height = env->grid_map->top_left_y - env->grid_map->bottom_right_y; float pixels_per_world_unit = client->height / map_height; @@ -3242,151 +3107,105 @@ void draw_scene(Drive* env, Client* client, int mode, int obs_only, int lasers, float raw_x = -env->entities[agent_idx].x * pixels_per_world_unit; float raw_y = env->entities[agent_idx].y * pixels_per_world_unit; - int screen_x = (int)raw_x + client->width/2 + 20; - int screen_y = (int)raw_y + client->height/2 - 25; + int screen_x = (int)raw_x + client->width / 2 + 20; + int screen_y = (int)raw_y + client->height / 2 - 25; - if (screen_x >= 0 && screen_x <= client->width && - screen_y >= 0 && screen_y <= client->height) { + if (screen_x >= 0 && 
screen_x <= client->width && screen_y >= 0 && screen_y <= client->height) { char text[32]; snprintf(text, sizeof(text), "%d", womd_track_idx); int text_width = MeasureText(text, 20); - DrawText(text, screen_x - text_width/2, screen_y, 20, PUFF_WHITE); + DrawText(text, screen_x - text_width / 2, screen_y, 20, PUFF_WHITE); } } } } -void saveTopDownImage(Drive* env, Client* client, const char *filename, RenderTexture2D target, int map_height, int obs, int lasers, int trajectories, int frame_count, float* path, int log_trajectories, int show_grid){ - // Top-down orthographic camera - Camera3D camera = {0}; - camera.position = (Vector3){ 0.0f, 0.0f, 500.0f }; // above the scene - camera.target = (Vector3){ 0.0f, 0.0f, 0.0f }; // look at origin - camera.up = (Vector3){ 0.0f, -1.0f, 0.0f }; - camera.fovy = map_height; - camera.projection = CAMERA_ORTHOGRAPHIC; - Color road = (Color){35, 35, 37, 255}; - - BeginTextureMode(target); - ClearBackground(road); - BeginMode3D(camera); - rlEnableDepthTest(); - - // Draw log trajectories FIRST (in background at lower Z-level) - if(log_trajectories){ - for(int i=0; iactive_agent_count;i++){ - int idx = env->active_agent_indices[i]; - for(int j=0; jentities[idx].array_size;j++){ - float x = env->entities[idx].traj_x[j]; - float y = env->entities[idx].traj_y[j]; - float valid = env->entities[idx].traj_valid[j]; - if(!valid) continue; - DrawSphere((Vector3){x,y,0.5f}, 0.3f, Fade(LIGHTGREEN, 0.6f)); - } - } - } - - // Draw current path trajectories SECOND (slightly higher than log trajectories) - if(trajectories){ - for(int i=0; iactive_agent_indices[env->human_agent_idx]; - Entity* agent = &env->entities[agent_idx]; - - Camera3D camera = {0}; - // Position camera behind and above the agent - camera.position = (Vector3){ - agent->x - (25.0f * cosf(agent->heading)), - agent->y - (25.0f * sinf(agent->heading)), - 15.0f - }; - camera.target = (Vector3){ - agent->x + 40.0f * cosf(agent->heading), - agent->y + 40.0f * sinf(agent->heading), - 
1.0f - }; - camera.up = (Vector3){ 0.0f, 0.0f, 1.0f }; - camera.fovy = 45.0f; - camera.projection = CAMERA_PERSPECTIVE; - - Color road = (Color){35, 35, 37, 255}; - - BeginTextureMode(target); - ClearBackground(road); - BeginMode3D(camera); - rlEnableDepthTest(); - draw_scene(env, client, 0, obs_only, lasers, show_grid); // mode=0 for agent view - EndMode3D(); - EndTextureMode(); - - // Save to file - Image img = LoadImageFromTexture(target.texture); - ImageFlipVertical(&img); - ExportImage(img, filename); - UnloadImage(img); -} - -void c_render(Drive* env) { +void c_render(Drive *env) { if (env->client == NULL) { env->client = make_client(env); } - Client* client = env->client; + Client *client = env->client; BeginDrawing(); Color road = (Color){35, 35, 37, 255}; ClearBackground(road); BeginMode3D(client->camera); handle_camera_controls(env->client); draw_scene(env, client, 0, 0, 0, 0); + + if (IsKeyPressed(KEY_TAB)) { + env->human_agent_idx = (env->human_agent_idx + 1) % env->active_agent_count; + } + // Draw debug info - DrawText(TextFormat("Camera Position: (%.2f, %.2f, %.2f)", - client->camera.position.x, - client->camera.position.y, - client->camera.position.z), 10, 10, 20, PUFF_WHITE); - DrawText(TextFormat("Camera Target: (%.2f, %.2f, %.2f)", - client->camera.target.x, - client->camera.target.y, - client->camera.target.z), 10, 30, 20, PUFF_WHITE); + DrawText(TextFormat("Camera Position: (%.2f, %.2f, %.2f)", client->camera.position.x, client->camera.position.y, + client->camera.position.z), + 10, 10, 20, PUFF_WHITE); + DrawText(TextFormat("Camera Target: (%.2f, %.2f, %.2f)", client->camera.target.x, client->camera.target.y, + client->camera.target.z), + 10, 30, 20, PUFF_WHITE); DrawText(TextFormat("Timestep: %d", env->timestep), 10, 50, 20, PUFF_WHITE); - // acceleration & steering + int human_idx = env->active_agent_indices[env->human_agent_idx]; DrawText(TextFormat("Controlling Agent: %d", env->human_agent_idx), 10, 70, 20, PUFF_WHITE); 
DrawText(TextFormat("Agent Index: %d", human_idx), 10, 90, 20, PUFF_WHITE); + + // Display current action values - yellow when controlling, white otherwise + Color action_color = IsKeyDown(KEY_LEFT_SHIFT) ? YELLOW : PUFF_WHITE; + + if (env->action_type == 0) { // discrete + int *action_array = (int *)env->actions; + int action_val = action_array[env->human_agent_idx]; + + if (env->dynamics_model == CLASSIC) { + int num_steer = 13; + int accel_idx = action_val / num_steer; + int steer_idx = action_val % num_steer; + float accel_value = ACCELERATION_VALUES[accel_idx]; + float steer_value = STEERING_VALUES[steer_idx]; + + DrawText(TextFormat("Acceleration: %.2f m/s^2", accel_value), 10, 110, 20, action_color); + DrawText(TextFormat("Steering: %.3f", steer_value), 10, 130, 20, action_color); + } else if (env->dynamics_model == JERK) { + int num_lat = 3; + int jerk_long_idx = action_val / num_lat; + int jerk_lat_idx = action_val % num_lat; + float jerk_long_value = JERK_LONG[jerk_long_idx]; + float jerk_lat_value = JERK_LAT[jerk_lat_idx]; + + DrawText(TextFormat("Longitudinal Jerk: %.2f m/s^3", jerk_long_value), 10, 110, 20, action_color); + DrawText(TextFormat("Lateral Jerk: %.2f m/s^3", jerk_lat_value), 10, 130, 20, action_color); + } + } else { // continuous + float (*action_array_f)[2] = (float (*)[2])env->actions; + DrawText(TextFormat("Acceleration: %.2f", action_array_f[env->human_agent_idx][0]), 10, 110, 20, action_color); + DrawText(TextFormat("Steering: %.2f", action_array_f[env->human_agent_idx][1]), 10, 130, 20, action_color); + } + + // Show key press status + int status_y = 150; + if (IsKeyDown(KEY_LEFT_SHIFT)) { + DrawText("[shift pressed]", 10, status_y, 20, YELLOW); + status_y += 20; + } + if (IsKeyDown(KEY_SPACE)) { + DrawText("[space pressed]", 10, status_y, 20, YELLOW); + status_y += 20; + } + if (IsKeyDown(KEY_LEFT_CONTROL)) { + DrawText("[ctrl pressed]", 10, status_y, 20, YELLOW); + status_y += 20; + } + // Controls help - DrawText("Controls: W/S - 
Accelerate/Brake, A/D - Steer, 1-4 - Switch Agent", - 10, client->height - 30, 20, PUFF_WHITE); - // acceleration & steering - if (env->action_type == 1) { // continuous (float) - float (*action_array_f)[2] = (float(*)[2])env->actions; - DrawText(TextFormat("Acceleration: %.2f", action_array_f[env->human_agent_idx][0]), 10, 110, 20, PUFF_WHITE); - DrawText(TextFormat("Steering: %.2f", action_array_f[env->human_agent_idx][1]), 10, 130, 20, PUFF_WHITE); - } else { // discrete (int) - int (*action_array)[2] = (int(*)[2])env->actions; - DrawText(TextFormat("Acceleration: %d", action_array[env->human_agent_idx][0]), 10, 110, 20, PUFF_WHITE); - DrawText(TextFormat("Steering: %d", action_array[env->human_agent_idx][1]), 10, 130, 20, PUFF_WHITE); - } - DrawText(TextFormat("Grid Rows: %d", env->grid_map->grid_rows), 10, 150, 20, PUFF_WHITE); - DrawText(TextFormat("Grid Cols: %d", env->grid_map->grid_cols), 10, 170, 20, PUFF_WHITE); + DrawText("Controls: SHIFT + W/S - Accelerate/Brake, SHIFT + A/D - Steer, TAB - Switch Agent", 10, + client->height - 30, 20, PUFF_WHITE); + + DrawText(TextFormat("Grid Rows: %d", env->grid_map->grid_rows), 10, status_y, 20, PUFF_WHITE); + DrawText(TextFormat("Grid Cols: %d", env->grid_map->grid_cols), 10, status_y + 20, 20, PUFF_WHITE); EndDrawing(); } -void close_client(Client* client){ +void close_client(Client *client) { for (int i = 0; i < 6; i++) { UnloadModel(client->cars[i]); } diff --git a/pufferlib/ocean/drive/drive.py b/pufferlib/ocean/drive/drive.py index c3c6be27af..420e726b5e 100644 --- a/pufferlib/ocean/drive/drive.py +++ b/pufferlib/ocean/drive/drive.py @@ -6,6 +6,8 @@ import pufferlib from pufferlib.ocean.drive import binding import torch +from multiprocessing import Pool, cpu_count +from tqdm import tqdm class Drive(pufferlib.PufferEnv): @@ -20,13 +22,16 @@ def __init__( reward_offroad_collision=-0.1, reward_goal=1.0, reward_goal_post_respawn=0.5, - reward_ade=0.0, goal_behavior=0, + goal_target_distance=10.0, goal_radius=2.0, 
+ goal_speed=20.0, collision_behavior=0, offroad_behavior=0, dt=0.1, scenario_length=None, + episode_length=None, + termination_mode=None, resample_frequency=91, num_maps=100, num_agents=512, @@ -45,10 +50,14 @@ def __init__( co_player_enabled=False, num_ego_agents=512, co_player_policy={}, + map_dir="resources/drive/binaries/training", + use_all_maps=False, + report_all_scenarios=False, ): # env self.dt = dt self.render_mode = render_mode + self.report_all_scenarios = report_all_scenarios self.num_maps = num_maps self.report_interval = report_interval self.reward_vehicle_collision = reward_vehicle_collision @@ -56,14 +65,20 @@ def __init__( self.reward_goal = reward_goal self.reward_goal_post_respawn = reward_goal_post_respawn self.goal_radius = goal_radius + self.goal_speed = goal_speed self.goal_behavior = goal_behavior + self.goal_target_distance = goal_target_distance self.collision_behavior = collision_behavior self.offroad_behavior = offroad_behavior - self.reward_ade = reward_ade self.human_agent_idx = human_agent_idx self.scenario_length = scenario_length + self.termination_mode = termination_mode self.resample_frequency = resample_frequency self.ini_file = ini_file + self.use_all_maps = use_all_maps + + if episode_length != None: + self.scenario_length = episode_length # Adaptive driving agent setup self.adaptive_driving_agent = int(adaptive_driving_agent) @@ -119,16 +134,24 @@ def __init__( self.dynamics_model = dynamics_model # Observation space calculation - base_ego_dim = 10 if self.dynamics_model == "jerk" else 7 + self.ego_features = {"classic": binding.EGO_FEATURES_CLASSIC, "jerk": binding.EGO_FEATURES_JERK}.get( + dynamics_model + ) + + self.ego_features += conditioning_dims + + # Extract observation shapes from constants + # These need to be defined in C, since they determine the shape of the arrays + self.max_road_objects = binding.MAX_ROAD_SEGMENT_OBSERVATIONS + self.max_partner_objects = binding.MAX_AGENTS - 1 + self.partner_features = 
binding.PARTNER_FEATURES + self.road_features = binding.ROAD_FEATURES - partner_features = 7 - road_features = 7 - max_partner_objects = 63 - max_road_objects = 200 self.num_obs = ( - base_ego_dim + conditioning_dims + max_partner_objects * partner_features + max_road_objects * road_features + self.ego_features + + self.max_partner_objects * self.partner_features + + self.max_road_objects * self.road_features ) - self.single_observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(self.num_obs,), dtype=np.float32) # Co-player policy setup @@ -159,12 +182,13 @@ def __init__( self.init_steps = init_steps self.init_mode_str = init_mode self.control_mode_str = control_mode + self.map_dir = map_dir if self.control_mode_str == "control_vehicles": self.control_mode = 0 elif self.control_mode_str == "control_agents": self.control_mode = 1 - elif self.control_mode_str == "control_tracks_to_predict": + elif self.control_mode_str == "control_wosac": self.control_mode = 2 elif self.control_mode_str == "control_sdc_only": self.control_mode = 3 @@ -188,7 +212,8 @@ def __init__( # Multi discrete (assume independence) # self.single_action_space = gymnasium.spaces.MultiDiscrete([7, 13]) elif dynamics_model == "jerk": - self.single_action_space = gymnasium.spaces.MultiDiscrete([4, 3]) + # Joint action space (assume dependence) - 4 longitudinal × 3 lateral = 12 + self.single_action_space = gymnasium.spaces.MultiDiscrete([4 * 3]) else: raise ValueError(f"dynamics_model must be 'classic' or 'jerk'. Got: {dynamics_model}") elif action_type == "continuous": @@ -199,17 +224,17 @@ def __init__( self._action_type_flag = 0 if action_type == "discrete" else 1 # Check if resources directory exists - binary_path = "resources/drive/binaries/map_000.bin" + binary_path = f"{map_dir}/map_000.bin" if not os.path.exists(binary_path): raise FileNotFoundError( f"Required directory {binary_path} not found. Please ensure the Drive maps are downloaded and installed correctly per docs." 
) # Check maps availability - available_maps = len([name for name in os.listdir("resources/drive/binaries") if name.endswith(".bin")]) + available_maps = len([name for name in os.listdir(map_dir) if name.endswith(".bin")]) if num_maps > available_maps: raise ValueError( - f"num_maps ({num_maps}) exceeds available maps in directory ({available_maps}). Please reduce num_maps or add more maps to resources/drive/binaries." + f"num_maps ({num_maps}) exceeds available maps in directory ({available_maps}). Please reduce num_maps or add more maps to {map_dir}." ) if self.population_play: if self.num_ego_agents > num_agents: @@ -237,7 +262,6 @@ def __init__( self.co_player_actions = np.zeros(co_player_atn_space.shape, dtype=co_player_atn_space.dtype) else: self.co_player_actions = np.zeros(co_player_atn_space.shape, dtype=np.int32) - env_ids = [] for i in range(self.num_envs): cur = self.agent_offsets[i] @@ -256,13 +280,15 @@ def __init__( reward_offroad_collision=reward_offroad_collision, reward_goal=reward_goal, reward_goal_post_respawn=reward_goal_post_respawn, - reward_ade=reward_ade, goal_radius=goal_radius, + goal_speed=goal_speed, goal_behavior=self.goal_behavior, + goal_target_distance=self.goal_target_distance, collision_behavior=self.collision_behavior, offroad_behavior=self.offroad_behavior, dt=dt, - scenario_length=(int(scenario_length) if scenario_length is not None else None), + scenario_length=(int(self.scenario_length) if self.scenario_length is not None else None), + termination_mode=(int(self.termination_mode) if self.termination_mode is not None else 0), max_controlled_agents=self.max_controlled_agents, map_id=self.map_ids[i], max_agents=nxt - cur, @@ -288,6 +314,7 @@ def __init__( discount_weight_ub=self.discount_weight_ub, init_mode=self.init_mode, control_mode=self.control_mode, + map_dir=map_dir, ) env_ids.append(env_id) @@ -304,6 +331,7 @@ def reset(self, seed=0): def _set_env_variables(self): my_shared_tuple = binding.shared( + map_dir=self.map_dir, 
num_agents=self.num_agents, num_maps=self.num_maps, init_mode=self.init_mode, @@ -313,6 +341,8 @@ def _set_env_variables(self): goal_behavior=self.goal_behavior, population_play=self.population_play, num_ego_agents=self.num_ego_agents, + goal_target_distance=self.goal_target_distance, + use_all_maps=self.use_all_maps, ) if self.population_play: @@ -377,7 +407,9 @@ def _set_env_variables(self): self.agent_offsets, self.map_ids, self.num_envs = my_shared_tuple self.ego_ids = [i for i in range(self.agent_offsets[-1])] if len(self.ego_ids) != self.num_agents: - raise ValueError("mismatch between number of ego agents and number of agents") + raise ValueError( + f"mismatch between number of ego agents {len(self.ego_ids)} and number of agents {self.num_agents}" + ) self.local_co_player_ids = [[] for i in range(self.num_envs)] self.local_ego_ids = [[0] for i in range(self.num_envs)] @@ -526,10 +558,6 @@ def _compute_delta_metrics(self): delta_key = f"ada_delta_{metric}" delta_metrics[delta_key] = last_metrics[metric] - first_metrics[metric] - # Add a count of how many agents this represents - if "n" in last_metrics: - delta_metrics["ada_agent_count"] = last_metrics["n"] - return delta_metrics def step(self, actions): @@ -547,19 +575,17 @@ def step(self, actions): info = [] if self.tick % self.report_interval == 0: - log = binding.vec_log(self.c_envs) + log = binding.vec_log(self.c_envs, self.num_agents) if log: if self.adaptive_driving_agent: self.current_scenario_infos.append(log) - - # Only append to info if we're in the 0th scenario - if self.current_scenario == 0: + # For training: only report 0-shot (scenario 0) metrics + # For evaluation: report all scenarios when report_all_scenarios=True + if self.current_scenario == 0 or self.report_all_scenarios: info.append(log) - print("0th scenario metrics are ", log, flush=True) else: # Non-adaptive mode: always append info.append(log) - print("Regular metrics are ", log, flush=True) if self.tick % self.scenario_length == 0: 
if self.adaptive_driving_agent and self.current_scenario_infos: @@ -571,7 +597,6 @@ def step(self, actions): delta_metrics = self._compute_delta_metrics() if delta_metrics: info.append(delta_metrics) - print("delta metrics are ", delta_metrics, flush=True) self.scenario_metrics = [] @@ -588,7 +613,6 @@ def step(self, actions): delta_metrics = self._compute_delta_metrics() if delta_metrics: info.append(delta_metrics) - print("delta metrics 2, are ", delta_metrics, flush=True) self.scenario_metrics = [] self.current_scenario_infos = [] self.current_scenario = 0 @@ -612,13 +636,14 @@ def step(self, actions): dynamics_model=self.dynamics_model, reward_vehicle_collision=self.reward_vehicle_collision, reward_offroad_collision=self.reward_offroad_collision, - reward_goal=self.reward_goal, - reward_goal_post_respawn=self.reward_goal_post_respawn, - reward_ade=self.reward_ade, goal_radius=self.goal_radius, goal_behavior=self.goal_behavior, collision_behavior=self.collision_behavior, offroad_behavior=self.offroad_behavior, + reward_goal=self.reward_goal, + reward_goal_post_respawn=self.reward_goal_post_respawn, + goal_speed=self.goal_speed, + goal_target_distance=self.goal_target_distance, dt=self.dt, scenario_length=(int(self.scenario_length) if self.scenario_length is not None else None), max_controlled_agents=self.max_controlled_agents, @@ -646,6 +671,7 @@ def step(self, actions): init_steps=self.init_steps, init_mode=self.init_mode, control_mode=self.control_mode, + map_dir=self.map_dir, ) env_ids.append(env_id) self.c_envs = binding.vectorize(*env_ids) @@ -662,7 +688,7 @@ def get_global_agent_state(self): """Get current global state of all active agents. 
Returns: - dict with keys 'x', 'y', 'z', 'heading', 'id' containing numpy arrays + dict with keys 'x', 'y', 'z', 'heading', 'id', 'length', 'width' containing numpy arrays of shape (num_active_agents,) """ num_agents = self.num_agents @@ -673,10 +699,19 @@ def get_global_agent_state(self): "z": np.zeros(num_agents, dtype=np.float32), "heading": np.zeros(num_agents, dtype=np.float32), "id": np.zeros(num_agents, dtype=np.int32), + "length": np.zeros(num_agents, dtype=np.float32), + "width": np.zeros(num_agents, dtype=np.float32), } binding.vec_get_global_agent_state( - self.c_envs, states["x"], states["y"], states["z"], states["heading"], states["id"] + self.c_envs, + states["x"], + states["y"], + states["z"], + states["heading"], + states["id"], + states["length"], + states["width"], ) return states @@ -715,6 +750,32 @@ def get_ground_truth_trajectories(self): return trajectories + def get_road_edge_polylines(self): + """Get road edge polylines for all scenarios. + + Returns: + dict with keys 'x', 'y', 'lengths', 'scenario_id' containing numpy arrays. + x, y are flattened point coordinates; lengths indicates points per polyline. 
+ """ + num_polylines, total_points = binding.vec_get_road_edge_counts(self.c_envs) + + polylines = { + "x": np.zeros(total_points, dtype=np.float32), + "y": np.zeros(total_points, dtype=np.float32), + "lengths": np.zeros(num_polylines, dtype=np.int32), + "scenario_id": np.zeros(num_polylines, dtype=np.int32), + } + + binding.vec_get_road_edge_polylines( + self.c_envs, + polylines["x"], + polylines["y"], + polylines["lengths"], + polylines["scenario_id"], + ) + + return polylines + def render(self): binding.vec_render(self.c_envs, 0) @@ -727,7 +788,13 @@ def calculate_area(p1, p2, p3): return 0.5 * abs((p1["x"] - p3["x"]) * (p2["y"] - p1["y"]) - (p1["x"] - p2["x"]) * (p3["y"] - p1["y"])) -def simplify_polyline(geometry, polyline_reduction_threshold): +def dist(a, b): + dx = a["x"] - b["x"] + dy = a["y"] - b["y"] + return dx * dx + dy * dy + + +def simplify_polyline(geometry, polyline_reduction_threshold, max_segment_length): """Simplify the given polyline using a method inspired by Visvalingham-Whyatt, optimized for Python.""" num_points = len(geometry) if num_points < 3: @@ -756,8 +823,7 @@ def simplify_polyline(geometry, polyline_reduction_threshold): point2 = geometry[k_1] point3 = geometry[k_2] area = calculate_area(point1, point2, point3) - - if area < polyline_reduction_threshold: + if area < polyline_reduction_threshold and dist(point1, point3) <= max_segment_length: skip[k_1] = True skip_changed = True k = k_2 @@ -786,8 +852,6 @@ def save_map_binary(map_data, output_file, unique_map_id): f.write(struct.pack("i", track_index)) # Count total entities - print(len(map_data.get("objects", []))) - print(len(map_data.get("roads", []))) num_objects = len(map_data.get("objects", [])) num_roads = len(map_data.get("roads", [])) # num_entities = num_objects + num_roads @@ -869,7 +933,7 @@ def save_map_binary(map_data, output_file, unique_map_id): road_type = 15 # breakpoint() if len(geometry) > 10 and road_type <= 16: - geometry = simplify_polyline(geometry, 0.1) + 
geometry = simplify_polyline(geometry, 0.1, 250) size = len(geometry) # breakpoint() if road_type >= 0 and road_type <= 3: @@ -916,32 +980,75 @@ def load_map(map_name, unique_map_id, binary_output=None): save_map_binary(map_data, binary_output, unique_map_id) -def process_all_maps(): - """Process all maps and save them as binaries""" +def _process_single_map(args): + """Worker function to process a single map file""" + i, map_path, binary_path = args + try: + load_map(str(map_path), i, str(binary_path)) + return (i, map_path.name, True, None) + except Exception as e: + return (i, map_path.name, False, str(e)) + + +def process_all_maps( + data_folder="data/processed/training", + max_maps=50_000, + num_workers=None, + shuffle=False, +): + """Process all maps and save them as binaries using multiprocessing + + Args: + data_folder: Path to the folder containing JSON map files + max_maps: Maximum number of maps to process + num_workers: Number of parallel workers (defaults to cpu_count()) + shuffle: If True, shuffle the JSON files before assigning map IDs. + This ensures that when using num_maps < total, you get + a random mix of all source maps instead of alphabetically first ones. 
+ """ from pathlib import Path + import random - # Create the binaries directory if it doesn't exist - binary_dir = Path("resources/drive/binaries") - binary_dir.mkdir(parents=True, exist_ok=True) + if num_workers is None: + num_workers = cpu_count() # Path to the training data - data_dir = Path("data/processed/training") + data_dir = Path(data_folder) + dataset_name = data_dir.name + + # Create the binaries directory if it doesn't exist + binary_dir = Path(f"resources/drive/binaries/{dataset_name}") + binary_dir.mkdir(parents=True, exist_ok=True) # Get all JSON files in the training directory json_files = sorted(data_dir.glob("*.json")) - print(f"Found {len(json_files)} JSON files") + if shuffle: + json_files = list(json_files) + random.shuffle(json_files) - # Process each JSON file - for i, map_path in enumerate(json_files[:10000]): - binary_file = f"map_{i:03d}.bin" # Use zero-padded numbers for consistent sorting + # Prepare arguments for parallel processing + tasks = [] + for i, map_path in enumerate(json_files[:max_maps]): + binary_file = f"map_{i:03d}.bin" binary_path = binary_dir / binary_file + tasks.append((i, map_path, binary_path)) - print(f"Processing {map_path.name} -> {binary_file}") - # try: - load_map(str(map_path), i, str(binary_path)) - # except Exception as e: - # print(f"Error processing {map_path.name}: {e}") + # Process maps in parallel with progress bar + with Pool(num_workers) as pool: + results = list( + tqdm(pool.imap(_process_single_map, tasks), total=len(tasks), desc="Processing maps", unit="map") + ) + + # Collect statistics + successful = sum(1 for _, _, success, _ in results if success) + failed = sum(1 for _, _, success, _ in results if not success) + + if failed > 0: + print(f"\nFailed {failed}/{len(results)} files:") + for i, name, success, error in results: + if not success: + print(f" {name}: {error}") def test_performance(timeout=10, atn_cache=1024, num_agents=1024): @@ -976,4 +1083,9 @@ def test_performance(timeout=10, 
atn_cache=1024, num_agents=1024): if __name__ == "__main__": # test_performance() - process_all_maps() + # Process the train dataset + process_all_maps(data_folder="/data/processed/training") + # Process the validation/test dataset + # process_all_maps(data_folder="data/processed/validation") + # # Process the validation_interactive dataset + # process_all_maps(data_folder="data/processed/validation_interactive") diff --git a/pufferlib/ocean/drive/drivenet.h b/pufferlib/ocean/drive/drivenet.h index 56dbda5212..e474bafba8 100644 --- a/pufferlib/ocean/drive/drivenet.h +++ b/pufferlib/ocean/drive/drivenet.h @@ -8,44 +8,47 @@ #include #include +#define NN_INPUT_SIZE 64 +#define NN_HIDDEN_SIZE 256 + typedef struct DriveNet DriveNet; struct DriveNet { int num_agents; int conditioning_dims; int ego_dim; - float* obs_self; - float* obs_partner; - float* obs_road; - float* partner_linear_output; - float* road_linear_output; - float* partner_layernorm_output; - float* road_layernorm_output; - float* partner_linear_output_two; - float* road_linear_output_two; - Linear* ego_encoder; - Linear* road_encoder; - Linear* partner_encoder; - LayerNorm* ego_layernorm; - LayerNorm* road_layernorm; - LayerNorm* partner_layernorm; - Linear* ego_encoder_two; - Linear* road_encoder_two; - Linear* partner_encoder_two; - MaxDim1* partner_max; - MaxDim1* road_max; - CatDim1* cat1; - CatDim1* cat2; - GELU* gelu; - Linear* shared_embedding; - ReLU* relu; - LSTM* lstm; - Linear* actor; - Linear* value_fn; - Multidiscrete* multidiscrete; + float *obs_self; + float *obs_partner; + float *obs_road; + float *partner_linear_output; + float *road_linear_output; + float *partner_layernorm_output; + float *road_layernorm_output; + float *partner_linear_output_two; + float *road_linear_output_two; + Linear *ego_encoder; + Linear *road_encoder; + Linear *partner_encoder; + LayerNorm *ego_layernorm; + LayerNorm *road_layernorm; + LayerNorm *partner_layernorm; + Linear *ego_encoder_two; + Linear 
*road_encoder_two; + Linear *partner_encoder_two; + MaxDim1 *partner_max; + MaxDim1 *road_max; + CatDim1 *cat1; + CatDim1 *cat2; + GELU *gelu; + Linear *shared_embedding; + ReLU *relu; + LSTM *lstm; + Linear *actor; + Linear *value_fn; + Multidiscrete *multidiscrete; }; -DriveNet* init_drivenet(Weights* weights, int num_agents, int dynamics_model, bool use_rc, bool use_ec, bool use_dc) { - DriveNet* net = calloc(1, sizeof(DriveNet)); +DriveNet *init_drivenet(Weights *weights, int num_agents, int dynamics_model, bool use_rc, bool use_ec, bool use_dc) { + DriveNet *net = calloc(1, sizeof(DriveNet)); int hidden_size = 256; int input_size = 64; @@ -60,8 +63,8 @@ DriveNet* init_drivenet(Weights* weights, int num_agents, int dynamics_model, bo action_size = 7 * 13; // Joint action space logit_sizes[0] = 7 * 13; action_dim = 1; - } else { // JERK - action_size = 7; // 4 + 3 + } else { // JERK + action_size = 7; // 4 + 3 logit_sizes[0] = 4; logit_sizes[1] = 3; action_dim = 2; @@ -69,16 +72,16 @@ DriveNet* init_drivenet(Weights* weights, int num_agents, int dynamics_model, bo net->num_agents = num_agents; - net->obs_self = calloc(num_agents*net->ego_dim, sizeof(float)); - net->obs_partner = calloc(num_agents*63*7, sizeof(float)); // 63 objects, 7 features - net->obs_road = calloc(num_agents*200*13, sizeof(float)); // 200 objects, 13 features + net->obs_self = calloc(num_agents * net->ego_dim, sizeof(float)); + net->obs_partner = calloc(num_agents * 63 * 7, sizeof(float)); // 63 objects, 7 features + net->obs_road = calloc(num_agents * 200 * 13, sizeof(float)); // 200 objects, 13 features - net->partner_linear_output = calloc(num_agents*63*input_size, sizeof(float)); - net->road_linear_output = calloc(num_agents*200*input_size, sizeof(float)); - net->partner_linear_output_two = calloc(num_agents*63*input_size, sizeof(float)); - net->road_linear_output_two = calloc(num_agents*200*input_size, sizeof(float)); - net->partner_layernorm_output = calloc(num_agents*63*input_size, 
sizeof(float)); - net->road_layernorm_output = calloc(num_agents*200*input_size, sizeof(float)); + net->partner_linear_output = calloc(num_agents * 63 * input_size, sizeof(float)); + net->road_linear_output = calloc(num_agents * 200 * input_size, sizeof(float)); + net->partner_linear_output_two = calloc(num_agents * 63 * input_size, sizeof(float)); + net->road_linear_output_two = calloc(num_agents * 200 * input_size, sizeof(float)); + net->partner_layernorm_output = calloc(num_agents * 63 * input_size, sizeof(float)); + net->road_layernorm_output = calloc(num_agents * 200 * input_size, sizeof(float)); net->ego_encoder = make_linear(weights, num_agents, net->ego_dim, input_size); net->ego_layernorm = make_layernorm(weights, num_agents, input_size); net->ego_encoder_two = make_linear(weights, num_agents, input_size, input_size); @@ -92,19 +95,19 @@ DriveNet* init_drivenet(Weights* weights, int num_agents, int dynamics_model, bo net->road_max = make_max_dim1(num_agents, 200, input_size); net->cat1 = make_cat_dim1(num_agents, input_size, input_size); net->cat2 = make_cat_dim1(num_agents, input_size + input_size, input_size); - net->gelu = make_gelu(num_agents, 3*input_size); - net->shared_embedding = make_linear(weights, num_agents, input_size*3, hidden_size); + net->gelu = make_gelu(num_agents, 3 * input_size); + net->shared_embedding = make_linear(weights, num_agents, input_size * 3, hidden_size); net->relu = make_relu(num_agents, hidden_size); net->actor = make_linear(weights, num_agents, hidden_size, action_size); net->value_fn = make_linear(weights, num_agents, hidden_size, 1); net->lstm = make_lstm(weights, num_agents, hidden_size, 256); - memset(net->lstm->state_h, 0, num_agents*256*sizeof(float)); - memset(net->lstm->state_c, 0, num_agents*256*sizeof(float)); + memset(net->lstm->state_h, 0, num_agents * 256 * sizeof(float)); + memset(net->lstm->state_c, 0, num_agents * 256 * sizeof(float)); net->multidiscrete = make_multidiscrete(num_agents, logit_sizes, 
action_dim); return net; } -void free_drivenet(DriveNet* net) { +void free_drivenet(DriveNet *net) { free(net->obs_self); free(net->obs_partner); free(net->obs_road); @@ -137,45 +140,50 @@ void free_drivenet(DriveNet* net) { free(net); } -void forward(DriveNet* net, float* observations, int* actions) { +void forward(DriveNet *net, float *observations, int *actions) { int ego_dim = net->ego_dim; + int max_partners = MAX_AGENTS - 1; + int max_road_obs = MAX_ROAD_SEGMENT_OBSERVATIONS; + int partner_features = PARTNER_FEATURES; + int road_features = ROAD_FEATURES; + int road_feat_onehot = road_features + 6; // one-hot extra 6 features for road // Clear previous observations memset(net->obs_self, 0, net->num_agents * ego_dim * sizeof(float)); - memset(net->obs_partner, 0, net->num_agents * 63 * 7 * sizeof(float)); - memset(net->obs_road, 0, net->num_agents * 200 * 13 * sizeof(float)); - - // Reshape observations into 2D boards and additional features - float* obs_self = net->obs_self; - float (*obs_partner)[63][7] = (float (*)[63][7])net->obs_partner; - float (*obs_road)[200][13] = (float (*)[200][13])net->obs_road; + memset(net->obs_partner, 0, net->num_agents * max_partners * partner_features * sizeof(float)); + memset(net->obs_road, 0, net->num_agents * max_road_obs * road_feat_onehot * sizeof(float)); for (int b = 0; b < net->num_agents; b++) { - int b_offset = b * (ego_dim + 63*7 + 200*7); // offset for each batch + int b_offset = b * (ego_dim + max_partners * partner_features + max_road_obs * road_features); int partner_offset = b_offset + ego_dim; - int road_offset = b_offset + ego_dim + 63*7; + int road_offset = b_offset + ego_dim + max_partners * partner_features; + // Process self observation - for(int i = 0; i < ego_dim; i++) { - obs_self[b*ego_dim + i] = observations[b_offset + i]; + for (int i = 0; i < ego_dim; i++) { + net->obs_self[b * ego_dim + i] = observations[b_offset + i]; } // Process partner observation - for(int i = 0; i < 63; i++) { - for(int j = 
0; j < 7; j++) { - net->obs_partner[b*63*7 + i*7 + j] = observations[partner_offset + i*7 + j]; + for (int i = 0; i < max_partners; i++) { + for (int j = 0; j < partner_features; j++) { + net->obs_partner[b * max_partners * partner_features + i * partner_features + j] = + observations[partner_offset + i * partner_features + j]; } } // Process road observation - for(int i = 0; i < 200; i++) { - for(int j = 0; j < 7; j++) { - net->obs_road[b*200*13 + i*13 + j] = observations[road_offset + i*7 + j]; + for (int i = 0; i < MAX_ROAD_SEGMENT_OBSERVATIONS; i++) { + for (int j = 0; j < 7; j++) { + net->obs_road[b * MAX_ROAD_SEGMENT_OBSERVATIONS * ROAD_FEATURES_ONEHOT + i * ROAD_FEATURES_ONEHOT + j] = + observations[road_offset + i * 7 + j]; } - for(int j = 0; j < 7; j++) { - if(j == observations[road_offset+i*7 + 6]) { - net->obs_road[b*200*13 + i*13 + 6 + j] = 1.0f; + for (int j = 0; j < 7; j++) { + if (j == observations[road_offset + i * 7 + 6]) { + net->obs_road[b * MAX_ROAD_SEGMENT_OBSERVATIONS * ROAD_FEATURES_ONEHOT + i * ROAD_FEATURES_ONEHOT + + 6 + j] = 1.0f; } else { - net->obs_road[b*200*13 + i*13 + 6 + j] = 0.0f; + net->obs_road[b * MAX_ROAD_SEGMENT_OBSERVATIONS * ROAD_FEATURES_ONEHOT + i * ROAD_FEATURES_ONEHOT + + 6 + j] = 0.0f; } } } @@ -186,57 +194,63 @@ void forward(DriveNet* net, float* observations, int* actions) { layernorm(net->ego_layernorm, net->ego_encoder->output); linear(net->ego_encoder_two, net->ego_layernorm->output); for (int b = 0; b < net->num_agents; b++) { - for (int obj = 0; obj < 63; obj++) { + for (int obj = 0; obj < max_partners; obj++) { // Get the 7 features for this object - float* obj_features = &net->obs_partner[b*63*7 + obj*7]; + float *obj_features = &net->obs_partner[b * max_partners * partner_features + obj * partner_features]; // Apply linear layer to this object _linear(obj_features, net->partner_encoder->weights, net->partner_encoder->bias, - &net->partner_linear_output[b*63*64 + obj*64], 1, 7, 64); + 
&net->partner_linear_output[b * max_partners * NN_INPUT_SIZE + obj * NN_INPUT_SIZE], 1, + partner_features, NN_INPUT_SIZE); } } for (int b = 0; b < net->num_agents; b++) { - for (int obj = 0; obj < 63; obj++) { - float* after_first = &net->partner_linear_output[b*63*64 + obj*64]; + for (int obj = 0; obj < max_partners; obj++) { + float *after_first = &net->partner_linear_output[b * max_partners * NN_INPUT_SIZE + obj * NN_INPUT_SIZE]; _layernorm(after_first, net->partner_layernorm->weights, net->partner_layernorm->bias, - &net->partner_layernorm_output[b*63*64 + obj*64], 1, 64); + &net->partner_layernorm_output[b * max_partners * NN_INPUT_SIZE + obj * NN_INPUT_SIZE], 1, + NN_INPUT_SIZE); } } for (int b = 0; b < net->num_agents; b++) { - for (int obj = 0; obj < 63; obj++) { + for (int obj = 0; obj < max_partners; obj++) { // Get the 7 features for this object - float* obj_features = &net->partner_layernorm_output[b*63*64 + obj*64]; + float *obj_features = + &net->partner_layernorm_output[b * max_partners * NN_INPUT_SIZE + obj * NN_INPUT_SIZE]; // Apply linear layer to this object _linear(obj_features, net->partner_encoder_two->weights, net->partner_encoder_two->bias, - &net->partner_linear_output_two[b*63*64 + obj*64], 1, 64, 64); - + &net->partner_linear_output_two[b * max_partners * NN_INPUT_SIZE + obj * NN_INPUT_SIZE], 1, + NN_INPUT_SIZE, NN_INPUT_SIZE); } } // Process road objects: apply linear to each object individually for (int b = 0; b < net->num_agents; b++) { - for (int obj = 0; obj < 200; obj++) { + for (int obj = 0; obj < max_road_obs; obj++) { // Get the 13 features for this object - float* obj_features = &net->obs_road[b*200*13 + obj*13]; + float *obj_features = &net->obs_road[b * max_road_obs * ROAD_FEATURES_ONEHOT + obj * ROAD_FEATURES_ONEHOT]; // Apply linear layer to this object _linear(obj_features, net->road_encoder->weights, net->road_encoder->bias, - &net->road_linear_output[b*200*64 + obj*64], 1, 13, 64); + &net->road_linear_output[b * 
max_road_obs * NN_INPUT_SIZE + obj * NN_INPUT_SIZE], 1, + ROAD_FEATURES_ONEHOT, NN_INPUT_SIZE); } } // Apply layer norm and second linear to each road object for (int b = 0; b < net->num_agents; b++) { - for (int obj = 0; obj < 200; obj++) { - float* after_first = &net->road_linear_output[b*200*64 + obj*64]; + for (int obj = 0; obj < max_road_obs; obj++) { + float *after_first = &net->road_linear_output[b * max_road_obs * NN_INPUT_SIZE + obj * NN_INPUT_SIZE]; _layernorm(after_first, net->road_layernorm->weights, net->road_layernorm->bias, - &net->road_layernorm_output[b*200*64 + obj*64], 1, 64); + &net->road_layernorm_output[b * max_road_obs * NN_INPUT_SIZE + obj * NN_INPUT_SIZE], 1, + NN_INPUT_SIZE); } } for (int b = 0; b < net->num_agents; b++) { - for (int obj = 0; obj < 200; obj++) { - float* after_first = &net->road_layernorm_output[b*200*64 + obj*64]; + for (int obj = 0; obj < max_road_obs; obj++) { + float *after_first = &net->road_layernorm_output[b * max_road_obs * NN_INPUT_SIZE + obj * NN_INPUT_SIZE]; _linear(after_first, net->road_encoder_two->weights, net->road_encoder_two->bias, - &net->road_linear_output_two[b*200*64 + obj*64], 1, 64, 64); + &net->road_linear_output_two[b * max_road_obs * NN_INPUT_SIZE + obj * NN_INPUT_SIZE], 1, + NN_INPUT_SIZE, NN_INPUT_SIZE); } } diff --git a/pufferlib/ocean/drive/error.h b/pufferlib/ocean/drive/error.h index b1eb78e7ed..77ae171bb5 100644 --- a/pufferlib/ocean/drive/error.h +++ b/pufferlib/ocean/drive/error.h @@ -18,21 +18,29 @@ typedef enum { ERROR_UNKNOWN } ErrorType; -const char* error_type_to_string(ErrorType type) { +const char *error_type_to_string(ErrorType type) { switch (type) { - case ERROR_NONE: return "No Error"; - case ERROR_NULL_POINTER: return "Null Pointer"; - case ERROR_INVALID_ARGUMENT: return "Invalid Argument"; - case ERROR_OUT_OF_BOUNDS: return "Out of Bounds"; - case ERROR_MEMORY_ALLOCATION: return "Memory Allocation Failed"; - case ERROR_FILE_NOT_FOUND: return "File Not Found"; - case 
ERROR_INITIALIZATION_FAILED: return "Initialization Failed"; - default: return "Unknown Error"; + case ERROR_NONE: + return "No Error"; + case ERROR_NULL_POINTER: + return "Null Pointer"; + case ERROR_INVALID_ARGUMENT: + return "Invalid Argument"; + case ERROR_OUT_OF_BOUNDS: + return "Out of Bounds"; + case ERROR_MEMORY_ALLOCATION: + return "Memory Allocation Failed"; + case ERROR_FILE_NOT_FOUND: + return "File Not Found"; + case ERROR_INITIALIZATION_FAILED: + return "Initialization Failed"; + default: + return "Unknown Error"; } } // Enhanced error function with custom message support -void raise_error_with_message(ErrorType type, const char* format, ...) { +void raise_error_with_message(ErrorType type, const char *format, ...) { printf("Error occurred: %s", error_type_to_string(type)); if (format != NULL) { @@ -47,35 +55,28 @@ void raise_error_with_message(ErrorType type, const char* format, ...) { } // Simple error function (backward compatibility) -void raise_error(ErrorType type) { - raise_error_with_message(type, NULL); -} +void raise_error(ErrorType type) { raise_error_with_message(type, NULL); } // Convenience macros for common error patterns -#define RAISE_FILE_ERROR(path) \ - raise_error_with_message(ERROR_FILE_NOT_FOUND, "at path: %s", path) +#define RAISE_FILE_ERROR(path) raise_error_with_message(ERROR_FILE_NOT_FOUND, "at path: %s", path) -#define RAISE_BOUNDS_ERROR() \ - raise_error(ERROR_OUT_OF_BOUNDS) +#define RAISE_BOUNDS_ERROR() raise_error(ERROR_OUT_OF_BOUNDS) -#define RAISE_BOUNDS_ERROR_WITH_BOUNDS(index, min, max) \ +#define RAISE_BOUNDS_ERROR_WITH_BOUNDS(index, min, max) \ raise_error_with_message(ERROR_OUT_OF_BOUNDS, "index %d exceeds minimum of %d and maximum %d", index, min, max) -#define RAISE_NULL_ERROR() \ - raise_error(ERROR_NULL_POINTER) +#define RAISE_NULL_ERROR() raise_error(ERROR_NULL_POINTER) -#define RAISE_NULL_ERROR_WITH_NAME(var_name) \ +#define RAISE_NULL_ERROR_WITH_NAME(var_name) \ raise_error_with_message(ERROR_NULL_POINTER, 
"variable '%s' is null", var_name) -#define RAISE_MEMORY_ERROR() \ - raise_error(ERROR_MEMORY_ALLOCATION) +#define RAISE_MEMORY_ERROR() raise_error(ERROR_MEMORY_ALLOCATION) -#define RAISE_MEMORY_ERROR_WITH_SIZE(size) \ +#define RAISE_MEMORY_ERROR_WITH_SIZE(size) \ raise_error_with_message(ERROR_MEMORY_ALLOCATION, "failed to allocate %zu bytes", size) -#define RAISE_INVALID_ARG_ERROR() \ - raise_error(ERROR_INVALID_ARGUMENT) +#define RAISE_INVALID_ARG_ERROR() raise_error(ERROR_INVALID_ARGUMENT) -#define RAISE_INVALID_ARG_ERROR_WITH_ARG(arg_name, value) \ +#define RAISE_INVALID_ARG_ERROR_WITH_ARG(arg_name, value) \ raise_error_with_message(ERROR_INVALID_ARGUMENT, "invalid value for '%s': %d", arg_name, value) #endif diff --git a/pufferlib/ocean/drive/visualize.c b/pufferlib/ocean/drive/visualize.c index 172ddd1a8f..4820bd9be0 100644 --- a/pufferlib/ocean/drive/visualize.c +++ b/pufferlib/ocean/drive/visualize.c @@ -43,20 +43,9 @@ bool OpenVideo(VideoRecorder *recorder, const char *output_filename, int width, for (int fd = 3; fd < 256; fd++) { close(fd); } - execlp("ffmpeg", "ffmpeg", - "-y", - "-f", "rawvideo", - "-pix_fmt", "rgba", - "-s", size_str, - "-r", "30", - "-i", "-", - "-c:v", "libx264", - "-pix_fmt", "yuv420p", - "-preset", "ultrafast", - "-crf", "23", - "-loglevel", "error", - output_filename, - NULL); + execlp("ffmpeg", "ffmpeg", "-y", "-f", "rawvideo", "-pix_fmt", "rgba", "-s", size_str, "-r", "30", "-i", "-", + "-c:v", "libx264", "-pix_fmt", "yuv420p", "-preset", "ultrafast", "-crf", "23", "-loglevel", "error", + output_filename, NULL); TraceLog(LOG_ERROR, "Failed to launch ffmpeg"); return false; } @@ -76,16 +65,25 @@ void CloseVideo(VideoRecorder *recorder) { waitpid(recorder->pid, NULL, 0); } -void renderTopDownView(Drive* env, Client* client, int map_height, int obs, int lasers, int trajectories, int frame_count, float* path, int log_trajectories, int show_grid, int img_width, int img_height) { - +void renderTopDownView(Drive *env, Client *client, 
int map_height, int obs, int lasers, int trajectories, + int frame_count, float *path, int show_human_logs, int show_grid, int img_width, int img_height, + int zoom_in) { BeginDrawing(); // Top-down orthographic camera Camera3D camera = {0}; - camera.position = (Vector3){ 0.0f, 0.0f, 500.0f }; // above the scene - camera.target = (Vector3){ 0.0f, 0.0f, 0.0f }; // look at origin - camera.up = (Vector3){ 0.0f, -1.0f, 0.0f }; - camera.fovy = map_height; + + if (zoom_in) { // Zoom in on part of the map + camera.position = (Vector3){0.0f, 0.0f, 500.0f}; // above the scene + camera.target = (Vector3){0.0f, 0.0f, 0.0f}; // look at origin + camera.fovy = map_height; + } else { // Show full map + camera.position = (Vector3){env->grid_map->top_left_x, env->grid_map->bottom_right_y, 500.0f}; + camera.target = (Vector3){env->grid_map->top_left_x, env->grid_map->bottom_right_y, 0.0f}; + camera.fovy = 2 * map_height; + } + + camera.up = (Vector3){0.0f, -1.0f, 0.0f}; camera.projection = CAMERA_ORTHOGRAPHIC; client->width = img_width; @@ -97,25 +95,25 @@ void renderTopDownView(Drive* env, Client* client, int map_height, int obs, int rlEnableDepthTest(); // Draw human replay trajectories if enabled - if(log_trajectories){ - for(int i=0; iactive_agent_count; i++){ + if (show_human_logs) { + for (int i = 0; i < env->active_agent_count; i++) { int idx = env->active_agent_indices[i]; Vector3 prev_point = {0}; bool has_prev = false; - for(int j = 0; j < env->entities[idx].array_size; j++){ + for (int j = 0; j < env->entities[idx].array_size; j++) { float x = env->entities[idx].traj_x[j]; float y = env->entities[idx].traj_y[j]; float valid = env->entities[idx].traj_valid[j]; - if(!valid) { + if (!valid) { has_prev = false; continue; } Vector3 curr_point = {x, y, 0.5f}; - if(has_prev) { + if (has_prev) { DrawLine3D(prev_point, curr_point, Fade(LIGHTGREEN, 0.6f)); } @@ -126,9 +124,9 @@ void renderTopDownView(Drive* env, Client* client, int map_height, int obs, int } // Draw agent trajs - 
if(trajectories){ - for(int i=0; iactive_agent_indices[env->human_agent_idx]; - Entity* agent = &env->entities[agent_idx]; + Entity *agent = &env->entities[agent_idx]; BeginDrawing(); Camera3D camera = {0}; // Position camera behind and above the agent - camera.position = (Vector3){ - agent->x - (25.0f * cosf(agent->heading)), - agent->y - (25.0f * sinf(agent->heading)), - 15.0f - }; - camera.target = (Vector3){ - agent->x + 40.0f * cosf(agent->heading), - agent->y + 40.0f * sinf(agent->heading), - 1.0f - }; - camera.up = (Vector3){ 0.0f, 0.0f, 1.0f }; + camera.position = + (Vector3){agent->x - (25.0f * cosf(agent->heading)), agent->y - (25.0f * sinf(agent->heading)), 15.0f}; + camera.target = (Vector3){agent->x + 40.0f * cosf(agent->heading), agent->y + 40.0f * sinf(agent->heading), 1.0f}; + camera.up = (Vector3){0.0f, 0.0f, 1.0f}; camera.fovy = 45.0f; camera.projection = CAMERA_PERSPECTIVE; @@ -180,33 +171,32 @@ static int run_cmd(const char *cmd) { } // Make a high-quality GIF from numbered PNG frames like frame_000.png -static int make_gif_from_frames(const char *pattern, int fps, - const char *palette_path, - const char *out_gif) { +static int make_gif_from_frames(const char *pattern, int fps, const char *palette_path, const char *out_gif) { char cmd[1024]; // 1) Generate palette (no quotes needed for simple filter) // NOTE: if your frames start at 000, you don't need -start_number. 
- snprintf(cmd, sizeof(cmd), - "ffmpeg -y -framerate %d -i %s -vf palettegen %s", - fps, pattern, palette_path); - if (run_cmd(cmd) != 0) return -1; + snprintf(cmd, sizeof(cmd), "ffmpeg -y -framerate %d -i %s -vf palettegen %s", fps, pattern, palette_path); + if (run_cmd(cmd) != 0) + return -1; // 2) Use palette to encode the GIF - snprintf(cmd, sizeof(cmd), - "ffmpeg -y -framerate %d -i %s -i %s -lavfi paletteuse -loop 0 %s", - fps, pattern, palette_path, out_gif); - if (run_cmd(cmd) != 0) return -1; + snprintf(cmd, sizeof(cmd), "ffmpeg -y -framerate %d -i %s -i %s -lavfi paletteuse -loop 0 %s", fps, pattern, + palette_path, out_gif); + if (run_cmd(cmd) != 0) + return -1; return 0; } -int eval_gif(const char* map_name, const char* policy_name, int show_grid, int obs_only, int lasers, int log_trajectories, int frame_skip, float goal_radius, int init_steps, int use_rc, int use_ec, int use_dc, int max_controlled_agents, const char* view_mode, const char* output_topdown, const char* output_agent, int num_maps, int scenario_length_override, int init_mode, int control_mode, int goal_behavior) { +int eval_gif(const char *map_name, const char *policy_name, int show_grid, int obs_only, int lasers, + int show_human_logs, int frame_skip, const char *view_mode, const char *output_topdown, + const char *output_agent, int num_maps, int zoom_in) { // Parse configuration from INI file - env_init_config conf = {0}; // Initialize to zero - const char* ini_file = "pufferlib/config/ocean/drive.ini"; - if(ini_parse(ini_file, handler, &conf) < 0) { + env_init_config conf = {0}; + const char *ini_file = "pufferlib/config/ocean/drive.ini"; + if (ini_parse(ini_file, handler, &conf) < 0) { fprintf(stderr, "Error: Could not load %s. 
Cannot determine environment configuration.\n", ini_file); return -1; } @@ -215,92 +205,123 @@ int eval_gif(const char* map_name, const char* policy_name, int show_grid, int o if (map_name == NULL) { srand(time(NULL)); int random_map = rand() % num_maps; - sprintf(map_buffer, "resources/drive/binaries/map_%03d.bin", random_map); // random map file + sprintf(map_buffer, "%s/map_%03d.bin", conf.map_dir, random_map); map_name = map_buffer; } if (frame_skip <= 0) { - frame_skip = 1; // Default: render every frame + frame_skip = 1; } // Check if map file exists - FILE* map_file = fopen(map_name, "rb"); + FILE *map_file = fopen(map_name, "rb"); if (map_file == NULL) { RAISE_FILE_ERROR(map_name); } fclose(map_file); - FILE* policy_file = fopen(policy_name, "rb"); + FILE *policy_file = fopen(policy_name, "rb"); if (policy_file == NULL) { RAISE_FILE_ERROR(policy_name); } fclose(policy_file); + int use_rc = (conf.conditioning != NULL) + ? (strcmp(conf.conditioning->type, "reward") == 0 || strcmp(conf.conditioning->type, "all") == 0) + : 0; + int use_ec = (conf.conditioning != NULL) + ? (strcmp(conf.conditioning->type, "entropy") == 0 || strcmp(conf.conditioning->type, "all") == 0) + : 0; + int use_dc = (conf.conditioning != NULL) + ? 
(strcmp(conf.conditioning->type, "discount") == 0 || strcmp(conf.conditioning->type, "all") == 0) + : 0; + // Initialize environment with all config values from INI [env] section Drive env = { + .action_type = conf.action_type, .dynamics_model = conf.dynamics_model, .reward_vehicle_collision = conf.reward_vehicle_collision, .reward_offroad_collision = conf.reward_offroad_collision, - .reward_ade = conf.reward_ade, - .goal_radius = goal_radius, + .reward_goal = conf.reward_goal, + .reward_goal_post_respawn = conf.reward_goal_post_respawn, + .goal_radius = conf.goal_radius, + .goal_behavior = conf.goal_behavior, + .goal_target_distance = conf.goal_target_distance, + .goal_speed = conf.goal_speed, .dt = conf.dt, - .map_name = (char*)map_name, - .init_steps = init_steps, - .max_controlled_agents = max_controlled_agents, + .scenario_length = conf.scenario_length, + .termination_mode = conf.termination_mode, .collision_behavior = conf.collision_behavior, .offroad_behavior = conf.offroad_behavior, - .goal_behavior = goal_behavior, - .init_mode = init_mode, - .control_mode = control_mode, + .init_steps = conf.init_steps, + .init_mode = conf.init_mode, + .control_mode = conf.control_mode, + .map_name = (char *)map_name, .use_rc = use_rc, .use_ec = use_ec, .use_dc = use_dc, - // Conditioning weight bounds (defaults from drive.py) - .collision_weight_lb = -0.0f, - .collision_weight_ub = -0.0f, - .offroad_weight_lb = -0.0f, - .offroad_weight_ub = -0.0f, - .goal_weight_lb = 1.0f, - .goal_weight_ub = 1.0f, - .entropy_weight_lb = 0.001f, - .entropy_weight_ub = 0.001f, - .discount_weight_lb = 0.98f, - .discount_weight_ub = 0.98f, + .collision_weight_lb = (conf.conditioning != NULL) ? conf.conditioning->reward_collision_weight_lb : 0.0f, + .collision_weight_ub = (conf.conditioning != NULL) ? conf.conditioning->reward_collision_weight_ub : 0.0f, + .offroad_weight_lb = (conf.conditioning != NULL) ? 
conf.conditioning->reward_offroad_weight_lb : 0.0f, + .offroad_weight_ub = (conf.conditioning != NULL) ? conf.conditioning->reward_offroad_weight_ub : 0.0f, + .goal_weight_lb = (conf.conditioning != NULL) ? conf.conditioning->reward_goal_weight_lb : 0.0f, + .goal_weight_ub = (conf.conditioning != NULL) ? conf.conditioning->reward_goal_weight_ub : 0.0f, + .entropy_weight_lb = (conf.conditioning != NULL) ? conf.conditioning->entropy_weight_lb : 0.0f, + .entropy_weight_ub = (conf.conditioning != NULL) ? conf.conditioning->entropy_weight_ub : 0.0f, + .discount_weight_lb = (conf.conditioning != NULL) ? conf.conditioning->discount_weight_lb : 0.0f, + .discount_weight_ub = (conf.conditioning != NULL) ? conf.conditioning->discount_weight_ub : 0.0f, + .max_controlled_agents = 32, }; - env.scenario_length = (scenario_length_override > 0) ? scenario_length_override : - (conf.scenario_length > 0) ? conf.scenario_length : TRAJECTORY_LENGTH_DEFAULT; allocate(&env); + // Check if map has any active agents + if (env.active_agent_count == 0) { + fprintf(stderr, "Error: Map %s has no controllable agents\n", map_name); + free_allocated(&env); + return -1; + } + // Set which vehicle to focus on for obs mode - env.human_agent_idx = 0; + int random_agent_idx = rand() % env.active_agent_count; + env.human_agent_idx = random_agent_idx; c_reset(&env); + // Make client for rendering - Client* client = (Client*)calloc(1, sizeof(Client)); + Client *client = (Client *)calloc(1, sizeof(Client)); env.client = client; SetConfigFlags(FLAG_WINDOW_HIDDEN); - SetTargetFPS(6000); float map_width = env.grid_map->bottom_right_x - env.grid_map->top_left_x; float map_height = env.grid_map->top_left_y - env.grid_map->bottom_right_y; printf("Map size: %.1fx%.1f\n", map_width, map_height); - float scale = 6.0f; // Can be used to increase the video quality + float scale = 6.0f; - // Calculate video width and height; round to nearest even number int img_width = (int)roundf(map_width * scale / 2.0f) * 2; int 
img_height = (int)roundf(map_height * scale / 2.0f) * 2; + InitWindow(img_width, img_height, "Puffer Drive"); SetConfigFlags(FLAG_MSAA_4X_HINT); - Weights* weights = load_weights(policy_name); + // Load the textures and models + client->puffers = LoadTexture("resources/puffers_128.png"); + client->cars[0] = LoadModel("resources/drive/RedCar.glb"); + client->cars[1] = LoadModel("resources/drive/WhiteCar.glb"); + client->cars[2] = LoadModel("resources/drive/BlueCar.glb"); + client->cars[3] = LoadModel("resources/drive/YellowCar.glb"); + client->cars[4] = LoadModel("resources/drive/GreenCar.glb"); + client->cars[5] = LoadModel("resources/drive/GreyCar.glb"); + client->cyclist = LoadModel("resources/drive/cyclist.glb"); + client->pedestrian = LoadModel("resources/drive/pedestrian.glb"); + + Weights *weights = load_weights(policy_name); printf("Active agents in map: %d\n", env.active_agent_count); - DriveNet* net = init_drivenet(weights, env.active_agent_count, env.dynamics_model, use_rc, use_ec, use_dc); + DriveNet *net = init_drivenet(weights, env.active_agent_count, env.dynamics_model, use_rc, use_ec, use_dc); int frame_count = env.scenario_length > 0 ? 
env.scenario_length : TRAJECTORY_LENGTH_DEFAULT; - int log_trajectory = log_trajectories; char filename_topdown[256]; char filename_agent[256]; @@ -313,10 +334,9 @@ int eval_gif(const char* map_name, const char* policy_name, int show_grid, int o *strrchr(policy_base, '.') = '\0'; char map[256]; - strcpy(map, basename((char*)map_name)); + strcpy(map, basename((char *)map_name)); *strrchr(map, '.') = '\0'; - // Create video directory if it doesn't exist char video_dir[256]; sprintf(video_dir, "%s/video", policy_base); char mkdir_cmd[512]; @@ -346,7 +366,8 @@ int eval_gif(const char* map_name, const char* policy_name, int show_grid, int o if (render_agent) { if (!OpenVideo(&agent_recorder, filename_agent, img_width, img_height)) { - if (render_topdown) CloseVideo(&topdown_recorder); + if (render_topdown) + CloseVideo(&topdown_recorder); CloseWindow(); return -1; } @@ -354,30 +375,32 @@ int eval_gif(const char* map_name, const char* policy_name, int show_grid, int o if (render_topdown) { printf("Recording topdown view...\n"); - for(int i = 0; i < frame_count; i++) { + for (int i = 0; i < frame_count; i++) { if (i % frame_skip == 0) { - renderTopDownView(&env, client, map_height, 0, 0, 0, frame_count, NULL, log_trajectories, show_grid, img_width, img_height); + renderTopDownView(&env, client, map_height, 0, 0, 0, frame_count, NULL, show_human_logs, show_grid, + img_width, img_height, zoom_in); WriteFrame(&topdown_recorder, img_width, img_height); rendered_frames++; } - int (*actions)[2] = (int(*)[2])env.actions; - forward(net, env.observations, (int*)env.actions); + forward(net, env.observations, (int *)env.actions); c_step(&env); } - } if (render_agent) { c_reset(&env); printf("Recording agent view...\n"); - for(int i = 0; i < frame_count; i++) { + for (int i = 0; i < frame_count; i++) { + int human_idx = env.active_agent_indices[env.human_agent_idx]; + if (env.entities[human_idx].respawn_count > 0) { + break; + } if (i % frame_skip == 0) { renderAgentView(&env, 
client, map_height, obs_only, lasers, show_grid); WriteFrame(&agent_recorder, img_width, img_height); rendered_frames++; } - int (*actions)[2] = (int(*)[2])env.actions; - forward(net, env.observations, (int*)env.actions); + forward(net, env.observations, (int *)env.actions); c_step(&env); } } @@ -386,8 +409,8 @@ int eval_gif(const char* map_name, const char* policy_name, int show_grid, int o double elapsedTime = endTime - startTime; double writeFPS = (elapsedTime > 0) ? rendered_frames / elapsedTime : 0; - printf("Wrote %d frames in %.2f seconds (%.2f FPS) to %s \n", - rendered_frames, elapsedTime, writeFPS, filename_topdown); + printf("Wrote %d frames in %.2f seconds (%.2f FPS) to %s\n", rendered_frames, elapsedTime, writeFPS, + filename_topdown); if (render_topdown) { CloseVideo(&topdown_recorder); @@ -397,7 +420,6 @@ int eval_gif(const char* map_name, const char* policy_name, int show_grid, int o } CloseWindow(); - // Clean up resources free(client); free_allocated(&env); free_drivenet(net); @@ -405,17 +427,21 @@ int eval_gif(const char* map_name, const char* policy_name, int show_grid, int o return 0; } -int main(int argc, char* argv[]) { +int main(int argc, char *argv[]) { + // Visualization-only parameters (not in [env] section) int show_grid = 0; int obs_only = 0; int lasers = 0; - int log_trajectories = 1; + int show_human_logs = 0; int frame_skip = 1; - float goal_radius = 2.0f; - int init_steps = 0; - const char* map_name = NULL; - const char* policy_name = "resources/drive/puffer_drive_weights.bin"; - int max_controlled_agents = -1; + int zoom_in = 0; + const char *view_mode = "both"; + + // File paths and num_maps (not in [env] section) + const char *map_name = NULL; + const char *policy_name = "resources/drive/puffer_drive_weights.bin"; + const char *output_topdown = NULL; + const char *output_agent = NULL; int num_maps = 1; int scenario_length_cli = -1; int use_rc = 0; @@ -425,10 +451,6 @@ int main(int argc, char* argv[]) { int control_mode = 0; int 
goal_behavior = 0; - const char* view_mode = "both"; // "both", "topdown", "agent" - const char* output_topdown = NULL; - const char* output_agent = NULL; - // Parse command line arguments for (int i = 1; i < argc; i++) { if (strcmp(argv[i], "--show-grid") == 0) { @@ -438,28 +460,34 @@ int main(int argc, char* argv[]) { } else if (strcmp(argv[i], "--lasers") == 0) { lasers = 1; } else if (strcmp(argv[i], "--log-trajectories") == 0) { - log_trajectories = 1; + show_human_logs = 1; } else if (strcmp(argv[i], "--frame-skip") == 0) { if (i + 1 < argc) { frame_skip = atoi(argv[i + 1]); - i++; // Skip the next argument since we consumed it + i++; if (frame_skip <= 0) { - frame_skip = 1; // Ensure valid value + frame_skip = 1; } } - } else if (strcmp(argv[i], "--goal-radius") == 0) { + } else if (strcmp(argv[i], "--zoom-in") == 0) { + zoom_in = 1; + } else if (strcmp(argv[i], "--view") == 0) { if (i + 1 < argc) { - goal_radius = atof(argv[i + 1]); + view_mode = argv[i + 1]; i++; - if (goal_radius <= 0) { - goal_radius = 2.0f; // Ensure valid value + if (strcmp(view_mode, "both") != 0 && strcmp(view_mode, "topdown") != 0 && + strcmp(view_mode, "agent") != 0) { + fprintf(stderr, "Error: --view must be 'both', 'topdown', or 'agent'\n"); + return 1; } + } else { + fprintf(stderr, "Error: --view option requires a value (both/topdown/agent)\n"); + return 1; } } else if (strcmp(argv[i], "--map-name") == 0) { - // Check if there's a next argument for the map path if (i + 1 < argc) { map_name = argv[i + 1]; - i++; // Skip the next argument since we used it as map path + i++; } else { fprintf(stderr, "Error: --map-name option requires a map file path\n"); return 1; @@ -472,20 +500,6 @@ int main(int argc, char* argv[]) { fprintf(stderr, "Error: --policy-name option requires a policy file path\n"); return 1; } - } else if (strcmp(argv[i], "--view") == 0) { - if (i + 1 < argc) { - view_mode = argv[i + 1]; - i++; - if (strcmp(view_mode, "both") != 0 && - strcmp(view_mode, "topdown") != 
0 && - strcmp(view_mode, "agent") != 0) { - fprintf(stderr, "Error: --view must be 'both', 'topdown', or 'agent'\n"); - return 1; - } - } else { - fprintf(stderr, "Error: --view option requires a value (both/topdown/agent)\n"); - return 1; - } } else if (strcmp(argv[i], "--output-topdown") == 0) { if (i + 1 < argc) { output_topdown = argv[i + 1]; @@ -496,61 +510,15 @@ int main(int argc, char* argv[]) { output_agent = argv[i + 1]; i++; } - } else if (strcmp(argv[i], "--init-steps") == 0) { - if (i + 1 < argc) { - init_steps = atoi(argv[i + 1]); - i++; - if (init_steps < 0) { - init_steps = 0; - } - } - } else if (strcmp(argv[i], "--init-mode") == 0) { - if (i + 1 < argc) { - init_mode = atoi(argv[i + 1]); - i++; - } - } else if (strcmp(argv[i], "--control-mode") == 0) { - if (i + 1 < argc) { - control_mode = atoi(argv[i + 1]); - i++; - } - } else if (strcmp(argv[i], "--max-controlled-agents") == 0) { - if (i + 1 < argc) { - max_controlled_agents = atoi(argv[i + 1]); - i++; - } } else if (strcmp(argv[i], "--num-maps") == 0) { if (i + 1 < argc) { num_maps = atoi(argv[i + 1]); i++; } - } else if (strcmp(argv[i], "--scenario-length") == 0) { - if (i + 1 < argc) { - scenario_length_cli = atoi(argv[i + 1]); - i++; - } - } else if (strcmp(argv[i], "--use-rc") == 0) { - if (i + 1 < argc) { - use_rc = atoi(argv[i + 1]); - i++; - } - } else if (strcmp(argv[i], "--use-ec") == 0) { - if (i + 1 < argc) { - use_ec = atoi(argv[i + 1]); - i++; - } - } else if (strcmp(argv[i], "--use-dc") == 0) { - if (i + 1 < argc) { - use_dc = atoi(argv[i + 1]); - } - } else if (strcmp(argv[i], "--goal-behavior") == 0) { - if (i + 1 < argc) { - goal_behavior = atoi(argv[i + 1]); - i++; - } } } - eval_gif(map_name, policy_name, show_grid, obs_only, lasers, log_trajectories, frame_skip, goal_radius, init_steps, use_rc, use_ec, use_dc, max_controlled_agents, view_mode, output_topdown, output_agent, num_maps, scenario_length_cli, init_mode, control_mode, goal_behavior); + eval_gif(map_name, 
policy_name, show_grid, obs_only, lasers, show_human_logs, frame_skip, view_mode, output_topdown, + output_agent, num_maps, zoom_in); return 0; } diff --git a/pufferlib/ocean/env_binding.h b/pufferlib/ocean/env_binding.h index 6f7fa3c7df..b0ea15458f 100644 --- a/pufferlib/ocean/env_binding.h +++ b/pufferlib/ocean/env_binding.h @@ -3,42 +3,36 @@ #include // Forward declarations for env-specific functions supplied by user -static int my_log(PyObject* dict, Log* log); -static int my_init(Env* env, PyObject* args, PyObject* kwargs); +static int my_log(PyObject *dict, Log *log); +static int my_init(Env *env, PyObject *args, PyObject *kwargs); -static PyObject* my_shared(PyObject* self, PyObject* args, PyObject* kwargs); +static PyObject *my_shared(PyObject *self, PyObject *args, PyObject *kwargs); #ifndef MY_SHARED -static PyObject* my_shared(PyObject* self, PyObject* args, PyObject* kwargs) { - return NULL; -} +static PyObject *my_shared(PyObject *self, PyObject *args, PyObject *kwargs) { return NULL; } #endif -static PyObject* my_get(PyObject* dict, Env* env); +static PyObject *my_get(PyObject *dict, Env *env); #ifndef MY_GET -static PyObject* my_get(PyObject* dict, Env* env) { - return NULL; -} +static PyObject *my_get(PyObject *dict, Env *env) { return NULL; } #endif -static int my_put(Env* env, PyObject* args, PyObject* kwargs); +static int my_put(Env *env, PyObject *args, PyObject *kwargs); #ifndef MY_PUT -static int my_put(Env* env, PyObject* args, PyObject* kwargs) { - return 0; -} +static int my_put(Env *env, PyObject *args, PyObject *kwargs) { return 0; } #endif #ifndef MY_METHODS #define MY_METHODS {NULL, NULL, 0, NULL} #endif -static Env* unpack_env(PyObject* args) { - PyObject* handle_obj = PyTuple_GetItem(args, 0); +static Env *unpack_env(PyObject *args) { + PyObject *handle_obj = PyTuple_GetItem(args, 0); if (!PyObject_TypeCheck(handle_obj, &PyLong_Type)) { PyErr_SetString(PyExc_TypeError, "env_handle must be an integer"); return NULL; } - Env* env = 
(Env*)PyLong_AsVoidPtr(handle_obj); + Env *env = (Env *)PyLong_AsVoidPtr(handle_obj); if (!env) { PyErr_SetString(PyExc_ValueError, "Invalid env handle"); return NULL; @@ -48,36 +42,36 @@ static Env* unpack_env(PyObject* args) { } // Python function to initialize the environment -static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) { +static PyObject *env_init(PyObject *self, PyObject *args, PyObject *kwargs) { if (PyTuple_Size(args) != 6) { PyErr_SetString(PyExc_TypeError, "Environment requires 5 arguments"); return NULL; } - Env* env = (Env*)calloc(1, sizeof(Env)); + Env *env = (Env *)calloc(1, sizeof(Env)); if (!env) { PyErr_SetString(PyExc_MemoryError, "Failed to allocate environment"); return NULL; } - PyObject* obs = PyTuple_GetItem(args, 0); + PyObject *obs = PyTuple_GetItem(args, 0); if (!PyObject_TypeCheck(obs, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Observations must be a NumPy array"); return NULL; } - PyArrayObject* observations = (PyArrayObject*)obs; + PyArrayObject *observations = (PyArrayObject *)obs; if (!PyArray_ISCONTIGUOUS(observations)) { PyErr_SetString(PyExc_ValueError, "Observations must be contiguous"); return NULL; } env->observations = PyArray_DATA(observations); - PyObject* act = PyTuple_GetItem(args, 1); + PyObject *act = PyTuple_GetItem(args, 1); if (!PyObject_TypeCheck(act, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Actions must be a NumPy array"); return NULL; } - PyArrayObject* actions = (PyArrayObject*)act; + PyArrayObject *actions = (PyArrayObject *)act; if (!PyArray_ISCONTIGUOUS(actions)) { PyErr_SetString(PyExc_ValueError, "Actions must be contiguous"); return NULL; @@ -88,12 +82,12 @@ static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) { return NULL; } - PyObject* rew = PyTuple_GetItem(args, 2); + PyObject *rew = PyTuple_GetItem(args, 2); if (!PyObject_TypeCheck(rew, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Rewards must be a NumPy array"); return 
NULL; } - PyArrayObject* rewards = (PyArrayObject*)rew; + PyArrayObject *rewards = (PyArrayObject *)rew; if (!PyArray_ISCONTIGUOUS(rewards)) { PyErr_SetString(PyExc_ValueError, "Rewards must be contiguous"); return NULL; @@ -104,12 +98,12 @@ static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) { } env->rewards = PyArray_DATA(rewards); - PyObject* term = PyTuple_GetItem(args, 3); + PyObject *term = PyTuple_GetItem(args, 3); if (!PyObject_TypeCheck(term, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Terminals must be a NumPy array"); return NULL; } - PyArrayObject* terminals = (PyArrayObject*)term; + PyArrayObject *terminals = (PyArrayObject *)term; if (!PyArray_ISCONTIGUOUS(terminals)) { PyErr_SetString(PyExc_ValueError, "Terminals must be contiguous"); return NULL; @@ -120,12 +114,12 @@ static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) { } env->terminals = PyArray_DATA(terminals); - PyObject* trunc = PyTuple_GetItem(args, 4); + PyObject *trunc = PyTuple_GetItem(args, 4); if (!PyObject_TypeCheck(trunc, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Truncations must be a NumPy array"); return NULL; } - PyArrayObject* truncations = (PyArrayObject*)trunc; + PyArrayObject *truncations = (PyArrayObject *)trunc; if (!PyArray_ISCONTIGUOUS(truncations)) { PyErr_SetString(PyExc_ValueError, "Truncations must be contiguous"); return NULL; @@ -136,8 +130,7 @@ static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) { } // env->truncations = PyArray_DATA(truncations); - - PyObject* seed_arg = PyTuple_GetItem(args, 5); + PyObject *seed_arg = PyTuple_GetItem(args, 5); if (!PyObject_TypeCheck(seed_arg, &PyLong_Type)) { PyErr_SetString(PyExc_TypeError, "seed must be an integer"); return NULL; @@ -151,11 +144,11 @@ static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) { if (kwargs == NULL) { kwargs = PyDict_New(); } else { - Py_INCREF(kwargs); // We need to increment the reference 
since we'll be modifying it + Py_INCREF(kwargs); // We need to increment the reference since we'll be modifying it } // Add the seed to kwargs - PyObject* py_seed = PyLong_FromLong(seed); + PyObject *py_seed = PyLong_FromLong(seed); if (PyDict_SetItemString(kwargs, "seed", py_seed) < 0) { PyErr_SetString(PyExc_RuntimeError, "Failed to set seed in kwargs"); Py_DECREF(py_seed); @@ -164,7 +157,7 @@ static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) { } Py_DECREF(py_seed); - PyObject* empty_args = PyTuple_New(0); + PyObject *empty_args = PyTuple_New(0); my_init(env, empty_args, kwargs); Py_DECREF(kwargs); if (PyErr_Occurred()) { @@ -175,14 +168,14 @@ static PyObject* env_init(PyObject* self, PyObject* args, PyObject* kwargs) { } // Python function to reset the environment -static PyObject* env_reset(PyObject* self, PyObject* args) { +static PyObject *env_reset(PyObject *self, PyObject *args) { if (PyTuple_Size(args) != 2) { PyErr_SetString(PyExc_TypeError, "env_reset requires 2 arguments"); return NULL; } - Env* env = unpack_env(args); - if (!env){ + Env *env = unpack_env(args); + if (!env) { return NULL; } c_reset(env); @@ -190,15 +183,15 @@ static PyObject* env_reset(PyObject* self, PyObject* args) { } // Python function to step the environment -static PyObject* env_step(PyObject* self, PyObject* args) { +static PyObject *env_step(PyObject *self, PyObject *args) { int num_args = PyTuple_Size(args); if (num_args != 1) { PyErr_SetString(PyExc_TypeError, "vec_render requires 1 argument"); return NULL; } - Env* env = unpack_env(args); - if (!env){ + Env *env = unpack_env(args); + if (!env) { return NULL; } c_step(env); @@ -206,9 +199,9 @@ static PyObject* env_step(PyObject* self, PyObject* args) { } // Python function to step the environment -static PyObject* env_render(PyObject* self, PyObject* args) { - Env* env = unpack_env(args); - if (!env){ +static PyObject *env_render(PyObject *self, PyObject *args) { + Env *env = unpack_env(args); + if 
(!env) { return NULL; } c_render(env); @@ -216,9 +209,9 @@ static PyObject* env_render(PyObject* self, PyObject* args) { } // Python function to close the environment -static PyObject* env_close(PyObject* self, PyObject* args) { - Env* env = unpack_env(args); - if (!env){ +static PyObject *env_close(PyObject *self, PyObject *args) { + Env *env = unpack_env(args); + if (!env) { return NULL; } c_close(env); @@ -226,12 +219,12 @@ static PyObject* env_close(PyObject* self, PyObject* args) { Py_RETURN_NONE; } -static PyObject* env_get(PyObject* self, PyObject* args) { - Env* env = unpack_env(args); - if (!env){ +static PyObject *env_get(PyObject *self, PyObject *args) { + Env *env = unpack_env(args); + if (!env) { return NULL; } - PyObject* dict = PyDict_New(); + PyObject *dict = PyDict_New(); my_get(dict, env); if (PyErr_Occurred()) { return NULL; @@ -239,19 +232,19 @@ static PyObject* env_get(PyObject* self, PyObject* args) { return dict; } -static PyObject* env_put(PyObject* self, PyObject* args, PyObject* kwargs) { +static PyObject *env_put(PyObject *self, PyObject *args, PyObject *kwargs) { int num_args = PyTuple_Size(args); if (num_args != 1) { PyErr_SetString(PyExc_TypeError, "env_put requires 1 positional argument"); return NULL; } - Env* env = unpack_env(args); - if (!env){ + Env *env = unpack_env(args); + if (!env) { return NULL; } - PyObject* empty_args = PyTuple_New(0); + PyObject *empty_args = PyTuple_New(0); my_put(env, empty_args, kwargs); if (PyErr_Occurred()) { return NULL; @@ -261,18 +254,18 @@ static PyObject* env_put(PyObject* self, PyObject* args, PyObject* kwargs) { } typedef struct { - Env** envs; + Env **envs; int num_envs; } VecEnv; -static VecEnv* unpack_vecenv(PyObject* args) { - PyObject* handle_obj = PyTuple_GetItem(args, 0); +static VecEnv *unpack_vecenv(PyObject *args) { + PyObject *handle_obj = PyTuple_GetItem(args, 0); if (!PyObject_TypeCheck(handle_obj, &PyLong_Type)) { PyErr_SetString(PyExc_TypeError, "env_handle must be an integer"); 
return NULL; } - VecEnv* vec = (VecEnv*)PyLong_AsVoidPtr(handle_obj); + VecEnv *vec = (VecEnv *)PyLong_AsVoidPtr(handle_obj); if (!vec) { PyErr_SetString(PyExc_ValueError, "Missing or invalid vec env handle"); return NULL; @@ -286,18 +279,18 @@ static VecEnv* unpack_vecenv(PyObject* args) { return vec; } -static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) { +static PyObject *vec_init(PyObject *self, PyObject *args, PyObject *kwargs) { if (PyTuple_Size(args) != 7) { PyErr_SetString(PyExc_TypeError, "vec_init requires 6 arguments"); return NULL; } - VecEnv* vec = (VecEnv*)calloc(1, sizeof(VecEnv)); + VecEnv *vec = (VecEnv *)calloc(1, sizeof(VecEnv)); if (!vec) { PyErr_SetString(PyExc_MemoryError, "Failed to allocate vec env"); return NULL; } - PyObject* num_envs_arg = PyTuple_GetItem(args, 5); + PyObject *num_envs_arg = PyTuple_GetItem(args, 5); if (!PyObject_TypeCheck(num_envs_arg, &PyLong_Type)) { PyErr_SetString(PyExc_TypeError, "num_envs must be an integer"); return NULL; @@ -308,25 +301,25 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) { return NULL; } vec->num_envs = num_envs; - vec->envs = (Env**)calloc(num_envs, sizeof(Env*)); + vec->envs = (Env **)calloc(num_envs, sizeof(Env *)); if (!vec->envs) { PyErr_SetString(PyExc_MemoryError, "Failed to allocate vec env"); return NULL; } - PyObject* seed_obj = PyTuple_GetItem(args, 6); + PyObject *seed_obj = PyTuple_GetItem(args, 6); if (!PyObject_TypeCheck(seed_obj, &PyLong_Type)) { PyErr_SetString(PyExc_TypeError, "seed must be an integer"); return NULL; } int seed = PyLong_AsLong(seed_obj); - PyObject* obs = PyTuple_GetItem(args, 0); + PyObject *obs = PyTuple_GetItem(args, 0); if (!PyObject_TypeCheck(obs, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Observations must be a NumPy array"); return NULL; } - PyArrayObject* observations = (PyArrayObject*)obs; + PyArrayObject *observations = (PyArrayObject *)obs; if (!PyArray_ISCONTIGUOUS(observations)) { 
PyErr_SetString(PyExc_ValueError, "Observations must be contiguous"); return NULL; @@ -336,12 +329,12 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) { return NULL; } - PyObject* act = PyTuple_GetItem(args, 1); + PyObject *act = PyTuple_GetItem(args, 1); if (!PyObject_TypeCheck(act, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Actions must be a NumPy array"); return NULL; } - PyArrayObject* actions = (PyArrayObject*)act; + PyArrayObject *actions = (PyArrayObject *)act; if (!PyArray_ISCONTIGUOUS(actions)) { PyErr_SetString(PyExc_ValueError, "Actions must be contiguous"); return NULL; @@ -351,12 +344,12 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) { return NULL; } - PyObject* rew = PyTuple_GetItem(args, 2); + PyObject *rew = PyTuple_GetItem(args, 2); if (!PyObject_TypeCheck(rew, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Rewards must be a NumPy array"); return NULL; } - PyArrayObject* rewards = (PyArrayObject*)rew; + PyArrayObject *rewards = (PyArrayObject *)rew; if (!PyArray_ISCONTIGUOUS(rewards)) { PyErr_SetString(PyExc_ValueError, "Rewards must be contiguous"); return NULL; @@ -366,12 +359,12 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) { return NULL; } - PyObject* term = PyTuple_GetItem(args, 3); + PyObject *term = PyTuple_GetItem(args, 3); if (!PyObject_TypeCheck(term, &PyArray_Type)) { PyErr_SetString(PyExc_TypeError, "Terminals must be a NumPy array"); return NULL; } - PyArrayObject* terminals = (PyArrayObject*)term; + PyArrayObject *terminals = (PyArrayObject *)term; if (!PyArray_ISCONTIGUOUS(terminals)) { PyErr_SetString(PyExc_ValueError, "Terminals must be contiguous"); return NULL; @@ -381,12 +374,12 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) { return NULL; } - PyObject* trunc = PyTuple_GetItem(args, 4); + PyObject *trunc = PyTuple_GetItem(args, 4); if (!PyObject_TypeCheck(trunc, &PyArray_Type)) { 
PyErr_SetString(PyExc_TypeError, "Truncations must be a NumPy array"); return NULL; } - PyArrayObject* truncations = (PyArrayObject*)trunc; + PyArrayObject *truncations = (PyArrayObject *)trunc; if (!PyArray_ISCONTIGUOUS(truncations)) { PyErr_SetString(PyExc_ValueError, "Truncations must be contiguous"); return NULL; @@ -400,11 +393,11 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) { if (kwargs == NULL) { kwargs = PyDict_New(); } else { - Py_INCREF(kwargs); // We need to increment the reference since we'll be modifying it + Py_INCREF(kwargs); // We need to increment the reference since we'll be modifying it } for (int i = 0; i < num_envs; i++) { - Env* env = (Env*)calloc(1, sizeof(Env)); + Env *env = (Env *)calloc(1, sizeof(Env)); if (!env) { PyErr_SetString(PyExc_MemoryError, "Failed to allocate environment"); Py_DECREF(kwargs); @@ -415,18 +408,18 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) { // // Make sure the log is initialized to 0 memset(&env->log, 0, sizeof(Log)); - env->observations = (void*)((char*)PyArray_DATA(observations) + i*PyArray_STRIDE(observations, 0)); - env->actions = (void*)((char*)PyArray_DATA(actions) + i*PyArray_STRIDE(actions, 0)); - env->rewards = (void*)((char*)PyArray_DATA(rewards) + i*PyArray_STRIDE(rewards, 0)); - env->terminals = (void*)((char*)PyArray_DATA(terminals) + i*PyArray_STRIDE(terminals, 0)); + env->observations = (void *)((char *)PyArray_DATA(observations) + i * PyArray_STRIDE(observations, 0)); + env->actions = (void *)((char *)PyArray_DATA(actions) + i * PyArray_STRIDE(actions, 0)); + env->rewards = (void *)((char *)PyArray_DATA(rewards) + i * PyArray_STRIDE(rewards, 0)); + env->terminals = (void *)((char *)PyArray_DATA(terminals) + i * PyArray_STRIDE(terminals, 0)); // env->truncations = (void*)((char*)PyArray_DATA(truncations) + i*PyArray_STRIDE(truncations, 0)); // Assumes each process has the same number of environments - int env_seed = i + 
seed*vec->num_envs; + int env_seed = i + seed * vec->num_envs; srand(env_seed); // Add the seed to kwargs for this environment - PyObject* py_seed = PyLong_FromLong(env_seed); + PyObject *py_seed = PyLong_FromLong(env_seed); if (PyDict_SetItemString(kwargs, "seed", py_seed) < 0) { PyErr_SetString(PyExc_RuntimeError, "Failed to set seed in kwargs"); Py_DECREF(py_seed); @@ -435,7 +428,7 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) { } Py_DECREF(py_seed); - PyObject* empty_args = PyTuple_New(0); + PyObject *empty_args = PyTuple_New(0); my_init(env, empty_args, kwargs); if (PyErr_Occurred()) { return NULL; @@ -446,22 +439,21 @@ static PyObject* vec_init(PyObject* self, PyObject* args, PyObject* kwargs) { return PyLong_FromVoidPtr(vec); } - // Python function to close the environment -static PyObject* vectorize(PyObject* self, PyObject* args) { +static PyObject *vectorize(PyObject *self, PyObject *args) { int num_envs = PyTuple_Size(args); if (num_envs == 0) { PyErr_SetString(PyExc_TypeError, "make_vec requires at least 1 env id"); return NULL; } - VecEnv* vec = (VecEnv*)calloc(1, sizeof(VecEnv)); + VecEnv *vec = (VecEnv *)calloc(1, sizeof(VecEnv)); if (!vec) { PyErr_SetString(PyExc_MemoryError, "Failed to allocate vec env"); return NULL; } - vec->envs = (Env**)calloc(num_envs, sizeof(Env*)); + vec->envs = (Env **)calloc(num_envs, sizeof(Env *)); if (!vec->envs) { PyErr_SetString(PyExc_MemoryError, "Failed to allocate vec env"); return NULL; @@ -469,29 +461,30 @@ static PyObject* vectorize(PyObject* self, PyObject* args) { vec->num_envs = num_envs; for (int i = 0; i < num_envs; i++) { - PyObject* handle_obj = PyTuple_GetItem(args, i); + PyObject *handle_obj = PyTuple_GetItem(args, i); if (!PyObject_TypeCheck(handle_obj, &PyLong_Type)) { - PyErr_SetString(PyExc_TypeError, "Env ids must be integers. Pass them as separate args with *env_ids, not as a list."); + PyErr_SetString(PyExc_TypeError, + "Env ids must be integers. 
Pass them as separate args with *env_ids, not as a list."); return NULL; } - vec->envs[i] = (Env*)PyLong_AsVoidPtr(handle_obj); + vec->envs[i] = (Env *)PyLong_AsVoidPtr(handle_obj); } return PyLong_FromVoidPtr(vec); } -static PyObject* vec_reset(PyObject* self, PyObject* args) { +static PyObject *vec_reset(PyObject *self, PyObject *args) { if (PyTuple_Size(args) != 2) { PyErr_SetString(PyExc_TypeError, "vec_reset requires 2 arguments"); return NULL; } - VecEnv* vec = unpack_vecenv(args); + VecEnv *vec = unpack_vecenv(args); if (!vec) { return NULL; } - PyObject* seed_arg = PyTuple_GetItem(args, 1); + PyObject *seed_arg = PyTuple_GetItem(args, 1); if (!PyObject_TypeCheck(seed_arg, &PyLong_Type)) { PyErr_SetString(PyExc_TypeError, "seed must be an integer"); return NULL; @@ -500,20 +493,20 @@ static PyObject* vec_reset(PyObject* self, PyObject* args) { for (int i = 0; i < vec->num_envs; i++) { // Assumes each process has the same number of environments - srand(i + seed*vec->num_envs); + srand(i + seed * vec->num_envs); c_reset(vec->envs[i]); } Py_RETURN_NONE; } -static PyObject* vec_step(PyObject* self, PyObject* arg) { +static PyObject *vec_step(PyObject *self, PyObject *arg) { int num_args = PyTuple_Size(arg); if (num_args != 1) { PyErr_SetString(PyExc_TypeError, "vec_step requires 1 argument"); return NULL; } - VecEnv* vec = unpack_vecenv(arg); + VecEnv *vec = unpack_vecenv(arg); if (!vec) { return NULL; } @@ -524,20 +517,20 @@ static PyObject* vec_step(PyObject* self, PyObject* arg) { Py_RETURN_NONE; } -static PyObject* vec_render(PyObject* self, PyObject* args) { +static PyObject *vec_render(PyObject *self, PyObject *args) { int num_args = PyTuple_Size(args); if (num_args != 2) { PyErr_SetString(PyExc_TypeError, "vec_render requires 2 arguments"); return NULL; } - VecEnv* vec = (VecEnv*)PyLong_AsVoidPtr(PyTuple_GetItem(args, 0)); + VecEnv *vec = (VecEnv *)PyLong_AsVoidPtr(PyTuple_GetItem(args, 0)); if (!vec) { PyErr_SetString(PyExc_ValueError, "Invalid vec_env 
handle"); return NULL; } - PyObject* env_id_arg = PyTuple_GetItem(args, 1); + PyObject *env_id_arg = PyTuple_GetItem(args, 1); if (!PyObject_TypeCheck(env_id_arg, &PyLong_Type)) { PyErr_SetString(PyExc_TypeError, "env_id must be an integer"); return NULL; @@ -548,13 +541,13 @@ static PyObject* vec_render(PyObject* self, PyObject* args) { Py_RETURN_NONE; } -static int assign_to_dict(PyObject* dict, char* key, float value) { - PyObject* v = PyFloat_FromDouble(value); +static int assign_to_dict(PyObject *dict, char *key, float value) { + PyObject *v = PyFloat_FromDouble(value); if (v == NULL) { PyErr_SetString(PyExc_TypeError, "Failed to convert log value"); return 1; } - if(PyDict_SetItemString(dict, key, v) < 0) { + if (PyDict_SetItemString(dict, key, v) < 0) { PyErr_SetString(PyExc_TypeError, "Failed to set log value"); return 1; } @@ -562,99 +555,114 @@ static int assign_to_dict(PyObject* dict, char* key, float value) { return 0; } -static PyObject* vec_log(PyObject* self, PyObject* args) { - VecEnv* vec = unpack_vecenv(args); +static PyObject *vec_log(PyObject *self, PyObject *args) { + if (PyTuple_Size(args) != 2) { + PyErr_SetString(PyExc_TypeError, "vec_log requires 2 arguments"); + return NULL; + } + VecEnv *vec = unpack_vecenv(args); if (!vec) { return NULL; } + PyObject *num_agents_arg = PyTuple_GetItem(args, 1); + float num_agents = (float)PyLong_AsLong(num_agents_arg); + // Iterates over logs one float at a time. Will break // horribly if Log has non-float data. 
Log aggregate = {0}; int num_keys = sizeof(Log) / sizeof(float); - // Adaptive agent logging variables - float ada_delta_completion_rate = 0.0f; - float ada_delta_score = 0.0f; - float ada_delta_perf = 0.0f; - float ada_delta_collision_rate = 0.0f; - float ada_delta_offroad_rate = 0.0f; - float ada_delta_num_goals_reached = 0.0f; - float ada_delta_dnf_rate = 0.0f; - float ada_delta_lane_alignment_rate = 0.0f; - float ada_delta_avg_displacement_error = 0.0f; - float ada_delta_episode_return = 0.0f; - int ada_agent_count = 0; - int has_co_players = 0; // Flag to check if any env has co-players - Co_Player_Log co_player_aggregate = {0}; - int num_co_player_keys = sizeof(Co_Player_Log) / sizeof(float); + int has_co_players = 0; // Flag to check if any env has co-players + Log co_player_aggregate = {0}; // Now using Log struct instead of Co_Player_Log for (int i = 0; i < vec->num_envs; i++) { - Env* env = vec->envs[i]; - + Env *env = vec->envs[i]; for (int j = 0; j < num_keys; j++) { - ((float*)&aggregate)[j] += ((float*)&env->log)[j]; - ((float*)&env->log)[j] = 0.0f; + ((float *)&aggregate)[j] += ((float *)&env->log)[j]; } if (env->population_play && env->num_co_players > 0 && env->co_player_ids != NULL) { has_co_players = 1; - - // Aggregate co-player logs - for (int j = 0; j < num_co_player_keys; j++) { - ((float*)&co_player_aggregate)[j] += ((float*)&env->co_player_log)[j]; - ((float*)&env->co_player_log)[j] = 0.0f; // Reset after aggregating + // Aggregate co-player logs (now same structure as ego logs) + for (int j = 0; j < num_keys; j++) { + ((float *)&co_player_aggregate)[j] += ((float *)&env->co_player_log)[j]; } } + } + + PyObject *dict = PyDict_New(); + // Check if we have enough data from EITHER ego agents OR total (ego + co-players) + float total_n = aggregate.n + (has_co_players ? co_player_aggregate.n : 0.0f); + if (total_n < num_agents) { + return dict; // Not enough data yet } - PyObject* dict = PyDict_New(); + // Got enough data. 
Reset logs and return metrics + for (int i = 0; i < vec->num_envs; i++) { + Env *env = vec->envs[i]; + for (int j = 0; j < num_keys; j++) { + ((float *)&env->log)[j] = 0.0f; + } - // Average regular logs - if (aggregate.n > 0.0f) { - float n = aggregate.n; + if (env->population_play && env->num_co_players > 0 && env->co_player_ids != NULL) { + for (int j = 0; j < num_keys; j++) { + ((float *)&env->co_player_log)[j] = 0.0f; + } + } + } + + float n = aggregate.n; + // Average across EGO agents only + if (n > 0) { for (int i = 0; i < num_keys; i++) { - ((float*)&aggregate)[i] /= n; + ((float *)&aggregate)[i] /= n; } - // User populates dict - my_log(dict, &aggregate); - assign_to_dict(dict, "n", n); + // Compute completion_rate from aggregated counts + aggregate.completion_rate = aggregate.goals_reached_this_episode / aggregate.goals_sampled_this_episode; } - if (has_co_players && co_player_aggregate.co_player_n > 0.0f) { - float co_player_n = co_player_aggregate.co_player_n; - // Only divide non-zero values to avoid corruption - for (int i = 0; i < num_co_player_keys; i++) { - if (((float*)&co_player_aggregate)[i] != 0.0f) { - ((float*)&co_player_aggregate)[i] /= co_player_n; - } + // User populates dict + my_log(dict, &aggregate); + assign_to_dict(dict, "n", n); + + // Handle co-player metrics + if (has_co_players && co_player_aggregate.n > 0.0f) { + float co_player_n = co_player_aggregate.n; + + // Average co-player metrics across CO-PLAYER agents only + for (int i = 0; i < num_keys; i++) { + ((float *)&co_player_aggregate)[i] /= co_player_n; } - // Add co-player metrics directly - assign_to_dict(dict, "ego_co_player_ratio", aggregate.n / co_player_n); - assign_to_dict(dict, "co_player_completion_rate", co_player_aggregate.co_player_completion_rate); - assign_to_dict(dict, "co_player_collision_rate", co_player_aggregate.co_player_collision_rate); - assign_to_dict(dict, "co_player_offroad_rate", co_player_aggregate.co_player_offroad_rate); - assign_to_dict(dict, 
"co_player_clean_collision_rate", co_player_aggregate.co_player_clean_collision_rate); - assign_to_dict(dict, "co_player_num_goals_reached", co_player_aggregate.co_player_num_goals_reached); - assign_to_dict(dict, "co_player_score", co_player_aggregate.co_player_score); - assign_to_dict(dict, "co_player_perf", co_player_aggregate.co_player_perf); - assign_to_dict(dict, "co_player_dnf_rate", co_player_aggregate.co_player_dnf_rate); - assign_to_dict(dict, "co_player_episode_length", co_player_aggregate.co_player_episode_length); - assign_to_dict(dict, "co_player_episode_return", co_player_aggregate.co_player_episode_return); - assign_to_dict(dict, "co_player_lane_alignment_rate", co_player_aggregate.co_player_lane_alignment_rate); - assign_to_dict(dict, "co_player_avg_displacement_error", co_player_aggregate.co_player_avg_displacement_error); + // Compute co-player completion rate + co_player_aggregate.completion_rate = + co_player_aggregate.goals_reached_this_episode / co_player_aggregate.goals_sampled_this_episode; + + // Add co-player metrics to dict with co_player_ prefix + assign_to_dict(dict, "ego_co_player_ratio", n / co_player_n); + assign_to_dict(dict, "co_player_completion_rate", co_player_aggregate.completion_rate); + assign_to_dict(dict, "co_player_collision_rate", co_player_aggregate.collision_rate); + assign_to_dict(dict, "co_player_collisions_per_agent", co_player_aggregate.collisions_per_agent); + assign_to_dict(dict, "co_player_offroad_rate", co_player_aggregate.offroad_rate); + assign_to_dict(dict, "co_player_offroad_per_agent", co_player_aggregate.offroad_per_agent); + assign_to_dict(dict, "co_player_score", co_player_aggregate.score); + assign_to_dict(dict, "co_player_dnf_rate", co_player_aggregate.dnf_rate); + assign_to_dict(dict, "co_player_episode_length", co_player_aggregate.episode_length); + assign_to_dict(dict, "co_player_episode_return", co_player_aggregate.episode_return); + assign_to_dict(dict, "co_player_lane_alignment_rate", 
co_player_aggregate.lane_alignment_rate); + assign_to_dict(dict, "co_player_speed_at_goal", co_player_aggregate.speed_at_goal); + assign_to_dict(dict, "co_player_goals_reached_this_episode", co_player_aggregate.goals_reached_this_episode); + assign_to_dict(dict, "co_player_goals_sampled_this_episode", co_player_aggregate.goals_sampled_this_episode); assign_to_dict(dict, "co_player_n", co_player_n); } - return dict; } - -static PyObject* vec_close(PyObject* self, PyObject* args) { - VecEnv* vec = unpack_vecenv(args); +static PyObject *vec_close(PyObject *self, PyObject *args) { + VecEnv *vec = unpack_vecenv(args); if (!vec) { return NULL; } @@ -668,93 +676,97 @@ static PyObject* vec_close(PyObject* self, PyObject* args) { Py_RETURN_NONE; } -static PyObject* get_global_agent_state(PyObject* self, PyObject* args) { - if (PyTuple_Size(args) != 5) { - PyErr_SetString(PyExc_TypeError, "get_global_agent_state requires 5 arguments"); +static PyObject *get_global_agent_state(PyObject *self, PyObject *args) { + if (PyTuple_Size(args) != 7) { + PyErr_SetString(PyExc_TypeError, "get_global_agent_state requires 7 arguments"); return NULL; } - Env* env = unpack_env(args); + Env *env = unpack_env(args); if (!env) { return NULL; } - Drive* drive = (Drive*)env; // Cast to Drive* + Drive *drive = (Drive *)env; // Cast to Drive* // Get the numpy arrays from arguments - PyObject* x_arr = PyTuple_GetItem(args, 1); - PyObject* y_arr = PyTuple_GetItem(args, 2); - PyObject* z_arr = PyTuple_GetItem(args, 3); - PyObject* heading_arr = PyTuple_GetItem(args, 4); - PyObject* id_arr = PyTuple_GetItem(args, 5); - - if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) || - !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) || - !PyArray_Check(id_arr)) { + PyObject *x_arr = PyTuple_GetItem(args, 1); + PyObject *y_arr = PyTuple_GetItem(args, 2); + PyObject *z_arr = PyTuple_GetItem(args, 3); + PyObject *heading_arr = PyTuple_GetItem(args, 4); + PyObject *id_arr = PyTuple_GetItem(args, 5); + PyObject 
*length_arr = PyTuple_GetItem(args, 6); + PyObject *width_arr = PyTuple_GetItem(args, 7); + + if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) || !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) || + !PyArray_Check(id_arr) || !PyArray_Check(length_arr) || !PyArray_Check(width_arr)) { PyErr_SetString(PyExc_TypeError, "All output arrays must be NumPy arrays"); return NULL; } - float* x_data = (float*)PyArray_DATA((PyArrayObject*)x_arr); - float* y_data = (float*)PyArray_DATA((PyArrayObject*)y_arr); - float* z_data = (float*)PyArray_DATA((PyArrayObject*)z_arr); - float* heading_data = (float*)PyArray_DATA((PyArrayObject*)heading_arr); - int* id_data = (int*)PyArray_DATA((PyArrayObject*)id_arr); + float *x_data = (float *)PyArray_DATA((PyArrayObject *)x_arr); + float *y_data = (float *)PyArray_DATA((PyArrayObject *)y_arr); + float *z_data = (float *)PyArray_DATA((PyArrayObject *)z_arr); + float *heading_data = (float *)PyArray_DATA((PyArrayObject *)heading_arr); + int *id_data = (int *)PyArray_DATA((PyArrayObject *)id_arr); + float *length_data = (float *)PyArray_DATA((PyArrayObject *)length_arr); + float *width_data = (float *)PyArray_DATA((PyArrayObject *)width_arr); - c_get_global_agent_state(drive, x_data, y_data, z_data, heading_data, id_data); + c_get_global_agent_state(drive, x_data, y_data, z_data, heading_data, id_data, length_data, width_data); Py_RETURN_NONE; } -static PyObject* vec_get_global_agent_state(PyObject* self, PyObject* args) { - if (PyTuple_Size(args) != 6) { - PyErr_SetString(PyExc_TypeError, "vec_get_global_agent_state requires 6 arguments"); +static PyObject *vec_get_global_agent_state(PyObject *self, PyObject *args) { + if (PyTuple_Size(args) != 8) { + PyErr_SetString(PyExc_TypeError, "vec_get_global_agent_state requires 8 arguments"); return NULL; } - VecEnv* vec = unpack_vecenv(args); + VecEnv *vec = unpack_vecenv(args); if (!vec) { return NULL; } // Get the numpy arrays from arguments - PyObject* x_arr = PyTuple_GetItem(args, 1); - 
PyObject* y_arr = PyTuple_GetItem(args, 2); - PyObject* z_arr = PyTuple_GetItem(args, 3); - PyObject* heading_arr = PyTuple_GetItem(args, 4); - PyObject* id_arr = PyTuple_GetItem(args, 5); - - if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) || - !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) || - !PyArray_Check(id_arr)) { + PyObject *x_arr = PyTuple_GetItem(args, 1); + PyObject *y_arr = PyTuple_GetItem(args, 2); + PyObject *z_arr = PyTuple_GetItem(args, 3); + PyObject *heading_arr = PyTuple_GetItem(args, 4); + PyObject *id_arr = PyTuple_GetItem(args, 5); + PyObject *length_arr = PyTuple_GetItem(args, 6); + PyObject *width_arr = PyTuple_GetItem(args, 7); + + if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) || !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) || + !PyArray_Check(id_arr) || !PyArray_Check(length_arr) || !PyArray_Check(width_arr)) { PyErr_SetString(PyExc_TypeError, "All output arrays must be NumPy arrays"); return NULL; } - PyArrayObject* x_array = (PyArrayObject*)x_arr; - PyArrayObject* y_array = (PyArrayObject*)y_arr; - PyArrayObject* z_array = (PyArrayObject*)z_arr; - PyArrayObject* heading_array = (PyArrayObject*)heading_arr; - PyArrayObject* id_array = (PyArrayObject*)id_arr; + PyArrayObject *x_array = (PyArrayObject *)x_arr; + PyArrayObject *y_array = (PyArrayObject *)y_arr; + PyArrayObject *z_array = (PyArrayObject *)z_arr; + PyArrayObject *heading_array = (PyArrayObject *)heading_arr; + PyArrayObject *id_array = (PyArrayObject *)id_arr; + PyArrayObject *length_array = (PyArrayObject *)length_arr; + PyArrayObject *width_array = (PyArrayObject *)width_arr; // Get base pointers to the arrays - float* x_base = (float*)PyArray_DATA(x_array); - float* y_base = (float*)PyArray_DATA(y_array); - float* z_base = (float*)PyArray_DATA(z_array); - float* heading_base = (float*)PyArray_DATA(heading_array); - int* id_base = (int*)PyArray_DATA(id_array); + float *x_base = (float *)PyArray_DATA(x_array); + float *y_base = (float 
*)PyArray_DATA(y_array); + float *z_base = (float *)PyArray_DATA(z_array); + float *heading_base = (float *)PyArray_DATA(heading_array); + int *id_base = (int *)PyArray_DATA(id_array); + float *length_base = (float *)PyArray_DATA(length_array); + float *width_base = (float *)PyArray_DATA(width_array); // Iterate through environments and write to correct offsets int offset = 0; for (int i = 0; i < vec->num_envs; i++) { - Drive* drive = (Drive*)vec->envs[i]; + Drive *drive = (Drive *)vec->envs[i]; // Write to the arrays at the current offset - c_get_global_agent_state(drive, - &x_base[offset], - &y_base[offset], - &z_base[offset], - &heading_base[offset], - &id_base[offset]); + c_get_global_agent_state(drive, &x_base[offset], &y_base[offset], &z_base[offset], &heading_base[offset], + &id_base[offset], &length_base[offset], &width_base[offset]); // Move offset forward by the number of agents in this environment offset += drive->active_agent_count; @@ -763,111 +775,105 @@ static PyObject* vec_get_global_agent_state(PyObject* self, PyObject* args) { Py_RETURN_NONE; } -static PyObject* get_ground_truth_trajectories(PyObject* self, PyObject* args) { +static PyObject *get_ground_truth_trajectories(PyObject *self, PyObject *args) { if (PyTuple_Size(args) != 7) { PyErr_SetString(PyExc_TypeError, "get_ground_truth_trajectories requires 7 arguments"); return NULL; } - Env* env = unpack_env(args); + Env *env = unpack_env(args); if (!env) { return NULL; } - Drive* drive = (Drive*)env; + Drive *drive = (Drive *)env; // Get the numpy arrays from arguments - PyObject* x_arr = PyTuple_GetItem(args, 1); - PyObject* y_arr = PyTuple_GetItem(args, 2); - PyObject* z_arr = PyTuple_GetItem(args, 3); - PyObject* heading_arr = PyTuple_GetItem(args, 4); - PyObject* valid_arr = PyTuple_GetItem(args, 5); - PyObject* id_arr = PyTuple_GetItem(args, 6); - PyObject* scenario_id_arr = PyTuple_GetItem(args, 7); - - if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) || - !PyArray_Check(z_arr) || 
!PyArray_Check(heading_arr) || + PyObject *x_arr = PyTuple_GetItem(args, 1); + PyObject *y_arr = PyTuple_GetItem(args, 2); + PyObject *z_arr = PyTuple_GetItem(args, 3); + PyObject *heading_arr = PyTuple_GetItem(args, 4); + PyObject *valid_arr = PyTuple_GetItem(args, 5); + PyObject *id_arr = PyTuple_GetItem(args, 6); + PyObject *scenario_id_arr = PyTuple_GetItem(args, 7); + + if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) || !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) || !PyArray_Check(valid_arr) || !PyArray_Check(id_arr) || !PyArray_Check(scenario_id_arr)) { PyErr_SetString(PyExc_TypeError, "All output arrays must be NumPy arrays"); return NULL; } - float* x_data = (float*)PyArray_DATA((PyArrayObject*)x_arr); - float* y_data = (float*)PyArray_DATA((PyArrayObject*)y_arr); - float* z_data = (float*)PyArray_DATA((PyArrayObject*)z_arr); - float* heading_data = (float*)PyArray_DATA((PyArrayObject*)heading_arr); - int* valid_data = (int*)PyArray_DATA((PyArrayObject*)valid_arr); - int* id_data = (int*)PyArray_DATA((PyArrayObject*)id_arr); - int* scenario_id_data = (int*)PyArray_DATA((PyArrayObject*)scenario_id_arr); + float *x_data = (float *)PyArray_DATA((PyArrayObject *)x_arr); + float *y_data = (float *)PyArray_DATA((PyArrayObject *)y_arr); + float *z_data = (float *)PyArray_DATA((PyArrayObject *)z_arr); + float *heading_data = (float *)PyArray_DATA((PyArrayObject *)heading_arr); + int *valid_data = (int *)PyArray_DATA((PyArrayObject *)valid_arr); + int *id_data = (int *)PyArray_DATA((PyArrayObject *)id_arr); + int *scenario_id_data = (int *)PyArray_DATA((PyArrayObject *)scenario_id_arr); - c_get_global_ground_truth_trajectories(drive, x_data, y_data, z_data, heading_data, valid_data, id_data, scenario_id_data); + c_get_global_ground_truth_trajectories(drive, x_data, y_data, z_data, heading_data, valid_data, id_data, + scenario_id_data); Py_RETURN_NONE; } -static PyObject* vec_get_global_ground_truth_trajectories(PyObject* self, PyObject* args) { +static 
PyObject *vec_get_global_ground_truth_trajectories(PyObject *self, PyObject *args) { if (PyTuple_Size(args) != 8) { PyErr_SetString(PyExc_TypeError, "vec_get_global_ground_truth_trajectories requires 8 arguments"); return NULL; } - VecEnv* vec = unpack_vecenv(args); + VecEnv *vec = unpack_vecenv(args); if (!vec) { return NULL; } // Get the numpy arrays from arguments - PyObject* x_arr = PyTuple_GetItem(args, 1); - PyObject* y_arr = PyTuple_GetItem(args, 2); - PyObject* z_arr = PyTuple_GetItem(args, 3); - PyObject* heading_arr = PyTuple_GetItem(args, 4); - PyObject* valid_arr = PyTuple_GetItem(args, 5); - PyObject* id_arr = PyTuple_GetItem(args, 6); - PyObject* scenario_id_arr = PyTuple_GetItem(args, 7); - - if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) || - !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) || + PyObject *x_arr = PyTuple_GetItem(args, 1); + PyObject *y_arr = PyTuple_GetItem(args, 2); + PyObject *z_arr = PyTuple_GetItem(args, 3); + PyObject *heading_arr = PyTuple_GetItem(args, 4); + PyObject *valid_arr = PyTuple_GetItem(args, 5); + PyObject *id_arr = PyTuple_GetItem(args, 6); + PyObject *scenario_id_arr = PyTuple_GetItem(args, 7); + + if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) || !PyArray_Check(z_arr) || !PyArray_Check(heading_arr) || !PyArray_Check(valid_arr) || !PyArray_Check(id_arr) || !PyArray_Check(scenario_id_arr)) { PyErr_SetString(PyExc_TypeError, "All output arrays must be NumPy arrays"); return NULL; } - PyArrayObject* x_array = (PyArrayObject*)x_arr; - PyArrayObject* y_array = (PyArrayObject*)y_arr; - PyArrayObject* z_array = (PyArrayObject*)z_arr; - PyArrayObject* heading_array = (PyArrayObject*)heading_arr; - PyArrayObject* valid_array = (PyArrayObject*)valid_arr; - PyArrayObject* id_array = (PyArrayObject*)id_arr; - PyArrayObject* scenario_id_array = (PyArrayObject*)scenario_id_arr; + PyArrayObject *x_array = (PyArrayObject *)x_arr; + PyArrayObject *y_array = (PyArrayObject *)y_arr; + PyArrayObject *z_array = 
(PyArrayObject *)z_arr; + PyArrayObject *heading_array = (PyArrayObject *)heading_arr; + PyArrayObject *valid_array = (PyArrayObject *)valid_arr; + PyArrayObject *id_array = (PyArrayObject *)id_arr; + PyArrayObject *scenario_id_array = (PyArrayObject *)scenario_id_arr; // Get base pointers to the arrays - float* x_base = (float*)PyArray_DATA(x_array); - float* y_base = (float*)PyArray_DATA(y_array); - float* z_base = (float*)PyArray_DATA(z_array); - float* heading_base = (float*)PyArray_DATA(heading_array); - int* valid_base = (int*)PyArray_DATA(valid_array); - int* id_base = (int*)PyArray_DATA(id_array); - int* scenario_id_base = (int*)PyArray_DATA(scenario_id_array); + float *x_base = (float *)PyArray_DATA(x_array); + float *y_base = (float *)PyArray_DATA(y_array); + float *z_base = (float *)PyArray_DATA(z_array); + float *heading_base = (float *)PyArray_DATA(heading_array); + int *valid_base = (int *)PyArray_DATA(valid_array); + int *id_base = (int *)PyArray_DATA(id_array); + int *scenario_id_base = (int *)PyArray_DATA(scenario_id_array); // Get number of timesteps from array shape - npy_intp* x_shape = PyArray_DIMS(x_array); - int num_timesteps = x_shape[1]; // Second dimension for 2D arrays + npy_intp *x_shape = PyArray_DIMS(x_array); + int num_timesteps = x_shape[1]; // Second dimension for 2D arrays // Iterate through environments and write to correct offsets - int agent_offset = 0; // Offset for 1D arrays (id, scenario_id) - int traj_offset = 0; // Offset for 2D arrays (x, y, z, heading, valid) + int agent_offset = 0; // Offset for 1D arrays (id, scenario_id) + int traj_offset = 0; // Offset for 2D arrays (x, y, z, heading, valid) for (int i = 0; i < vec->num_envs; i++) { - Drive* drive = (Drive*)vec->envs[i]; + Drive *drive = (Drive *)vec->envs[i]; - c_get_global_ground_truth_trajectories(drive, - &x_base[traj_offset], - &y_base[traj_offset], - &z_base[traj_offset], - &heading_base[traj_offset], - &valid_base[traj_offset], - &id_base[agent_offset], - 
&scenario_id_base[agent_offset]); + c_get_global_ground_truth_trajectories(drive, &x_base[traj_offset], &y_base[traj_offset], &z_base[traj_offset], + &heading_base[traj_offset], &valid_base[traj_offset], + &id_base[agent_offset], &scenario_id_base[agent_offset]); // Move offsets forward agent_offset += drive->active_agent_count; @@ -876,8 +882,64 @@ static PyObject* vec_get_global_ground_truth_trajectories(PyObject* self, PyObje Py_RETURN_NONE; } -static double unpack(PyObject* kwargs, char* key) { - PyObject* val = PyDict_GetItemString(kwargs, key); + +static PyObject *vec_get_road_edge_counts(PyObject *self, PyObject *args) { + VecEnv *vec = unpack_vecenv(args); + if (!vec) + return NULL; + + int total_polylines = 0, total_points = 0; + for (int i = 0; i < vec->num_envs; i++) { + Drive *drive = (Drive *)vec->envs[i]; + int np, tp; + c_get_road_edge_counts(drive, &np, &tp); + total_polylines += np; + total_points += tp; + } + return Py_BuildValue("(ii)", total_polylines, total_points); +} + +static PyObject *vec_get_road_edge_polylines(PyObject *self, PyObject *args) { + if (PyTuple_Size(args) != 5) { + PyErr_SetString(PyExc_TypeError, "vec_get_road_edge_polylines requires 5 arguments"); + return NULL; + } + + VecEnv *vec = unpack_vecenv(args); + if (!vec) + return NULL; + + PyObject *x_arr = PyTuple_GetItem(args, 1); + PyObject *y_arr = PyTuple_GetItem(args, 2); + PyObject *lengths_arr = PyTuple_GetItem(args, 3); + PyObject *scenario_ids_arr = PyTuple_GetItem(args, 4); + + if (!PyArray_Check(x_arr) || !PyArray_Check(y_arr) || !PyArray_Check(lengths_arr) || + !PyArray_Check(scenario_ids_arr)) { + PyErr_SetString(PyExc_TypeError, "All output arrays must be NumPy arrays"); + return NULL; + } + + float *x_base = (float *)PyArray_DATA((PyArrayObject *)x_arr); + float *y_base = (float *)PyArray_DATA((PyArrayObject *)y_arr); + int *lengths_base = (int *)PyArray_DATA((PyArrayObject *)lengths_arr); + int *scenario_ids_base = (int *)PyArray_DATA((PyArrayObject 
*)scenario_ids_arr); + + int poly_offset = 0, pt_offset = 0; + for (int i = 0; i < vec->num_envs; i++) { + Drive *drive = (Drive *)vec->envs[i]; + int np, tp; + c_get_road_edge_counts(drive, &np, &tp); + c_get_road_edge_polylines(drive, &x_base[pt_offset], &y_base[pt_offset], &lengths_base[poly_offset], + &scenario_ids_base[poly_offset]); + poly_offset += np; + pt_offset += tp; + } + Py_RETURN_NONE; +} + +static double unpack(PyObject *kwargs, char *key) { + PyObject *val = PyDict_GetItemString(kwargs, key); if (val == NULL) { char error_msg[100]; snprintf(error_msg, sizeof(error_msg), "Missing required keyword argument '%s'", key); @@ -904,8 +966,8 @@ static double unpack(PyObject* kwargs, char* key) { return 1; } -static char* unpack_str(PyObject* kwargs, char* key) { - PyObject* val = PyDict_GetItemString(kwargs, key); +static char *unpack_str(PyObject *kwargs, char *key) { + PyObject *val = PyDict_GetItemString(kwargs, key); if (val == NULL) { char error_msg[100]; snprintf(error_msg, sizeof(error_msg), "Missing required keyword argument '%s'", key); @@ -918,12 +980,12 @@ static char* unpack_str(PyObject* kwargs, char* key) { PyErr_SetString(PyExc_TypeError, error_msg); return NULL; } - const char* str_val = PyUnicode_AsUTF8(val); + const char *str_val = PyUnicode_AsUTF8(val); if (str_val == NULL) { // PyUnicode_AsUTF8 sets an error on failure return NULL; } - char* ret = strdup(str_val); + char *ret = strdup(str_val); if (ret == NULL) { PyErr_SetString(PyExc_MemoryError, "strdup failed in unpack_str"); } @@ -932,7 +994,8 @@ static char* unpack_str(PyObject* kwargs, char* key) { // Method table static PyMethodDef methods[] = { - {"env_init", (PyCFunction)env_init, METH_VARARGS | METH_KEYWORDS, "Init environment with observation, action, reward, terminal, truncation arrays"}, + {"env_init", (PyCFunction)env_init, METH_VARARGS | METH_KEYWORDS, + "Init environment with observation, action, reward, terminal, truncation arrays"}, {"env_reset", env_reset, 
METH_VARARGS, "Reset the environment"}, {"env_step", env_step, METH_VARARGS, "Step the environment"}, {"env_render", env_render, METH_VARARGS, "Render the environment"}, @@ -950,21 +1013,36 @@ static PyMethodDef methods[] = { {"get_global_agent_state", get_global_agent_state, METH_VARARGS, "Get global agent state"}, {"vec_get_global_agent_state", vec_get_global_agent_state, METH_VARARGS, "Get agent state from vectorized env"}, {"get_ground_truth_trajectories", get_ground_truth_trajectories, METH_VARARGS, "Get ground truth trajectories"}, - {"vec_get_global_ground_truth_trajectories", vec_get_global_ground_truth_trajectories, METH_VARARGS, "Get ground truth trajectories from vectorized env"}, + {"vec_get_global_ground_truth_trajectories", vec_get_global_ground_truth_trajectories, METH_VARARGS, + "Get ground truth trajectories from vectorized env"}, + {"vec_get_road_edge_counts", vec_get_road_edge_counts, METH_VARARGS, + "Get road edge polyline counts from vectorized env"}, + {"vec_get_road_edge_polylines", vec_get_road_edge_polylines, METH_VARARGS, + "Get road edge polylines from vectorized env"}, MY_METHODS, - {NULL, NULL, 0, NULL} -}; + {NULL, NULL, 0, NULL}}; // Module definition -static PyModuleDef module = { - PyModuleDef_HEAD_INIT, - "binding", - NULL, - -1, - methods -}; +static PyModuleDef module = {PyModuleDef_HEAD_INIT, "binding", NULL, -1, methods}; PyMODINIT_FUNC PyInit_binding(void) { import_array(); - return PyModule_Create(&module); + PyObject *m = PyModule_Create(&module); // Changed variable name from 'module' to 'm' + + if (m == NULL) { + return NULL; + } + + // Make constants accessible from Python + PyModule_AddIntConstant(m, "MAX_ROAD_SEGMENT_OBSERVATIONS", MAX_ROAD_SEGMENT_OBSERVATIONS); + PyModule_AddIntConstant(m, "MAX_AGENTS", MAX_AGENTS); + PyModule_AddIntConstant(m, "TRAJECTORY_LENGTH", TRAJECTORY_LENGTH); + PyModule_AddIntConstant(m, "MAX_ENTITIES_PER_CELL", MAX_ENTITIES_PER_CELL); + + PyModule_AddIntConstant(m, "ROAD_FEATURES", 
ROAD_FEATURES); + PyModule_AddIntConstant(m, "PARTNER_FEATURES", PARTNER_FEATURES); + PyModule_AddIntConstant(m, "EGO_FEATURES_CLASSIC", EGO_FEATURES_CLASSIC); + PyModule_AddIntConstant(m, "EGO_FEATURES_JERK", EGO_FEATURES_JERK); + + return m; } diff --git a/pufferlib/ocean/env_config.h b/pufferlib/ocean/env_config.h index e8400c51c9..4f26a8b32e 100644 --- a/pufferlib/ocean/env_config.h +++ b/pufferlib/ocean/env_config.h @@ -6,9 +6,22 @@ #include #include +typedef struct { + char *type; + float reward_offroad_weight_lb; + float reward_offroad_weight_ub; + float reward_collision_weight_lb; + float reward_collision_weight_ub; + float reward_goal_weight_lb; + float reward_goal_weight_ub; + float entropy_weight_lb; + float entropy_weight_ub; + float discount_weight_lb; + float discount_weight_ub; +} conditioning_config; + // Config struct for parsing INI files - contains all environment configuration -typedef struct -{ +typedef struct { int action_type; int dynamics_model; float reward_vehicle_collision; @@ -16,49 +29,50 @@ typedef struct float reward_goal; float reward_goal_post_respawn; float reward_vehicle_collision_post_respawn; - float reward_ade; float goal_radius; + float goal_speed; int collision_behavior; int offroad_behavior; int spawn_immunity_timer; float dt; int goal_behavior; + float goal_target_distance; int scenario_length; + int termination_mode; int init_steps; int init_mode; int control_mode; + char map_dir[256]; + conditioning_config *conditioning; } env_init_config; // INI file parser handler - parses all environment configuration from drive.ini -static int handler( - void* config, - const char* section, - const char* name, - const char* value -) { - env_init_config* env_config = (env_init_config*)config; - #define MATCH(s, n) strcmp(section, s) == 0 && strcmp(name, n) == 0 +static int handler(void *config, const char *section, const char *name, const char *value) { + env_init_config *env_config = (env_init_config *)config; +#define MATCH(s, n) 
strcmp(section, s) == 0 && strcmp(name, n) == 0 if (MATCH("env", "action_type")) { - if (strcmp(value, "\"discrete\"") == 0 ||strcmp(value, "discrete") == 0) { - env_config->action_type = 0; // DISCRETE + if (strcmp(value, "\"discrete\"") == 0 || strcmp(value, "discrete") == 0) { + env_config->action_type = 0; // DISCRETE } else if (strcmp(value, "\"continuous\"") == 0 || strcmp(value, "continuous") == 0) { - env_config->action_type = 1; // CONTINUOUS + env_config->action_type = 1; // CONTINUOUS } else { printf("Warning: Unknown action_type value '%s', defaulting to DISCRETE\n", value); - env_config->action_type = 0; // Default to DISCRETE + env_config->action_type = 0; // Default to DISCRETE } } else if (MATCH("env", "dynamics_model")) { if (strcmp(value, "\"classic\"") == 0 || strcmp(value, "classic") == 0) { - env_config->dynamics_model = 0; // CLASSIC + env_config->dynamics_model = 0; // CLASSIC } else if (strcmp(value, "\"jerk\"") == 0 || strcmp(value, "jerk") == 0) { - env_config->dynamics_model = 1; // JERK + env_config->dynamics_model = 1; // JERK } else { printf("Warning: Unknown dynamics_model value '%s', defaulting to JERK\n", value); - env_config->dynamics_model = 1; // Default to JERK + env_config->dynamics_model = 1; // Default to JERK } } else if (MATCH("env", "goal_behavior")) { env_config->goal_behavior = atoi(value); + } else if (MATCH("env", "goal_target_distance")) { + env_config->goal_target_distance = atof(value); } else if (MATCH("env", "reward_vehicle_collision")) { env_config->reward_vehicle_collision = atof(value); } else if (MATCH("env", "reward_offroad_collision")) { @@ -69,13 +83,13 @@ static int handler( env_config->reward_goal_post_respawn = atof(value); } else if (MATCH("env", "reward_vehicle_collision_post_respawn")) { env_config->reward_vehicle_collision_post_respawn = atof(value); - } else if (MATCH("env", "reward_ade")) { - env_config->reward_ade = atof(value); } else if (MATCH("env", "goal_radius")) { env_config->goal_radius = 
atof(value); - } else if(MATCH("env", "collision_behavior")){ + } else if (MATCH("env", "goal_speed")) { + env_config->goal_speed = atof(value); + } else if (MATCH("env", "collision_behavior")) { env_config->collision_behavior = atoi(value); - } else if(MATCH("env", "offroad_behavior")){ + } else if (MATCH("env", "offroad_behavior")) { env_config->offroad_behavior = atoi(value); } else if (MATCH("env", "spawn_immunity_timer")) { env_config->spawn_immunity_timer = atoi(value); @@ -83,15 +97,90 @@ static int handler( env_config->dt = atof(value); } else if (MATCH("env", "scenario_length")) { env_config->scenario_length = atoi(value); + } else if (MATCH("env", "termination_mode")) { + env_config->termination_mode = atoi(value); } else if (MATCH("env", "init_steps")) { env_config->init_steps = atoi(value); } else if (MATCH("env", "init_mode")) { env_config->init_mode = atoi(value); } else if (MATCH("env", "control_mode")) { env_config->control_mode = atoi(value); + } else if (MATCH("env", "map_dir")) { + if (sscanf(value, "\"%255[^\"]\"", env_config->map_dir) != 1) { + strncpy(env_config->map_dir, value, sizeof(env_config->map_dir) - 1); + env_config->map_dir[sizeof(env_config->map_dir) - 1] = '\0'; + } + // printf("Parsed map_dir: '%s'\n", env_config->map_dir); + } else if (MATCH("env.conditioning", "type")) { + if (env_config->conditioning == NULL) { + env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + // Remove quotes if present + if (value[0] == '"') { + size_t len = strlen(value) - 2; // -2 for both quotes + env_config->conditioning->type = (char *)malloc(len + 1); + strncpy(env_config->conditioning->type, value + 1, len); + env_config->conditioning->type[len] = '\0'; + } else { + env_config->conditioning->type = strdup(value); + } + } else if (MATCH("env.conditioning", "collision_weight_lb")) { + if (env_config->conditioning == NULL) { + env_config->conditioning = (conditioning_config 
*)malloc(sizeof(conditioning_config)); + } + env_config->conditioning->reward_collision_weight_lb = atof(value); + } else if (MATCH("env.conditioning", "collision_weight_ub")) { + if (env_config->conditioning == NULL) { + env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->conditioning->reward_collision_weight_ub = atof(value); + } else if (MATCH("env.conditioning", "offroad_weight_lb")) { + if (env_config->conditioning == NULL) { + env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->conditioning->reward_offroad_weight_lb = atof(value); + } else if (MATCH("env.conditioning", "offroad_weight_ub")) { + if (env_config->conditioning == NULL) { + env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->conditioning->reward_offroad_weight_ub = atof(value); + } else if (MATCH("env.conditioning", "goal_weight_lb")) { + if (env_config->conditioning == NULL) { + env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->conditioning->reward_goal_weight_lb = atof(value); + } else if (MATCH("env.conditioning", "goal_weight_ub")) { + if (env_config->conditioning == NULL) { + env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->conditioning->reward_goal_weight_ub = atof(value); + } else if (MATCH("env.conditioning", "entropy_weight_lb")) { + if (env_config->conditioning == NULL) { + env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->conditioning->entropy_weight_lb = atof(value); + } else if (MATCH("env.conditioning", "entropy_weight_ub")) { + if (env_config->conditioning == NULL) { + env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->conditioning->entropy_weight_ub = atof(value); + } else if (MATCH("env.conditioning", 
"discount_weight_lb")) { + if (env_config->conditioning == NULL) { + env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->conditioning->discount_weight_lb = atof(value); + } else if (MATCH("env.conditioning", "discount_weight_ub")) { + if (env_config->conditioning == NULL) { + env_config->conditioning = (conditioning_config *)malloc(sizeof(conditioning_config)); + } + env_config->conditioning->discount_weight_ub = atof(value); + } + + else { + return 0; // Unknown section/name, indicate failure to handle } - #undef MATCH +#undef MATCH return 1; } diff --git a/pufferlib/ocean/torch.py b/pufferlib/ocean/torch.py index 7aa3daa388..53912f2e33 100644 --- a/pufferlib/ocean/torch.py +++ b/pufferlib/ocean/torch.py @@ -10,12 +10,19 @@ Recurrent = pufferlib.models.LSTMWrapper +Transformer = pufferlib.models.TransformerWrapper class Drive(nn.Module): def __init__(self, env, input_size=128, hidden_size=128, **kwargs): super().__init__() self.hidden_size = hidden_size + self.observation_size = env.single_observation_space.shape[0] + self.max_partner_objects = env.max_partner_objects + self.partner_features = env.partner_features + self.max_road_objects = env.max_road_objects + self.road_features = env.road_features + self.road_features_after_onehot = env.road_features + 6 # 6 is the number of one-hot encoded categories # Conditioning setup self.use_rc = env.reward_conditioned @@ -26,23 +33,23 @@ def __init__(self, env, input_size=128, hidden_size=128, **kwargs): # Determine ego dimension from environment's dynamics model base_ego_dim = 10 if env.dynamics_model == "jerk" else 7 self.ego_dim = base_ego_dim + self.conditioning_dims - + print(f"ego dimensions: {self.ego_dim}", flush=True) self.ego_encoder = nn.Sequential( pufferlib.pytorch.layer_init(nn.Linear(self.ego_dim, input_size)), nn.LayerNorm(input_size), # nn.ReLU(), pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)), ) - max_road_objects = 13 + 
self.road_encoder = nn.Sequential( - pufferlib.pytorch.layer_init(nn.Linear(max_road_objects, input_size)), + pufferlib.pytorch.layer_init(nn.Linear(self.road_features_after_onehot, input_size)), nn.LayerNorm(input_size), # nn.ReLU(), pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)), ) - max_partner_objects = 7 + self.partner_encoder = nn.Sequential( - pufferlib.pytorch.layer_init(nn.Linear(max_partner_objects, input_size)), + pufferlib.pytorch.layer_init(nn.Linear(self.partner_features, input_size)), nn.LayerNorm(input_size), # nn.ReLU(), pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)), @@ -72,17 +79,18 @@ def forward_train(self, x, state=None): def encode_observations(self, observations, state=None): ego_dim = self.ego_dim - partner_dim = 63 * 7 - road_dim = 200 * 7 + partner_dim = self.max_partner_objects * self.partner_features + road_dim = self.max_road_objects * self.road_features ego_obs = observations[:, :ego_dim] partner_obs = observations[:, ego_dim : ego_dim + partner_dim] road_obs = observations[:, ego_dim + partner_dim : ego_dim + partner_dim + road_dim] - partner_objects = partner_obs.view(-1, 63, 7) - road_objects = road_obs.view(-1, 200, 7) - road_continuous = road_objects[:, :, :6] # First 6 features - road_categorical = road_objects[:, :, 6] - road_onehot = F.one_hot(road_categorical.long(), num_classes=7) # Shape: [batch, 200, 7] + partner_objects = partner_obs.view(-1, self.max_partner_objects, self.partner_features) + + road_objects = road_obs.view(-1, self.max_road_objects, self.road_features) + road_continuous = road_objects[:, :, : self.road_features - 1] + road_categorical = road_objects[:, :, self.road_features - 1] + road_onehot = F.one_hot(road_categorical.long(), num_classes=7) # Shape: [batch, ROAD_MAX_OBJECTS, 7] road_objects = torch.cat([road_continuous, road_onehot], dim=2) ego_features = self.ego_encoder(ego_obs) partner_features, _ = self.partner_encoder(partner_objects).max(dim=1) diff --git 
a/pufferlib/pufferl.py b/pufferlib/pufferl.py index 68b2e1ef2a..270205ed70 100644 --- a/pufferlib/pufferl.py +++ b/pufferlib/pufferl.py @@ -20,6 +20,7 @@ import configparser from threading import Thread from collections import defaultdict, deque +from pathlib import Path import numpy as np import psutil @@ -34,6 +35,7 @@ import pufferlib.vector import pufferlib.pytorch import pufferlib.utils +import pufferlib.utils try: from pufferlib import _C @@ -74,7 +76,13 @@ def __init__(self, config, vecenv, policy, logger=None): # Vecenv info self.adaptive_driving_agent = getattr(vecenv.driver_env, "env_name", None) == "adaptive_drive" if self.adaptive_driving_agent: - config["bptt_horizon"] = vecenv.driver_env.episode_length + if config.get("policy_architecture", "Recurrent") == "Recurrent": + config["bptt_horizon"] = vecenv.driver_env.episode_length + if config.get("policy_architecture", "Recurrent") == "Transformer": + config["context_window"] = self.context_length = vecenv.driver_env.episode_length + config["bptt_horizon"] = ( + vecenv.driver_env.episode_length + ) ## this is used downstream so you need to define it too vecenv.async_reset(seed) obs_space = vecenv.single_observation_space @@ -84,7 +92,10 @@ def __init__(self, config, vecenv, policy, logger=None): if self.population_play: total_ego_agents = vecenv.num_ego_agents agents_for_calc = total_ego_agents - batch_size = vecenv.driver_env.num_ego_agents * config["bptt_horizon"] * vecenv.num_workers + if config.get("policy_architecture", "Recurrent") == "Recurrent": + batch_size = vecenv.driver_env.num_ego_agents * config["bptt_horizon"] * vecenv.num_workers + if config.get("policy_architecture", "Recurrent") == "Transformer": + batch_size = vecenv.driver_env.num_ego_agents * config["context_window"] * vecenv.num_workers config["batch_size"] = batch_size ## this is dynamic and based on ego agents else: agents_for_calc = total_agents @@ -93,15 +104,40 @@ def __init__(self, config, vecenv, policy, logger=None): 
self.total_agents = total_agents # Experience - if config["batch_size"] == "auto" and config["bptt_horizon"] == "auto": - raise pufferlib.APIUsageError("Must specify batch_size or bptt_horizon") + if ( + config["batch_size"] == "auto" + and config.get("bptt_horizon", "auto") == "auto" + and config.get("context_window", "auto") == "auto" + ): + raise pufferlib.APIUsageError("Must specify batch_size, bptt_horizon, or context_window") elif config["batch_size"] == "auto": - config["batch_size"] = agents_for_calc * config["bptt_horizon"] - elif config["bptt_horizon"] == "auto": + if config.get("policy_architecture", "Recurrent") == "Recurrent": + config["batch_size"] = agents_for_calc * config["bptt_horizon"] + elif config.get("policy_architecture", "Recurrent") == "Transformer": + config["batch_size"] = agents_for_calc * config["context_window"] + elif ( + config.get("bptt_horizon", "auto") == "auto" + and config.get("policy_architecture", "Recurrent") == "Recurrent" + ): config["bptt_horizon"] = config["batch_size"] // agents_for_calc + elif ( + config.get("context_window", "auto") == "auto" + and config.get("policy_architecture", "Recurrent") == "Transformer" + ): + config["context_window"] = config["batch_size"] // agents_for_calc batch_size = config["batch_size"] - horizon = config["bptt_horizon"] + + # Set horizon based on model type + if config.get("policy_architecture", "Recurrent") == "Recurrent": + horizon = config["bptt_horizon"] + elif config.get("policy_architecture", "Recurrent") == "Transformer": + horizon = config["context_window"] + else: + horizon = config.get("bptt_horizon", config.get("context_window", 1)) + + config["bptt_horizon"] = horizon # For backward compatibility + segments = batch_size // horizon self.segments = segments if not self.population_play: @@ -144,7 +180,7 @@ def __init__(self, config, vecenv, policy, logger=None): ensure_drive_binary() # LSTM - if config["use_rnn"]: + if config.get("policy_architecture", "Recurrent") == 
"Recurrent": h = policy.hidden_size if self.population_play: n = vecenv.ego_agents_per_batch # Use ego agents per batch @@ -156,6 +192,27 @@ def __init__(self, config, vecenv, policy, logger=None): self.lstm_h = {i * n: torch.zeros(n, h, device=device) for i in range(total_agents // n)} self.lstm_c = {i * n: torch.zeros(n, h, device=device) for i in range(total_agents // n)} + # TRANSFORMER + if config.get("policy_architecture", "Recurrent") == "Transformer": + h = policy.hidden_size + + if self.population_play: + n = vecenv.ego_agents_per_batch # Use ego agents per batch + num_chunks = total_ego_agents // n + # Initialize transformer context buffers + self.transformer_context = {i * n: torch.zeros(n, 0, h, device=device) for i in range(num_chunks)} + self.transformer_position = { + i * n: torch.zeros(n, dtype=torch.long, device=device) for i in range(num_chunks) + } + else: + n = vecenv.agents_per_batch + num_chunks = total_agents // n + # Initialize transformer context buffers + self.transformer_context = {i * n: torch.zeros(n, 0, h, device=device) for i in range(num_chunks)} + self.transformer_position = { + i * n: torch.zeros(n, dtype=torch.long, device=device) for i in range(num_chunks) + } + # Minibatching & gradient accumulation if self.adaptive_driving_agent: minibatch_size = config["minibatch_multiplier"] * horizon @@ -179,7 +236,7 @@ def __init__(self, config, vecenv, policy, logger=None): self.minibatch_segments = self.minibatch_size // horizon if self.minibatch_segments * horizon != self.minibatch_size: raise pufferlib.APIUsageError( - f"minibatch_size {self.minibatch_size} must be divisible by bptt_horizon {horizon}" + f"minibatch_size {self.minibatch_size} must be divisible by horizon {horizon}" ) # Torch compile @@ -187,7 +244,8 @@ def __init__(self, config, vecenv, policy, logger=None): self.policy = policy if config["compile"]: self.policy = torch.compile(policy, mode=config["compile_mode"]) - self.policy.forward_eval = torch.compile(policy, 
mode=config["compile_mode"]) + if hasattr(policy, "forward_eval"): + self.policy.forward_eval = torch.compile(policy.forward_eval, mode=config["compile_mode"]) pufferlib.pytorch.sample_logits = torch.compile( pufferlib.pytorch.sample_logits, mode=config["compile_mode"] ) @@ -217,6 +275,7 @@ def __init__(self, config, vecenv, policy, logger=None): raise ValueError(f"Unknown optimizer: {config['optimizer']}") self.optimizer = optimizer + # Logging self.logger = logger if logger is None: @@ -260,6 +319,7 @@ def __init__(self, config, vecenv, policy, logger=None): self.stats = defaultdict(list) self.last_stats = defaultdict(list) self.losses = {} + # Dashboard self.model_size = sum(p.numel() for p in policy.parameters() if p.requires_grad) self.print_dashboard(clear=True) @@ -284,11 +344,20 @@ def evaluate(self): config = self.config device = config["device"] - if config["use_rnn"]: + # Reset hidden states for both RNN and Transformer + if config.get("policy_architecture", "Recurrent") == "Recurrent": for k in self.lstm_h: self.lstm_h[k] = torch.zeros(self.lstm_h[k].shape, device=device) self.lstm_c[k] = torch.zeros(self.lstm_c[k].shape, device=device) + if config.get("policy_architecture", "Recurrent") == "Transformer": + h = self.policy.hidden_size + for k in self.transformer_context: + n = self.transformer_context[k].shape[0] + # Pre-allocate full buffer instead of empty + self.transformer_context[k] = torch.zeros(n, self.context_length, h, device=device) + self.transformer_position[k] = torch.zeros(1, dtype=torch.long, device=device) + self.full_rows = 0 while self.full_rows < self.segments: profile("env", epoch) @@ -296,24 +365,28 @@ def evaluate(self): # print(f"o shape is {o.shape}", flush = True) if self.population_play: batch_size = self.vecenv.batch_size - ego_ids = info[-1] + # Filter info to get only the ego_ids lists (not the metrics dicts) + ego_ids_per_env = [item for item in info if isinstance(item, list)] if batch_size > 1: total_agents = len(o) 
num_agents_per_env = total_agents // batch_size - original_shape = o.shape - - o = o.reshape(batch_size, num_agents_per_env, *original_shape[1:]) - r = r.reshape(batch_size, num_agents_per_env) - d = d.reshape(batch_size, num_agents_per_env) - t = t.reshape(batch_size, num_agents_per_env) - - o = o[:, ego_ids].reshape(batch_size * len(ego_ids), *original_shape[1:]) - r = r[:, ego_ids].flatten() - d = d[:, ego_ids].flatten() - t = t[:, ego_ids].flatten() + # Create flat ego_ids by adding batch offset + flat_ego_ids = [] + for env_idx in range(batch_size): + ego_ids = ego_ids_per_env[env_idx] + offset = env_idx * num_agents_per_env + flat_ego_ids.extend([int(idx) + offset for idx in ego_ids]) + + # Simply index with the flat ego_ids + o = o[flat_ego_ids] + r = r[flat_ego_ids] + d = d[flat_ego_ids] + t = t[flat_ego_ids] else: + ego_ids = ego_ids_per_env[0] # Single environment + ego_ids = [int(idx) for idx in ego_ids] # Convert to int o = o[ego_ids] r = r[ego_ids] d = d[ego_ids] @@ -338,10 +411,21 @@ def evaluate(self): env_id=env_id, mask=mask, ) + # Get appropriate batch key for state lookup + if self.population_play: + batch_size = self.vecenv.ego_agents_per_batch + else: + batch_size = self.vecenv.agents_per_batch + state_key = (env_id.start // batch_size) * batch_size - if config["use_rnn"]: - state["lstm_h"] = self.lstm_h[env_id.start] - state["lstm_c"] = self.lstm_c[env_id.start] + if config.get("policy_architecture", "Recurrent") == "Recurrent": + state["lstm_h"] = self.lstm_h[state_key] + state["lstm_c"] = self.lstm_c[state_key] + + if config.get("policy_architecture", "Recurrent") == "Transformer": + state["transformer_context"] = self.transformer_context[state_key] + state["transformer_position"] = self.transformer_position[state_key] + # Note: terminals not needed for eval since we're doing single-step inference logits, value = self.policy.forward_eval(o_device, state) action, logprob, _ = pufferlib.pytorch.sample_logits(logits) @@ -349,18 +433,37 @@ def 
evaluate(self): profile("eval_copy", epoch) with torch.no_grad(): - if config["use_rnn"]: - # Use the same lstm_key calculation + # Update hidden states after forward pass + if config.get("policy_architecture", "Recurrent") == "Recurrent": if self.population_play: batch_size = self.vecenv.ego_agents_per_batch else: batch_size = self.vecenv.agents_per_batch lstm_key = (env_id.start // batch_size) * batch_size - self.lstm_h[lstm_key] = state["lstm_h"] self.lstm_c[lstm_key] = state["lstm_c"] + if config.get("policy_architecture", "Recurrent") == "Transformer": + if self.population_play: + batch_size = self.vecenv.ego_agents_per_batch + else: + batch_size = self.vecenv.agents_per_batch + + transformer_key = (env_id.start // batch_size) * batch_size + self.transformer_context[transformer_key] = state["transformer_context"] + self.transformer_position[transformer_key] = state["transformer_position"] + + # Reset transformer context on episode boundaries + if done_mask.any(): + done_indices = torch.where(torch.from_numpy(done_mask))[0] + if len(done_indices) > 0: + batch_start_in_group = env_id.start % batch_size + global_indices = batch_start_in_group + done_indices + valid_mask = global_indices < self.transformer_position[transformer_key].shape[0] + valid_indices = global_indices[valid_mask] + if len(valid_indices) > 0: + self.transformer_position[transformer_key][valid_indices] = -1 # Fast path for fully vectorized envs l = self.ep_lengths[env_id.start].item() batch_rows = slice(self.ep_indices[env_id.start].item(), 1 + self.ep_indices[env_id.stop - 1].item()) @@ -378,7 +481,13 @@ def evaluate(self): # Note: We are not yet handling masks in this version self.ep_lengths[env_id] += 1 - if l + 1 >= config["bptt_horizon"]: + # Use appropriate horizon based on model type + horizon = ( + config.get("context_window") + if config.get("policy_architecture", "Recurrent") == "Transformer" + else config["bptt_horizon"] + ) + if l + 1 >= horizon: num_full = env_id.stop - 
env_id.start self.ep_indices[env_id] = self.free_idx + torch.arange(num_full, device=config["device"]).int() self.ep_lengths[env_id] = 0 @@ -400,12 +509,17 @@ def evaluate(self): self.stats[k].append(v) profile("env", epoch) - self.vecenv.send(action) profile("eval_misc", epoch) self.free_idx = self.total_agents - self.ep_indices = torch.arange(self.total_agents, device=device, dtype=torch.int32) + + if self.population_play: + total_agents = self.vecenv.num_ego_agents + else: + total_agents = self.total_agents + + self.ep_indices = torch.arange(total_agents, device=device, dtype=torch.int32) self.ep_lengths.zero_() profile.end() return self.stats @@ -480,17 +594,45 @@ def train(self): mb_advantages = advantages[idx] profile("train_forward", epoch) - if not config["use_rnn"]: + + # Handle observation reshaping based on model type + if ( + not config.get("policy_architecture", "Recurrent") == "Recurrent" + and not config.get("policy_architecture", "Recurrent") == "Transformer" + ): + # Flatten for non-recurrent models mb_obs = mb_obs.reshape(-1, *self.vecenv.single_observation_space.shape) state = dict( action=mb_actions, - lstm_h=None, - lstm_c=None, ) + # Add appropriate state based on model type + if config.get("policy_architecture", "Recurrent") == "Recurrent": + state["lstm_h"] = None + state["lstm_c"] = None + elif config.get("policy_architecture", "Recurrent") == "Transformer": + state["transformer_context"] = None + state["transformer_position"] = None + state["terminals"] = mb_terminals # For episode boundary masking + logits, newvalue = self.policy(mb_obs, state) - actions, newlogprob, entropy = pufferlib.pytorch.sample_logits(logits, action=mb_actions) + + # Handle action sampling based on observation shape + if ( + config.get("policy_architecture", "Recurrent") == "Recurrent" + or config.get("policy_architecture", "Recurrent") == "Transformer" + ): + # Add this right before calling sample_logits + if isinstance(logits, tuple): + logits = logits[0] + 
actions, newlogprob, entropy = pufferlib.pytorch.sample_logits(logits, action=mb_actions) + else: + # Need to flatten actions for non-recurrent models + actions, newlogprob, entropy = pufferlib.pytorch.sample_logits( + logits, + action=mb_actions.reshape(-1, *mb_actions.shape[2:]) if len(mb_actions.shape) > 2 else mb_actions, + ) profile("train_misc", epoch) newlogprob = newlogprob.reshape(mb_logprobs.shape) @@ -508,6 +650,8 @@ def train(self): mb_gammas = gammas[idx] else: mb_gammas = torch.full((len(idx),), config["gamma"], device=device, dtype=torch.float32) + + # Recompute advantages with new ratios adv = compute_puff_advantage( mb_values, mb_rewards, @@ -607,6 +751,7 @@ def train(self): self.msg = f"Checkpoint saved at update {self.epoch}" if self.render and self.epoch % self.render_interval == 0: + print("Attempting Render ") model_dir = os.path.join(self.config["data_dir"], f"{self.config['env']}_{self.logger.run_id}") model_files = glob.glob(os.path.join(model_dir, "model_*.pt")) @@ -627,7 +772,9 @@ def train(self): path=bin_path, silent=True, ) - pufferlib.utils.render_videos(self.config, self.vecenv, self.logger, self.global_step, bin_path) + pufferlib.utils.render_videos( + self.config, self.vecenv, self.logger, self.epoch, self.global_step, bin_path + ) except Exception as e: print(f"Failed to export model weights: {e}") @@ -1052,6 +1199,7 @@ def __init__(self, args, load_id=None, resume="allow"): save_code=False, resume=resume, config=args, + name=args.get("wandb_name"), tags=[args["tag"]] if args["tag"] is not None else [], ) self.wandb = wandb @@ -1107,7 +1255,7 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None): elif args["wandb"]: logger = WandbLogger(args) - train_config = dict(**args["train"], env=env_name, eval=args.get("eval", {})) + train_config = dict(**args["train"], env=env_name, eval=args.get("eval", {}), env_config=args.get("env", {})) pufferl = PuffeRL(train_config, vecenv, policy, logger) all_logs = [] @@ -1149,15 
+1297,18 @@ def eval(env_name, args=None, vecenv=None, policy=None): wosac_enabled = args["eval"]["wosac_realism_eval"] human_replay_enabled = args["eval"]["human_replay_eval"] + args["env"]["map_dir"] = args["eval"]["map_dir"] + args["env"]["num_maps"] = args["eval"]["num_maps"] + args["env"]["use_all_maps"] = True + dataset_name = args["env"]["map_dir"].split("/")[-1] if wosac_enabled: - print(f"Running WOSAC realism evaluation. \n") + print(f"Running WOSAC realism evaluation with {dataset_name} dataset. \n") from pufferlib.ocean.benchmark.evaluator import WOSACEvaluator backend = args["eval"]["backend"] assert backend == "PufferEnv" or not wosac_enabled, "WOSAC evaluation only supports PufferEnv backend." args["vec"] = dict(backend=backend, num_envs=1) - args["env"]["num_agents"] = args["eval"]["wosac_num_agents"] args["env"]["init_mode"] = args["eval"]["wosac_init_mode"] args["env"]["control_mode"] = args["eval"]["wosac_control_mode"] args["env"]["init_steps"] = args["eval"]["wosac_init_steps"] @@ -1172,6 +1323,10 @@ def eval(env_name, args=None, vecenv=None, policy=None): # Collect ground truth trajectories from the dataset gt_trajectories = evaluator.collect_ground_truth_trajectories(vecenv) + print(f"Number of scenarios: {len(np.unique(gt_trajectories['scenario_id']))}") + print(f"Number of controlled agents: {gt_trajectories['x'].shape[0]}") + print(f"Number of evaluated agents: {np.sum(gt_trajectories['id'] >= 0)}") + # Roll out trained policy in the simulator simulated_trajectories = evaluator.collect_simulated_trajectories(args, vecenv, policy) @@ -1179,32 +1334,39 @@ def eval(env_name, args=None, vecenv=None, policy=None): evaluator._quick_sanity_check(gt_trajectories, simulated_trajectories) # Analyze and compute metrics + agent_state = vecenv.driver_env.get_global_agent_state() + road_edge_polylines = vecenv.driver_env.get_road_edge_polylines() results = evaluator.compute_metrics( - gt_trajectories, simulated_trajectories, 
args["eval"]["wosac_aggregate_results"] + gt_trajectories, + simulated_trajectories, + agent_state, + road_edge_polylines, + args["eval"]["wosac_aggregate_results"], ) if args["eval"]["wosac_aggregate_results"]: import json - print("WOSAC_METRICS_START") + print("\nWOSAC_METRICS_START") print(json.dumps(results)) print("WOSAC_METRICS_END") return results elif human_replay_enabled: - print("Running human replay evaluation.\n") + print(f"Running human replay evaluation with {dataset_name} dataset.\n") from pufferlib.ocean.benchmark.evaluator import HumanReplayEvaluator backend = args["eval"].get("backend", "PufferEnv") args["vec"] = dict(backend=backend, num_envs=1) - args["env"]["num_agents"] = args["eval"]["human_replay_num_agents"] args["env"]["control_mode"] = args["eval"]["human_replay_control_mode"] - args["env"]["scenario_length"] = 91 # Standard scenario length + args["env"]["episode_length"] = 91 # WOMD scenario length vecenv = vecenv or load_env(env_name, args) policy = policy or load_policy(args, vecenv, env_name) + print(f"Effective number of scenarios used: {len(vecenv.driver_env.agent_offsets) - 1}") + evaluator = HumanReplayEvaluator(args) # Run rollouts with human replays @@ -1312,6 +1474,118 @@ def sweep(args=None, env_name=None): args["train"]["total_timesteps"] = total_timesteps +def controlled_exp(env_name, args=None): + """Run experiments with all combinations of specified parameter values.""" + import itertools + from copy import deepcopy + + args = args or load_config(env_name) + if not args["wandb"] and not args["neptune"]: + raise pufferlib.APIUsageError("Targeted experiments require either wandb or neptune") + + # Check if controlled_exp config exists + if "controlled_exp" not in args: + raise pufferlib.APIUsageError("No [controlled_exp.*] sections found in config") + + # Extract parameters from controlled_exp namespace + params = {} + for section, section_config in args["controlled_exp"].items(): + if isinstance(section_config, dict): + for 
param, param_config in section_config.items(): + if isinstance(param_config, dict) and "values" in param_config: + params[f"{section}.{param}"] = param_config["values"] + + if not params: + raise pufferlib.APIUsageError("No parameters with 'values' lists found in [controlled_exp.*] sections") + + # Generate all combinations + keys = list(params.keys()) + combinations = list(itertools.product(*[params[k] for k in keys])) + + print(f"Running a total of {len(combinations)} experiments with parameters: {keys}") + + # Run each combination + for i, combo in enumerate(combinations, 1): + exp_args = deepcopy(args) + + # Set parameters + for key, value in zip(keys, combo): + section, param = key.split(".") + exp_args[section][param] = value + + print(f"\nExperiment {i}/{len(combinations)}: {dict(zip(keys, combo))}") + + # Train + train(env_name, args=exp_args) + + print(f"\n✓ Completed all {len(combinations)} experiments") + + +def sanity(env_name, args=None): + args = args or load_config(env_name) + base_dir = Path(__file__).resolve().parent / "resources" / "drive" / "sanity" + json_dir = base_dir / "sanity_jsons" + binary_dir = base_dir / "sanity_binaries" + + available_maps = {p.stem: p for p in json_dir.glob("*.json")} + selected = args.get("sanity_maps") + if isinstance(selected, str): + selected = [selected] + + if selected: + missing = [name for name in selected if name not in available_maps] + if missing: + raise pufferlib.APIUsageError(f"Unknown sanity maps: {', '.join(sorted(missing))}") + chosen = [(name, available_maps[name]) for name in selected] + else: + chosen = sorted(available_maps.items()) + + if not chosen: + raise pufferlib.APIUsageError(f"No sanity maps found in {json_dir}") + + from pufferlib.ocean.drive.drive import load_map + + binary_dir.mkdir(parents=True, exist_ok=True) + binaries = [] + for idx, (name, json_path) in enumerate(chosen): + output_path = binary_dir / f"{name}.bin" + load_map(str(json_path), idx, str(output_path)) + 
binaries.append((name, output_path)) + + runs = [] + for name, binary in binaries: + map_zero = binary_dir / "map_000.bin" + shutil.copy2(binary, map_zero) + + run_args = { + **args, + "env": {**args["env"], "num_maps": 1, "map_dir": str(binary_dir)}, + "train": {**args["train"], "render_map": str(map_zero)}, + } + if run_args.get("wandb"): + run_args["wandb_name"] = name + + print(f"Running sanity map '{name}' from {binary.name}") + run_logs = train(env_name=env_name, args=run_args) + runs.append({"map": name, "logs": run_logs}) + + print("Sanity checklist:") + for entry in runs: + name = entry["map"] + logs = entry.get("logs") or [] + final = logs[-1] if logs else {} + score = final.get("environment/score") + if score is None: + status = "unknown (no score)" + elif score >= 0.95: + status = "✅ Solved" + else: + status = "❌ unsolved" + print(f" - {name}: {status} (score={score})") + + return runs + + def profile(args=None, env_name=None, vecenv=None, policy=None): args = load_config() vecenv = vecenv or load_env(env_name, args) @@ -1357,20 +1631,14 @@ def ensure_drive_binary(): binary is always up-to-date with the latest code changes. 
""" if os.path.exists("./visualize"): - print("Removing existing visualize binary...") - try: - os.remove("./visualize") - except FileNotFoundError: - print("Binary not found") - print("Building visualize binary...") + os.remove("./visualize") + try: result = subprocess.run( ["bash", "scripts/build_ocean.sh", "visualize", "local"], capture_output=True, text=True, timeout=300 ) - if result.returncode == 0: - print("Successfully built visualize binary") - else: + if result.returncode != 0: print(f"Build failed: {result.stderr}") raise RuntimeError("Failed to build visualize binary for rendering") except subprocess.TimeoutExpired: @@ -1405,9 +1673,18 @@ def load_policy(args, vecenv, env_name=""): policy_cls = getattr(env_module.torch, args["policy_name"]) policy = policy_cls(vecenv.driver_env, **args["policy"]) - rnn_name = args["rnn_name"] - if rnn_name is not None: - rnn_cls = getattr(env_module.torch, args["rnn_name"]) + # Handle both RNN and Transformer wrappers + rnn_name = args.get("rnn_name") + transformer_name = args.get("transformer_name") + + if transformer_name is not None: + # Load transformer wrapper + transformer_cls = getattr(env_module.torch, transformer_name) + args["transformer"]["context_length"] = vecenv.driver_env.episode_length + policy = transformer_cls(vecenv.driver_env, policy, **args["transformer"]) + elif rnn_name is not None: + # Load RNN wrapper + rnn_cls = getattr(env_module.torch, rnn_name) policy = rnn_cls(vecenv.driver_env, policy, **args["rnn"]) policy = policy.to(device) @@ -1466,7 +1743,7 @@ def load_config(env_name): parser.add_argument("--neptune-project", type=str, default="ablations") parser.add_argument("--local-rank", type=int, default=0, help="Used by torchrun for DDP") parser.add_argument("--tag", type=str, default=None, help="Tag for experiment") - + parser.add_argument("--sanity-maps", nargs="*", default=None, help="Optional list of sanity map base names to run") args = parser.parse_known_args()[0] # Load defaults and 
config @@ -1517,9 +1794,7 @@ def puffer_type(value): def main(): - err = ( - "Usage: puffer [train, eval, sweep, autotune, profile, export] [env_name] [optional args]. --help for more info" - ) + err = "Usage: puffer [train, eval, sweep, controlled_exp, autotune, profile, export, sanity] [env_name] [optional args]. --help for more info" if len(sys.argv) < 3: raise pufferlib.APIUsageError(err) @@ -1531,12 +1806,16 @@ def main(): eval(env_name=env_name) elif mode == "sweep": sweep(env_name=env_name) + elif mode == "controlled_exp": + controlled_exp(env_name=env_name) elif mode == "autotune": autotune(env_name=env_name) elif mode == "profile": profile(env_name=env_name) elif mode == "export": export(env_name=env_name) + elif mode == "sanity": + sanity(env_name=env_name) else: raise pufferlib.APIUsageError(err) diff --git a/pufferlib/resources/drive/binaries/map_000.bin b/pufferlib/resources/drive/binaries/training/map_000.bin similarity index 100% rename from pufferlib/resources/drive/binaries/map_000.bin rename to pufferlib/resources/drive/binaries/training/map_000.bin diff --git a/pufferlib/utils.py b/pufferlib/utils.py index a93c6b2945..8ef3cb4034 100644 --- a/pufferlib/utils.py +++ b/pufferlib/utils.py @@ -10,6 +10,8 @@ def run_human_replay_eval_in_subprocess(config, logger, global_step): """ Run human replay evaluation in a subprocess and log metrics to wandb. + For adaptive agents, this runs evaluate_human_logs.py with --human-replay flag. + For non-adaptive agents, this runs pufferl eval with human-replay-eval flag. 
""" try: run_id = logger.run_id @@ -22,50 +24,145 @@ def run_human_replay_eval_in_subprocess(config, logger, global_step): latest_cpt = max(model_files, key=os.path.getctime) - # Prepare evaluation command - eval_config = config["eval"] - cmd = [ - sys.executable, - "-m", - "pufferlib.pufferl", - "eval", - config["env"], - "--load-model-path", - latest_cpt, - "--eval.wosac-realism-eval", - "False", - "--eval.human-replay-eval", - "True", - "--eval.human-replay-num-agents", - str(eval_config["human_replay_num_agents"]), - "--eval.human-replay-control-mode", - str(eval_config["human_replay_control_mode"]), - ] + # Check if this is an adaptive driving agent + # config["env"] is the env name string (e.g., "puffer_adaptive_drive") + env_name = config.get("env", "") + is_adaptive = "adaptive" in env_name - # Run human replay evaluation in subprocess - result = subprocess.run(cmd, capture_output=True, text=True, timeout=600, cwd=os.getcwd()) + # Get nested config sections + env_config = config.get("env_config", {}) + eval_config = config.get("eval", {}) - if result.returncode == 0: - # Extract JSON from stdout between markers - stdout = result.stdout - if "HUMAN_REPLAY_METRICS_START" in stdout and "HUMAN_REPLAY_METRICS_END" in stdout: - start = stdout.find("HUMAN_REPLAY_METRICS_START") + len("HUMAN_REPLAY_METRICS_START") - end = stdout.find("HUMAN_REPLAY_METRICS_END") - json_str = stdout[start:end].strip() - human_replay_metrics = json.loads(json_str) + print(f"[Human Replay Eval] env_name={env_name}, is_adaptive={is_adaptive}") + print(f"[Human Replay Eval] Using model: {latest_cpt}") + + if is_adaptive: + # Use evaluate_human_logs.py for adaptive agents with human replay + cmd = [ + sys.executable, + "evaluate_human_logs.py", + "--policy-path", + latest_cpt, + "--policy-architecture", + config.get("policy_architecture", "Transformer"), + "--adaptive-driving-agent", + "1", + "--k-scenarios", + str(env_config.get("k_scenarios", 2)), + "--num-agents", + 
str(eval_config.get("human_replay_num_agents", 32)), + "--num-maps", + str(eval_config.get("human_replay_num_maps", 100)), + "--num-rollouts", + str(eval_config.get("human_replay_num_rollouts", 100)), + "--dynamics-model", + str(env_config.get("dynamics_model", "classic")), + "--human-replay", # Enable human replay mode + "--max-controlled-agents", + "1", + "--output", + "/tmp/human_replay_eval.json", + ] + print(f"[Human Replay Eval] Command: {' '.join(cmd)}") - # Log to wandb if available - if hasattr(logger, "wandb") and logger.wandb: - logger.wandb.log( - { - "eval/human_replay_collision_rate": human_replay_metrics["collision_rate"], - "eval/human_replay_offroad_rate": human_replay_metrics["offroad_rate"], - "eval/human_replay_completion_rate": human_replay_metrics["completion_rate"], - }, - step=global_step, - ) + result = subprocess.run(cmd, capture_output=True, text=True, timeout=600, cwd=os.getcwd()) + + if result.returncode == 0: + # Read metrics from output JSON file + try: + with open("/tmp/human_replay_eval.json", "r") as f: + human_replay_metrics = json.load(f) + + # Log to wandb if available + if hasattr(logger, "wandb") and logger.wandb: + log_data = { + "eval/human_replay_collision_rate": human_replay_metrics.get("collision_rate", 0), + "eval/human_replay_offroad_rate": human_replay_metrics.get("offroad_rate", 0), + "eval/human_replay_completion_rate": human_replay_metrics.get("completion_rate", 0), + "eval/human_replay_score": human_replay_metrics.get("score", 0), + } + # Add adaptive delta metrics if available (difference between last and first scenario) + if "ada_delta_score" in human_replay_metrics: + # All delta metrics + delta_metrics = [ + "ada_delta_score", + "ada_delta_collision_rate", + "ada_delta_offroad_rate", + "ada_delta_completion_rate", + "ada_delta_episode_return", + "ada_delta_perf", + "ada_delta_dnf_rate", + "ada_delta_num_goals_reached", + ] + for metric in delta_metrics: + if metric in human_replay_metrics: + 
log_data[f"eval/human_replay_{metric}"] = human_replay_metrics[metric] + + # First and last scenario metrics + scenario_metrics = [ + "first_scenario_score", + "first_scenario_collision_rate", + "first_scenario_offroad_rate", + "first_scenario_episode_return", + "last_scenario_score", + "last_scenario_collision_rate", + "last_scenario_offroad_rate", + "last_scenario_episode_return", + ] + for metric in scenario_metrics: + if metric in human_replay_metrics: + log_data[f"eval/human_replay_{metric}"] = human_replay_metrics[metric] + + logger.wandb.log(log_data, step=global_step) + except (FileNotFoundError, json.JSONDecodeError) as e: + print(f"Failed to read human replay metrics: {e}") + else: + print(f"Human replay evaluation failed with exit code {result.returncode}") + print(f"stdout: {result.stdout}") + print(f"stderr: {result.stderr}") else: - print(f"Human replay evaluation failed with exit code {result.returncode}: {result.stderr}") + # Non-adaptive: use original pufferl eval path + cmd = [ + sys.executable, + "-m", + "pufferlib.pufferl", + "eval", + config["env"], + "--load-model-path", + latest_cpt, + "--eval.wosac-realism-eval", + "False", + "--eval.human-replay-eval", + "True", + "--eval.human-replay-num-agents", + str(eval_config.get("human_replay_num_agents", 64)), + "--eval.human-replay-control-mode", + str(eval_config.get("human_replay_control_mode", "control_sdc_only")), + ] + + result = subprocess.run(cmd, capture_output=True, text=True, timeout=600, cwd=os.getcwd()) + + if result.returncode == 0: + # Extract JSON from stdout between markers + stdout = result.stdout + if "HUMAN_REPLAY_METRICS_START" in stdout and "HUMAN_REPLAY_METRICS_END" in stdout: + start = stdout.find("HUMAN_REPLAY_METRICS_START") + len("HUMAN_REPLAY_METRICS_START") + end = stdout.find("HUMAN_REPLAY_METRICS_END") + json_str = stdout[start:end].strip() + human_replay_metrics = json.loads(json_str) + + # Log to wandb if available + if hasattr(logger, "wandb") and logger.wandb: + 
logger.wandb.log( + { + "eval/human_replay_collision_rate": human_replay_metrics["collision_rate"], + "eval/human_replay_offroad_rate": human_replay_metrics["offroad_rate"], + "eval/human_replay_completion_rate": human_replay_metrics["completion_rate"], + }, + step=global_step, + ) + else: + print(f"Human replay evaluation failed with exit code {result.returncode}: {result.stderr}") except subprocess.TimeoutExpired: print("Human replay evaluation timed out") @@ -114,7 +211,7 @@ def run_wosac_eval_in_subprocess(config, logger, global_step): "--eval.wosac-init-mode", str(eval_config.get("wosac_init_mode", "create_all_valid")), "--eval.wosac-control-mode", - str(eval_config.get("wosac_control_mode", "control_tracks_to_predict")), + str(eval_config.get("wosac_control_mode", "control_wosac")), "--eval.wosac-init-steps", str(eval_config.get("wosac_init_steps", 10)), "--eval.wosac-goal-behavior", @@ -151,15 +248,23 @@ def run_wosac_eval_in_subprocess(config, logger, global_step): step=global_step, ) else: - print(f"WOSAC evaluation failed with exit code {result.returncode}: {result.stderr}") + print(f"WOSAC evaluation failed with exit code {result.returncode}") + print(f"Error: {result.stderr}") + + # Check for memory issues + stderr_lower = result.stderr.lower() + if "out of memory" in stderr_lower or "cuda out of memory" in stderr_lower: + print("GPU out of memory. Skipping this WOSAC evaluation.") except subprocess.TimeoutExpired: - print("WOSAC evaluation timed out") + print("WOSAC evaluation timed out after 600 seconds") + except MemoryError as e: + print(f"WOSAC evaluation ran out of memory. Skipping this evaluation: {e}") except Exception as e: - print(f"Failed to run WOSAC evaluation: {e}") + print(f"Failed to run WOSAC evaluation: {type(e).__name__}: {e}") -def render_videos(config, vecenv, logger, global_step, bin_path): +def render_videos(config, vecenv, logger, epoch, global_step, bin_path): """ Generate and log training videos using C-based rendering. 
@@ -194,91 +299,104 @@ def render_videos(config, vecenv, logger, global_step, bin_path): # TODO: Fix memory leaks so that this is not needed # Suppress AddressSanitizer exit code (temp) - env = os.environ.copy() - env["ASAN_OPTIONS"] = "exitcode=0" - - cmd = ["xvfb-run", "-a", "-s", "-screen 0 1280x720x24", "./visualize"] - - # Add render configurations - if config["show_grid"]: - cmd.append("--show-grid") - if config["obs_only"]: - cmd.append("--obs-only") - if config["show_lasers"]: - cmd.append("--lasers") - if config["show_human_logs"]: - cmd.append("--log-trajectories") - if vecenv.driver_env.goal_radius is not None: - cmd.extend(["--goal-radius", str(vecenv.driver_env.goal_radius)]) - if vecenv.driver_env.init_steps > 0: - cmd.extend(["--init-steps", str(vecenv.driver_env.init_steps)]) - if config["render_map"] is not None: - map_path = config["render_map"] - if os.path.exists(map_path): - cmd.extend(["--map-name", map_path]) - if vecenv.driver_env.init_mode is not None: - cmd.extend(["--init-mode", str(vecenv.driver_env.init_mode)]) - if vecenv.driver_env.control_mode is not None: - cmd.extend(["--control-mode", str(vecenv.driver_env.control_mode)]) - - if hasattr(vecenv.driver_env, "reward_conditioned"): - cmd.extend(["--use-rc", "1" if vecenv.driver_env.reward_conditioned else "0"]) - if hasattr(vecenv.driver_env, "entropy_conditioned"): - cmd.extend(["--use-ec", "1" if vecenv.driver_env.entropy_conditioned else "0"]) - if hasattr(vecenv.driver_env, "discount_conditioned"): - cmd.extend(["--use-dc", "1" if vecenv.driver_env.discount_conditioned else "0"]) - - # Specify output paths for videos - cmd.extend(["--output-topdown", "resources/drive/output_topdown.mp4"]) - cmd.extend(["--output-agent", "resources/drive/output_agent.mp4"]) - - # Add environment configuration + env_vars = os.environ.copy() + env_vars["ASAN_OPTIONS"] = "exitcode=0" + + # Base command with only visualization flags (env config comes from INI) + base_cmd = ["xvfb-run", "-a", "-s", 
"-screen 0 1280x720x24", "./visualize"] + + # Visualization config flags only + if config.get("show_grid", False): + base_cmd.append("--show-grid") + if config.get("obs_only", False): + base_cmd.append("--obs-only") + if config.get("show_lasers", False): + base_cmd.append("--lasers") + if config.get("show_human_logs", False): + base_cmd.append("--show-human-logs") + if config.get("zoom_in", False): + base_cmd.append("--zoom-in") + + # Frame skip for rendering performance + frame_skip = config.get("frame_skip", 1) + if frame_skip > 1: + base_cmd.extend(["--frame-skip", str(frame_skip)]) + + # View mode + view_mode = config.get("view_mode", "both") + base_cmd.extend(["--view", view_mode]) + + # Get num_maps if available env_cfg = getattr(vecenv, "driver_env", None) - if env_cfg is not None: - n_policy = getattr(env_cfg, "max_controlled_agents", -1) - try: - n_policy = int(n_policy) - except (TypeError, ValueError): - n_policy = -1 - if n_policy > 0: - cmd += ["--num-policy-controlled-agents", str(n_policy)] - if getattr(env_cfg, "num_maps", False): - cmd.extend(["--num-maps", str(env_cfg.num_maps)]) - if getattr(env_cfg, "scenario_length", None): - cmd.extend(["--scenario-length", str(env_cfg.scenario_length)]) - - # Call C code that runs eval_gif() in subprocess - result = subprocess.run(cmd, cwd=os.getcwd(), capture_output=True, text=True, timeout=120, env=env) - - vids_exist = os.path.exists("resources/drive/output_topdown.mp4") and os.path.exists( - "resources/drive/output_agent.mp4" - ) - - if result.returncode == 0 or (result.returncode == 1 and vids_exist): - # Move both generated videos to the model directory - videos = [ - ("resources/drive/output_topdown.mp4", f"step_{global_step:09d}_topdown.mp4"), - ("resources/drive/output_agent.mp4", f"step_{global_step:09d}_agent.mp4"), - ] - - for source_vid, target_filename in videos: - if os.path.exists(source_vid): - target_gif = os.path.join(video_output_dir, target_filename) - shutil.move(source_vid, target_gif) 
- - # Log to wandb if available - if hasattr(logger, "wandb") and logger.wandb: - import wandb - - view_type = "world_state" if "topdown" in target_filename else "agent_view" - logger.wandb.log( - {f"render/{view_type}": wandb.Video(target_gif, format="mp4")}, - step=global_step, - ) - else: - print(f"Video generation completed but {source_vid} not found") + if env_cfg is not None and getattr(env_cfg, "num_maps", None): + base_cmd.extend(["--num-maps", str(env_cfg.num_maps)]) + + # Handle single or multiple map rendering + render_maps = config.get("render_map", None) + if render_maps is None: + render_maps = [None] + elif isinstance(render_maps, (str, os.PathLike)): + render_maps = [render_maps] else: - print(f"C rendering failed with exit code {result.returncode}: {result.stdout}") + # Ensure list-like + render_maps = list(render_maps) + + # Collect videos to log as lists so W&B shows all in the same step + videos_to_log_world = [] + videos_to_log_agent = [] + + for i, map_path in enumerate(render_maps): + cmd = list(base_cmd) # copy + if map_path is not None and os.path.exists(map_path): + cmd.extend(["--map-name", str(map_path)]) + + # Output paths (overwrite each iteration; then moved/renamed) + cmd.extend(["--output-topdown", "resources/drive/output_topdown.mp4"]) + cmd.extend(["--output-agent", "resources/drive/output_agent.mp4"]) + + result = subprocess.run(cmd, cwd=os.getcwd(), capture_output=True, text=True, timeout=600, env=env_vars) + + vids_exist = os.path.exists("resources/drive/output_topdown.mp4") and os.path.exists( + "resources/drive/output_agent.mp4" + ) + + if result.returncode == 0 or (result.returncode == 1 and vids_exist): + videos = [ + ( + "resources/drive/output_topdown.mp4", + f"epoch_{epoch:06d}_map{i:02d}_topdown.mp4" if map_path else f"epoch_{epoch:06d}_topdown.mp4", + ), + ( + "resources/drive/output_agent.mp4", + f"epoch_{epoch:06d}_map{i:02d}_agent.mp4" if map_path else f"epoch_{epoch:06d}_agent.mp4", + ), + ] + + for source_vid, 
target_filename in videos: + if os.path.exists(source_vid): + target_path = os.path.join(video_output_dir, target_filename) + shutil.move(source_vid, target_path) + # Accumulate for a single wandb.log call + if hasattr(logger, "wandb") and logger.wandb: + import wandb + + if "topdown" in target_filename: + videos_to_log_world.append(wandb.Video(target_path, format="mp4")) + else: + videos_to_log_agent.append(wandb.Video(target_path, format="mp4")) + else: + print(f"Video generation completed but {source_vid} not found") + else: + print(f"C rendering failed (map index {i}) with exit code {result.returncode}: {result.stdout}") + + # Log all videos at once so W&B keeps all of them under the same step + if hasattr(logger, "wandb") and logger.wandb and (videos_to_log_world or videos_to_log_agent): + payload = {} + if videos_to_log_world: + payload["render/world_state"] = videos_to_log_world + if videos_to_log_agent: + payload["render/agent_view"] = videos_to_log_agent + logger.wandb.log(payload, step=global_step) except subprocess.TimeoutExpired: print("C rendering timed out") diff --git a/pufferlib/vector.py b/pufferlib/vector.py index 24ac492405..c397ec7e16 100644 --- a/pufferlib/vector.py +++ b/pufferlib/vector.py @@ -838,13 +838,10 @@ def make(env_creator_or_creators, env_args=None, env_kwargs=None, backend=Puffer import gymnasium from pufferlib.ocean.torch import Drive import pufferlib.models + from pufferlib.ocean.drive import binding - dynamics_model = env_k.get("dynamics_model", "jerk") - # Observation space calculation - if dynamics_model == "classic": - ego_features = 7 - elif dynamics_model == "jerk": - ego_features = 10 + dynamics_model = env_k.get("dynamics_model", "classic") + action_type = env_k.get("action_type", "discrete") co_player_policy = env_k["co_player_policy"] @@ -858,27 +855,68 @@ def make(env_creator_or_creators, env_args=None, env_kwargs=None, backend=Puffer reward_conditioned = condition_type in ("reward", "all") entropy_conditioned = 
condition_type in ("entropy", "all") discount_conditioned = condition_type in ("discount", "all") - # Calculate conditioning dimensions + + if action_type == "discrete": + if dynamics_model == "classic": + # Joint action space (assume dependence) + single_action_space = gymnasium.spaces.MultiDiscrete([7 * 13]) + # Multi discrete (assume independence) + # self.single_action_space = gymnasium.spaces.MultiDiscrete([7, 13]) + elif dynamics_model == "jerk": + # Joint action space (assume dependence) - 4 longitudinal × 3 lateral = 12 + single_action_space = gymnasium.spaces.MultiDiscrete([4 * 3]) + else: + raise ValueError(f"dynamics_model must be 'classic' or 'jerk'. Got: {dynamics_model}") + elif action_type == "continuous": + single_action_space = gymnasium.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) + else: + raise ValueError(f"action_space must be 'discrete' or 'continuous'. Got: {action_type}") + + # # Observation space calculation + ego_features = {"classic": binding.EGO_FEATURES_CLASSIC, "jerk": binding.EGO_FEATURES_JERK}.get(dynamics_model) + conditioning_dims = ( (3 if reward_conditioned else 0) + (1 if entropy_conditioned else 0) + (1 if discount_conditioned else 0) ) - # Base observations + conditioning observations - num_obs = ego_features + conditioning_dims + 63 * 7 + 200 * 7 - temp_env = SimpleNamespace( - single_action_space=gymnasium.spaces.MultiDiscrete([7 * 13]), - single_observation_space=gymnasium.spaces.Box(low=-1, high=1, shape=(num_obs,), dtype=np.float32), + ego_features += conditioning_dims + + # # Extract observation shapes from constants + # # These need to be defined in C, since they determine the shape of the arrays + # max_road_objects = 200 + # max_partner_objects = 63 + # partner_features = 7 + # road_features = 7 + + # Extract observation shapes from constants + # These need to be defined in C, since they determine the shape of the arrays + max_road_objects = binding.MAX_ROAD_SEGMENT_OBSERVATIONS + max_partner_objects = 
binding.MAX_AGENTS - 1 + partner_features = binding.PARTNER_FEATURES + road_features = binding.ROAD_FEATURES + + num_obs = ego_features + max_partner_objects * partner_features + max_road_objects * road_features + + single_observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(num_obs,), dtype=np.float32) + + co_player_env = SimpleNamespace( + single_action_space=single_action_space, + single_observation_space=single_observation_space, reward_conditioned=reward_conditioned, entropy_conditioned=entropy_conditioned, discount_conditioned=discount_conditioned, dynamics_model=dynamics_model, ## keep these the same I think, multiple dynamics models could get weird + max_partner_objects=max_partner_objects, + partner_features=partner_features, + max_road_objects=max_road_objects, + road_features=road_features, ) - base_policy = Drive(temp_env, input_size=input_size, hidden_size=hidden_size) + base_policy = Drive(co_player_env, input_size=input_size, hidden_size=hidden_size) if co_player_rnn: policy = pufferlib.models.LSTMWrapper( - temp_env, + co_player_env, base_policy, input_size=co_player_rnn.get("input_size"), hidden_size=co_player_rnn.get("hidden_size"), diff --git a/pyproject.toml b/pyproject.toml index 4bcc818849..72a1b34949 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,8 +117,10 @@ metta = [ 'hydra-core', 'duckdb', 'raylib>=5.5.0', - 'metta-common @ git+https://github.com/metta-ai/metta.git@main#subdirectory=common', - 'metta-mettagrid @ git+https://github.com/metta-ai/metta.git@main#subdirectory=mettagrid', + # 'metta-common @ git+ssh://git@github.com/metta-ai/metta.git@main#subdirectory=common', + # 'metta-mettagrid @ git+ssh://git@github.com/metta-ai/metta.git@main#subdirectory=mettagrid', + 'mettagrid' + ] microrts = [ diff --git a/scripts/gpu_heartbeat.py b/scripts/gpu_heartbeat.py new file mode 100644 index 0000000000..1acf77a3d1 --- /dev/null +++ b/scripts/gpu_heartbeat.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +""" +GPU Heartbeat for 
RL Training (CPU-bound workloads) +Keeps GPU utilization above cluster threshold when training is idle. +Tuned for L40S - increase N and/or loop count for H200. +""" + +import torch +import time +import os +import subprocess + +# Settings +THRESHOLD = 65 # If util is below threshold, we wake up (buffer above 50% requirement) +CHECK_INTERVAL = 0.5 # Check nvidia-smi every 0.5 seconds +N = 11000 # Size of matrix (~1GB VRAM, tuned for L40S) + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print(f"Starting GPU Heartbeat on {torch.cuda.get_device_name(0)}") +print(f"PID: {os.getpid()}") + +# Pre-allocate memory so we don't slow down allocation later +x = torch.randn(N, N, device=device) +y = torch.randn(N, N, device=device) + + +def get_gpu_utilization(): + """Reads the current GPU utilization directly from nvidia-smi""" + try: + result = subprocess.check_output( + ["nvidia-smi", "--query-gpu=utilization.gpu", "--format=csv,noheader,nounits"], encoding="utf-8" + ) + return int(result.strip()) + except Exception: + # If checking fails, assume high load to be safe and sleep + return 100 + + +while True: + current_util = get_gpu_utilization() + + if current_util < THRESHOLD: + # IDLE MODE: Generate load + for _ in range(25): + z = torch.mm(x, y) + torch.cuda.synchronize() + else: + # TRAINING MODE: Get out of the way + time.sleep(CHECK_INTERVAL) diff --git a/scripts/run.sh b/scripts/run.sh index 4038d0cd04..e9c7031e7b 100644 --- a/scripts/run.sh +++ b/scripts/run.sh @@ -3,11 +3,11 @@ #SBATCH --output=/scratch/mmk9418/logs/%A_%a_%x.out #SBATCH --error=/scratch/mmk9418/logs/%A_%a_%x.err #SBATCH --mem=128GB -#SBATCH --time=12:00:00 +#SBATCH --time=24:00:00 #SBATCH --nodes=1 #SBATCH --ntasks=1 #SBATCH --account=torch_pr_355_tandon_priority -#SBATCH --cpus-per-task=32 +#SBATCH --cpus-per-task=48 #SBATCH --gres=gpu:1 #SBATCH --array=0-15 @@ -53,6 +53,11 @@ singularity exec --nv \ cd /scratch/mmk9418/projects/Adaptive_Driving_Agent source 
.venv/bin/activate + # Start GPU heartbeat in background (for RL training which is CPU-bound) + nice -n 19 python scripts/gpu_heartbeat.py & + HEARTBEAT_PID=\$! + echo \"Started GPU Heartbeat with PID: \$HEARTBEAT_PID\" + puffer train puffer_adaptive_drive --wandb --env.num-maps 1000 \ --env.conditioning.type none \ --env.co-player-enabled 1 \ @@ -62,4 +67,6 @@ singularity exec --nv \ --env.co-player-policy.conditioning.discount-weight-ub $DISCOUNT_UB \ --env.co-player-policy.conditioning.entropy-weight-lb $ENTROPY_LB \ --env.co-player-policy.conditioning.entropy-weight-ub $ENTROPY_UB + + kill \$HEARTBEAT_PID " diff --git a/scripts/run_carla.sh b/scripts/run_carla.sh new file mode 100755 index 0000000000..efe55dc92d --- /dev/null +++ b/scripts/run_carla.sh @@ -0,0 +1,26 @@ +#!/bin/bash +#SBATCH --job-name=puffer_drive_carla +#SBATCH --output=/scratch/mmk9418/logs/%A_%a_%x.out +#SBATCH --error=/scratch/mmk9418/logs/%A_%a_%x.err +#SBATCH --mem=128GB +#SBATCH --time=12:00:00 +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --account=torch_pr_355_tandon_priority +#SBATCH --cpus-per-task=48 +#SBATCH --gres=gpu:1 + +singularity exec --nv \ + --overlay "$OVERLAY_FILE:ro" \ + "$SINGULARITY_IMAGE" \ + bash -c " + set -e + + source ~/.bashrc + cd /scratch/mmk9418/projects/Adaptive_Driving_Agent + source .venv/bin/activate + + puffer train puffer_drive --wandb \ + --env.map-dir resources/drive/binaries/carla_data \ + --env.num-maps 8 + " diff --git a/tests/test_drive_conditioning.py b/tests/test_drive_conditioning.py index ef9f91c71d..1060180ce9 100644 --- a/tests/test_drive_conditioning.py +++ b/tests/test_drive_conditioning.py @@ -12,7 +12,13 @@ ) def test_no_conditioning(dynamics_model, base_dim, total_dim): """Test that condition_type='none' works for both dynamics models.""" - env = Drive(num_agents=4, condition_type="none", num_maps=1, dynamics_model=dynamics_model, scenario_length=91) + env = Drive( + num_agents=4, + conditioning={"type": "none"}, + num_maps=1, + 
dynamics_model=dynamics_model, + scenario_length=91, + ) assert env.single_observation_space.shape[0] == base_dim + 63 * 7 + 200 * 7 assert not env.reward_conditioned assert not env.entropy_conditioned @@ -32,13 +38,15 @@ def test_reward_conditioning(dynamics_model, base_dim): """Test that RC adds 3 dimensions and weights are in range for both dynamics models.""" env = Drive( num_agents=4, - condition_type="reward", - collision_weight_lb=-1.0, - collision_weight_ub=0.0, - offroad_weight_lb=-1.0, - offroad_weight_ub=0.0, - goal_weight_lb=0.0, - goal_weight_ub=1.0, + conditioning={ + "type": "reward", + "collision_weight_lb": -1.0, + "collision_weight_ub": 0.0, + "offroad_weight_lb": -1.0, + "offroad_weight_ub": 0.0, + "goal_weight_lb": 0.0, + "goal_weight_ub": 1.0, + }, num_maps=1, dynamics_model=dynamics_model, scenario_length=91, @@ -64,9 +72,11 @@ def test_entropy_conditioning(dynamics_model, base_dim): """Test that EC adds 1 dimension and weight is in range for both dynamics models.""" env = Drive( num_agents=4, - condition_type="entropy", - entropy_weight_lb=0.0, - entropy_weight_ub=0.1, + conditioning={ + "type": "entropy", + "entropy_weight_lb": 0.0, + "entropy_weight_ub": 0.1, + }, num_maps=1, dynamics_model=dynamics_model, scenario_length=91, @@ -90,9 +100,11 @@ def test_discount_conditioning(dynamics_model, base_dim): """Test that DC adds 1 dimension and weight is in range for both dynamics models.""" env = Drive( num_agents=4, - condition_type="discount", - discount_weight_lb=0.9, - discount_weight_ub=0.99, + conditioning={ + "type": "discount", + "discount_weight_lb": 0.9, + "discount_weight_ub": 0.99, + }, num_maps=1, dynamics_model=dynamics_model, scenario_length=91, @@ -116,17 +128,19 @@ def test_combined_conditioning(dynamics_model, base_dim): """Test that RC + EC + DC work together for both dynamics models.""" env = Drive( num_agents=4, - condition_type="all", - collision_weight_lb=-1.0, - collision_weight_ub=0.0, - offroad_weight_lb=-1.0, - 
offroad_weight_ub=0.0, - goal_weight_lb=0.0, - goal_weight_ub=1.0, - entropy_weight_lb=0.0, - entropy_weight_ub=0.1, - discount_weight_lb=0.9, - discount_weight_ub=0.99, + conditioning={ + "type": "all", + "collision_weight_lb": -1.0, + "collision_weight_ub": 0.0, + "offroad_weight_lb": -1.0, + "offroad_weight_ub": 0.0, + "goal_weight_lb": 0.0, + "goal_weight_ub": 1.0, + "entropy_weight_lb": 0.0, + "entropy_weight_ub": 0.1, + "discount_weight_lb": 0.9, + "discount_weight_ub": 0.99, + }, num_maps=1, dynamics_model=dynamics_model, scenario_length=91, diff --git a/tests/test_drive_render.py b/tests/test_drive_render.py index 75ea9c3746..618562ca64 100644 --- a/tests/test_drive_render.py +++ b/tests/test_drive_render.py @@ -58,7 +58,7 @@ def test_drive_render(): "--frame-skip", "10", "--map-name", - "resources/drive/binaries/map_000.bin", + "resources/drive/binaries/training/map_000.bin", "--output-topdown", "resources/drive/output_topdown.mp4", "--output-agent", diff --git a/tests/test_drive_scenarios.py b/tests/test_drive_scenarios.py index 62f7b54033..0d538c3c30 100644 --- a/tests/test_drive_scenarios.py +++ b/tests/test_drive_scenarios.py @@ -19,10 +19,10 @@ def run_training_test(env_name, config_overrides, target_steps=10000, test_name= "device": "cpu", "compile": False, "total_timesteps": 100000, - "batch_size": 128, + "batch_size": 64, "bptt_horizon": 8, - "minibatch_size": 128, - "max_minibatch_size": 128, + "minibatch_size": 64, + "max_minibatch_size": 64, "update_epochs": 1, "render": False, "checkpoint_interval": 999999, @@ -41,8 +41,10 @@ def run_training_test(env_name, config_overrides, target_steps=10000, test_name= args["env"].update( { "num_agents": 8, + "num_ego_agents": 8, "action_type": "discrete", "num_maps": 1, + "map_dir": "resources/drive/binaries/training", } ) diff --git a/tests/test_drive_train.py b/tests/test_drive_train.py index 0b54aea17d..f9f16ef648 100644 --- a/tests/test_drive_train.py +++ b/tests/test_drive_train.py @@ -26,10 +26,10 @@ 
def test_drive_training(): "device": "cpu", "compile": False, "total_timesteps": 100000, - "batch_size": 128, + "batch_size": 64, "bptt_horizon": 8, - "minibatch_size": 128, - "max_minibatch_size": 128, + "minibatch_size": 64, + "max_minibatch_size": 64, "update_epochs": 1, "render": False, "checkpoint_interval": 999999, @@ -48,8 +48,12 @@ def test_drive_training(): args["env"].update( { "num_agents": 8, # 1 env * 8 agents = 8 total <= 16 segments + "num_ego_agents": 8, # Must match num_agents for population play "action_type": "discrete", "num_maps": 1, + "map_dir": "resources/drive/binaries/training", + "init_mode": "create_all_valid", + "control_mode": "control_agents", } ) diff --git a/tests/test_simulator_perf.py b/tests/test_simulator_perf.py index 6843eaf101..fd6d2fd15f 100644 --- a/tests/test_simulator_perf.py +++ b/tests/test_simulator_perf.py @@ -14,7 +14,7 @@ def test_simulator_raw(): num_agents = 32 # ---- Run simulation ---- - env = Drive(num_agents=num_agents, num_maps=1, scenario_length=91) + env = Drive(num_agents=num_agents, num_maps=1, episode_length=91, map_dir="resources/drive/binaries/training") obs, _ = env.reset() tick = 0