-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathloss.py
More file actions
168 lines (132 loc) · 6.26 KB
/
loss.py
File metadata and controls
168 lines (132 loc) · 6.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import torch
from torch import Tensor
from dataset import Batch
from model import LSTMModel
from util import make_mask_2d
from collections import namedtuple
from torch.nn.functional import cross_entropy
from torch.nn.functional import ctc_loss as ctc
# Result record shared by the loss functions in this module: `loss` carries the
# (possibly reduced) loss value and `logits` the raw scores produced by the model.
ModelOutput = namedtuple("ModelOutput", ["loss", "logits"])
def _get_logits(model: LSTMModel, batch: Batch) -> Tensor:
    """Run ``model`` on the source and feature fields of ``batch``.

    The model is invoked with keyword arguments only (``inputs``, ``lengths``,
    ``features``, ``feature_lengths``) and its return value is handed back
    unchanged.
    """
    forward_kwargs = {
        "inputs": batch.sources,
        "lengths": batch.source_lengths,
        "features": batch.features,
        "feature_lengths": batch.feature_lengths,
    }
    return model(**forward_kwargs)
def cross_entropy_loss(model: LSTMModel, batch: Batch, reduction: str = "mean") -> ModelOutput:
    """Cross-entropy between the model's logits and ``batch.targets``.

    All logit axes except the last are collapsed into one, the targets are
    flattened to a single axis and moved to the logits' device, and the two
    are compared with ``torch.nn.functional.cross_entropy``. Target entries
    equal to 0 do not contribute (``ignore_index=0``; presumably 0 is the
    padding index -- confirm against the dataset code).

    Args:
        model: Model that produces the logits.
        batch: Batch providing the model inputs and ``targets``.
        reduction: Reduction mode forwarded to ``cross_entropy``.

    Returns:
        ``ModelOutput`` holding the loss and the un-flattened logits.
    """
    logits = _get_logits(model=model, batch=batch)

    # (..., #labels) -> (N, #labels) so that it lines up with the 1-d targets.
    scores_2d = logits.flatten(end_dim=-2)
    gold = batch.targets.flatten().to(logits.device)

    return ModelOutput(
        loss=cross_entropy(scores_2d, gold, ignore_index=0, reduction=reduction),
        logits=logits,
    )
def ctc_loss(model: LSTMModel, batch: Batch, reduction: str = "mean") -> ModelOutput:
    """CTC loss between the model's predictions and ``batch.targets``.

    The logits are log-softmax-normalised over the last axis and their first
    two axes are swapped before being passed to
    ``torch.nn.functional.ctc_loss`` (which expects the time axis first; this
    assumes the model returns batch-first logits -- confirm). Input lengths
    are ``model.tau * batch.source_lengths`` and label 0 is the blank.

    Args:
        model: Model that produces the logits; must expose ``tau``.
        batch: Batch providing inputs, ``targets``, ``source_lengths`` and
            ``target_lengths``.
        reduction: Reduction mode forwarded to ``ctc_loss``.

    Returns:
        ``ModelOutput`` holding the loss and the raw logits.
    """
    logits = _get_logits(model=model, batch=batch)
    time_major_log_probs = logits.log_softmax(dim=-1).transpose(0, 1)
    gold = batch.targets.to(logits.device)

    loss = ctc(
        log_probs=time_major_log_probs,
        targets=gold,
        input_lengths=model.tau * batch.source_lengths,
        target_lengths=batch.target_lengths,
        blank=0,
        reduction=reduction,
    )
    return ModelOutput(loss=loss, logits=logits)
def crf_loss(model: LSTMModel, batch: Batch, reduction: str = "mean") -> ModelOutput:
    """Negative score of the gold tag sequence under the model's CRF layer.

    For every sequence the score is the sum over time steps of the
    log-softmax emission score of the gold tag plus the matching transition
    score (the prior for the first tag, ``get_transition_scores`` for the
    rest), plus the transition score from the last gold tag to the stop state.
    Masked positions contribute 0.

    Args:
        model: Model that produces the logits; must expose ``crf`` with
            ``prior``, ``get_transition_scores`` and
            ``final_transition_scores``.
        batch: Batch providing inputs, ``targets`` (batch x timesteps),
            ``target_lengths`` and ``source_lengths``.
        reduction: ``"mean"`` or ``"sum"`` over the batch, or ``"none"`` for
            one value per sequence.

    Returns:
        ``ModelOutput`` holding the loss and the raw logits.

    Raises:
        ValueError: If ``reduction`` is not one of the values above.
    """
    logits = _get_logits(model=model, batch=batch)
    gold_tags = batch.targets.to(logits.device)
    num_sequences = logits.shape[0]

    # Emission term: log-probability of the gold tag at every position.
    # logits: batch x timesteps x #labels  ->  emissions: batch x timesteps
    log_probs = logits.log_softmax(dim=-1)
    emissions = log_probs.gather(dim=2, index=gold_tags.unsqueeze(2)).squeeze(2)

    # Transition term: the prior scores the first tag, the remaining columns
    # come from the CRF layer.
    transitions = model.crf.get_transition_scores(batch.targets)
    start_scores = model.crf.prior[batch.targets[:, 0]].contiguous()
    start_scores = start_scores.reshape((num_sequences, 1))
    transitions = torch.cat([start_scores, transitions], dim=1)

    # Transition from the last real tag of each sequence to the stop state.
    last_position = (batch.target_lengths - 1).unsqueeze(1).to(emissions.device)
    last_tags = gold_tags.gather(dim=1, index=last_position).flatten()
    stop_scores = model.crf.final_transition_scores[last_tags].contiguous()

    # Per-position path scores; masked positions are zeroed so they do not
    # affect the per-sequence sum.
    # NOTE(review): the mask is built from source_lengths while the stop tag is
    # located via target_lengths -- presumably both agree here; confirm.
    path_scores = transitions + emissions
    padding_mask = make_mask_2d(batch.source_lengths).to(emissions.device)
    path_scores = path_scores.masked_fill(padding_mask, 0.0)

    nll = -(path_scores.sum(dim=1) + stop_scores)

    if reduction not in ("mean", "sum", "none"):
        raise ValueError(f"Unknown reduction: {reduction}")
    if reduction == "mean":
        nll = nll.mean(dim=0)
    elif reduction == "sum":
        nll = nll.sum(dim=0)

    return ModelOutput(loss=nll, logits=logits)
