Skip to content

Commit b42edb2

Browse files
committed
add weights more dynamic with size
1 parent c84843f commit b42edb2

File tree

2 files changed

+25
-23
lines changed

2 files changed

+25
-23
lines changed

chebai/preprocessing/datasets/base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,7 @@ def __init__(
725725
self,
726726
ensemble=True,
727727
load_path=None,
728+
dim = 1528,
728729
**kwargs,
729730
):
730731
super(_DynamicDataset, self).__init__(**kwargs)
@@ -733,7 +734,7 @@ def __init__(
733734
self._dynamic_df_train = None
734735
self._dynamic_df_test = None
735736
self._dynamic_df_val = None
736-
self.loader= Ensemble_loader(ensemble=ensemble,load_path=load_path)
737+
self.loader= Ensemble_loader(ensemble=ensemble,load_path=load_path,dim=dim)
737738
# Path of csv file which contains a list of ids & their assignment to a dataset (either train,
738739
# validation or test).
739740
self.splits_file_path = self._validate_splits_file_path(
@@ -1188,7 +1189,7 @@ def load_processed_data(
11881189

11891190
if self.loader.ensemble:
11901191

1191-
data = self.loader.add_val_weights(data)
1192+
data = self.loader.add_val_weights(data,self.loader.dim)
11921193
if self.loader.load_path is not None:
11931194

11941195
data = self.loader.add_duplicates(data,self.loader.load_path)
@@ -1197,7 +1198,7 @@ def load_processed_data(
11971198
data = self.loader.add_train_weights(data,self.loader.load_path)
11981199

11991200
if kind == "validation" :
1200-
data = self.loader.add_val_weights(data)
1201+
data = self.loader.add_val_weights(data,self.loader.dim)
12011202

12021203

12031204

extras/adamh.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@ def __init__(
1111
#True :bagging, False : boosting
1212
ensemble:bool,
1313
load_path:str,
14+
dim:int,
1415
):
1516
self.ensemble=ensemble
1617
self.load_path=load_path
18+
self.dim=dim
1719

1820

1921

@@ -26,14 +28,13 @@ def add_train_weights(self,ids,load_path):
2628
if it % 10000 == 0:
2729
print(it)
2830
ident = i["ident"]
29-
print(d[str(ident)])
3031
i["weight"] = d[str(ident)]
3132
it = it + 1
3233
return ids
3334

34-
def add_val_weights(self,ids):
35+
def add_val_weights(self,ids,dim):
3536
for i in ids:
36-
i["weight"] = [1]*1528
37+
i["weight"] = [1]*dim
3738
return ids
3839
#dict reverse to the dict created by the method bootstrapping in sample.py
3940
def add_duplicates(self,data,load_path):
@@ -64,35 +65,35 @@ def create_data_weights(batchsize:int,dim:int,weights:dict[str,list[float,...]],
6465
index = index + 1
6566
return weight
6667

67-
def create_weight(path_to_split="/home/programmer/Bachelorarbeit/split/splits.csv"):
68-
weights = {}
69-
with open(path_to_split, 'r') as csvfile:
70-
reader = csv.reader(csvfile)
71-
i = 0
72-
for row in reader:
73-
if (row[1] == "train") and i > 0:
74-
#print(row[0])
75-
weights[row[0]] = torch.full((1,1528),int(row[0]))
76-
#print(row[0])
77-
i = i +1
78-
print(len(weights))
79-
torch.save(weights,"/home/programmer/Bachelorarbeit/weights/init_mh.pt")
68+
# def create_weight(path_to_split="/home/programmer/Bachelorarbeit/split/splits.csv"):
69+
# weights = {}
70+
# with open(path_to_split, 'r') as csvfile:
71+
# reader = csv.reader(csvfile)
72+
# i = 0
73+
# for row in reader:
74+
# if (row[1] == "train") and i > 0:
75+
# #print(row[0])
76+
# weights[row[0]] = torch.full((1,1528),int(row[0]))
77+
# #print(row[0])
78+
# i = i +1
79+
# print(len(weights))
80+
# torch.save(weights,"/home/programmer/Bachelorarbeit/weights/init_mh.pt")
8081

8182

8283
#for 1_ada_no_normal_weights weights =0.0001
83-
def new_create_weight(path_to_split="/home/programmer/Bachelorarbeit/split/splits.csv"):
84+
def new_create_weight(path_to_split="/home/programmer/Bachelorarbeit/split/reworked_splits.csv"):
8485
weights = {}
8586
with open(path_to_split, 'r') as csvfile:
8687
reader = csv.reader(csvfile)
8788
i = 0
8889
for row in reader:
8990
if (row[1] == "train") and i > 0:
9091
# print(row[0])
91-
weights[row[0]] = [1/(1528*160715)]* 1528
92+
weights[row[0]] = [(1/(1528*160677))*10000]* 1528
9293
# print(row[0])
9394
i = i + 1
9495
print(len(weights))
95-
torch.save(weights, "/home/programmer/Bachelorarbeit/weights/init_mh.pt")
96+
torch.save(weights, "/home/programmer/Bachelorarbeit/weights/init_mh_10000.pt")
9697

9798

9899

@@ -114,4 +115,4 @@ def new_create_weight(path_to_split="/home/programmer/Bachelorarbeit/split/split
114115

115116

116117
#new_create_weight()
117-
#create_weight()
118+

0 commit comments

Comments
 (0)