diff --git a/src/post_training/methods/common.py b/src/post_training/methods/common.py
index 4cf0db4..af995cc 100644
--- a/src/post_training/methods/common.py
+++ b/src/post_training/methods/common.py
@@ -102,7 +102,7 @@ def build_common_training_kwargs(
         weight_decay=t.weight_decay,
         adam_epsilon=t.adam_epsilon,
         gradient_accumulation_steps=grad_accum,
-        warmup_steps=t.warmup_ratio,
+        warmup_ratio=t.warmup_ratio,
         lr_scheduler_type=t.lr_scheduler_type,
         lr_scheduler_kwargs={
             k: v for k, v in dataclasses.asdict(t.lr_scheduler_kwargs).items() if v is not None