Skip to content

Commit 18a910f

Browse files
committed
fix compare_encoders.py main()
1 parent 79189f3 commit 18a910f

File tree

1 file changed

+1
-5
lines changed

1 file changed

+1
-5
lines changed

compare_encoders.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def train_model(encoder_cls, full_dataset, random_seed, val_ratio=0.2):
178178
return final_loss, auc, val_final, val_auc
179179

180180

181-
def main():
181+
if __name__ == "__main__":
182182
parser = argparse.ArgumentParser(description="Compare graph encoders")
183183
parser.add_argument("--num-runs", type=int, default=1, help="number of runs")
184184
parser.add_argument("--num-samples", type=int, default=1000, help="number of optimizers to sample")
@@ -256,7 +256,3 @@ def main():
256256
print(f"Mean val loss (async encoder): {np.mean(res_async_val):.4f}")
257257
print(f"Mean AUC loss (async encoder): {np.mean(res_async_auc):.4f}")
258258
print(f"Mean val AUC (async encoder): {np.mean(res_async_val_auc):.4f}")
259-
260-
261-
if __name__ == "__main__":
262-
main()

0 commit comments

Comments
 (0)