
What shape should the HSIC input have? #2

@pending1face


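Hello. According to the description in the paper, the input to HSIC should have shape batch_size * assign_num * embedding_dim. In the code, however, diffpool_net uses torch.sum(h, dim=1) as its readout, so the input becomes batch_size * embedding_dim.
I printed feature1 inside the HSIC_weight module in nets.superpixels_graph_classification.diffpool.py: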

    # Excerpt from HSIC_weight in nets.superpixels_graph_classification.diffpool.py;
    # torch, Variable and softmax are assumed to be imported or defined in that module.
    def forward(self, lambdap, global_epoch, lambda_decay_rate, lambda_decay_epoch, min_lambda_times, first):
        if True:  # the else branch below is never taken
            self.all_weights = self.weights
        else:
            self.all_weights = torch.cat((self.weights, self.pre_weight.detach()), dim=0)
        lossb = Variable(torch.FloatTensor([0]).cuda())
        # Sum the dependence loss over every pair (i, j) of embedding blocks
        for i in range(self.assign_num-1):
            for j in range(i+1, self.assign_num):
                feature1 = self.all_features[:, i*self.embedding_size : (i+1)*self.embedding_size]
                feature2 = self.all_features[:, j*self.embedding_size : (j+1)*self.embedding_size]

                print(feature1)  # debug print added to inspect the slice
                #lossb += self.loss_dependence(feature1, feature2, softmax(self.all_weights), self.all_features.size(0)).view(1)
                lossb += self.biased_estimator(feature1, feature2, softmax(self.all_weights)).view(1)
        #lossq = (torch.sum(self.weights*self.weights)-self.n)**2
        lossp = softmax(self.weights).pow(2).sum()
        lambdap = lambdap * max((lambda_decay_rate ** (global_epoch // lambda_decay_epoch)),
                                min_lambda_times)
        lossg = 1e+3*lossb / lambdap + lossp
        return lossg, lossb
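
For reference, the loop above hands each pair of slices to self.biased_estimator, whose body is not shown here. The standard biased HSIC estimator treats its two inputs as n paired samples (one per row) and computes trace(K H L H) / (n - 1)^2, where K and L are kernel Gram matrices and H = I - (1/n) * ones is the centering matrix. Both inputs therefore need the same number of rows (here batch_size), while their column counts may differ. The sketch below is a minimal, unweighted version assuming Gaussian (RBF) kernels with a fixed bandwidth; the repository's estimator also receives softmax(self.all_weights) as sample weights, and how it applies them is not shown, so this is not its implementation. The names hsic_biased, rbf_gram and sigma are hypothetical.

```python
import torch


def rbf_gram(x: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Gaussian (RBF) Gram matrix over the rows of x; result is (n, n)."""
    sq_dists = torch.cdist(x, x).pow(2)
    return torch.exp(-sq_dists / (2 * sigma ** 2))


def hsic_biased(x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Unweighted biased HSIC estimate: trace(K H L H) / (n - 1)^2.

    x and y hold n paired samples as rows; their column counts may differ.
    """
    n = x.size(0)
    if y.size(0) != n:
        raise ValueError("x and y must have the same number of rows (samples)")
    k_x = rbf_gram(x, sigma)
    k_y = rbf_gram(y, sigma)
    # Centering matrix H = I - (1/n) * ones(n, n)
    centering = torch.eye(n, dtype=x.dtype) - torch.full((n, n), 1.0 / n, dtype=x.dtype)
    return torch.trace(k_x @ centering @ k_y @ centering) / (n - 1) ** 2


torch.manual_seed(0)
a = torch.randn(64, 2)
b = torch.randn(64, 2)
print(hsic_biased(a, a).item())  # identical inputs: strongly dependent
print(hsic_biased(a, b).item())  # independent inputs: typically much smaller
```

The sketch runs on CPU and ignores the .cuda() placement used in the repository.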

The printed feature1 has an odd shape:
[Screenshot of the printed feature1 output]
Could you explain this?
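
To make the shape question concrete, the sketch below compares the two layouts on a random tensor h standing in for pooled embeddings of shape (batch_size, assign_num, embedding_dim), as the paper describes. It assumes, hypothetically, that self.all_features is meant to be a flattened batch_size x (assign_num * embedding_dim) matrix, which is the width the slicing in forward() appears to expect; the issue does not show how all_features is actually built, and embedding_size is assumed equal to embedding_dim. The block helper and the sizes are illustrative only.

```python
import torch

# Hypothetical sizes; h stands in for pooled embeddings of shape
# (batch_size, assign_num, embedding_dim).
batch_size, assign_num, embedding_dim = 4, 3, 5
h = torch.randn(batch_size, assign_num, embedding_dim)

# Readout used in diffpool_net: summing over clusters drops the assign_num axis.
summed = torch.sum(h, dim=1)  # (batch_size, embedding_dim)

# Assumed alternative layout: one embedding_dim-wide column block per cluster.
flat = h.reshape(batch_size, assign_num * embedding_dim)


def block(features: torch.Tensor, i: int) -> torch.Tensor:
    """Same slice as forward(): columns i*embedding_dim up to (i+1)*embedding_dim."""
    return features[:, i * embedding_dim:(i + 1) * embedding_dim]


for i in range(assign_num):
    print(i, tuple(block(flat, i).shape), tuple(block(summed, i).shape))
# Flattened layout: every block is (4, 5) and equals h[:, i, :].
# Summed readout: only block 0 is (4, 5); blocks 1 and 2 are empty, (4, 0).
```

Because reshape on a contiguous tensor keeps row-major order, column i*embedding_dim + d of flat holds h[:, i, d], so each slice lines up with one cluster; with the summed readout there is only one block's worth of columns to slice.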
