Hello. According to the description in the paper, the input to HSIC should have size batch_size * assign_num * embedding_dim. In the code, however, diffpool_net uses a torch.sum(h, dim=1) readout, so the input becomes batch_size * embedding_dim.
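To make the concern concrete, here is a minimal PyTorch sketch with made-up sizes (batch_size=4, assign_num=3, embedding_dim=5). The names `h`, `flat`, `summed`, and `cluster_slice` are illustrative and do not come from the repository. It assumes that `all_features` is sliced with the same pattern as `feature1`/`feature2` in `HSIC_weight.forward`. With a flattened batch_size x (assign_num * embedding_dim) input, each slice is one cluster's embedding. If `all_features` really were batch_size x embedding_dim, every slice after the first would be empty.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
batch_size, assign_num, embedding_dim = 4, 3, 5

# Per-cluster embeddings, as the paper describes: batch_size x assign_num x embedding_dim.
h = torch.randn(batch_size, assign_num, embedding_dim)

# Flattening keeps one embedding_dim-wide block per cluster, which matches
# the layout that the i*embedding_size:(i+1)*embedding_size slicing expects.
flat = h.reshape(batch_size, assign_num * embedding_dim)

# A sum readout over dim=1 collapses the cluster axis instead.
summed = torch.sum(h, dim=1)


def cluster_slice(features, i):
    # Same slicing pattern as feature1/feature2 in HSIC_weight.forward.
    return features[:, i * embedding_dim:(i + 1) * embedding_dim]


for i in range(assign_num):
    print(i, tuple(cluster_slice(flat, i).shape), tuple(cluster_slice(summed, i).shape))
# 0 (4, 5) (4, 5)
# 1 (4, 5) (4, 0)
# 2 (4, 5) (4, 0)
```

Slicing past the end of a dimension does not raise an error in PyTorch. It silently returns an empty tensor, so a shape mismatch like this can go unnoticed unless the slices are printed or asserted on.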
I printed feature1 in the HSIC_weight module of nets.superpixels_graph_classification.diffpool.py:
def forward(self, lambdap, global_epoch, lambda_decay_rate, lambda_decay_epoch, min_lambda_times, first):
    # `first` is not used: `if True` means this branch is always taken.
    if True:
        self.all_weights = self.weights
    else:
        self.all_weights = torch.cat((self.weights, self.pre_weight.detach()), dim=0)
    lossb = Variable(torch.FloatTensor([0]).cuda())
    # Accumulate the HSIC term over every pair (i, j) of cluster-embedding slices.
    for i in range(self.assign_num-1):
        for j in range(i+1, self.assign_num):
            feature1 = self.all_features[:, i*self.embedding_size : (i+1)*self.embedding_size]
            feature2 = self.all_features[:, j*self.embedding_size : (j+1)*self.embedding_size]
            print(feature1)  # debug print added to inspect the slice
            #lossb += self.loss_dependence(feature1, feature2, softmax(self.all_weights), self.all_features.size(0)).view(1)
            lossb += self.biased_estimator(feature1, feature2, softmax(self.all_weights)).view(1)
    #lossq = (torch.sum(self.weights*self.weights)-self.n)**2
    # Penalty on the squared softmax-normalized sample weights.
    lossp = softmax(self.weights).pow(2).sum()
    # Step decay of lambdap, floored at min_lambda_times.
    lambdap = lambdap * max((lambda_decay_rate ** (global_epoch // lambda_decay_epoch)),
                            min_lambda_times)
    lossg = 1e+3*lossb / lambdap + lossp
    return lossg, lossb
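The last three statements of forward combine the HSIC term and the weight penalty using a step-decayed lambdap. The plain-Python sketch below reproduces only that arithmetic so the schedule can be checked in isolation. The function names `decayed_lambda` and `total_loss` are hypothetical helpers, not part of the repository, and the example values are made up.

```python
import math


def decayed_lambda(lambdap, global_epoch, lambda_decay_rate, lambda_decay_epoch, min_lambda_times):
    # Multiply by lambda_decay_rate once every lambda_decay_epoch epochs,
    # but never let the multiplier drop below min_lambda_times.
    factor = max(lambda_decay_rate ** (global_epoch // lambda_decay_epoch), min_lambda_times)
    return lambdap * factor


def total_loss(lossb, lossp, lambdap_eff):
    # Same combination as lossg = 1e+3*lossb / lambdap + lossp in forward().
    return 1e3 * lossb / lambdap_eff + lossp


# Hypothetical schedule: lambdap=1.0, rate=0.5, decay every 10 epochs, floor 0.1.
for epoch in (0, 10, 25, 100):
    print(epoch, decayed_lambda(1.0, epoch, 0.5, 10, 0.1))
# 0 1.0
# 10 0.5
# 25 0.25
# 100 0.1
```

Because lambdap divides lossb, decaying lambdap over training increases the relative weight of the HSIC term against lossp.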
