Skip to content

Commit 7ade9cb

Browse files
committed
enables bagging and parallel training
1 parent a5e64ac commit 7ade9cb

File tree

2 files changed

+94
-36
lines changed

2 files changed

+94
-36
lines changed

chebai/preprocessing/datasets/base.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from chebai.preprocessing import reader as dr
2121

22-
import extras.adamh as f
22+
from extras.adamh import Ensemble_loader
2323

2424

2525
class XYBaseDataModule(LightningDataModule):
@@ -723,6 +723,8 @@ class _DynamicDataset(XYBaseDataModule, ABC):
723723

724724
def __init__(
725725
self,
726+
ensemble: bool,
727+
load_path: str,
726728
**kwargs,
727729
):
728730
super(_DynamicDataset, self).__init__(**kwargs)
@@ -731,6 +733,7 @@ def __init__(
731733
self._dynamic_df_train = None
732734
self._dynamic_df_test = None
733735
self._dynamic_df_val = None
736+
self.loader= Ensemble_loader(ensemble=ensemble,load_path=load_path)
734737
# Path of csv file which contains a list of ids & their assignment to a dataset (either train,
735738
# validation or test).
736739
self.splits_file_path = self._validate_splits_file_path(
@@ -1182,11 +1185,20 @@ def load_processed_data(
11821185
data_df = self.dynamic_split_dfs[kind]
11831186
data = data_df.to_dict(orient="records")
11841187
if kind == "train" :
1185-
# f.init_weights()
1186-
data = f.add_train_weights(data)
1188+
1189+
if self.loader.ensemble:
1190+
data = self.loader.add_val_weights(data)
1191+
1192+
data = self.loader.add_duplicates(data,self.loader.load_path)
1193+
1194+
else:
1195+
data = self.loader.add_train_weights(data,self.loader.load_path)
1196+
exit()
11871197
if kind == "validation" :
1188-
data = f.add_val_weights(data)
1189-
# torch.save(data,"gewicht.pt")
1198+
data = self.loader.add_val_weights(data)
1199+
1200+
1201+
11901202

11911203
return data
11921204

extras/adamh.py

Lines changed: 77 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,68 @@
33
import numpy
44

55

6-
train = 0
76

7+
class Ensemble_loader():
    """Attach per-sample ensemble weights to preprocessed data records.

    ``ensemble=True`` selects bagging (uniform weights plus bootstrap
    duplication of samples); ``ensemble=False`` selects boosting (weights
    loaded from ``load_path``) — inferred from the caller's branching,
    TODO confirm.
    """

    def __init__(
        self,
        # True: bagging, False: boosting
        ensemble: bool,
        load_path: str,
    ):
        self.ensemble = ensemble
        self.load_path = load_path

    def add_train_weights(self, ids, load_path):
        """Attach a stored weight vector to every training record.

        ``load_path`` points to a torch-saved dict mapping ``str(ident)``
        to a weight list. Mutates ``ids`` in place and returns it.
        """
        # weights_only=False: the file stores a plain Python dict, not tensors.
        # NOTE(review): this unpickles arbitrary objects — only load trusted files.
        weights = torch.load(load_path, weights_only=False)
        for record in ids:
            record["weight"] = weights[str(record["ident"])]
        return ids

    def add_val_weights(self, ids):
        """Attach a uniform all-ones weight vector to every record.

        1528 is presumably the number of target classes — TODO confirm.
        Mutates ``ids`` in place and returns it.
        """
        for record in ids:
            record["weight"] = [1] * 1528
        return ids

    # Inverse of the dict created by the bootstrapping method in sample.py.
    def add_duplicates(self, data, load_path):
        """Replicate records according to stored bootstrap draw counts.

        ``load_path`` points to a torch-saved dict mapping ``str(ident)`` to
        how often that sample was drawn; a record drawn r times ends up r
        times in the list. Appended duplicates are the *same* dict objects
        (shared references). Mutates ``data`` in place and returns it.
        """
        counts = torch.load(load_path, weights_only=False)
        # Iterate only over the original length; appends grow the list.
        for i in range(len(data)):
            drawn = counts[str(data[i]["ident"])]
            if drawn > 1:
                data.extend([data[i]] * (drawn - 1))
        return data
54+
55+
56+
def create_data_weights(batchsize: int, dim: int, weights: dict[str, list[float]], idents: tuple[int, ...]) -> torch.Tensor:
    """Stack per-sample weight vectors into one float32 tensor.

    Args:
        batchsize: unused — kept for interface compatibility with callers.
        dim: unused — kept for interface compatibility with callers.
        weights: maps ``str(ident)`` to that sample's weight list.
        idents: sample identifiers, in batch order.

    Returns:
        A ``(len(idents), D)`` float32 tensor on CUDA when available,
        else CPU; ``None`` for empty ``idents`` (matches prior behavior).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if not idents:
        return None
    # Build once instead of repeated torch.cat in a loop (was O(n^2)).
    rows = [weights[str(ident)] for ident in idents]
    # dtype fixed to float32 to match the previous torch.Tensor(...) default.
    return torch.tensor(rows, dtype=torch.float32, device=device)
868

969
def create_weight(path_to_split="/home/programmer/Bachelorarbeit/split/splits.csv"):
1070
weights = {}
@@ -21,7 +81,7 @@ def create_weight(path_to_split="/home/programmer/Bachelorarbeit/split/splits.cs
2181
torch.save(weights,"/home/programmer/Bachelorarbeit/weights/init_mh.pt")
2282

2383

24-
84+
#for 1_ada_no_normal_weights weights =0.0001
2585
def new_create_weight(path_to_split="/home/programmer/Bachelorarbeit/split/splits.csv"):
2686
weights = {}
2787
with open(path_to_split, 'r') as csvfile:
@@ -30,44 +90,30 @@ def new_create_weight(path_to_split="/home/programmer/Bachelorarbeit/split/split
3090
for row in reader:
3191
if (row[1] == "train") and i > 0:
3292
# print(row[0])
33-
weights[row[0]] = [1/(1528 * 160715)]* 1528
93+
weights[row[0]] = [1/(1528*160715)]* 1528
3494
# print(row[0])
3595
i = i + 1
3696
print(len(weights))
3797
torch.save(weights, "/home/programmer/Bachelorarbeit/weights/init_mh.pt")
3898

3999

40-
def add_train_weights(ids):
41-
d = torch.load("/home/programmer/Bachelorarbeit/weights/init_mh.pt",weights_only=False)
42-
it = 0
43-
for i in ids:
44-
if it % 10000 == 0:
45-
print(it)
46-
ident = i["ident"]
47-
i["weight"] = d[str(ident)]
48-
it = it + 1
49-
return ids
50100

51-
def add_val_weights(ids):
52-
for i in ids:
53-
weight = 1
54-
#i["weight"] = torch.full((1,1528),1)
55-
i["weight"] = [1]*1528
56101

57-
return ids
58102

59-
def create_data_weights(batchsize:int,dim:int,weights:dict[str,list[float,...]],idents:tuple[int,...])-> torch.tensor:
60-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
61-
weight = None
62-
index = 0
63-
for i in idents:
64-
w = torch.Tensor([weights[str(i)],]).to(device)
65-
if weight == None:
66-
weight = w
67-
else:
68-
weight = torch.cat((weight,w),0)
69-
index = index + 1
70-
return weight
103+
104+
105+
106+
107+
108+
109+
110+
111+
112+
113+
114+
115+
116+
71117

72118
#new_create_weight()
73119
#create_weight()

0 commit comments

Comments
 (0)