Skip to content

Commit b2b681d

Browse files
authored
Merge pull request #20 from TimeDelta/codex/add-mlflow-to-compare_encoders.py-experiment
Add MLflow to encoder experiment
2 parents: 82e679d + 18a910f; commit: b2b681d

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

compare_encoders.py

Lines changed: 35 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,10 @@
1+
import argparse
12
import os
23
import random
34
import time
45
import tracemalloc
56

7+
import mlflow
68
import numpy as np
79
import torch
810
from torch_geometric.data import Data
@@ -177,11 +179,18 @@ def train_model(encoder_cls, full_dataset, random_seed, val_ratio=0.2):
177179

178180

179181
if __name__ == "__main__":
182+
parser = argparse.ArgumentParser(description="Compare graph encoders")
183+
parser.add_argument("--num-runs", type=int, default=1, help="number of runs")
184+
parser.add_argument("--num-samples", type=int, default=1000, help="number of optimizers to sample")
185+
parser.add_argument("--experiment-name", type=str, default="compare_encoders", help="MLflow experiment name")
186+
args = parser.parse_args()
187+
188+
mlflow.set_experiment(args.experiment_name)
189+
180190
num_node_types = len(NODE_TYPE_TO_INDEX)
181191
graph_latent_dim = 16
182192
task_latent_dim = 8
183193

184-
# lists for both metrics
185194
res_attention_final = []
186195
res_attention_auc = []
187196
res_attention_val = []
@@ -191,15 +200,36 @@ def train_model(encoder_cls, full_dataset, random_seed, val_ratio=0.2):
191200
res_async_val = []
192201
res_async_val_auc = []
193202

194-
for i in range(100):
203+
for i in range(args.num_runs):
195204
random_seed = random.randint(0, 99999999)
196-
data = generate_data(1000)
205+
data = generate_data(args.num_samples)
197206
attr_name_vocab = data[4]
198-
shared_attr_vocab = SharedAttributeVocab(attr_name_vocab, 5)
199-
fitness_dim = len(data[1][0])
207+
globals()["shared_attr_vocab"] = SharedAttributeVocab(attr_name_vocab, 5)
208+
globals()["fitness_dim"] = len(data[1][0])
200209

201210
final_att, auc_att, val_att, val_auc_att = train_model(GraphEncoder, data[:4], random_seed)
211+
with mlflow.start_run(run_name=f"GraphEncoder_{i}"):
212+
mlflow.log_params({"encoder": "GraphEncoder", "seed": random_seed, "num_samples": args.num_samples})
213+
mlflow.log_metrics(
214+
{
215+
"train_final_loss": final_att,
216+
"train_auc": auc_att,
217+
"val_final_loss": val_att,
218+
"val_auc": val_auc_att,
219+
}
220+
)
221+
202222
final_async, auc_async, val_async, val_auc_async = train_model(AsyncGraphEncoder, data[:4], random_seed)
223+
with mlflow.start_run(run_name=f"AsyncGraphEncoder_{i}"):
224+
mlflow.log_params({"encoder": "AsyncGraphEncoder", "seed": random_seed, "num_samples": args.num_samples})
225+
mlflow.log_metrics(
226+
{
227+
"train_final_loss": final_async,
228+
"train_auc": auc_async,
229+
"val_final_loss": val_async,
230+
"val_auc": val_auc_async,
231+
}
232+
)
203233

204234
res_attention_final.append(final_att)
205235
res_attention_auc.append(auc_att)

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,3 +9,4 @@ scipy
99
pre-commit
1010
black
1111
scikit-learn
12+
mlflow

0 commit comments

Comments (0)