4343from mp_actors import close_proxy , move_to_child_process
4444
4545from .. import dev
46+ from .._backend_training import (
47+ aggregate_rl_training_metrics ,
48+ build_rl_train_configs ,
49+ )
4650from ..backend import AnyTrainableModel , Backend
4751from ..costs import build_cost_calculator , get_model_pricing
4852from ..metrics_taxonomy import (
4953 TRAIN_GRADIENT_STEPS_KEY ,
50- average_metric_samples ,
5154 build_training_summary_metrics ,
5255 summarize_trajectory_groups ,
5356)
@@ -642,45 +645,36 @@ async def train( # type: ignore[override]
642645 if adam_params is not None :
643646 raise ValueError ("LocalBackend requires adam_params=None." )
644647
645- # Build config objects from explicit kwargs
646- config = TrainConfig (
647- learning_rate = learning_rate , kl_penalty_coef = kl_penalty_coef
648- )
649- dev_config : dev .TrainConfig = {
650- "advantage_balance" : advantage_balance ,
651- "allow_training_without_logprobs" : allow_training_without_logprobs ,
652- "importance_sampling_level" : importance_sampling_level ,
653- "kl_penalty_coef" : kl_penalty_coef ,
654- "mask_prob_ratio" : mask_prob_ratio ,
655- "plot_tensors" : plot_tensors ,
656- "ppo" : loss_fn == "ppo" ,
657- "precalculate_logprobs" : precalculate_logprobs ,
658- "scale_learning_rate_by_reward_std_dev" : scale_learning_rate_by_reward_std_dev ,
659- "scale_rewards" : scale_rewards ,
660- "logprob_calculation_chunk_size" : logprob_calculation_chunk_size ,
661- "num_trajectories_learning_rate_multiplier_power" : num_trajectories_learning_rate_multiplier_power ,
662- }
663- # Only include optional fields if they're set
664- if epsilon is not None :
665- dev_config ["epsilon" ] = epsilon
666- if epsilon_high is not None :
667- dev_config ["epsilon_high" ] = epsilon_high
668- if max_negative_advantage_importance_sampling_weight is not None :
669- dev_config ["max_negative_advantage_importance_sampling_weight" ] = (
670- max_negative_advantage_importance_sampling_weight
671- )
672- if kimi_k2_tau is not None :
673- dev_config ["kimi_k2_tau" ] = kimi_k2_tau
674- if truncated_importance_sampling is not None :
675- dev_config ["truncated_importance_sampling" ] = truncated_importance_sampling
676- if kl_ref_adapter_path is not None :
677- dev_config ["kl_ref_adapter_path" ] = kl_ref_adapter_path
678- elif kl_penalty_reference_step is not None :
679- ref_checkpoint_dir = get_step_checkpoint_dir (
648+ resolved_kl_ref_adapter_path = kl_ref_adapter_path
649+ if (
650+ resolved_kl_ref_adapter_path is None
651+ and kl_penalty_reference_step is not None
652+ ):
653+ resolved_kl_ref_adapter_path = get_step_checkpoint_dir (
680654 get_model_dir (model = model , art_path = self ._path ),
681655 kl_penalty_reference_step ,
682656 )
683- dev_config ["kl_ref_adapter_path" ] = ref_checkpoint_dir
657+ config , dev_config = build_rl_train_configs (
658+ learning_rate = learning_rate ,
659+ advantage_balance = advantage_balance ,
660+ scale_rewards = scale_rewards ,
661+ importance_sampling_level = importance_sampling_level ,
662+ mask_prob_ratio = mask_prob_ratio ,
663+ ppo = loss_fn == "ppo" ,
664+ precalculate_logprobs = precalculate_logprobs ,
665+ epsilon = epsilon ,
666+ epsilon_high = epsilon_high ,
667+ max_negative_advantage_importance_sampling_weight = max_negative_advantage_importance_sampling_weight ,
668+ kimi_k2_tau = kimi_k2_tau ,
669+ kl_penalty_coef = kl_penalty_coef ,
670+ allow_training_without_logprobs = allow_training_without_logprobs ,
671+ plot_tensors = plot_tensors ,
672+ truncated_importance_sampling = truncated_importance_sampling ,
673+ scale_learning_rate_by_reward_std_dev = scale_learning_rate_by_reward_std_dev ,
674+ logprob_calculation_chunk_size = logprob_calculation_chunk_size ,
675+ num_trajectories_learning_rate_multiplier_power = num_trajectories_learning_rate_multiplier_power ,
676+ kl_ref_adapter_path = resolved_kl_ref_adapter_path ,
677+ )
684678
685679 # Collect metrics from training
686680 training_metrics : list [dict [str , float ]] = []
@@ -690,21 +684,10 @@ async def train( # type: ignore[override]
690684 ):
691685 training_metrics .append (metrics )
692686
693- # Aggregate metrics
694- avg_metrics = average_metric_samples (training_metrics )
695- summary = summarize_trajectory_groups (groups_list )
696- avg_metrics .setdefault (
697- "time/step_trainer_s" , time .monotonic () - trainer_started
698- )
699- avg_metrics .update (
700- {
701- key : value
702- for key , value in build_training_summary_metrics (
703- summary ,
704- include_trainable_groups = True ,
705- ).items ()
706- if key not in avg_metrics
707- }
687+ avg_metrics = aggregate_rl_training_metrics (
688+ training_metrics = training_metrics ,
689+ trajectory_groups = groups_list ,
690+ trainer_started = trainer_started ,
708691 )
709692
710693 # Get step and checkpoint path
0 commit comments