Skip to content

Commit 1c553f7

Browse files
committed
sync
1 parent 96ef82e commit 1c553f7

File tree

3 files changed

+50
-6
lines changed

3 files changed

+50
-6
lines changed

chebai/models/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,10 @@ def _execute(
270270
if self.pass_loss_kwargs:
271271
loss_kwargs = loss_kwargs_candidates
272272
#torch.save(loss_data,"loss_data.pt")
273-
loss_kwargs['weights'] = f.create_data_weights(batchsize=len(data['idents']),dim=data['labels'].size(dim=1),weights=data["loss_kwargs"],idents=data["idents"])
273+
if not f.class_weights:
274+
loss_kwargs['weights'] = f.create_data_weights(batchsize=len(data['idents']),dim=data['labels'].size(dim=1),weights=data["loss_kwargs"],idents=data["idents"])
275+
else:
276+
loss_kwargs['weights'] = f.create_weight_class_tensor(len(data['idents']))
274277
loss_kwargs["current_epoch"] = self.trainer.current_epoch
275278
loss = self.criterion(loss_data, loss_labels, **loss_kwargs)
276279
if isinstance(loss, tuple):

chebai/preprocessing/datasets/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,10 +1181,12 @@ def load_processed_data(
11811181
else:
11821182
data_df = self.dynamic_split_dfs[kind]
11831183
data = data_df.to_dict(orient="records")
1184-
if kind == "train":
1184+
if f.class_weights:
1185+
f.create_class_tensor("../../../weights/test.pt")
1186+
if kind == "train" and f.class_weights == False:
11851187
f.init_weights()
11861188
data = f.add_train_weights(data)
1187-
if kind == "validation":
1189+
if kind == "validation" and f.class_weights == False:
11881190
data = f.add_val_weights(data)
11891191
# torch.save(data,"gewicht.pt")
11901192

extras/weight_loader.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch
33
import os
44

5+
class_weights = True
56

67
#inint weights in a csv file
78
def init_weights(path="../weights/first_it.csv",path_to_split="../split/splits.csv"):
@@ -113,7 +114,45 @@ def check_weights(data):
113114
for i in data:
114115
print(f"({i["ident"]} , {i["weight"]}")
115116

116-
init_weights()
117-
mock_init_weights()
118-
# print(get_weights((233713,51990)))
119117

118+
def init_class_weights(class_path: str, weight_path: str, weight: float) -> None:
    """Create a per-class weight CSV from a class-list CSV.

    Reads every row of ``class_path`` and writes it to ``weight_path`` with
    ``weight`` appended as an extra column, under a ``class,weight`` header.

    :param class_path: path to the input CSV listing the classes (one row per class).
    :param weight_path: path of the output CSV to create (overwritten if present).
    :param weight: uniform weight value appended to every class row.
    """
    # newline='' is required by the csv module to avoid spurious blank
    # lines on platforms that translate line endings (see csv docs).
    with open(class_path, 'r', newline='') as classes, \
            open(weight_path, 'w', newline='') as weights:
        reader = csv.reader(classes)
        writer = csv.writer(weights)
        writer.writerow(["class", "weight"])
        for row in reader:
            writer.writerow(row + [weight])
127+
128+
def create_class_tensor(
    save_path: str,
    csv_path: str = "../../weights/class_first_it.csv",
    num_classes: int = 1528,
) -> torch.Tensor:
    """Build a ``(1, num_classes)`` weight tensor from a class-weight CSV and save it.

    Each CSV row is expected to carry the weight in its second column; a
    header row whose second column is the literal ``"weight"`` is skipped.

    :param save_path: where to ``torch.save`` the resulting tensor.
    :param csv_path: CSV produced by :func:`init_class_weights`
        (default preserves the original hard-coded path).
    :param num_classes: number of weight columns (default preserves the
        original hard-coded 1528).
    :return: the saved tensor of shape ``(1, num_classes)``.
    """
    # zeros (not empty) so positions without a CSV row are deterministic.
    t = torch.zeros(1, num_classes)
    with open(csv_path, 'r', newline='') as f:
        reader = csv.reader(f)
        index = 0
        for row in reader:
            if row[1] == "weight":  # header row
                continue
            t[0][index] = float(row[1])
            index += 1
    torch.save(t, save_path)
    # Bug fix: the original declared "-> torch.Tensor" but returned None.
    return t
140+
141+
def create_weight_class_tensor(
    batch_size: int,
    weight_path: str = "../../weights/test.pt",
) -> torch.Tensor:
    """Load a saved class-weight tensor and tile it once per batch element.

    :param batch_size: number of rows in the returned tensor; ``0`` yields
        an empty tensor (the original crashed on ``w.shape`` of ``None``).
    :param weight_path: file produced by :func:`create_class_tensor`
        (default preserves the original hard-coded path).
    :return: tensor of shape ``(batch_size, num_classes)``.
    """
    t = torch.load(weight_path)
    # Single repeat instead of the original O(n^2) torch.cat-in-a-loop.
    # Assumes t is (1, num_classes) as written by create_class_tensor.
    return t.repeat(batch_size, 1)
151+
152+
153+
154+
155+
156+
if __name__ == "__main__":
    # Guarded entry point: the original ran these calls at module top level,
    # so merely importing this module performed file I/O and could crash if
    # the hard-coded weight files were missing.
    # init_class_weights("../../data/chebi_v241/ChEBI50/processed/classes.txt","../../weights/class_first_it.csv",1)
    create_class_tensor("../../weights/test.pt")
    create_weight_class_tensor(32)

0 commit comments

Comments
 (0)