Skip to content

Commit b5ec09a

Browse files
committed
updated config for teacher training
1 parent d1ad1c4 commit b5ec09a

File tree

3 files changed

+41
-45
lines changed

3 files changed

+41
-45
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,5 @@ outputs/
8282
.env/
8383

8484
.DS_Store
85+
86+
runs/

configs/teacher/sst2_hf.yaml

Lines changed: 37 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,43 @@
1-
model:
2-
model_name: "microsoft/deberta-v3-base"
3-
num_labels: 2
4-
use_fast_tokenizer: true
5-
6-
data:
7-
dataset_path: "./data/clean/" # Use HF dataset identifier
8-
max_len: 32
9-
train_split: "train"
10-
validation_split: "val"
11-
test_split: "test"
12-
131
training:
14-
output_dir: "runs/teacher/deberta_v3_base" # Specific output for this run
2+
# ---------- bookkeeping ----------
3+
output_dir: "runs/teacher/deberta_v3_base"
154
overwrite_output_dir: true
16-
run_name: "teacher_sst2_deberta_v3_base_run" # Optional W&B/TensorBoard run name
17-
18-
# Reporting
19-
report_to: "wandb"
20-
wandb_project: "senti_synth_teacher"
21-
22-
# Batching & Epochs
23-
per_device_train_batch_size: 16
24-
per_device_eval_batch_size: 32
25-
gradient_accumulation_steps: 1
26-
num_train_epochs: 3
27-
28-
# Optimizer & Scheduler
29-
learning_rate: 0.00003
30-
warmup_ratio: 0.1
31-
32-
# Logging, Saving, Evaluation
33-
logging_steps: 50
34-
eval_steps: 200 # Evaluate every N steps
35-
save_steps: 200 # Save checkpoint every N steps
36-
save_total_limit: 2 # Keep only the best and the latest checkpoints
37-
load_best_model_at_end: true # Load the best model found during training
38-
metric_for_best_model: "eval_f1" # Metric to determine the 'best' model
5+
run_name: "teacher_sst2_deberta_v3_base_h100"
6+
7+
report_to: "wandb"
8+
wandb_project: "senti_synth_teacher"
9+
10+
# ---------- batch size & epochs ----------
11+
per_device_train_batch_size: 64 # 4× bigger than before; fits easily in 80 GB
12+
per_device_eval_batch_size: 256 # evaluation is memory‑lighter, so push higher
13+
gradient_accumulation_steps: 1 # no need for micro‑batching on an H100
14+
num_train_epochs: 6               # SST‑2 is small; ~4 epochs usually reaches peak F1 — early stopping trims the extra
15+
16+
# ---------- precision & speed ----------
17+
bf16: true # H100 has native BF16; gives ~1.8× speed‑up over FP32
18+
fp16: false # turn FP16 off to avoid two mixed‑precision modes
19+
# If you prefer automatic selection, drop bf16/fp16 and add `torch_dtype: "auto"`
20+
21+
# ---------- optimiser & scheduler ----------
22+
learning_rate: 0.0001             # linear‑scale LR (16→64 batch ⇒ ×4 of 3e‑5 ≈ 1.2e‑4, rounded down to 1e‑4)
23+
warmup_ratio: 0.05 # keep warm‑up tokens roughly constant after batch change
24+
25+
# ---------- misc performance knobs ----------
26+
dataloader_num_workers: 8 # plenty of CPU headroom; hides data‑loading latency
27+
gradient_checkpointing: false # not needed; trade memory for speed
28+
max_grad_norm: 1.0 # good default when using larger LR + BF16
29+
30+
# ---------- logging, saving, early stop ----------
31+
logging_steps: 100
32+
eval_steps: 500
33+
save_steps: 500
34+
save_total_limit: 3
35+
load_best_model_at_end: true
36+
metric_for_best_model: "eval_f1"
3937
greater_is_better: true
4038

41-
# Hardware & Performance
42-
fp16: true # Set to false if GPU doesn't support FP16 or causes issues
43-
44-
# Callbacks
4539
use_early_stopping: true
46-
early_stopping_patience: 3
47-
early_stopping_threshold: 0.001 # Small improvement needed to reset patience
40+
early_stopping_patience: 2        # eval interval grew to 500 steps, so 2 missed evals is enough to stop
41+
early_stopping_threshold: 0.0005
4842

49-
# Optional: Evaluate on test set after training
50-
do_test_eval: true
43+
do_test_eval: true

src/cli/01_train_teacher.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ def main(config_path: Path = typer.Argument(..., help="Path to YAML config")):
9898
trainer.save_metrics("train", metrics)
9999

100100
# Evaluate on test set if available
101-
test_dataset = data_module.get_test_dataset()
101+
# We use the sanity set as test set since the test set labels are all -1
102+
test_dataset = data_module.get_sanity_dataset()
102103
if test_dataset and cfg['training'].get("do_test_eval", True):
103104
logger.info("Evaluating on test set...")
104105
test_metrics = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")

0 commit comments

Comments
 (0)