@@ -165,17 +165,13 @@ def train_model(encoder_cls, full_dataset, random_seed, val_ratio=0.2):
165165 trainer = OnlineTrainer (model , optimizer )
166166 trainer .add_data (train_graphs , train_fitnesses , task_type , task_features )
167167
168- train_losses = []
169- val_losses = []
170- loss_history = trainer .train (epochs = 100 , batch_size = 4 , warmup_epochs = 100 , verbose = True )
171- train_losses .append (loss_history [- 1 ].sum ().item ())
172- val_losses = evaluate_fitness_loss (model , val_graphs , val_fitnesses , task_type , task_features )
168+ loss_history = trainer .train (epochs = 10 , batch_size = 4 , warmup_epochs = 10 , verbose = True )
169+ train_losses = [lh .sum ().item () for lh in loss_history ]
170+ val_loss = evaluate_fitness_loss (model , val_graphs , val_fitnesses , task_type , task_features )
173171
174172 final_loss = train_losses [- 1 ]
175173 auc = np .trapz (train_losses , dx = 1 )
176- val_final = val_losses [- 1 ]
177- val_auc = np .trapz (val_losses , dx = 1 )
178- return final_loss , auc , val_final , val_auc
174+ return final_loss , auc , val_loss
179175
180176
181177if __name__ == "__main__" :
@@ -207,44 +203,38 @@ def train_model(encoder_cls, full_dataset, random_seed, val_ratio=0.2):
207203 globals ()["shared_attr_vocab" ] = SharedAttributeVocab (attr_name_vocab , 5 )
208204 globals ()["fitness_dim" ] = len (data [1 ][0 ])
209205
210- final_att , auc_att , val_att , val_auc_att = train_model (GraphEncoder , data [:4 ], random_seed )
206+ final_att , auc_att , val_att = train_model (GraphEncoder , data [:4 ], random_seed )
211207 with mlflow .start_run (run_name = f"GraphEncoder_{ i } " ):
212208 mlflow .log_params ({"encoder" : "GraphEncoder" , "seed" : random_seed , "num_samples" : args .num_samples })
213209 mlflow .log_metrics (
214210 {
215211 "train_final_loss" : final_att ,
216212 "train_auc" : auc_att ,
217- "val_final_loss" : val_att ,
218- "val_auc" : val_auc_att ,
213+ "val_loss" : val_att ,
219214 }
220215 )
221216
222- final_async , auc_async , val_async , val_auc_async = train_model (AsyncGraphEncoder , data [:4 ], random_seed )
217+ final_async , auc_async , val_async = train_model (AsyncGraphEncoder , data [:4 ], random_seed )
223218 with mlflow .start_run (run_name = f"AsyncGraphEncoder_{ i } " ):
224219 mlflow .log_params ({"encoder" : "AsyncGraphEncoder" , "seed" : random_seed , "num_samples" : args .num_samples })
225220 mlflow .log_metrics (
226221 {
227222 "train_final_loss" : final_async ,
228223 "train_auc" : auc_async ,
229- "val_final_loss" : val_async ,
230- "val_auc" : val_auc_async ,
224+ "val_loss" : val_async ,
231225 }
232226 )
233227
234228 res_attention_final .append (final_att )
235229 res_attention_auc .append (auc_att )
236230 res_attention_val .append (val_att )
237- res_attention_val_auc .append (val_auc_att )
238231 res_async_final .append (final_async )
239232 res_async_auc .append (auc_async )
240233 res_async_val .append (val_async )
241- res_async_val_auc .append (val_auc_async )
242234
235+ print (f" attention encoder → train loss: { final_att :.4f} , val loss: { val_att :.4f} , AUC loss: { auc_att :.4f} " )
243236 print (
244- f" attention encoder → train loss: { final_att :.4f} , val loss: { val_att :.4f} , AUC loss: { auc_att :.4f} , val AUC: { val_auc_att :.4f} "
245- )
246- print (
247- f" async encoder → train loss: { final_async :.4f} , val loss: { val_async :.4f} , AUC loss: { auc_async :.4f} , val AUC: { val_auc_async :.4f} "
237+ f" async encoder → train loss: { final_async :.4f} , val loss: { val_async :.4f} , AUC loss: { auc_async :.4f} "
248238 )
249239
250240 print ("\n Summary:" )
0 commit comments