Skip to content

Commit 0c09181

Browse files
committed
fix dumb things from chatgpt
1 parent b2b681d commit 0c09181

File tree

1 file changed

+10
-20
lines changed

1 file changed

+10
-20
lines changed

compare_encoders.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -165,17 +165,13 @@ def train_model(encoder_cls, full_dataset, random_seed, val_ratio=0.2):
165165
trainer = OnlineTrainer(model, optimizer)
166166
trainer.add_data(train_graphs, train_fitnesses, task_type, task_features)
167167

168-
train_losses = []
169-
val_losses = []
170-
loss_history = trainer.train(epochs=100, batch_size=4, warmup_epochs=100, verbose=True)
171-
train_losses.append(loss_history[-1].sum().item())
172-
val_losses = evaluate_fitness_loss(model, val_graphs, val_fitnesses, task_type, task_features)
168+
loss_history = trainer.train(epochs=10, batch_size=4, warmup_epochs=10, verbose=True)
169+
train_losses = [lh.sum().item() for lh in loss_history]
170+
val_loss = evaluate_fitness_loss(model, val_graphs, val_fitnesses, task_type, task_features)
173171

174172
final_loss = train_losses[-1]
175173
auc = np.trapz(train_losses, dx=1)
176-
val_final = val_losses[-1]
177-
val_auc = np.trapz(val_losses, dx=1)
178-
return final_loss, auc, val_final, val_auc
174+
return final_loss, auc, val_loss
179175

180176

181177
if __name__ == "__main__":
@@ -207,44 +203,38 @@ def train_model(encoder_cls, full_dataset, random_seed, val_ratio=0.2):
207203
globals()["shared_attr_vocab"] = SharedAttributeVocab(attr_name_vocab, 5)
208204
globals()["fitness_dim"] = len(data[1][0])
209205

210-
final_att, auc_att, val_att, val_auc_att = train_model(GraphEncoder, data[:4], random_seed)
206+
final_att, auc_att, val_att = train_model(GraphEncoder, data[:4], random_seed)
211207
with mlflow.start_run(run_name=f"GraphEncoder_{i}"):
212208
mlflow.log_params({"encoder": "GraphEncoder", "seed": random_seed, "num_samples": args.num_samples})
213209
mlflow.log_metrics(
214210
{
215211
"train_final_loss": final_att,
216212
"train_auc": auc_att,
217-
"val_final_loss": val_att,
218-
"val_auc": val_auc_att,
213+
"val_loss": val_att,
219214
}
220215
)
221216

222-
final_async, auc_async, val_async, val_auc_async = train_model(AsyncGraphEncoder, data[:4], random_seed)
217+
final_async, auc_async, val_async = train_model(AsyncGraphEncoder, data[:4], random_seed)
223218
with mlflow.start_run(run_name=f"AsyncGraphEncoder_{i}"):
224219
mlflow.log_params({"encoder": "AsyncGraphEncoder", "seed": random_seed, "num_samples": args.num_samples})
225220
mlflow.log_metrics(
226221
{
227222
"train_final_loss": final_async,
228223
"train_auc": auc_async,
229-
"val_final_loss": val_async,
230-
"val_auc": val_auc_async,
224+
"val_loss": val_async,
231225
}
232226
)
233227

234228
res_attention_final.append(final_att)
235229
res_attention_auc.append(auc_att)
236230
res_attention_val.append(val_att)
237-
res_attention_val_auc.append(val_auc_att)
238231
res_async_final.append(final_async)
239232
res_async_auc.append(auc_async)
240233
res_async_val.append(val_async)
241-
res_async_val_auc.append(val_auc_async)
242234

235+
print(f" attention encoder → train loss: {final_att:.4f}, val loss: {val_att:.4f}, AUC loss: {auc_att:.4f}")
243236
print(
244-
f" attention encoder → train loss: {final_att:.4f}, val loss: {val_att:.4f}, AUC loss: {auc_att:.4f}, val AUC: {val_auc_att:.4f}"
245-
)
246-
print(
247-
f" async encoder → train loss: {final_async:.4f}, val loss: {val_async:.4f}, AUC loss: {auc_async:.4f}, val AUC: {val_auc_async:.4f}"
237+
f" async encoder → train loss: {final_async:.4f}, val loss: {val_async:.4f}, AUC loss: {auc_async:.4f}"
248238
)
249239

250240
print("\nSummary:")

0 commit comments

Comments (0)