Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -286,10 +286,16 @@ You must specify exactly one determining factor for training duration in the `tr

### 6. Environment Modes

- **Offline**: `offline: true`
- **Offline**: `offline: true`
Disables Hugging Face Hub / Weights & Biases network calls (essential for air-gapped nodes).
- **Debug**: `debug.enabled: true`
- **Debug**: `debug.enabled: true`
Forces `report_to: none`, uses a separate output directory, and allows overwriting existing runs.
- **Tokenize only**: `--tokenize-only` (CLI flag on `train.py` / `submit.py`)
Exits immediately after the trainer is initialized — dataset loading, tokenization, and packing all run, but the training loop is never entered. Useful for pretokenizing the dataset before committing to a full run. When passed to `submit.py`, the job is automatically constrained to 1 node and 1 GPU.

```bash
python scripts/submit.py --config configs/trl/sft.yaml --tokenize-only
```

### 7. Logging & Experiment Tracking

Expand Down
15 changes: 12 additions & 3 deletions scripts/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,26 @@ def _parse_args() -> tuple[str, list[str], bool]:
action="store_true",
help="Skip the interactive guardrails review and submit immediately.",
)
parser.add_argument(
"--tokenize-only",
action="store_true",
help="Pass --tokenize-only to train.py — exits after trainer initialization.",
)
known, unknown = parser.parse_known_args()
return known.config, unknown, known.confirm
return known.config, unknown, known.confirm, known.tokenize_only


def main() -> None:
setup_logging()

config_path, cli_overrides, confirmed = _parse_args()
config_path, cli_overrides, confirmed, tokenize_only = _parse_args()
logger.info("Loading config from %s", config_path)
config = PostTrainingConfig.load(config_path, cli_overrides)

if tokenize_only:
config.slurm.num_nodes = 1
config.slurm.gpus_per_node = 1

if config.offline:
logger.info(
"offline=True: pre-fetching models and datasets on the login node "
Expand All @@ -66,7 +75,7 @@ def main() -> None:
logger.info("Run directory: %s", run_dir)

if not confirmed:
run_guardrails(config, run_dir)
run_guardrails(config, run_dir, tokenize_only=tokenize_only)

# CRITICAL: Set run_name so it's preserved in the frozen config.
# This ensures train.py uses the same directory when it loads the config.
Expand Down
13 changes: 11 additions & 2 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,19 @@ def _parse_args() -> tuple[str, list[str]]:
default="configs/trl/sft.yaml",
help="Path to the YAML config file.",
)
parser.add_argument(
"--tokenize-only",
action="store_true",
help="Exit after initializing the trainer (useful for verifying tokenization).",
)
known, unknown = parser.parse_known_args()
return known.config, unknown
return known.config, known.tokenize_only, unknown


def main() -> None:
setup_logging()

config_path, cli_overrides = _parse_args()
config_path, tokenize_only, cli_overrides = _parse_args()
logger.info("Loading config from %s", config_path)
config = PostTrainingConfig.load(config_path, cli_overrides)

Expand Down Expand Up @@ -108,6 +113,10 @@ def main() -> None:
# ── Build trainer & launch ──────────────────────────────────────
trainer = build_trainer(config, run_dir)

if tokenize_only:
logger.info("--tokenize-only set — exiting after trainer initialization.")
return

# Auto-resume from the latest checkpoint if one exists.
checkpoints_dir = run_dir / "checkpoints"
existing = sorted(checkpoints_dir.glob("checkpoint-*"))
Expand Down
8 changes: 7 additions & 1 deletion src/post_training/utils/guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _batch_summary(config: PostTrainingConfig, total_gpus: int) -> tuple[str, st
# ---------------------------------------------------------------------------


def run_guardrails(config: PostTrainingConfig, run_dir: Path) -> None:
def run_guardrails(config: PostTrainingConfig, run_dir: Path, tokenize_only: bool = False) -> None:
"""Print a full config summary and ask for confirmation.

Parameters
Expand Down Expand Up @@ -261,6 +261,12 @@ def run_guardrails(config: PostTrainingConfig, run_dir: Path) -> None:
_row("Debug mode", _red("*** ENABLED — output dir may be overwritten ***"), warn=True)
else:
_row("Debug mode", "disabled")
if tokenize_only:
_row(
"Tokenize only",
_yellow("*** ENABLED — will exit after trainer init, no training ***"),
warn=True,
)

# ------------------------------------------------------------------
# Final confirmation
Expand Down
Loading