@@ -11,9 +11,11 @@ def __init__(
1111 #True :bagging, False : boosting
1212 ensemble :bool ,
1313 load_path :str ,
14+ dim :int ,
1415 ):
1516 self .ensemble = ensemble
1617 self .load_path = load_path
18+ self .dim = dim
1719
1820
1921
@@ -26,14 +28,13 @@ def add_train_weights(self,ids,load_path):
2628 if it % 10000 == 0 :
2729 print (it )
2830 ident = i ["ident" ]
29- print (d [str (ident )])
3031 i ["weight" ] = d [str (ident )]
3132 it = it + 1
3233 return ids
3334
34- def add_val_weights (self ,ids ):
35+ def add_val_weights (self ,ids , dim ):
3536 for i in ids :
36- i ["weight" ] = [1 ]* 1528
37+ i ["weight" ] = [1 ]* dim
3738 return ids
3839 #dict reverse to the dict created by the method bootstrapping in sample.py
3940 def add_duplicates (self ,data ,load_path ):
@@ -64,35 +65,35 @@ def create_data_weights(batchsize:int,dim:int,weights:dict[str,list[float,...]],
6465 index = index + 1
6566 return weight
6667
67- def create_weight (path_to_split = "/home/programmer/Bachelorarbeit/split/splits.csv" ):
68- weights = {}
69- with open (path_to_split , 'r' ) as csvfile :
70- reader = csv .reader (csvfile )
71- i = 0
72- for row in reader :
73- if (row [1 ] == "train" ) and i > 0 :
74- #print(row[0])
75- weights [row [0 ]] = torch .full ((1 ,1528 ),int (row [0 ]))
76- #print(row[0])
77- i = i + 1
78- print (len (weights ))
79- torch .save (weights ,"/home/programmer/Bachelorarbeit/weights/init_mh.pt" )
68+ # def create_weight(path_to_split="/home/programmer/Bachelorarbeit/split/splits.csv"):
69+ # weights = {}
70+ # with open(path_to_split, 'r') as csvfile:
71+ # reader = csv.reader(csvfile)
72+ # i = 0
73+ # for row in reader:
74+ # if (row[1] == "train") and i > 0:
75+ # #print(row[0])
76+ # weights[row[0]] = torch.full((1,1528),int(row[0]))
77+ # #print(row[0])
78+ # i = i +1
79+ # print(len(weights))
80+ # torch.save(weights,"/home/programmer/Bachelorarbeit/weights/init_mh.pt")
8081
8182
8283#for 1_ada_no_normal_weights weights =0.0001
83- def new_create_weight (path_to_split = "/home/programmer/Bachelorarbeit/split/splits .csv" ):
84+ def new_create_weight (path_to_split = "/home/programmer/Bachelorarbeit/split/reworked_splits .csv" ):
8485 weights = {}
8586 with open (path_to_split , 'r' ) as csvfile :
8687 reader = csv .reader (csvfile )
8788 i = 0
8889 for row in reader :
8990 if (row [1 ] == "train" ) and i > 0 :
9091 # print(row[0])
91- weights [row [0 ]] = [1 / (1528 * 160715 ) ]* 1528
92+ weights [row [0 ]] = [( 1 / (1528 * 160677 )) * 10000 ]* 1528
9293 # print(row[0])
9394 i = i + 1
9495 print (len (weights ))
95- torch .save (weights , "/home/programmer/Bachelorarbeit/weights/init_mh .pt" )
96+ torch .save (weights , "/home/programmer/Bachelorarbeit/weights/init_mh_10000 .pt" )
9697
9798
9899
@@ -114,4 +115,4 @@ def new_create_weight(path_to_split="/home/programmer/Bachelorarbeit/split/split
114115
115116
116117#new_create_weight()
117- #create_weight()
118+
0 commit comments