1+ import argparse
12import os
23import random
34import time
45import tracemalloc
56
7+ import mlflow
68import numpy as np
79import torch
810from torch_geometric .data import Data
@@ -176,12 +178,19 @@ def train_model(encoder_cls, full_dataset, random_seed, val_ratio=0.2):
176178 return final_loss , auc , val_final , val_auc
177179
178180
def main():
    """Compare GraphEncoder and AsyncGraphEncoder over repeated runs.

    Each run samples a fresh dataset, trains both encoders from the same
    random seed, logs the per-run train/validation metrics to MLflow
    (one MLflow run per encoder per iteration), and finally prints the
    mean of every collected metric across all runs.
    """
    parser = argparse.ArgumentParser(description="Compare graph encoders")
    parser.add_argument("--num-runs", type=int, default=1, help="number of runs")
    parser.add_argument(
        "--num-samples",
        type=int,
        default=1000,
        help="number of optimizers to sample",
    )
    parser.add_argument(
        "--experiment-name",
        type=str,
        default="compare_encoders",
        help="MLflow experiment name",
    )
    args = parser.parse_args()

    mlflow.set_experiment(args.experiment_name)

    # NOTE(review): these three appear unused inside main() — they may be
    # leftovers from when this code ran at module scope; confirm before removing.
    num_node_types = len(NODE_TYPE_TO_INDEX)
    graph_latent_dim = 16
    task_latent_dim = 8

    res_attention_final = []
    res_attention_auc = []
    res_attention_val = []
    res_attention_val_auc = []  # NOTE(review): reconstructed from truncated diff — verify
    res_async_final = []  # NOTE(review): reconstructed from truncated diff — verify
    res_async_auc = []  # NOTE(review): reconstructed from truncated diff — verify
    res_async_val = []
    res_async_val_auc = []

    # train_model presumably reads shared_attr_vocab / fitness_dim as module
    # globals (TODO confirm), so they must be published at module scope.
    # A `global` declaration states that intent explicitly instead of the
    # opaque globals()[...] dict mutation.
    global shared_attr_vocab, fitness_dim

    for i in range(args.num_runs):
        random_seed = random.randint(0, 99999999)
        data = generate_data(args.num_samples)
        attr_name_vocab = data[4]
        shared_attr_vocab = SharedAttributeVocab(attr_name_vocab, 5)
        fitness_dim = len(data[1][0])

        final_att, auc_att, val_att, val_auc_att = train_model(GraphEncoder, data[:4], random_seed)
        with mlflow.start_run(run_name=f"GraphEncoder_{i}"):
            mlflow.log_params({"encoder": "GraphEncoder", "seed": random_seed, "num_samples": args.num_samples})
            mlflow.log_metrics(
                {
                    "train_final_loss": final_att,
                    "train_auc": auc_att,
                    "val_final_loss": val_att,
                    "val_auc": val_auc_att,
                }
            )

        final_async, auc_async, val_async, val_auc_async = train_model(AsyncGraphEncoder, data[:4], random_seed)
        with mlflow.start_run(run_name=f"AsyncGraphEncoder_{i}"):
            mlflow.log_params({"encoder": "AsyncGraphEncoder", "seed": random_seed, "num_samples": args.num_samples})
            mlflow.log_metrics(
                {
                    "train_final_loss": final_async,
                    "train_auc": auc_async,
                    "val_final_loss": val_async,
                    "val_auc": val_auc_async,
                }
            )

        res_attention_final.append(final_att)
        res_attention_auc.append(auc_att)
        # NOTE(review): the remaining appends fall in a truncated diff hunk and
        # are reconstructed from the list names used above/below — verify.
        res_attention_val.append(val_att)
        res_attention_val_auc.append(val_auc_att)
        res_async_final.append(final_async)
        res_async_auc.append(auc_async)
        res_async_val.append(val_async)
        res_async_val_auc.append(val_auc_async)

    # NOTE(review): only the last three print lines are visible in the diff;
    # the preceding summary lines are reconstructed to mirror them — verify
    # the exact labels against the original file.
    print(f"Mean final loss (attention encoder): {np.mean(res_attention_final):.4f}")
    print(f"Mean AUC loss (attention encoder): {np.mean(res_attention_auc):.4f}")
    print(f"Mean val loss (attention encoder): {np.mean(res_attention_val):.4f}")
    print(f"Mean val AUC (attention encoder): {np.mean(res_attention_val_auc):.4f}")
    print(f"Mean final loss (async encoder): {np.mean(res_async_final):.4f}")
    print(f"Mean val loss (async encoder): {np.mean(res_async_val):.4f}")
    print(f"Mean AUC loss (async encoder): {np.mean(res_async_auc):.4f}")
    print(f"Mean val AUC (async encoder): {np.mean(res_async_val_auc):.4f}")


if __name__ == "__main__":
    main()