
KL loss specifications #8

@JoakimEdin

Hi again! I am having trouble reproducing your results, and I suspect my loss function is the issue. Below is the code I used, with my questions as comments. I would appreciate any guidance 🙏

```python
import torch

kl_div_func = torch.nn.KLDivLoss(reduction='batchmean')  # Is this correct, or should I choose 'mean' or 'sum' instead?

y_prob, attention = model(input)

evidence_token_ids = torch.softmax(evidence_token_ids, dim=-1)  # Is this correct, or should I remove this line?
attention = torch.log(attention)  # KLDivLoss expects log-probabilities as its input

# Note: this expects raw logits; if y_prob is already a probability, binary_cross_entropy is the matching function.
binary_cross_loss = torch.nn.functional.binary_cross_entropy_with_logits(y_prob, y)
kl_div = kl_div_func(attention, evidence_token_ids)  # Or do you call it as kl_div_func(evidence_token_ids, attention) instead?

loss = binary_cross_loss + lambda_1 * kl_div
```
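For reference, here is a minimal sketch of the KL term under assumptions not stated in the issue: `attention` is a `(batch, seq_len)` tensor of probabilities that sum to 1 over tokens, and the evidence is a boolean `(batch, seq_len)` mask. The names `attention_kl` and `evidence_mask` are hypothetical, and this is not necessarily how the repository's authors computed the loss. It illustrates two points raised in the comments of the snippet. First, `KLDivLoss` is called as `(input, target)`, with log-probabilities as `input` and probabilities as `target`, and computes KL(target || attention). Second, a softmax over a 0/1 mask still assigns noticeable mass to non-evidence tokens, whereas dividing the mask by its count spreads the mass uniformly over the evidence tokens only.

```python
import torch


def attention_kl(attention, evidence_mask, eps=1e-12):
    """KL(target || attention), averaged over the batch ('batchmean').

    attention: (batch, seq_len) probabilities, each row summing to 1.
    evidence_mask: (batch, seq_len) boolean mask of evidence tokens;
        every row is assumed to contain at least one True entry.
    """
    # Spread probability mass uniformly over the evidence tokens.
    target = evidence_mask.float()
    target = target / target.sum(dim=-1, keepdim=True)
    # KLDivLoss expects log-probabilities as input; clamp to avoid log(0) = -inf.
    log_attention = torch.log(attention.clamp_min(eps))
    kl = torch.nn.KLDivLoss(reduction="batchmean")
    return kl(log_attention, target)  # argument order: (input, target)


# Compare two ways of building a target distribution from the same mask.
mask = torch.tensor([[True, False, False, True]])
normalized = mask.float() / mask.float().sum(dim=-1, keepdim=True)
softmaxed = torch.softmax(mask.float(), dim=-1)
print(normalized)  # evidence tokens get 0.5 each, the rest get 0
print(softmaxed)   # non-evidence tokens still get about 0.134 each

attention = torch.tensor([[0.4, 0.1, 0.1, 0.4]])
print(attention_kl(attention, mask))  # log(0.5 / 0.4), about 0.2231
```

The `clamp_min(eps)` guard is a design choice: without it, an attention weight of exactly 0 gives `log(0) = -inf`, which can make the loss infinite or NaN.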

Did you calculate the KL divergence between the attention and the boolean ground-truth evidence directly, or did you apply a softmax to the ground truth first (I have seen some people do this)? Furthermore, which reduction did you use for the KL divergence? torch.nn.KLDivLoss lets you choose between 'mean', 'batchmean', and 'sum' (see the documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html). Finally, how did you handle examples without annotated evidence?
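On the reduction question, a short sketch may help. `'sum'` adds up every pointwise term, `'batchmean'` divides that sum by the batch size, and `'mean'` divides it by the total number of elements (batch × tokens); the PyTorch documentation notes that only `'batchmean'` matches the mathematical definition of KL divergence. For examples without annotated evidence, one option, assumed here rather than taken from the authors' code, is to compute the pointwise loss with `reduction='none'`, sum over tokens, and average only over examples whose mask contains at least one evidence token. The function name `masked_attention_kl` is hypothetical, and the same shape assumptions as before apply (`attention` as `(batch, seq_len)` probabilities, evidence as a boolean mask).

```python
import torch
import torch.nn.functional as F


def masked_attention_kl(attention, evidence_mask, eps=1e-12):
    """Mean KL(target || attention) over examples that have evidence.

    Hypothetical helper: rows of evidence_mask with no True entries are
    ignored, and a zero tensor is returned if no example is annotated.
    """
    target = evidence_mask.float()
    counts = target.sum(dim=-1, keepdim=True)  # evidence tokens per example
    has_evidence = counts.squeeze(-1) > 0      # (batch,) boolean
    # Normalize the mask; clamping the denominator keeps empty rows all zero.
    target = target / counts.clamp_min(1.0)
    log_attention = torch.log(attention.clamp_min(eps))
    # Pointwise terms, summed over tokens -> one KL value per example.
    per_example = F.kl_div(log_attention, target, reduction="none").sum(dim=-1)
    if not has_evidence.any():
        return attention.new_zeros(())
    return per_example[has_evidence].mean()


attention = torch.tensor([[0.4, 0.1, 0.1, 0.4],
                          [0.25, 0.25, 0.25, 0.25]])
mask = torch.tensor([[True, False, False, True],
                     [False, False, False, False]])  # second example unannotated
print(masked_attention_kl(attention, mask))  # only the first row contributes

# Compare reductions on a batch of two identical, fully annotated examples.
log_att = torch.log(torch.tensor([[0.4, 0.1, 0.1, 0.4]] * 2))
target = torch.tensor([[0.5, 0.0, 0.0, 0.5]] * 2)
for reduction in ("sum", "batchmean", "mean"):
    # sum: total over all elements; batchmean: sum / batch size; mean: sum / numel
    print(reduction, F.kl_div(log_att, target, reduction=reduction).item())
```

With this batch of 2 examples × 4 tokens, `'batchmean'` is half of `'sum'` and `'mean'` is one eighth of it, so `'mean'` also shrinks as sequences get longer.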
