loss_func.py
import torch
import torch.nn as nn


class FocalLoss(nn.Module):
    def __init__(self, alpha=1.0, gamma=0.5, logits=False, reduce=False):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits  # kept for API compatibility; cross_entropy already takes raw logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        eps = 1e-8
        # Per-sample cross entropy (reduction='none') so the focal weighting
        # below applies element-wise and `reduce` controls the final mean.
        ce_loss = nn.functional.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        # Standard form: F_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        # eps guards log(0) for perfectly classified samples.
        F_loss = self.alpha * (1 - pt + eps) ** self.gamma * (-torch.log(pt + eps))
        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss
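

# A minimal usage sketch (illustrative, not part of the original file): the
# batch size and class count below are assumptions.
def _demo_focal_loss():
    logits = torch.randn(4, 10)          # raw scores for 4 samples, 10 classes
    labels = torch.randint(0, 10, (4,))  # integer class targets
    per_sample = FocalLoss(alpha=1.0, gamma=0.5)(logits, labels)  # shape (4,)
    mean_loss = FocalLoss(reduce=True)(logits, labels)            # scalar
    return per_sample, mean_loss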


class AccuracyLoss(nn.Module):
    def __init__(self):
        super(AccuracyLoss, self).__init__()

    def forward(self, inputs, targets):
        # Hinge-like overlap: scores are positive only where inputs and
        # targets jointly exceed 1; the positive part is averaged over the batch.
        scores = inputs + targets - 1
        zeros = torch.zeros_like(scores)
        return torch.max(zeros, scores).sum() / inputs.size(0)
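

# Illustrative sketch (an assumption about intended usage): AccuracyLoss appears
# to expect values on a comparable [0, 1] scale, e.g. probabilities against
# one-hot targets.
def _demo_accuracy_loss():
    probs = torch.softmax(torch.randn(4, 10), dim=1)
    labels = torch.randint(0, 10, (4,))
    one_hot = nn.functional.one_hot(labels, num_classes=10).float()
    return AccuracyLoss()(probs, one_hot)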


class EMDLoss(nn.Module):
    def __init__(self):
        super(EMDLoss, self).__init__()

    def forward(self, inputs, targets):
        # Normalize each distribution by its own mass so both rows are valid
        # probability vectors before comparing CDFs.
        normalized_inputs = inputs / inputs.sum(dim=1, keepdim=True)
        normalized_targets = targets / targets.sum(dim=1, keepdim=True)
        cdf_inputs = torch.cumsum(normalized_inputs, dim=1)
        cdf_targets = torch.cumsum(normalized_targets, dim=1)
        # 1-D Earth Mover's Distance: mean absolute gap between the two CDFs.
        emd = torch.sum(torch.abs(cdf_targets - cdf_inputs)) / inputs.size(1)
        return emd
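

# Sketch of EMDLoss on two non-negative histograms; the 10-bin shapes are
# illustrative assumptions. The loss is 0 when the normalized CDFs coincide.
def _demo_emd_loss():
    pred = torch.rand(4, 10)    # non-negative scores per bin
    target = torch.rand(4, 10)  # e.g. soft label histograms
    return EMDLoss()(pred, target)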


loss_functions = {
    'accuracy': AccuracyLoss(),
    'acc': AccuracyLoss(),
    # 'kld' and 'ce' are negated, so larger values mean a better fit; note
    # that kl_div expects `x` to contain log-probabilities.
    'kld': lambda x, y: -nn.functional.kl_div(x, y, reduction='batchmean'),
    'ce': lambda x, y: -nn.functional.cross_entropy(x, y),
    'emd': EMDLoss(),
}


def get_loss_function(name: str):
    return loss_functions[name]
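

# Minimal lookup sketch; the 'emd' key and tensor shapes follow the
# definitions above.
if __name__ == "__main__":
    criterion = get_loss_function('emd')
    pred = torch.rand(4, 10)
    target = torch.rand(4, 10)
    print(criterion(pred, target))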