Skip to content

Commit fe389ff

Browse files
authored
chg: small corrections (#7)
1 parent 3a85f66 commit fe389ff

7 files changed

Lines changed: 22 additions & 14 deletions

File tree

frame/evaluate.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,16 @@ def main():
5353

5454
# * Get checkpoint and prepare Explainer
5555
model = models.select_model(model_name, tune)
56-
model.load_state_dict(torch.load(path_checkpoint))
56+
model.load_state_dict(torch.load(path_checkpoint,
57+
map_location=device,
58+
weights_only=True))
5759
model.eval()
5860

5961
agg_pred = []
6062
agg_lbl = []
6163
agg_true = []
6264
for data in tqdm(test_loader, ncols=120, desc="Explaining"):
63-
data.to(device)
65+
data = data.to(device)
6466

6567
# * Make predictions
6668
model_out = model(x=data.x.float(),

frame/explain.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ def main():
6060

6161
# * Get checkpoint and prepare Explainer
6262
model = models.select_model(model_name, tune)
63-
model.load_state_dict(torch.load(path_checkpoint))
63+
model.load_state_dict(torch.load(path_checkpoint,
64+
map_location=device,
65+
weights_only=True))
6466
model.eval()
6567

6668
if task == "classification":
@@ -77,7 +79,7 @@ def main():
7779
return_type="raw"))
7880

7981
for data in tqdm(dataloader, ncols=120, desc="Explaining"):
80-
data.to(device)
82+
data = data.to(device)
8183

8284
# * Make predictions
8385
model_out = model(x=data.x.float(),

frame/generate.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,11 @@ def main():
4242
raise NotImplementedError("Loader not available")
4343

4444
# * Export
45-
bce_weight = (len(dataset.y) - sum(dataset.y)) / sum(dataset.y)
45+
task = params["Data"].get("task", "classification").lower()
46+
if task == "classification" and sum(dataset.y) > 0:
47+
bce_weight = (len(dataset.y) - sum(dataset.y)) / sum(dataset.y)
48+
else:
49+
bce_weight = torch.tensor(1.0)
4650
metadata = {"feat_size": dataset.num_node_features,
4751
"edge_dim": dataset.num_edge_features,
4852
"bce_weight": bce_weight,

frame/source/datasets/decompose.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def process_data(self):
5959
# * Iterate
6060
data_list = []
6161
for line in tqdm(dataset, ncols=120, desc="Creating graphs"):
62-
line = re.sub(r"\'.*\'", "", line) # Replace ".*" strings.
62+
line = re.sub(r"\'.*?\'", "", line) # Replace '...' strings.
6363
line = line.split(",")
6464

6565
# Get label
@@ -206,7 +206,7 @@ def _gen_features(smiles):
206206
# [single, double, triple, aromatic, conjugation, ring] + stereo)
207207

208208
# edge_attrs += [edge_attr, edge_attr]
209-
# frag_edge_attr = torch.stack(edge_attrs, dim=0)
209+
# frag_edge_attr = torch.stack(edge_attrs, dim=0)
210210

211211
agg_x = torch.sum(frag_x, dim=0)
212212
return agg_x

frame/source/datasets/default.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def process_data(self):
5959
# * Iterate
6060
data_list = []
6161
for line in tqdm(dataset, ncols=120, desc="Creating graphs"):
62-
line = re.sub(r"\'.*\'", "", line) # Replace ".*" strings.
62+
line = re.sub(r"\'.*?\'", "", line) # Replace '...' strings.
6363
line = line.split(",")
6464

6565
# Get label

frame/source/train/epoch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ def train_epoch(model, optim, scheduler, lossfn, loader):
2626
model = model.train()
2727
for batch in loader:
2828
batch = batch.to(device)
29-
optim.zero_grad(batch)
29+
optim.zero_grad()
3030

3131
# * Make predictions
3232
out = model(x=batch.x.float(),
3333
edge_index=batch.edge_index,
34-
edge_attr=batch.edge_attr,
34+
edge_attr=batch.edge_attr.float(),
3535
batch=batch.batch)
3636

3737
# * Compute loss
@@ -64,7 +64,7 @@ def valid_epoch(model, task, loader):
6464
# * Make predictions
6565
out = model(x=batch.x.float(),
6666
edge_index=batch.edge_index,
67-
edge_attr=batch.edge_attr,
67+
edge_attr=batch.edge_attr.float(),
6868
batch=batch.batch)
6969

7070
# * Read prediction values

frame/tune.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ def objective(trial, params, dataset):
6464
patience_counter = 0
6565
best_model_state = None
6666

67+
start = time.time()
6768
for epoch in tqdm(range(epochs), ncols=120, desc="Training"):
68-
start = time.time()
6969
_ = train.train_epoch(model, optim, schdlr,
7070
lossfn, train_loader)
7171
val_metrics = train.valid_epoch(model, task, valid_loader)
@@ -82,7 +82,7 @@ def objective(trial, params, dataset):
8282
if patience_counter >= patience:
8383
break
8484

85-
fit_time = time.time() - start
85+
fit_time = time.time() - start
8686

8787
# Prepare best model
8888
model.load_state_dict(best_model_state)
@@ -153,7 +153,7 @@ def main():
153153
params = yaml.safe_load(stream)
154154

155155
# * Initialize
156-
task = name = params["Data"]["task"]
156+
task = params["Data"]["task"]
157157
name = params["Data"]["name"]
158158
if name.lower() == "none":
159159
name = str(uuid.uuid4()).split("-")[0]

0 commit comments

Comments
 (0)