def ctc_crf_loss(model: LSTMModel, batch: Batch, reduction: str = "mean") -> ModelOutput:
    """CTC-style alignment loss that also scores CRF transitions between labels.

    A forward recursion in log space fills ``alpha[t][j]``: the log-score of
    having emitted the first ``j`` target labels after consuming ``t`` frames.
    From ``alpha[t-1]`` a frame may either

    * emit a blank (label 0) and stay at ``j``, or
    * emit target label ``j`` and advance from ``j-1``, which additionally
      adds the CRF score for entering that label (the prior for the first
      label, ``get_transition_scores`` for the others).

    The per-sequence log-likelihood is read at frame
    ``model.tau * batch.source_lengths`` and label position
    ``batch.target_lengths``; the transition score from the last target label
    to the stop state is added before negating.

    Args:
        model: Model that produces the logits; must expose ``tau`` and ``crf``
            (with ``prior``, ``get_transition_scores`` and
            ``final_transition_scores``).
        batch: Batch providing inputs, ``targets`` (batch x target length),
            ``source_lengths`` and ``target_lengths``.
        reduction: ``"mean"`` or ``"sum"`` over the batch, or ``"none"`` for
            one value per sequence.

    Returns:
        ``ModelOutput`` holding the loss and the raw logits.

    Raises:
        ValueError: If ``reduction`` is not one of the values above.
    """
    # Get prediction log-probs
    logits = _get_logits(model=model, batch=batch)
    scores = torch.log_softmax(logits, dim=-1)
    source_lengths = model.tau * batch.source_lengths

    # Save constants
    batch_size = scores.shape[0]
    source_length = scores.shape[1]
    target_length = batch.targets.shape[1]
    batch_indexer = torch.arange(batch_size)

    # Move the targets to the scoring device once; the recursion below needs
    # them on every time step, so the transfer must not sit inside the loop.
    targets = batch.targets.to(scores.device)

    # Large negative stand-in for log(0); finite so sums never produce NaN.
    neg_inf_score = -1e8
    neg_inf_array = torch.full((batch_size, 1), fill_value=neg_inf_score, device=scores.device)

    # Extract prior
    prior = model.crf.prior[batch.targets[:, 0]].contiguous()
    prior = prior.reshape((batch_size, 1))

    # Extract transition scores. Column 0 belongs to the "nothing emitted yet"
    # state, which can never be entered by emitting a label.
    transition_scores = model.crf.get_transition_scores(batch.targets)
    transition_scores = torch.cat([neg_inf_array, prior, transition_scores], dim=1)

    # alpha_0: before any frame is consumed, only position 0 is reachable.
    alpha_0 = torch.cat(
        [
            torch.zeros(batch_size, 1, device=scores.device),
            torch.full((batch_size, target_length), fill_value=neg_inf_score, device=scores.device),
        ],
        dim=1,
    )
    alpha = [alpha_0]

    for t in range(1, source_length + 1):
        prev_alpha = alpha[-1]

        # Option 1: emit blank, stay at the same label position.
        blank_scores_t = scores[:, t - 1, 0].unsqueeze(1).expand(batch_size, target_length + 1)
        blank_transition_score = prev_alpha + blank_scores_t

        # Option 2: emit the next target label, advance by one position.
        scores_t = scores[:, t - 1, :].gather(dim=-1, index=targets)
        prediction_scores_t = torch.cat([neg_inf_array, scores_t], dim=1)
        prediction_transition_score = torch.cat([neg_inf_array, prev_alpha[:, :-1]], dim=1)
        prediction_transition_score = prediction_transition_score + prediction_scores_t + transition_scores

        alpha_t = torch.logsumexp(torch.stack([blank_transition_score, prediction_transition_score]), dim=0)
        alpha.append(alpha_t)

    # (frames + 1) x batch x (targets + 1)  ->  batch x (frames + 1) x (targets + 1)
    alpha = torch.stack(alpha)
    alpha = alpha.transpose(0, 1)

    log_likelihoods = alpha[batch_indexer, source_lengths, batch.target_lengths]
    log_likelihoods = log_likelihoods.flatten().contiguous()

    # Calculate transition probabilities to stop tag
    length_index = (batch.target_lengths - 1).unsqueeze(1).to(scores.device)
    final_tags = torch.gather(targets, index=length_index, dim=1)
    final_tags = final_tags.flatten()
    final_transition_scores = model.crf.final_transition_scores[final_tags]
    final_transition_scores = final_transition_scores.contiguous()

    nll = -(log_likelihoods + final_transition_scores)

    if reduction == "mean":
        nll = torch.mean(nll, dim=0)
    elif reduction == "sum":
        nll = torch.sum(nll, dim=0)
    elif reduction == "none":
        pass
    else:
        raise ValueError(f"Unknown reduction: {reduction}")

    return ModelOutput(loss=nll, logits=logits)