diff --git a/pufferlib/config/ocean/drive.ini b/pufferlib/config/ocean/drive.ini index 571b2ea582..7d73ecf1b6 100644 --- a/pufferlib/config/ocean/drive.ini +++ b/pufferlib/config/ocean/drive.ini @@ -12,7 +12,7 @@ batch_size = auto [policy] ; Encoder layer -input_size = 64 +input_size = 256 encoder_gigaflow = True dropout = 0.0 ; Shared backbone layer @@ -25,7 +25,7 @@ actor_num_layers = 0 critic_hidden_size = 512 critic_num_layers = 0 ; Dual or shared actor-critic backbone -split_network = False +split_network = True [rnn] input_size = 512 @@ -48,6 +48,24 @@ dynamics_model = "jerk" dt = 0.1 ; Optional nonzero launch speed for gigaflow random spawns spawn_initial_speed = 0.0 +; Per-agent spawn dimensions (meters). Training and eval use independent +; ranges so a policy trained on non-car shapes (e.g. trucks) can be eval'd +; on matching shapes. width is clipped to <= length at spawn time. +spawn_length_min = 0.8 +spawn_length_max = 7.0 +spawn_width_min = 0.8 +spawn_width_max = 3.0 +eval_spawn_length_min = 2.0 +eval_spawn_length_max = 5.5 +eval_spawn_width_min = 1.5 +eval_spawn_width_max = 2.5 +; Mixed-population spawn: per-agent P(is_truck). 0.0 (default) preserves +; the legacy single-population behavior +truck_fraction = 0.0 +truck_spawn_length_min = 9.0 +truck_spawn_length_max = 15.0 +truck_spawn_width_min = 2.0 +truck_spawn_width_max = 2.6 ; Collision behavior - options: 0 - Ignore, 1 - Stop, 2 - Remove collision_behavior = 1 ; Offroad behavior - options: 0 - Ignore, 1 - Stop, 2 - Remove @@ -85,8 +103,8 @@ min_waypoint_spacing = 20.0 max_waypoint_spacing = 60.0 ; --- Rewards --- -reward_conditioning = False -reward_randomization = False +reward_conditioning = True +reward_randomization = True reward_goal = 1.0 reward_vehicle_collision = 1.0 reward_offroad_collision = 1.0 @@ -578,3 +596,22 @@ values = [0.001, 0.003, 0.01] [controlled_exp.train.ent_coef] values = [0.01, 0.005] + +[finetune] +enabled = False +; Strategy: full | freeze | lora +; full - train every parameter starting from base weights (standard finetune) +; freeze - freeze params matching freeze_regex; train the rest +; lora - wrap target Linears with LoRA adapters; freeze base weights of those layers +mode = full +; Overrides train.learning_rate when enabled. None = inherit train.learning_rate. +base_lr = None +; Regex over named_parameters; matched params get requires_grad=False (used by +; mode=freeze; also additive on top of mode=lora). +freeze_regex = None +; LoRA knobs (used when mode=lora). lora_target is a regex over named_modules +; matching nn.Linear layers to wrap. +lora_rank = 16 +lora_alpha = 32 +lora_target = None +lora_lr_mult = 10.0 diff --git a/pufferlib/config/ocean/drive_finetune_nuplan.ini b/pufferlib/config/ocean/drive_finetune_nuplan.ini new file mode 100644 index 0000000000..d02fb17076 --- /dev/null +++ b/pufferlib/config/ocean/drive_finetune_nuplan.ini @@ -0,0 +1,28 @@ +[env] +; --- nuPlan training data (replay-mode rollouts of real traffic logs) --- +map_dir = "/scratch/ev2237/data/nuplan/nuplan_mini_train_bins" +num_maps = 200 +simulation_mode = "replay" +control_mode = "control_sdc_only" +init_mode = "create_all_valid" +; nuPlan log length is ~20s at 10Hz = 200 steps. +1 to match behavior eval sections. +scenario_length = 201 +resample_frequency = 201 +; SDC is the only controlled agent; the "too many inactive agents" early-reset +; trigger doesn't apply here. Use scenario-length termination. +termination_mode = 0 + +[train] +total_timesteps = 2_000_000_000 +checkpoint_interval = 50 + +[finetune] +enabled = True +mode = lora +; LoRA only on the shared 4-layer backbone — encoders + heads stay fully +; trainable so they can adapt to replay-style observations end-to-end. +lora_target = "actor_backbone\\.backbone\\." +lora_rank = 32 +lora_alpha = 64 +lora_lr_mult = 10.0 +base_lr = 1e-4 diff --git a/pufferlib/config/ocean/drive_finetune_reward.ini b/pufferlib/config/ocean/drive_finetune_reward.ini new file mode 100644 index 0000000000..0b4d1095cb --- /dev/null +++ b/pufferlib/config/ocean/drive_finetune_reward.ini @@ -0,0 +1,18 @@ +[env] +reward_goal = 2.0 +reward_velocity = 0.3 +reward_vehicle_collision = 0.5 +reward_offroad_collision = 0.5 + +[train] +total_timesteps = 2_000_000_000 +checkpoint_interval = 50 + +[finetune] +enabled = True +mode = lora +lora_target = "actor_backbone\\.backbone\\." +lora_rank = 32 +lora_alpha = 64 +lora_lr_mult = 10.0 +base_lr = 1e-4 diff --git a/pufferlib/config/ocean/drive_finetune_truck_mixed.ini b/pufferlib/config/ocean/drive_finetune_truck_mixed.ini new file mode 100644 index 0000000000..5888866858 --- /dev/null +++ b/pufferlib/config/ocean/drive_finetune_truck_mixed.ini @@ -0,0 +1,19 @@ +[env] +; --- Mixed population: 20% trucks at spawn, rest sample the default car range --- +truck_fraction = 0.2 +truck_spawn_length_min = 9.0 +truck_spawn_length_max = 15.0 +truck_spawn_width_min = 2.0 +truck_spawn_width_max = 2.6 + +goal_radius = 4.0 +reward_lane_align = 0.01 + +[train] +total_timesteps = 2_000_000_000 +checkpoint_interval = 50 + +[finetune] +enabled = True +mode = full +base_lr = 1e-4 diff --git a/pufferlib/finetune.py b/pufferlib/finetune.py new file mode 100644 index 0000000000..919f2b8af2 --- /dev/null +++ b/pufferlib/finetune.py @@ -0,0 +1,191 @@ +"""Parameter-efficient finetuning primitives for PufferLib. + +Activated by [finetune].enabled = True in the env config (drive.ini) or the +overlay passed via --finetune-config. Three strategies: + + - full : every parameter trainable; nothing here applies. + - freeze : params matching [finetune].freeze_regex get requires_grad=False. + - lora : nn.Linear submodules whose dotted name matches + [finetune].lora_target get wrapped with LoRALinear (frozen base + weight + trainable rank-r adapter B@A). + +Ordering inside load_policy: + policy = build() + policy.load_state_dict(base_pt) # base weights present, NO lora keys + apply_freeze(policy, ...) # freeze the layers the user pinned + wrap_lora(policy, ...) # swap nn.Linear -> LoRALinear in place + # ... then DDP wrap, torch.compile, optimizer +""" + +import math +import re +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LoRALinear(nn.Module): + """Drop-in replacement for nn.Linear with a frozen base weight and a + trainable rank-r adapter. + + forward(x) = F.linear(x, W, b) + (alpha / r) * F.linear(F.linear(x, A), B) + + State-dict layout: + Saves under {prefix}weight (MERGED: W + scaling * B @ A) and {prefix}bias. + Does NOT save lora_A / lora_B — by design. Every LoRA finetune restarts + its adapter from scratch on resume; the merged weight carries forward + the previous run's learning so progress isn't lost. This keeps saved + .pt files load-compatible with a vanilla nn.Linear at the same path, + which is what subprocess evals and downstream finetunes expect. + """ + + def __init__(self, base: nn.Linear, rank: int, alpha: float): + super().__init__() + self.in_features = base.in_features + self.out_features = base.out_features + + # Copy the base weight + bias as own parameters (NOT a nested + # submodule) so the state_dict key layout matches a vanilla Linear. + self.weight = nn.Parameter(base.weight.data.clone()) + self.weight.requires_grad = False + if base.bias is not None: + self.bias = nn.Parameter(base.bias.data.clone()) + self.bias.requires_grad = False + else: + self.register_parameter("bias", None) + + self.rank = int(rank) + self.alpha = float(alpha) + self.scaling = (self.alpha / self.rank) if self.rank > 0 else 0.0 + + # Kaiming-uniform init for A (standard LoRA practice). + # B is zero so at step 0 the adapter contributes nothing — the + # wrapped model is mathematically identical to the base policy + # before any gradient step. + self.lora_A = nn.Parameter(torch.empty(self.rank, self.in_features)) + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + self.lora_B = nn.Parameter(torch.zeros(self.out_features, self.rank)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = F.linear(x, self.weight, self.bias) + if self.rank > 0: + out = out + self.scaling * F.linear(F.linear(x, self.lora_A), self.lora_B) + return out + + def merged_weight(self) -> torch.Tensor: + if self.rank > 0: + return self.weight + self.scaling * (self.lora_B @ self.lora_A) + return self.weight + + def _save_to_state_dict(self, destination, prefix, keep_vars): + merged = self.merged_weight() + destination[prefix + "weight"] = merged if keep_vars else merged.detach() + if self.bias is not None: + destination[prefix + "bias"] = self.bias if keep_vars else self.bias.detach() + + def extra_repr(self) -> str: + return ( + f"in_features={self.in_features}, out_features={self.out_features}, " + f"rank={self.rank}, alpha={self.alpha}" + ) + + +def _resolve_optional_str(value) -> Optional[str]: + """drive.ini may carry the literal string 'None' for unset values when + routed through ast.literal_eval; treat that the same as Python None.""" + if value is None: + return None + if isinstance(value, str) and value.strip() in ("", "None"): + return None + return value + + +def apply_freeze(policy: nn.Module, regex) -> int: + """Set requires_grad=False on every named_parameter whose name matches + `regex` (re.search semantics). Returns the count of frozen tensors. + No-op if regex is None / empty. + """ + regex = _resolve_optional_str(regex) + if not regex: + return 0 + pattern = re.compile(regex) + count = 0 + for name, p in policy.named_parameters(): + if pattern.search(name): + p.requires_grad = False + count += 1 + return count + + +def wrap_lora(policy: nn.Module, target_regex, rank: int, alpha: float) -> int: + """Replace nn.Linear submodules whose dotted module-name matches + `target_regex` (re.search semantics) with LoRALinear. Returns the count + of wrapped modules. No-op if target_regex is None / empty or rank <= 0. + + Note: regex matches MODULE NAMES (e.g. 'actor_backbone.backbone.1'), not + parameter names. Type filtering to nn.Linear is enforced internally, so + your regex doesn't need to include 'Linear'. + """ + target_regex = _resolve_optional_str(target_regex) + if not target_regex or rank <= 0: + return 0 + pattern = re.compile(target_regex) + targets = [] + for name, module in policy.named_modules(): + # Skip LoRALinear instances so re-runs are idempotent. + if isinstance(module, LoRALinear): + continue + if isinstance(module, nn.Linear) and pattern.search(name): + targets.append((name, module)) + for name, module in targets: + parent_name, _, child_name = name.rpartition(".") + parent = policy.get_submodule(parent_name) if parent_name else policy + wrapped = LoRALinear(module, rank=rank, alpha=alpha).to( + device=module.weight.device, dtype=module.weight.dtype + ) + setattr(parent, child_name, wrapped) + return len(targets) + + +def get_lora_params(policy: nn.Module) -> List[nn.Parameter]: + """Collect every LoRALinear.lora_A / lora_B in the policy. Used to build + a separate optimizer parameter group with its own LR.""" + params: List[nn.Parameter] = [] + for module in policy.modules(): + if isinstance(module, LoRALinear): + params.append(module.lora_A) + params.append(module.lora_B) + return params + + +def build_param_groups(policy: nn.Module, base_lr: float, lora_lr: float) -> list: + """Build optimizer parameter groups. LoRA adapter params get `lora_lr`; + every other trainable param gets `base_lr`. Frozen params are excluded. + Returns a list of dicts suitable for torch.optim.Optimizer. + """ + lora_params = get_lora_params(policy) + lora_param_ids = {id(p) for p in lora_params} + base_params = [ + p for p in policy.parameters() + if p.requires_grad and id(p) not in lora_param_ids + ] + groups = [] + if base_params: + groups.append({"params": base_params, "lr": base_lr}) + active_lora = [p for p in lora_params if p.requires_grad] + if active_lora: + groups.append({"params": active_lora, "lr": lora_lr}) + return groups + + +def trainable_summary(policy: nn.Module) -> str: + total = sum(p.numel() for p in policy.parameters()) + trainable = sum(p.numel() for p in policy.parameters() if p.requires_grad) + lora_modules = sum(1 for m in policy.modules() if isinstance(m, LoRALinear)) + pct = (trainable / total * 100.0) if total > 0 else 0.0 + return ( + f"[finetune] trainable {trainable:,} / {total:,} ({pct:.2f}%) " + f"| LoRA modules wrapped: {lora_modules}" + ) diff --git a/pufferlib/ocean/drive/binding.c b/pufferlib/ocean/drive/binding.c index f67243a32d..f0652a3e47 100644 --- a/pufferlib/ocean/drive/binding.c +++ b/pufferlib/ocean/drive/binding.c @@ -1814,6 +1814,19 @@ static int my_init(Env *env, PyObject *args, PyObject *kwargs) { env->traffic_control_scope = (int) unpack(kwargs, "traffic_control_scope"); env->dt = (float) unpack(kwargs, "dt"); env->spawn_initial_speed = (float) unpack(kwargs, "spawn_initial_speed"); + env->spawn_length_min = (float) unpack(kwargs, "spawn_length_min"); + env->spawn_length_max = (float) unpack(kwargs, "spawn_length_max"); + env->spawn_width_min = (float) unpack(kwargs, "spawn_width_min"); + env->spawn_width_max = (float) unpack(kwargs, "spawn_width_max"); + env->eval_spawn_length_min = (float) unpack(kwargs, "eval_spawn_length_min"); + env->eval_spawn_length_max = (float) unpack(kwargs, "eval_spawn_length_max"); + env->eval_spawn_width_min = (float) unpack(kwargs, "eval_spawn_width_min"); + env->eval_spawn_width_max = (float) unpack(kwargs, "eval_spawn_width_max"); + env->truck_fraction = (float) unpack(kwargs, "truck_fraction"); + env->truck_spawn_length_min = (float) unpack(kwargs, "truck_spawn_length_min"); + env->truck_spawn_length_max = (float) unpack(kwargs, "truck_spawn_length_max"); + env->truck_spawn_width_min = (float) unpack(kwargs, "truck_spawn_width_min"); + env->truck_spawn_width_max = (float) unpack(kwargs, "truck_spawn_width_max"); env->goal_speed = (float) unpack(kwargs, "goal_speed"); env->scenario_length = (int) unpack(kwargs, "scenario_length"); env->termination_mode = (int) unpack(kwargs, "termination_mode"); diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h index 329e1bdf76..8c61d33986 100644 --- a/pufferlib/ocean/drive/drive.h +++ b/pufferlib/ocean/drive/drive.h @@ -343,6 +343,29 @@ struct Drive { float world_mean_y; float dt; float spawn_initial_speed; + // Per-agent spawn dimension ranges (training vs eval modes use separate ranges). + // length: longitudinal extent in meters; width: lateral extent in meters. + // width is clipped to <= length at spawn time. + float spawn_length_min; + float spawn_length_max; + float spawn_width_min; + float spawn_width_max; + float eval_spawn_length_min; + float eval_spawn_length_max; + float eval_spawn_width_min; + float eval_spawn_width_max; + // Mixed-population spawn: per-agent P(is_truck). When truck_fraction == 0 + // (default), every agent samples dims from the spawn_*_{min,max} range above + // and behavior is identical to the single-population path. When > 0, a + // Bernoulli(truck_fraction) draw at spawn time picks between the car range + // (spawn_*) and the truck range (truck_spawn_*). Used for co-training a + // single policy on a realistic mixed-vehicle population (~10% trucks). + // Eval mode reuses the same truck range; eval_spawn_* still covers cars. + float truck_fraction; + float truck_spawn_length_min; + float truck_spawn_length_max; + float truck_spawn_width_min; + float truck_spawn_width_max; float goal_radius; float goal_speed; float min_waypoint_spacing; @@ -2896,19 +2919,27 @@ static int spawn_agent(Drive *env, int agent_idx, int num_agents) { agent->active_agent = 1; agent->mark_as_expert = 0; - // Default vehicle dimensions - // length: [0.8, 7.0] m - // width: [0.8, 3.0] m - // width = min(width, length) + // Vehicle dimensions sampled from configured ranges. Training and eval + // modes use independent ranges so e.g. a truck-trained policy can be + // eval'd on truck-sized agents instead of falling back to cars. + // Mixed population: when truck_fraction > 0, draw is_truck ~ + // Bernoulli(truck_fraction) per agent and sample dims from the + // truck_spawn_* range; otherwise sample from the standard (car) range. + // The truck range is shared across train and eval — only the car range + // differs between modes. + // width = min(width, length) is enforced below. float spawn_length, spawn_width; - if (env->eval_mode) { - // Fixed size for eval mode - spawn_length = random_uniform(2.0f, 5.5f); - spawn_width = random_uniform(1.5f, 2.5f); + int is_truck = (env->truck_fraction > 0.0f) + && (random_uniform(0.0f, 1.0f) < env->truck_fraction); + if (is_truck) { + spawn_length = random_uniform(env->truck_spawn_length_min, env->truck_spawn_length_max); + spawn_width = random_uniform(env->truck_spawn_width_min, env->truck_spawn_width_max); + } else if (env->eval_mode) { + spawn_length = random_uniform(env->eval_spawn_length_min, env->eval_spawn_length_max); + spawn_width = random_uniform(env->eval_spawn_width_min, env->eval_spawn_width_max); } else { - // Random size for training mode - spawn_length = random_uniform(0.8f, 7.0f); - spawn_width = random_uniform(0.8f, 3.0f); + spawn_length = random_uniform(env->spawn_length_min, env->spawn_length_max); + spawn_width = random_uniform(env->spawn_width_min, env->spawn_width_max); } if (spawn_width > spawn_length) { spawn_width = spawn_length; diff --git a/pufferlib/ocean/drive/drive.py b/pufferlib/ocean/drive/drive.py index d49563ebf6..4528f60ffe 100644 --- a/pufferlib/ocean/drive/drive.py +++ b/pufferlib/ocean/drive/drive.py @@ -56,6 +56,19 @@ def __init__( emit_completed_episodes=False, dt=0.1, spawn_initial_speed=0.0, + spawn_length_min=0.8, + spawn_length_max=7.0, + spawn_width_min=0.8, + spawn_width_max=3.0, + eval_spawn_length_min=2.0, + eval_spawn_length_max=5.5, + eval_spawn_width_min=1.5, + eval_spawn_width_max=2.5, + truck_fraction=0.0, + truck_spawn_length_min=9.0, + truck_spawn_length_max=15.0, + truck_spawn_width_min=2.0, + truck_spawn_width_max=2.6, goal_speed=3.0, scenario_length=None, resample_frequency=91, @@ -108,6 +121,19 @@ def __init__( ): self.dt = dt self.spawn_initial_speed = float(spawn_initial_speed) + self.spawn_length_min = float(spawn_length_min) + self.spawn_length_max = float(spawn_length_max) + self.spawn_width_min = float(spawn_width_min) + self.spawn_width_max = float(spawn_width_max) + self.eval_spawn_length_min = float(eval_spawn_length_min) + self.eval_spawn_length_max = float(eval_spawn_length_max) + self.eval_spawn_width_min = float(eval_spawn_width_min) + self.eval_spawn_width_max = float(eval_spawn_width_max) + self.truck_fraction = float(truck_fraction) + self.truck_spawn_length_min = float(truck_spawn_length_min) + self.truck_spawn_length_max = float(truck_spawn_length_max) + self.truck_spawn_width_min = float(truck_spawn_width_min) + self.truck_spawn_width_max = float(truck_spawn_width_max) self.goal_speed = float(goal_speed) self.reward_conditioning = reward_conditioning self.reward_randomization = reward_randomization @@ -394,6 +420,19 @@ def _env_init_kwargs(self, map_file, max_agents): "traffic_control_scope": self.traffic_control_scope, "dt": self.dt, "spawn_initial_speed": self.spawn_initial_speed, + "spawn_length_min": self.spawn_length_min, + "spawn_length_max": self.spawn_length_max, + "spawn_width_min": self.spawn_width_min, + "spawn_width_max": self.spawn_width_max, + "eval_spawn_length_min": self.eval_spawn_length_min, + "eval_spawn_length_max": self.eval_spawn_length_max, + "eval_spawn_width_min": self.eval_spawn_width_min, + "eval_spawn_width_max": self.eval_spawn_width_max, + "truck_fraction": self.truck_fraction, + "truck_spawn_length_min": self.truck_spawn_length_min, + "truck_spawn_length_max": self.truck_spawn_length_max, + "truck_spawn_width_min": self.truck_spawn_width_min, + "truck_spawn_width_max": self.truck_spawn_width_max, "goal_speed": self.goal_speed, "scenario_length": int(self.scenario_length) if self.scenario_length is not None else None, "termination_mode": int(self.termination_mode), diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py index aa038f67d7..6af52e9df1 100644 --- a/pufferlib/pufferl.py +++ b/pufferlib/pufferl.py @@ -115,6 +115,30 @@ def logits_to_float(logits): return logits.float() return tuple(l.float() for l in logits) +# [env] keys whose values shape the observation / action tensors and therefore +# the network's input/output dimensions. When loading a checkpoint, these are +# pulled from the base run's config.yaml and override whatever drive.ini / the +# finetune overlay / CLI args specify — otherwise state_dict load would fail. +# [policy] and [rnn] sections are whole-section locked (see train()). +FINETUNE_LOCKED_ENV_KEYS = { + "action_type", + "dynamics_model", + "target_type", + "num_target_waypoints", + "reward_conditioning", + "reward_randomization", + "trajectory_prediction_length", + "num_trajectory_scaling_factors", + "trajectory_scaling_factors", + "max_boundary_segment_observations", + "max_lane_segment_observations", + "max_partner_observations", + "max_traffic_control_observations", + "traffic_control_scope", + "boundary_segment_dropout", + "lane_segment_dropout", +} + class PuffeRL: def __init__(self, config, vecenv, policy, logger=None): @@ -232,17 +256,56 @@ def __init__(self, config, vecenv, policy, logger=None): self.policy.forward_eval = torch.compile(self.uncompiled_policy.forward_eval, **compile_kwargs) pufferlib.pytorch.sample_logits = torch.compile(pufferlib.pytorch.sample_logits, **compile_kwargs) + # Build the iterable of params (or param groups) for the optimizer. + # Under finetune.enabled, freeze_regex / mode=lora may have already + # masked some params (requires_grad=False) and added LoRA adapter + # params; build a two-group config so the LoRA adapters can ride a + # different LR (typically 10x base_lr per the LoRA paper). + finetune_cfg = (config.get("finetune") or {}) if isinstance(config, dict) else {} + if finetune_cfg.get("enabled", False): + from pufferlib.finetune import build_param_groups + + base_lr_for_opt = config["learning_rate"] + lora_lr = base_lr_for_opt * float(finetune_cfg.get("lora_lr_mult", 1.0) or 1.0) + optim_param_arg = build_param_groups(self.policy, base_lr_for_opt, lora_lr) + if not optim_param_arg: + raise pufferlib.APIUsageError( + "[finetune] no trainable params after applying freeze/LoRA. " + "Check freeze_regex / lora_target." + ) + else: + optim_param_arg = self.policy.parameters() + + # Build the iterable of params (or param groups) for the optimizer. + # Under finetune.enabled, freeze_regex / mode=lora may have already + # masked some params (requires_grad=False) and added LoRA adapter + # params; build a two-group config so the LoRA adapters can ride a + # different LR (typically 10x base_lr per the LoRA paper). + finetune_cfg = (config.get("finetune") or {}) if isinstance(config, dict) else {} + if finetune_cfg.get("enabled", False): + from pufferlib.finetune import build_param_groups + + base_lr_for_opt = config["learning_rate"] + lora_lr = base_lr_for_opt * float(finetune_cfg.get("lora_lr_mult", 1.0) or 1.0) + optim_param_arg = build_param_groups(self.policy, base_lr_for_opt, lora_lr) + if not optim_param_arg: + raise pufferlib.APIUsageError( + "[finetune] no trainable params after applying freeze/LoRA. Check freeze_regex / lora_target." + ) + else: + optim_param_arg = self.policy.parameters() + # Optimizer if config["optimizer"] == "adam": optimizer = torch.optim.Adam( - self.policy.parameters(), + optim_param_arg, lr=config["learning_rate"], betas=(config["adam_beta1"], config["adam_beta2"]), eps=config["adam_eps"], ) elif config["optimizer"] == "adamw": optimizer = torch.optim.AdamW( - self.policy.parameters(), + optim_param_arg, lr=config["learning_rate"], betas=(config["adam_beta1"], config["adam_beta2"]), eps=config["adam_eps"], @@ -261,7 +324,7 @@ def __init__(self, config, vecenv, policy, logger=None): # heavyball_momentum=True introduced in heavyball 2.1.1 # recovers heavyball-1.7.2 behaviour - previously swept hyperparameters work well optimizer = ForeachMuon( - self.policy.parameters(), + optim_param_arg, lr=config["learning_rate"], betas=(config["adam_beta1"], config["adam_beta2"]), eps=config["adam_eps"], @@ -1363,43 +1426,11 @@ def _save_experiment_config(args, path): def train(env_name, args=None, vecenv=None, policy=None, logger=None, early_stop_fn=None): args = args or load_config(env_name) - # Fine-tuning: reload network, observation configuration from config.yaml and override the args --> only change new reward / new maps / new simulation mode - if args["load_model_path"]: - experiment_dir = os.path.dirname(args["load_model_path"]) - config_yaml_path = os.path.join(experiment_dir, "config.yaml") - KEYS_OF_INTEREST = { - "action_type", - "dynamics_model", - "target_type", - "num_target_waypoints", - "reward_conditioning", - "reward_randomization", - "trajectory_prediction_length", - "num_trajectory_scaling_factors", - "trajectory_scaling_factors", - "max_boundary_segment_observations", - "max_lane_segment_observations", - "boundary_segment_dropout", - "lane_segment_dropout", - "max_partner_observations", - "max_traffic_control_observations", - "traffic_control_scope", - } - if os.path.exists(config_yaml_path): - print(f"Found config.yaml at {config_yaml_path}. Merging with defaults...") - with open(config_yaml_path, "r") as f: - yaml_config = yaml.safe_load(f) - - # Override Policy and RNN dimensions from model config - for section in ["policy", "rnn"]: - if section in yaml_config and isinstance(yaml_config[section], dict): - for k, v in yaml_config[section].items(): - args[section][k] = v - # Override ENV parameters for observation size from model config - if "env" in yaml_config and isinstance(yaml_config["env"], dict): - for k, v in yaml_config["env"].items(): - if k in KEYS_OF_INTEREST: - args["env"][k] = v + # Fine-tuning: reload network / observation config from the base run's + # config.yaml and override args so the rebuilt policy matches the saved + # checkpoint's shapes. Everything not in the locked set (rewards, maps, + # sim params, train HPs) stays free to differ. + _apply_checkpoint_arch_lock(args) # Assume TorchRun DDP is used if LOCAL_RANK is set if "LOCAL_RANK" in os.environ: @@ -1441,7 +1472,21 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None, early_stop experiment_dir=experiment_dir, ) - train_config = dict(**args["train"], env=env_name, eval=args.get("eval", {})) + # Apply finetune.base_lr override before PuffeRL builds the optimizer. + finetune_cfg = args.get("finetune", {}) + finetune_enabled = bool(finetune_cfg.get("enabled", False)) + if finetune_enabled: + base_lr = finetune_cfg.get("base_lr", None) + if base_lr is not None: + print(f"[finetune] overriding train.learning_rate {args['train']['learning_rate']} -> {base_lr}") + args["train"]["learning_rate"] = base_lr + + train_config = dict( + **args["train"], + env=env_name, + eval=args.get("eval", {}), + finetune=args.get("finetune", {}), + ) pufferl = PuffeRL(train_config, vecenv, policy, logger) if args["train"].get("resume_state_path"): @@ -1451,10 +1496,11 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None, early_stop pufferl._eval_manager = EvalManager.from_config(args, run_id=logger.run_id if logger else None) - # Restore optimizer state + step counters when resuming from a checkpoint. - # save_checkpoint writes models/model__.pt and trainer_state.pt - # (sibling of models/) — so trainer_state.pt is one dir above the .pt path. - if args.get("load_model_path"): + # Restore optimizer state + step counters when RESUMING from a checkpoint. + # When finetune.enabled=True we deliberately start fresh — a finetune is a + # new training run that happens to inherit base weights, not a continuation + # of the base's optimizer momentum / LR schedule / step counter. + if args.get("load_model_path") and not finetune_enabled: trainer_state_path = os.path.join(os.path.dirname(os.path.dirname(args["load_model_path"])), "trainer_state.pt") if os.path.exists(trainer_state_path): print(f"Resuming optimizer/step state from {trainer_state_path}") @@ -1468,10 +1514,35 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None, early_stop pufferl.scheduler.step() else: print(f"No trainer_state.pt next to {args['load_model_path']}; starting optimizer fresh.") + elif args.get("load_model_path") and finetune_enabled: + print( + f"[finetune] mode={finetune_cfg.get('mode', 'full')}: loaded weights from " + f"{args['load_model_path']}; starting optimizer / step counter / LR schedule from epoch 0." + ) path = os.path.join(args["train"]["data_dir"], f"{env_name}_{pufferl.logger.run_id}") _save_experiment_config(args, path) + # Drop a finetune_meta.yaml next to config.yaml so the run is self-describing: + # what base checkpoint was used, what strategy, which params were trainable. + if finetune_enabled and args.get("load_model_path"): + meta = { + "base_checkpoint": args["load_model_path"], + "mode": finetune_cfg.get("mode", "full"), + "base_lr": args["train"]["learning_rate"], + "total_timesteps": train_config["total_timesteps"], + "freeze_regex": finetune_cfg.get("freeze_regex", None), + "lora_target": finetune_cfg.get("lora_target", None), + "lora_rank": finetune_cfg.get("lora_rank", None), + "lora_alpha": finetune_cfg.get("lora_alpha", None), + "lora_lr_mult": finetune_cfg.get("lora_lr_mult", None), + "finetune_config": args.get("finetune_config", None), + } + meta_path = os.path.join(path, "finetune_meta.yaml") + with open(meta_path, "w") as f: + yaml.dump(meta, f, sort_keys=False) + print(f"[finetune] wrote {meta_path}") + # Sweep needs data for early stopped runs, so send data when steps > 100M logging_threshold = min(0.20 * train_config["total_timesteps"], 100_000_000) all_logs = [] @@ -1665,11 +1736,10 @@ def eval( args = args or load_config(env_name) - # When evaluating a checkpoint, adopt its network architecture from the - # training run's sibling config.yaml so the policy is built to match the - # weights regardless of what drive.ini currently says. - if args.get("load_model_path"): - _merge_checkpoint_arch(args, args["load_model_path"]) + # Subprocess evals (and standalone `puffer eval`) re-read drive.ini and + # have no idea what arch the parent run used. Apply the same checkpoint + # arch-lock train() does so the policy is built with matching shapes. + _apply_checkpoint_arch_lock(args) if evaluator_name is None: evaluator_name = args.get("evaluator") @@ -2041,6 +2111,35 @@ def load_env(env_name, args): return pufferlib.vector.make(make_env, env_kwargs=args["env"], **args["vec"]) +def _load_and_verify_checkpoint(policy, ckpt_path, device, source): + """Load `ckpt_path` into `policy` strictly, then verify every tensor in + the checkpoint was copied in exactly. Raises on mismatch; logs a single- + line confirmation on success so finetune runs can prove which weights + they started from. + """ + state_dict = torch.load(ckpt_path, map_location=device) + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + + # strict=True: raises RuntimeError on any missing or unexpected key. + policy.load_state_dict(state_dict, strict=True) + + policy_sd = policy.state_dict() + mismatched = [k for k, v in state_dict.items() if not torch.equal(policy_sd[k], v.to(policy_sd[k].device))] + if mismatched: + raise RuntimeError( + f"[load] {len(mismatched)}/{len(state_dict)} tensors did not match after load from {ckpt_path}: " + f"{mismatched[:5]}{'...' if len(mismatched) > 5 else ''}" + ) + + total_params = sum(v.numel() for v in state_dict.values()) + # Content fingerprint: tiny but stable across machines for the same .pt. + fingerprint = sum(float(v.detach().to(torch.float64).sum().item()) for v in state_dict.values()) + print( + f"[load] {source}: loaded {len(state_dict)} tensors ({total_params:,} params) " + f"from {ckpt_path}; all values verified (sum-fingerprint={fingerprint:.6e})." + ) + + def load_policy(args, vecenv, env_name=""): package = args["package"] module_name = "pufferlib.ocean" if package == "ocean" else f"pufferlib.environments.{package}" @@ -2066,23 +2165,123 @@ def load_policy(args, vecenv, env_name=""): else: raise pufferlib.APIUsageError("No run id provided for eval") - state_dict = torch.load(path, map_location=device) - policy.load_state_dict(clean_policy_state_dict(state_dict)) + _load_and_verify_checkpoint(policy, path, device, source=f"load_id={load_id}") load_path = args["load_model_path"] if load_path == "latest": load_path = max(glob.glob(f"experiments/{env_name}*.pt"), key=os.path.getctime) if load_path is not None: - state_dict = torch.load(load_path, map_location=device) - policy.load_state_dict(clean_policy_state_dict(state_dict)) + _load_and_verify_checkpoint(policy, load_path, device, source="load_model_path") # state_path = os.path.join(*load_path.split('/')[:-1], 'state.pt') # optim_state = torch.load(state_path)['optimizer_state_dict'] # pufferl.optimizer.load_state_dict(optim_state) + # ----- Phase 2: parameter-efficient finetune wiring ----- + # Order: load_state_dict has already run (base weights are in `policy`). + # apply_freeze first so freeze_regex can pin any non-LoRA params; then + # wrap_lora replaces matched nn.Linears with LoRALinear (which copies the + # just-loaded weight as a frozen base + adds trainable A,B init at zero). + # Must happen BEFORE DDP wrap / torch.compile / optimizer construction. + finetune_cfg = (args.get("finetune") or {}) if isinstance(args, dict) else {} + if finetune_cfg.get("enabled", False): + from pufferlib.finetune import apply_freeze, wrap_lora, trainable_summary + + mode = finetune_cfg.get("mode", "full") or "full" + if mode in ("freeze", "lora"): + n_frozen = apply_freeze(policy, finetune_cfg.get("freeze_regex")) + if n_frozen: + print( + f"[finetune] froze {n_frozen} parameter tensors matching " + f"freeze_regex={finetune_cfg.get('freeze_regex')!r}" + ) + if mode == "lora": + rank = int(finetune_cfg.get("lora_rank", 16)) + alpha = float(finetune_cfg.get("lora_alpha", 32)) + lora_target = finetune_cfg.get("lora_target") + n_wrapped = wrap_lora(policy, target_regex=lora_target, rank=rank, alpha=alpha) + if n_wrapped: + print( + f"[finetune] wrapped {n_wrapped} nn.Linear modules with LoRA " + f"(rank={rank}, alpha={alpha}, target={lora_target!r})" + ) + else: + print( + f"[finetune] WARNING: mode=lora but 0 modules matched " + f"lora_target={lora_target!r} — training will run as mode=full." + ) + if mode not in ("full", "freeze", "lora"): + raise pufferlib.APIUsageError(f"[finetune] unknown mode={mode!r}; expected one of full / freeze / lora") + print(trainable_summary(policy)) + return policy +def _apply_checkpoint_arch_lock(args): + """If args has a load_model_path with a sibling config.yaml, pull the + locked architecture / observation-shape keys from that config.yaml and + override the current args in place so the rebuilt policy has shapes that + match the saved checkpoint. Whole [policy] and [rnn] sections are + overridden; within [env], only FINETUNE_LOCKED_ENV_KEYS are. + + Called from both train() (so finetune runs inherit base architecture) + and eval() (so subprocess evals match the parent's architecture even + though the subprocess only receives --load-model-path on the CLI). + """ + if not args.get("load_model_path"): + return + # save_checkpoint layout: /models/_.pt — config.yaml + # lives at /config.yaml (one dir above models/). + experiment_dir = os.path.dirname(os.path.dirname(args["load_model_path"])) + config_yaml_path = os.path.join(experiment_dir, "config.yaml") + if not os.path.exists(config_yaml_path): + return + print(f"Found config.yaml at {config_yaml_path}. Merging with defaults...") + with open(config_yaml_path, "r") as f: + yaml_config = yaml.safe_load(f) + for section in ("policy", "rnn"): + if section in yaml_config and isinstance(yaml_config[section], dict): + args.setdefault(section, {}) + for k, v in yaml_config[section].items(): + args[section][k] = v + if "env" in yaml_config and isinstance(yaml_config["env"], dict): + args.setdefault("env", {}) + for k, v in yaml_config["env"].items(): + if k in FINETUNE_LOCKED_ENV_KEYS: + args["env"][k] = v + + +def _warn_locked_overlay_keys(overlay_p, overlay_path): + """Emit a warning for any key in the finetune overlay that will be ignored + because it's locked to the base checkpoint's architecture. + + Whole sections [policy] and [rnn] are locked. Within [env], only the keys + in FINETUNE_LOCKED_ENV_KEYS are locked. We warn but do not strip; the + locked override happens later in train() via the base run's config.yaml. + """ + locked_hits = [] # list of (section, key) + for section in ("policy", "rnn"): + if section in overlay_p.sections(): + for key in overlay_p[section]: + locked_hits.append((section, key)) + if "env" in overlay_p.sections(): + for key in overlay_p["env"]: + if key in FINETUNE_LOCKED_ENV_KEYS: + locked_hits.append(("env", key)) + if not locked_hits: + return + print( + f"[finetune] WARNING: {len(locked_hits)} key(s) in overlay '{overlay_path}' " + f"will be IGNORED because they are locked to the base checkpoint's architecture:" + ) + for section, key in locked_hits: + print(f"[finetune] - [{section}] {key}") + print( + "[finetune] (architecture / observation-shape keys are inherited from the base's " + "config.yaml at load time so the saved weights can be loaded.)" + ) + + def load_config(env_name, config_dir=None): parser = argparse.ArgumentParser( description=f":blowfish: PufferLib [bright_cyan]{pufferlib.__version__}[/]" @@ -2117,6 +2316,13 @@ def load_config(env_name, config_dir=None): parser.add_argument( "--eval_simulation", type=str, default=None, help="Simulation mode for evaluation - gigaflow/replay" ) + parser.add_argument( + "--finetune-config", + type=str, + default=None, + help="Optional overlay .ini layered on top of the env config. Used to specify only the " + "keys that differ for a finetune run (rewards, maps, [finetune] section, etc.).", + ) args = parser.parse_known_args()[0] if config_dir is None: @@ -2128,18 +2334,36 @@ def load_config(env_name, config_dir=None): # Load defaults and config puffer_config_dir = os.path.join(puffer_dir, "config/**/*.ini") puffer_default_config = os.path.join(puffer_dir, "config/default.ini") + # Layered .ini reads: defaults → env config → optional finetune overlay. + # ConfigParser.read() processes files left-to-right with last-write-wins, so the + # overlay can override any key in the env config (except FINETUNE_LOCKED_ENV_KEYS, + # which we warn about below and which train() later overrides from base config.yaml). + config_layers = [puffer_default_config] if env_name == "default": p = configparser.ConfigParser(inline_comment_prefixes=(";", "#")) - p.read(puffer_default_config) + p.read(config_layers) else: for path in glob.glob(puffer_config_dir, recursive=True): p = configparser.ConfigParser(inline_comment_prefixes=(";", "#")) - p.read([puffer_default_config, path]) + p.read(config_layers + [path]) if env_name in p["base"]["env_name"].split(): + config_layers.append(path) break else: raise pufferlib.APIUsageError("No config for env_name {}".format(env_name)) + if args.finetune_config is not None: + if not os.path.exists(args.finetune_config): + raise pufferlib.APIUsageError(f"--finetune-config path not found: {args.finetune_config}") + # Read the overlay alone first so we can detect which keys came from it + # (vs which were inherited from drive.ini) — used by the locked-key warning. + overlay_p = configparser.ConfigParser(inline_comment_prefixes=(";", "#")) + overlay_p.read(args.finetune_config) + _warn_locked_overlay_keys(overlay_p, args.finetune_config) + # Merge overlay on top. + p.read(args.finetune_config) + config_layers.append(args.finetune_config) + # Dynamic help menu from config def puffer_type(value): try: