2929)
3030logger = logging .getLogger (__name__ )
3131
32+ # Set OMP_NUM_THREADS to 1 to avoid potential CPU over-subscription
33+ os .environ ['OMP_NUM_THREADS' ] = '1'
34+ logger .info (f"Setting OMP_NUM_THREADS=1" )
35+
3236
3337def load_config (config_path : str ) -> Dict :
3438 """Loads configuration from a YAML file."""
@@ -46,9 +50,11 @@ def load_datasets(dataset_path: str) -> DatasetDict:
4650 logger .info (f"Datasets loaded: { datasets } " )
4751 # Ensure standard column names
4852 if "sentence" in datasets ["train" ].column_names :
49- datasets = datasets .rename_column ("sentence" , "text" )
53+ if "text" not in datasets ["train" ].column_names :
54+ datasets = datasets .rename_column ("sentence" , "text" )
5055 if "label" in datasets ["train" ].column_names :
51- datasets = datasets .rename_column ("label" , "labels" )
56+ if "labels" not in datasets ["train" ].column_names :
57+ datasets = datasets .rename_column ("label" , "labels" )
5258 # Make sure 'labels' column exists
5359 if "labels" not in datasets ["train" ].column_names :
5460 raise ValueError (
@@ -101,6 +107,7 @@ def main(config_path: str):
101107 # --- Setup W&B ---
102108 run_name = f"train_teacher_{ int (time .time ())} "
103109 if config .get ("report_to" ) == "wandb" :
110+ os .environ .pop ("WANDB_DISABLED" , None )
104111 os .environ ["WANDB_PROJECT" ] = config ["project_name" ]
105112 logger .info (f"Logging to W&B project: { config ['project_name' ]} " )
106113 else :
@@ -114,7 +121,7 @@ def main(config_path: str):
114121 model = AutoModelForSequenceClassification .from_pretrained (
115122 config ["model_name" ], num_labels = 2 # Assuming binary classification for SST-2
116123 )
117- tokenizer = AutoTokenizer .from_pretrained (config ["model_name" ], use_fast = False )
124+ tokenizer = AutoTokenizer .from_pretrained (config ["model_name" ], use_fast = True )
118125
119126 # --- Load and Prepare Data ---
120127 raw_datasets = load_datasets (config ["dataset_path" ])
@@ -126,8 +133,24 @@ def main(config_path: str):
126133 eval_dataset = tokenized_datasets [config ["eval_split" ]]
127134 sanity_dataset = tokenized_datasets [config ["sanity_split" ]]
128135
136+ ## TODO: COMMENT OUT the two subsampling lines below for a real run (debug-size subsets only)::
137+ train_dataset = train_dataset .shuffle (seed = 42 ).select (range (256 ))
138+ eval_dataset = eval_dataset .shuffle (seed = 42 ).select (range (128 ))
139+ # sanity_dataset = sanity_dataset.shuffle(seed=42).select(range(256))
140+
129141 data_collator = DataCollatorWithPadding (tokenizer = tokenizer )
130142
143+ # --- Detect device (CUDA, MPS, or CPU) ----------------------------
144+ if torch .cuda .is_available ():
145+ device = torch .device ("cuda" )
146+ elif torch .backends .mps .is_available ():
147+ device = torch .device ("mps" )
148+ else :
149+ device = torch .device ("cpu" )
150+
151+ model .to (device )
152+ logger .info (f"Using device: { device } " )
153+
131154 # --- Training Arguments ---
132155 logger .info ("Setting up Training Arguments..." )
133156 training_args = TrainingArguments (
@@ -144,8 +167,8 @@ def main(config_path: str):
144167 fp16 = config ["fp16" ] and torch .cuda .is_available (),
145168 logging_dir = f"{ config ['output_dir' ]} /logs" ,
146169 logging_steps = config ["logging_steps" ],
147- evaluation_strategy = IntervalStrategy .STEPS ,
148170 eval_steps = config ["eval_steps" ],
171+ eval_strategy = IntervalStrategy .STEPS ,
149172 save_strategy = IntervalStrategy .STEPS ,
150173 save_steps = config ["save_steps" ],
151174 save_total_limit = 2 , # Saves the best and the latest checkpoints
@@ -156,6 +179,7 @@ def main(config_path: str):
156179 run_name = run_name ,
157180 label_names = ["labels" ], # Specify label column name
158181 remove_unused_columns = False , # Keep all columns tokenized earlier
182+ ddp_find_unused_parameters = False ,
159183 )
160184 logger .info (f"FP16 enabled: { training_args .fp16 } " )
161185
@@ -171,7 +195,7 @@ def main(config_path: str):
171195 model = model ,
172196 args = training_args ,
173197 train_dataset = train_dataset ,
174- eval_dataset = { "eval" : eval_dataset , "sanity" : sanity_dataset }, # Evaluate on both
198+ eval_dataset = eval_dataset ,
175199 tokenizer = tokenizer ,
176200 data_collator = data_collator ,
177201 compute_metrics = compute_metrics ,
@@ -193,12 +217,7 @@ def main(config_path: str):
193217 trainer .log_metrics ("train" , metrics )
194218 trainer .save_metrics ("train" , metrics )
195219
196- # Evaluate one last time on both sets with the best model
197- logger .info ("Evaluating best model on eval and sanity sets..." )
198- eval_metrics = trainer .evaluate (eval_dataset = eval_dataset , metric_key_prefix = "final_eval" )
199- trainer .log_metrics ("final_eval" , eval_metrics )
200- trainer .save_metrics ("final_eval" , eval_metrics )
201-
220+ # Evaluate on sanity set
202221 sanity_metrics = trainer .evaluate (eval_dataset = sanity_dataset , metric_key_prefix = "final_sanity" )
203222 trainer .log_metrics ("final_sanity" , sanity_metrics )
204223 trainer .save_metrics ("final_sanity" , sanity_metrics )
0 commit comments