Skip to content

Commit 03fa77f

Browse files
committed
added config files and training loop code
1 parent 3d1572e commit 03fa77f

7 files changed

Lines changed: 71 additions & 19 deletions

File tree

.gitignore

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ wheels/
2424
venv/
2525
env/
2626
ENV/
27-
27+
venv312/
28+
venv311/
2829
# Jupyter Notebook
2930
.ipynb_checkpoints
3031

@@ -66,4 +67,7 @@ data/clean/*
6667

6768
# Logs
6869
logs/
69-
*.log
70+
*.log
71+
72+
wandb/
73+
outputs/

.python-version

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
3.11.2

configs/deberta_large_sst2.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
model_name: microsoft/deberta-v3-large
2-
dataset_path: data/sst2_dd
2+
dataset_path: data/processed
33
train_split: train
44
eval_split: val
5-
sanity_split: sanity
5+
sanity_split: sent_sanity
6+
test_split: test
67
max_len: 128
78
per_device_train_batch_size: 8
89
per_device_eval_batch_size: 32
910
gradient_accumulation_steps: 4
1011
num_train_epochs: 3
11-
learning_rate: 2e-5
12+
learning_rate: 0.00005
1213
warmup_ratio: 0.06
1314
fp16: true
1415
logging_steps: 50
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
model_name: prajjwal1/bert-small
2+
dataset_path: data/processed
3+
train_split: train
4+
eval_split: val
5+
sanity_split: sent_sanity
6+
test_split: test
7+
max_len: 128
8+
per_device_train_batch_size: 2
9+
per_device_eval_batch_size: 32
10+
gradient_accumulation_steps: 16
11+
num_train_epochs: 10
12+
learning_rate: 0.0005
13+
warmup_ratio: 0.06
14+
fp16: false
15+
logging_steps: 5
16+
eval_steps: 10
17+
save_steps: 10000
18+
output_dir: outputs/teacher
19+
report_to: wandb
20+
project_name: sst2_teacher_mps

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
torch>=1.9.0
2-
transformers>=4.10.0
2+
transformers
33
datasets>=1.11.0
44
numpy>=1.19.5
55
scikit-learn>=0.24.2

src/train_teacher.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
)
3030
logger = logging.getLogger(__name__)
3131

32+
# Set OMP_NUM_THREADS to 1 to avoid potential CPU over-subscription
33+
os.environ['OMP_NUM_THREADS'] = '1'
34+
logger.info(f"Setting OMP_NUM_THREADS=1")
35+
3236

3337
def load_config(config_path: str) -> Dict:
3438
"""Loads configuration from a YAML file."""
@@ -46,9 +50,11 @@ def load_datasets(dataset_path: str) -> DatasetDict:
4650
logger.info(f"Datasets loaded: {datasets}")
4751
# Ensure standard column names
4852
if "sentence" in datasets["train"].column_names:
49-
datasets = datasets.rename_column("sentence", "text")
53+
if "text" not in datasets["train"].column_names:
54+
datasets = datasets.rename_column("sentence", "text")
5055
if "label" in datasets["train"].column_names:
51-
datasets = datasets.rename_column("label", "labels")
56+
if "labels" not in datasets["train"].column_names:
57+
datasets = datasets.rename_column("label", "labels")
5258
# Make sure 'labels' column exists
5359
if "labels" not in datasets["train"].column_names:
5460
raise ValueError(
@@ -101,6 +107,7 @@ def main(config_path: str):
101107
# --- Setup W&B ---
102108
run_name = f"train_teacher_{int(time.time())}"
103109
if config.get("report_to") == "wandb":
110+
os.environ.pop("WANDB_DISABLED", None)
104111
os.environ["WANDB_PROJECT"] = config["project_name"]
105112
logger.info(f"Logging to W&B project: {config['project_name']}")
106113
else:
@@ -114,7 +121,7 @@ def main(config_path: str):
114121
model = AutoModelForSequenceClassification.from_pretrained(
115122
config["model_name"], num_labels=2 # Assuming binary classification for SST-2
116123
)
117-
tokenizer = AutoTokenizer.from_pretrained(config["model_name"], use_fast=False)
124+
tokenizer = AutoTokenizer.from_pretrained(config["model_name"], use_fast=True)
118125

119126
# --- Load and Prepare Data ---
120127
raw_datasets = load_datasets(config["dataset_path"])
@@ -126,8 +133,24 @@ def main(config_path: str):
126133
eval_dataset = tokenized_datasets[config["eval_split"]]
127134
sanity_dataset = tokenized_datasets[config["sanity_split"]]
128135

136+
## TODO: COMMENT OUT FOR REAL RUN (debug-only subsampling of train/eval sets):
137+
train_dataset = train_dataset.shuffle(seed=42).select(range(256))
138+
eval_dataset = eval_dataset.shuffle(seed=42).select(range(128))
139+
# sanity_dataset = sanity_dataset.shuffle(seed=42).select(range(256))
140+
129141
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
130142

143+
# --- Detect device (CUDA, MPS, or CPU) ----------------------------
144+
if torch.cuda.is_available():
145+
device = torch.device("cuda")
146+
elif torch.backends.mps.is_available():
147+
device = torch.device("mps")
148+
else:
149+
device = torch.device("cpu")
150+
151+
model.to(device)
152+
logger.info(f"Using device: {device}")
153+
131154
# --- Training Arguments ---
132155
logger.info("Setting up Training Arguments...")
133156
training_args = TrainingArguments(
@@ -144,8 +167,8 @@ def main(config_path: str):
144167
fp16=config["fp16"] and torch.cuda.is_available(),
145168
logging_dir=f"{config['output_dir']}/logs",
146169
logging_steps=config["logging_steps"],
147-
evaluation_strategy=IntervalStrategy.STEPS,
148170
eval_steps=config["eval_steps"],
171+
eval_strategy=IntervalStrategy.STEPS,
149172
save_strategy=IntervalStrategy.STEPS,
150173
save_steps=config["save_steps"],
151174
save_total_limit=2, # Saves the best and the latest checkpoints
@@ -156,6 +179,7 @@ def main(config_path: str):
156179
run_name=run_name,
157180
label_names=["labels"], # Specify label column name
158181
remove_unused_columns=False, # Keep all columns tokenized earlier
182+
ddp_find_unused_parameters=False,
159183
)
160184
logger.info(f"FP16 enabled: {training_args.fp16}")
161185

@@ -171,7 +195,7 @@ def main(config_path: str):
171195
model=model,
172196
args=training_args,
173197
train_dataset=train_dataset,
174-
eval_dataset={"eval": eval_dataset, "sanity": sanity_dataset}, # Evaluate on both
198+
eval_dataset=eval_dataset,
175199
tokenizer=tokenizer,
176200
data_collator=data_collator,
177201
compute_metrics=compute_metrics,
@@ -193,12 +217,7 @@ def main(config_path: str):
193217
trainer.log_metrics("train", metrics)
194218
trainer.save_metrics("train", metrics)
195219

196-
# Evaluate one last time on both sets with the best model
197-
logger.info("Evaluating best model on eval and sanity sets...")
198-
eval_metrics = trainer.evaluate(eval_dataset=eval_dataset, metric_key_prefix="final_eval")
199-
trainer.log_metrics("final_eval", eval_metrics)
200-
trainer.save_metrics("final_eval", eval_metrics)
201-
220+
# Evaluate on sanity set
202221
sanity_metrics = trainer.evaluate(eval_dataset=sanity_dataset, metric_key_prefix="final_sanity")
203222
trainer.log_metrics("final_sanity", sanity_metrics)
204223
trainer.save_metrics("final_sanity", sanity_metrics)

train.sh

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,14 @@
22
# Make sure to run: chmod +x train.sh
33
set -e
44

5-
CONFIG=configs/deberta_large_sst2.yaml
5+
# Check if config argument was provided
6+
if [ $# -eq 0 ]; then
7+
echo "Error: Please provide config file path as argument"
8+
echo "Usage: ./train.sh <config_path>"
9+
exit 1
10+
fi
11+
12+
CONFIG=$1
613

714
# Check if CONFIG file exists
815
if [ ! -f "$CONFIG" ]; then
@@ -21,4 +28,4 @@ echo "Starting training using config: $CONFIG"
2128
torchrun --nnodes 1 --nproc_per_node 4 --master_port 12345 \
2229
src/train_teacher.py $CONFIG
2330

24-
echo "Training script finished."
31+
echo "Training script finished."

0 commit comments

Comments
 (0)