forked from jkallini/mrt5
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodeling_mrt5.py
More file actions
1039 lines (905 loc) · 44.3 KB
/
modeling_mrt5.py
File metadata and controls
1039 lines (905 loc) · 44.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# modeling_mrt5.py
# Author: Julie Kallini
# Description: This file contains the implementation of the MrT5 model.
# The code is adapted from HuggingFace's modeling_t5.py. New code sequences
# are labeled with comments.
import torch
import copy
from torch import nn
from transformers import GradientCheckpointingLayer
from transformers.cache_utils import DynamicCache, EncoderDecoderCache
from transformers.models.t5.modeling_t5 import (
T5Attention,
T5LayerNorm,
T5LayerFF,
T5Stack,
T5ForConditionalGeneration,
)
from .configuration_mrt5 import MrT5Config
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
)
from transformers.utils import logging, is_torchdynamo_compiling
from transformers.utils.deprecation import deprecate_kwarg
from typing import Optional, Tuple, Union
from dataclasses import dataclass
logger = logging.get_logger(__name__)
@dataclass
class MrT5BaseModelOutputWithPastAndCrossAttentions(BaseModelOutputWithPastAndCrossAttentions):
    """
    Stack-level output container for MrT5.

    Adds delete-gate and attention-diagnostic fields on top of
    `BaseModelOutputWithPastAndCrossAttentions`. Every added field defaults
    to None. The field order is part of the interface and must not be
    changed.
    """

    # Fields filled by `MrT5Stack.forward` in this file.
    # Mask taken from the last block that owns a delete gate (or the value
    # that was passed into the stack when no block owns one).
    delete_gate_mask: Optional[torch.FloatTensor] = None
    # Activated gate values returned by the delete gate.
    delete_gate_output: Optional[torch.FloatTensor] = None
    # Raw (pre-activation) output of the delete gate's linear layer.
    delete_gate_logits: Optional[torch.FloatTensor] = None
    # 0/1 attention mask rebuilt after hard deletion; stays None otherwise.
    attention_mask: Optional[torch.FloatTensor] = None

    # NOTE(review): the remaining fields are not assigned by any code visible
    # in this part of the file; presumably diagnostic tensors - confirm
    # against callers before relying on them.
    attention_queries: Optional[torch.FloatTensor] = None
    attention_keys: Optional[torch.FloatTensor] = None
    attention_values: Optional[torch.FloatTensor] = None
    attention_scores: Optional[torch.FloatTensor] = None
    cross_attention_keys: Optional[torch.FloatTensor] = None
    cross_attention_queries: Optional[torch.FloatTensor] = None
    cross_attention_values: Optional[torch.FloatTensor] = None
    cross_attention_scores: Optional[torch.FloatTensor] = None
@dataclass
class MrT5Seq2SeqLMOutput(Seq2SeqLMOutput):
    """
    Sequence-to-sequence LM output container for MrT5.

    Extends `Seq2SeqLMOutput` with delete-gate fields and per-component
    attention diagnostics. Every added field defaults to None. The field
    order is part of the interface and must not be changed.

    NOTE(review): the code that populates these fields is not visible in this
    part of the file; the grouping comments below follow the field names only
    - confirm against the model's forward pass.
    """

    # Delete-gate tensors (names mirror the stack-level output class).
    delete_gate_mask: Optional[torch.FloatTensor] = None
    delete_gate_output: Optional[torch.FloatTensor] = None
    delete_gate_logits: Optional[torch.FloatTensor] = None

    # Encoder self-attention diagnostics.
    encoder_keys: Optional[torch.FloatTensor] = None
    encoder_queries: Optional[torch.FloatTensor] = None
    encoder_values: Optional[torch.FloatTensor] = None
    encoder_scores: Optional[torch.FloatTensor] = None

    # Decoder self-attention diagnostics.
    decoder_keys: Optional[torch.FloatTensor] = None
    decoder_queries: Optional[torch.FloatTensor] = None
    decoder_values: Optional[torch.FloatTensor] = None
    decoder_scores: Optional[torch.FloatTensor] = None

    # Decoder cross-attention diagnostics.
    cross_attention_keys: Optional[torch.FloatTensor] = None
    cross_attention_queries: Optional[torch.FloatTensor] = None
    cross_attention_values: Optional[torch.FloatTensor] = None
    cross_attention_scores: Optional[torch.FloatTensor] = None
# Lookup table from an initializer name to the `torch.nn.init` function of the
# same name. Both the in-place spellings (trailing underscore) and the older
# spellings without the underscore are registered, in this order.
TORCH_INIT_FUNCTIONS = {
    init_name: getattr(nn.init, init_name)
    for init_name in (
        "uniform_",
        "normal_",
        "trunc_normal_",
        "constant_",
        "xavier_uniform_",
        "xavier_normal_",
        "kaiming_uniform_",
        "kaiming_normal_",
        "uniform",
        "normal",
        "xavier_uniform",
        "xavier_normal",
        "kaiming_uniform",
        "kaiming_normal",
    )
}
def softmax1(logits: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Softmax variant whose denominator carries one extra unit term, i.e.
    ``exp(x_i) / (1 + sum_j exp(x_j))`` along ``dim``.

    The maximum along ``dim`` is subtracted before exponentiating; the extra
    term is rescaled by the same shift (``exp(-max)``). The maximum is taken
    on a detached copy, so it does not contribute to gradients.

    Args:
        logits: Input scores.
        dim: Dimension to normalize over.

    Returns:
        A tensor shaped like ``logits``. When ``logits`` has size 0 along
        ``dim`` it is returned unchanged.
    """
    if logits.shape[dim] == 0:
        return logits

    peak = logits.detach().max(dim, keepdim=True).values
    shifted_exp = torch.exp(logits - peak)
    denominator = shifted_exp.sum(dim, keepdim=True) + torch.exp(-peak)
    return shifted_exp / denominator
class ScaledSigmoid(nn.Module):
    """
    Activation computing ``sigmoid_mask_scale * sigmoid(-x)``.

    Args:
        sigmoid_mask_scale: Constant factor applied to the sigmoid output.
    """

    def __init__(self, sigmoid_mask_scale):
        super().__init__()
        self.sigmoid_mask_scale = sigmoid_mask_scale

    def forward(self, input):
        """Return the scaled sigmoid of the negated ``input``."""
        negated_sigmoid = torch.sigmoid(-input)
        return self.sigmoid_mask_scale * negated_sigmoid
class SigmoidDeleteGate(nn.Module):
    """
    Delete gate: layer norm -> linear projection to one logit per position ->
    `ScaledSigmoid`.

    The activation is ``sigmoid_mask_scale * sigmoid(-logit)``, so each gate
    value lies between 0 and ``config.sigmoid_mask_scale``. Positions whose
    input id is 0 (treated as padding here) are forced to exactly
    ``sigmoid_mask_scale``.
    """

    def __init__(self, config):
        super().__init__()
        self.layer_norm = T5LayerNorm(config.hidden_size)
        self.feed_forward = nn.Linear(config.hidden_size, 1)
        self._init_weights(self.feed_forward)
        self.activation = ScaledSigmoid(config.sigmoid_mask_scale)

    def forward(self, hidden_states, input_ids):
        """
        Compute gate values and raw logits for every position.

        Args:
            hidden_states: Activations fed to the gate; the last dimension
                must be ``config.hidden_size``.
            input_ids: Token ids aligned with ``hidden_states`` (one id per
                position), or None when the caller only has embeddings. When
                None, no padding override is applied.

        Returns:
            Tuple ``(gate_values, delete_gate_logits)``, both with a trailing
            dimension of size 1.
        """
        hidden_states = self.layer_norm(hidden_states)
        delete_gate_logits = self.feed_forward(hidden_states)
        gate_values = self.activation(delete_gate_logits)

        # `input_ids` is None when the model is driven through inputs_embeds;
        # in that case padding cannot be identified and the gate is left as is.
        if input_ids is not None:
            # Set gate values for pad tokens (input_ids == 0) to sigmoid_mask_scale.
            # masked_fill is a no-op when there is no padding, so no `.any()`
            # pre-check (and the host synchronisation it implies) is needed, and
            # the fill value automatically follows gate_values' dtype and device.
            pad_mask = (input_ids == 0).unsqueeze(-1)
            gate_values = gate_values.masked_fill(pad_mask, self.activation.sigmoid_mask_scale)

        return gate_values, delete_gate_logits

    def _init_weights(self, m, init_func="xavier_uniform_"):
        """
        Initialize a linear layer: weights via ``TORCH_INIT_FUNCTIONS[init_func]``,
        bias filled with 1. Modules that are not `nn.Linear` are left untouched.
        """
        # Initialize the weights. This is necessary because
        # HuggingFace disables initialization during "from_pretrained"
        if isinstance(m, nn.Linear):
            TORCH_INIT_FUNCTIONS[init_func](m.weight)
            m.bias.data.fill_(1)
class MrT5Attention(T5Attention):
    """
    Extends the T5Attention class to include a delete gate. Only the forward
    method is modified. The delete_gate_mask passed to the forward function
    is applied to the attention scores.

    Attention weights are normalised with `softmax1`, whose denominator has
    one extra term, so the weights of a query can sum to less than one.
    """

    def __init__(self, config: MrT5Config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
        super().__init__(config, has_relative_attention_bias, layer_idx)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_values=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
        #### NEW CODE ####
        delete_gate_mask=None,
        #### NEW CODE ####
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).

        Args:
            hidden_states: Query-side input, (batch_size, seq_length, dim).
            mask: Additive attention mask. Only read when `position_bias` is
                None, in which case it is folded into the position bias built
                here.
            key_value_states: Source states for cross-attention; None selects
                self-attention.
            position_bias: Precomputed bias added to the scores. When None it
                is created here (zeros, or `compute_bias` if this layer has a
                relative attention bias).
            past_key_values: Cache object (an `EncoderDecoderCache` or another
                object exposing `update`), or None.
            layer_head_mask: Optional factor multiplied into the attention
                weights.
            query_length: Used instead of `cache_position[-1] + 1` as the
                query length when the position bias is computed.
            use_cache: Not read inside this method.
            output_attentions: When true, the attention weights are appended
                to the returned tuple.
            cache_position: Positions of the current queries; forwarded to the
                cache update (self-attention only) and to `compute_bias`.
            delete_gate_mask: Optional additive mask over key positions coming
                from the delete gate. Assumed to be (batch_size, key_length, 1)
                - TODO confirm against callers.

        Returns:
            `(attn_output, position_bias)`, plus `attn_weights` when
            `output_attentions` is set.
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
        batch_size, seq_length = hidden_states.shape[:2]

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None

        query_states = self.q(hidden_states)
        query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        # Check whether an encoder-decoder cache is being used. Otherwise we'll get `DynamicCache`
        is_updated = False
        if isinstance(past_key_values, EncoderDecoderCache):
            is_updated = past_key_values.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated id, we can subsequently re-use all key/value_states from cache
                curr_past_key_value = past_key_values.cross_attention_cache
            else:
                curr_past_key_value = past_key_values.self_attention_cache
        else:
            curr_past_key_value = past_key_values

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_value.layers[self.layer_idx].keys
            value_states = curr_past_key_value.layers[self.layer_idx].values
        else:
            key_states = self.k(current_states)
            value_states = self.v(current_states)
            key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
            value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

            if past_key_values is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        if position_bias is None:
            key_length = key_states.shape[-2]
            # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
            real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(
                    real_seq_length, key_length, device=scores.device, cache_position=cache_position
                )
                # keep only the bias rows belonging to the current queries
                position_bias = position_bias[:, :, -seq_length:, :]

            if mask is not None:
                # the mask is merged into the bias, so callers that pass the
                # returned position_bias back in do not need the mask again
                causal_mask = mask[:, :, :, : key_states.shape[-2]]
                position_bias = position_bias + causal_mask

        if self.pruned_heads:
            # NOTE: `mask` is re-bound here as a per-head keep mask
            mask = torch.ones(position_bias.shape[1])
            mask[list(self.pruned_heads)] = 0
            position_bias_masked = position_bias[:, mask.bool()]
        else:
            position_bias_masked = position_bias

        scores += position_bias_masked

        #### NEW CODE ####
        # Apply the mask from the delete gate
        if delete_gate_mask is not None:
            # drop the trailing dim and insert two singleton dims before the
            # key axis so the mask broadcasts over heads and query positions
            scores = scores + delete_gate_mask.squeeze(-1).unsqueeze(-2).unsqueeze(-2)

        # softmax1 is evaluated in float32 and cast back to the scores dtype
        attn_weights = softmax1(scores.float(), dim=-1).type_as(scores)
        #### NEW CODE ####

        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = torch.matmul(attn_weights, value_states)

        # merge the heads back: (batch, heads, seq, head_dim) -> (batch, seq, inner_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.inner_dim)
        attn_output = self.o(attn_output)

        outputs = (attn_output, position_bias)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
class MrT5LayerSelfAttention(nn.Module):
    """
    Self-attention sub-layer of an MrT5 block: layer norm, `MrT5Attention`,
    dropout and a residual connection. Mirrors T5LayerSelfAttention but
    forwards the delete-gate mask to the attention module.
    """

    def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
        super().__init__()
        #### NEW CODE ####
        # MrT5Attention replaces T5Attention
        self.SelfAttention = MrT5Attention(
            config,
            has_relative_attention_bias=has_relative_attention_bias,
            layer_idx=layer_idx,
        )
        #### NEW CODE ####
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_values=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
        #### NEW CODE ####
        delete_gate_mask=None,
        #### NEW CODE ####
    ):
        """
        Run pre-norm self-attention and add the result to the input.

        Returns:
            Tuple whose first item is the updated hidden states, followed by
            everything the attention module returned after its output
            (position bias and, if requested, attention weights).
        """
        attention_kwargs = {
            "mask": attention_mask,
            "position_bias": position_bias,
            "layer_head_mask": layer_head_mask,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "output_attentions": output_attentions,
            "cache_position": cache_position,
            #### NEW CODE ####
            "delete_gate_mask": delete_gate_mask,
            #### NEW CODE ####
        }
        attn_result = self.SelfAttention(self.layer_norm(hidden_states), **attention_kwargs)

        # residual connection around the attention output
        residual_sum = hidden_states + self.dropout(attn_result[0])
        # everything after the attention output is passed through untouched
        return (residual_sum, *attn_result[1:])
class MrT5LayerCrossAttention(nn.Module):
    """
    Cross-attention sub-layer of an MrT5 decoder block: layer norm,
    `MrT5Attention` over the encoder states, dropout and a residual
    connection. Mirrors T5LayerCrossAttention but forwards the delete-gate
    mask to the attention module.
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        #### NEW CODE ####
        # MrT5Attention replaces T5Attention; cross-attention never owns a
        # relative attention bias
        self.EncDecAttention = MrT5Attention(
            config,
            has_relative_attention_bias=False,
            layer_idx=layer_idx,
        )
        #### NEW CODE ####
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_values=None,
        use_cache=False,
        query_length=None,
        output_attentions=False,
        cache_position=None,
        #### NEW CODE ####
        delete_gate_mask=None,
        #### NEW CODE ####
    ):
        """
        Run pre-norm cross-attention over `key_value_states` and add the
        result to the input.

        Returns:
            Tuple whose first item is the updated hidden states, followed by
            everything the attention module returned after its output
            (position bias and, if requested, attention weights).
        """
        attention_kwargs = {
            "mask": attention_mask,
            "key_value_states": key_value_states,
            "position_bias": position_bias,
            "layer_head_mask": layer_head_mask,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "query_length": query_length,
            "output_attentions": output_attentions,
            "cache_position": cache_position,
            #### NEW CODE ####
            "delete_gate_mask": delete_gate_mask,
            #### NEW CODE ####
        }
        attn_result = self.EncDecAttention(self.layer_norm(hidden_states), **attention_kwargs)

        # residual connection around the attention output
        residual_sum = hidden_states + self.dropout(attn_result[0])
        # everything after the attention output is passed through untouched
        return (residual_sum, *attn_result[1:])
class MrT5Block(GradientCheckpointingLayer):
    """
    Modified version of T5Block that uses MrT5LayerSelfAttention and
    MrT5LayerCrossAttention instead of T5LayerSelfAttention and
    T5LayerCrossAttention.

    An encoder block whose `layer_idx` equals `config.delete_gate_layer` also
    owns a `SigmoidDeleteGate`. That gate is evaluated on the block's input;
    its values become the delete-gate mask and, when `hard_delete` is set,
    positions at or below the deletion threshold are physically removed from
    the sequence before self-attention runs.
    """

    def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.layer = nn.ModuleList()

        #### NEW CODE ####
        # Use MrT5LayerSelfAttention and MrT5LayerCrossAttention
        # instead of T5LayerSelfAttention and T5LayerCrossAttention
        self.layer.append(MrT5LayerSelfAttention(
            config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx))
        if self.is_decoder:
            self.layer.append(MrT5LayerCrossAttention(config, layer_idx=layer_idx))
        #### NEW CODE ####

        self.layer.append(T5LayerFF(config))

        #### NEW CODE ####
        # Add delete gate if needed
        self.has_delete_gate = not config.is_decoder and (layer_idx == config.delete_gate_layer)
        if self.has_delete_gate:
            self.delete_gate = SigmoidDeleteGate(config)

        # Set hard_delete flags
        self.sigmoid_mask_scale = config.sigmoid_mask_scale
        self.deletion_threshold = config.deletion_threshold
        #### NEW CODE ####

    #### NEW CODE ####
    def __get_new_positions_and_mask(self, batch_size, seq_len, delete_gate_mask, deletion_threshold, device):
        """
        Compute, per sequence, which source position feeds each slot of the
        compacted (hard-deleted) sequence, plus an additive mask for the
        unused tail slots.

        Returns:
            `(src_side_pos, new_mask)` where `src_side_pos` is
            (batch_size, new_len) int64 and `new_mask` is
            (batch_size, new_len, 1) holding 0 for used slots and -1e9 for
            slots past the number of kept tokens of that sequence.
        """
        delete_gate_mask = delete_gate_mask.squeeze(-1)

        # Create filter from delete gate mask
        deletion_threshold = deletion_threshold if deletion_threshold is not None else self.deletion_threshold
        keep_this = delete_gate_mask > deletion_threshold

        # Calculate the target position for each token
        target_pos = torch.cumsum(keep_this, dim=1) - 1
        # The longest kept sequence in the batch fixes the new length
        new_len = target_pos[:, -1].max().item() + 1

        # Clamp the target position to avoid out of bounds when deleting everything
        target_pos = target_pos.clamp(min=0)

        # Map the positions to the src side. Do this in int32, because it's faster and we will not have sequences
        # longer than 2^31
        positions = torch.arange(seq_len, device=device, dtype=torch.int32).repeat(batch_size, 1)
        # Deleted tokens contribute 0, so the scatter-add below leaves exactly
        # the kept token's source index in each target slot
        positions *= keep_this.int()

        src_side_pos = torch.zeros(batch_size, new_len, device=device, dtype=torch.int32)
        src_side_pos.scatter_add_(1, target_pos, positions)

        # Create the new mask
        new_mask = torch.arange(new_len, device=device).expand(batch_size, -1) <= target_pos[:, -1:]
        new_mask = (~new_mask).float() * -1e9
        new_mask = new_mask.unsqueeze(-1)

        return src_side_pos.long(), new_mask

    def __hard_delete_hidden_states(self, hidden_states, positions):
        """Gather the rows of `hidden_states` (dim 1) selected by `positions`."""
        new_hidden_states = torch.gather(hidden_states, 1, positions.unsqueeze(2).expand(-1, -1, hidden_states.size(2)))
        return new_hidden_states

    def __hard_delete_4_dimensions(self, position_bias, positions):
        """Gather along dim 1 of a 4-D tensor using `positions`."""
        new_position_bias = torch.gather(position_bias, 1, positions.unsqueeze(2).unsqueeze(3).expand(-1, -1, position_bias.size(2), position_bias.size(3)))
        return new_position_bias

    @staticmethod
    def _clamp_fp16(hidden_states):
        """Clamp float16 activations to a finite range; other dtypes pass through unchanged."""
        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.where(
                torch.isinf(hidden_states).any(),
                torch.finfo(hidden_states.dtype).max - 1000,
                torch.finfo(hidden_states.dtype).max,
            )
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
        return hidden_states
    #### NEW CODE ####

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        #### NEW CODE ####
        delete_gate_mask=None,
        #### NEW CODE ####
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_values=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
        cache_position=None,
        #### NEW CODE ####
        input_ids=None,
        hard_delete=None,
        deletion_threshold=None,
        #### NEW CODE ####
    ):
        """
        Run the block: (delete gate) -> self-attention -> (cross-attention)
        -> feed-forward.

        Args:
            delete_gate_mask: Mask produced by an earlier block. Replaced by
                this block's gate output when the block owns a delete gate.
                Applied in self-attention for encoder blocks and in
                cross-attention for decoder blocks.
            input_ids: Passed to the delete gate so padding can be handled.
            hard_delete: When truthy and this block owns a delete gate, the
                hidden states, attention mask and position bias are
                compacted to the kept positions.
            deletion_threshold: Per-call override of
                `config.deletion_threshold`; used both for the all-deleted
                check and for hard deletion.

        Returns:
            Tuple `(hidden_states, self-attn position bias,
            [self-attn weights], [cross-attn position bias],
            [cross-attn weights])`, followed - only for a block that owns a
            delete gate - by `(delete_gate_values, delete_gate_logits,
            delete_gate_mask, attention_mask)`.

        Raises:
            ValueError: If every gate value in the batch is below the
                deletion threshold.
        """
        ##### NEW CODE #####
        # Initialize delete gate values and logits for logging/loss calculation
        delete_gate_values = None
        delete_gate_logits = None

        if self.has_delete_gate:
            delete_gate_values, delete_gate_logits = self.delete_gate(
                hidden_states, input_ids)
            delete_gate_mask = delete_gate_values

            # Resolve the threshold once so that this check and the hard
            # deletion below always agree; previously the check ignored the
            # per-call override, so a stricter override could delete every
            # token and fail later inside scatter_add_ instead of here.
            threshold = deletion_threshold if deletion_threshold is not None else self.deletion_threshold

            # Raise error if all tokens of the whole batch are deleted
            if (delete_gate_values < threshold).all():
                raise ValueError("All tokens are deleted in this batch. " + \
                                 "Please adjust the deletion rate or " + \
                                 "alpha hyperparameter.")

            # Apply hard deletion
            if hard_delete:
                # Compute new token positions
                new_positions, delete_gate_mask = self.__get_new_positions_and_mask(
                    hidden_states.size(0), hidden_states.size(1), delete_gate_mask, threshold, hidden_states.device)

                # Compute new position bias
                if position_bias is not None:
                    # Gather the query axis first, then the key axis, by
                    # rotating each in turn into dim 1
                    new_position_bias = self.__hard_delete_4_dimensions(
                        position_bias.permute(0, 2, 3, 1), new_positions)
                    new_position_bias = self.__hard_delete_4_dimensions(
                        new_position_bias.permute(0, 2, 1, 3), new_positions)
                    position_bias = new_position_bias.permute(0, 3, 2, 1)

                # Compute new attention mask
                new_attention_mask = self.__hard_delete_4_dimensions(
                    attention_mask.permute(0, 3, 1, 2), new_positions)
                attention_mask = new_attention_mask.permute(0, 2, 3, 1)

                # Compute new hidden states and delete gate mask
                hidden_states = self.__hard_delete_hidden_states(
                    hidden_states, new_positions)
        ##### NEW CODE #####

        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
            #### NEW CODE ####
            # Only apply delete_gate_mask to self-attention if the block
            # is the encoder
            delete_gate_mask=None if self.is_decoder else delete_gate_mask,
            #### NEW CODE ####
        )
        hidden_states = self_attention_outputs[0]
        attention_outputs = self_attention_outputs[1:]  # Keep self-attention outputs and relative position weights

        hidden_states = self._clamp_fp16(hidden_states)

        do_cross_attention = self.is_decoder and encoder_hidden_states is not None
        if do_cross_attention:
            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_values=past_key_values,
                query_length=cache_position[-1] + 1,
                use_cache=use_cache,
                output_attentions=output_attentions,
                #### NEW CODE ####
                delete_gate_mask=delete_gate_mask,
                #### NEW CODE ####
            )
            hidden_states = cross_attention_outputs[0]

            hidden_states = self._clamp_fp16(hidden_states)

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[1:]

        # Apply Feed Forward layer
        hidden_states = self.layer[-1](hidden_states)

        hidden_states = self._clamp_fp16(hidden_states)

        outputs = (hidden_states,)

        ##### NEW CODE #####
        # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
        outputs = outputs + attention_outputs

        if self.has_delete_gate:
            outputs = outputs + \
                (delete_gate_values, delete_gate_logits, delete_gate_mask, attention_mask)

        # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights), (delete_gate_values), (delete_gate_logits), (delete_gate_mask), (attention_mask)
        return outputs
        ##### NEW CODE #####
class MrT5Stack(T5Stack):
    """
    T5Stack whose blocks are `MrT5Block`s.

    The forward pass follows T5Stack's and additionally threads the
    delete-gate mask through the blocks, collects the delete gate's outputs,
    and - after a hard deletion - tracks the shortened attention mask.
    """

    def __init__(self, config, embed_tokens=None):
        super().__init__(config, embed_tokens)

        ##### NEW CODE #####
        # Replace the blocks built by T5Stack with MrT5 blocks; only the
        # first block owns the relative attention bias.
        self.block = nn.ModuleList(
            [MrT5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)]
        )
        ##### NEW CODE #####

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
        #### NEW CODE ####
        delete_gate_mask=None,
        delete_gate_output=None,
        delete_gate_logits=None,
        hard_delete=None,
        deletion_threshold=None,
        #### NEW CODE ####
    ):
        """
        Run the stack.

        MrT5-specific args:
            delete_gate_mask: Mask handed to every block; overwritten by the
                output of a block that owns a delete gate.
            delete_gate_output: Returned as is unless a block with a delete
                gate overwrites it with its gate values.
            delete_gate_logits: Returned as is unless a block with a delete
                gate overwrites it with its gate logits.
            hard_delete: Forwarded to every block; when truthy, the block
                with the delete gate shortens the sequence and the stack
                returns the matching 0/1 attention mask.
            deletion_threshold: Forwarded to every block.

        Returns:
            `MrT5BaseModelOutputWithPastAndCrossAttentions`, or - when
            `return_dict` is false - a tuple of its non-None values.

        Raises:
            ValueError: If both or neither of `input_ids` / `inputs_embeds`
                are given, if no token embeddings are available, or if
                `use_cache` is requested on an encoder stack.
        """
        # Model parallel
        if self.model_parallel:
            torch.cuda.set_device(self.first_device)
            self.embed_tokens = self.embed_tokens.to(self.first_device)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if inputs_embeds is None:
            if self.embed_tokens is None:
                raise ValueError("You have to initialize the model with valid token embeddings")
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        if use_cache is True:
            if not self.is_decoder:
                raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")

        if self.is_decoder:
            if use_cache and past_key_values is None:
                if self.config.is_encoder_decoder:
                    past_key_values = EncoderDecoderCache(
                        DynamicCache(config=self.config), DynamicCache(config=self.config)
                    )
                else:
                    past_key_values = DynamicCache(config=self.config)
        elif not self.is_decoder:
            # do not pass cache object down the line for encoder stack
            # it messes indexing later in decoder-stack because cache object is modified in-place
            past_key_values = None

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        if cache_position is None:
            cache_position = torch.arange(
                past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
            )

        if attention_mask is None and not is_torchdynamo_compiling():
            # required mask seq length can be calculated via length of past cache
            mask_seq_length = past_key_values_length + seq_length
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

        if self.config.is_decoder:
            causal_mask = self._update_causal_mask(
                attention_mask,
                inputs_embeds,
                cache_position,
                past_key_values.self_attention_cache
                if isinstance(past_key_values, EncoderDecoderCache)
                else past_key_values,
                output_attentions,
            )
        elif attention_mask is not None:
            # Encoder: turn the 0/1 mask into an additive mask that is 0 for
            # kept positions and the dtype's minimum for masked ones
            causal_mask = attention_mask[:, None, None, :]
            causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
            causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
        else:
            causal_mask = None

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(
                    encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
                )
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        #### NEW CODE ####
        # Return a new encoder attention mask if hard delete is enabled
        attention_mask_to_return = None
        #### NEW CODE ####

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)

        for i, layer_module in enumerate(self.block):
            layer_head_mask = head_mask[i]
            cross_attn_layer_head_mask = cross_attn_head_mask[i]
            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Ensure that attention_mask is always on the same device as hidden_states
                if causal_mask is not None:
                    causal_mask = causal_mask.to(hidden_states.device)
                if position_bias is not None:
                    position_bias = position_bias.to(hidden_states.device)
                if encoder_hidden_states is not None:
                    encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
                if encoder_extended_attention_mask is not None:
                    encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
                if encoder_decoder_position_bias is not None:
                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
                if layer_head_mask is not None:
                    layer_head_mask = layer_head_mask.to(hidden_states.device)
                if cross_attn_layer_head_mask is not None:
                    cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                causal_mask,
                position_bias,
                encoder_hidden_states,
                encoder_extended_attention_mask,
                encoder_decoder_position_bias,  # as a positional argument for gradient checkpointing
                #### NEW CODE ####
                delete_gate_mask,
                #### NEW CODE ####
                layer_head_mask=layer_head_mask,
                cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                output_attentions=output_attentions,
                return_dict=return_dict,
                cache_position=cache_position,
                #### NEW CODE ####
                input_ids=input_ids,
                hard_delete=hard_delete,
                deletion_threshold=deletion_threshold,
                #### NEW CODE ####
            )

            #### NEW CODE ####
            # Update delete_gate_mask if the previous layer had a delete gate
            if layer_module.has_delete_gate:
                # A block with a delete gate appends four extra items to its outputs
                delete_gate_output, delete_gate_logits, delete_gate_mask, new_attention_mask = layer_outputs[-4], layer_outputs[-3], layer_outputs[-2], layer_outputs[-1]

                # Update resized masks if the previous layer did a hard deletion
                if hard_delete:
                    # Keep the mask that is threaded through the remaining
                    # layers in sync with the shortened sequence. (This used
                    # to be stored under a name nothing reads, which left
                    # causal_mask at the pre-deletion length.)
                    causal_mask = new_attention_mask
                    attention_mask_to_return = causal_mask.squeeze(-2).squeeze(-2)
                    # Additive mask -> 0/1 mask: kept positions hold exactly 0
                    attention_mask_to_return = (attention_mask_to_return == 0).int()
            #### NEW CODE ####

            hidden_states = layer_outputs[0]

            # We share the position biases between the layers - the first layer store them
            # layer_outputs = hidden-states, (self-attention position bias), (self-attention weights),
            # (cross-attention position bias), (cross-attention weights)
            position_bias = layer_outputs[1]
            if self.is_decoder and encoder_hidden_states is not None:
                #### NEW CODE ####
                encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]
                #### NEW CODE ####

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[2],)
                if self.is_decoder:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[4],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # None entries are dropped, so tuple positions depend on which
            # optional outputs are present
            return tuple(
                v
                for v in [
                    hidden_states,
                    past_key_values,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                    #### NEW CODE ####
                    delete_gate_mask,
                    delete_gate_output,
                    delete_gate_logits,
                    attention_mask_to_return,
                    #### NEW CODE ####
                ]
                if v is not None
            )
        return MrT5BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
            #### NEW CODE ####
            delete_gate_mask=delete_gate_mask,
            delete_gate_output=delete_gate_output,
            delete_gate_logits=delete_gate_logits,
            attention_mask=attention_mask_to_return,
            #### NEW CODE ####
        )
class MrT5ForConditionalGeneration(T5ForConditionalGeneration):
config_class = MrT5Config
def __init__(self, config: MrT5Config):
    """Build the model, swapping in ``MrT5Stack`` for the encoder and decoder.

    The parent constructor runs first; the ``encoder`` and ``decoder``
    attributes are then reassigned to ``MrT5Stack`` instances, each of which
    receives ``self.shared`` as its second argument.

    Args:
        config: Model configuration. It is deep-copied once per stack, so
            the per-stack overrides below are applied to copies rather
            than to the object passed in.
    """
    super().__init__(config)

    #### NEW CODE ####
    # Encoder settings: not a decoder, and caching switched off.
    enc_cfg = copy.deepcopy(config)
    enc_cfg.is_decoder = False
    enc_cfg.use_cache = False

    # Decoder settings: depth is taken from ``num_decoder_layers``.
    dec_cfg = copy.deepcopy(config)
    dec_cfg.is_decoder = True
    dec_cfg.num_layers = config.num_decoder_layers

    # Instantiate the encoder before the decoder so that module
    # construction happens in the same order as before.
    self.encoder = MrT5Stack(enc_cfg, self.shared)
    self.decoder = MrT5Stack(dec_cfg, self.shared)
    #### NEW CODE ####
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
#### NEW CODE ####
hard_delete: bool = False,
deletion_threshold: Optional[float] = None,
#### NEW CODE ####
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
if head_mask is not None and decoder_head_mask is None:
if self.config.num_layers == self.config.num_decoder_layers:
decoder_head_mask = head_mask
# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
# Convert encoder inputs in embeddings if needed
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
#### NEW CODE ####
hard_delete=hard_delete,
deletion_threshold=deletion_threshold,
#### NEW CODE ####
)
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
#### NEW CODE ####
encoder_outputs = MrT5BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=encoder_outputs.last_hidden_state,
hidden_states=encoder_outputs.hidden_states if 'hidden_states' in encoder_outputs else None,
attentions=encoder_outputs.attentions if 'attentions' in encoder_outputs else None,
delete_gate_mask=encoder_outputs.delete_gate_mask if 'delete_gate_mask' in encoder_outputs else None,
)
#### NEW CODE ####
#### NEW CODE ####
hidden_states = encoder_outputs.last_hidden_state
attention_mask = encoder_outputs.attention_mask if 'attention_mask' in encoder_outputs else attention_mask
#### NEW CODE ####
if self.model_parallel:
torch.cuda.set_device(self.decoder.first_device)
if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
# get decoder inputs from shifting lm labels to the right
decoder_input_ids = self._shift_right(labels)
# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.decoder.first_device)
hidden_states = hidden_states.to(self.decoder.first_device)
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids.to(
self.decoder.first_device)
if attention_mask is not None:
attention_mask = attention_mask.to(self.decoder.first_device)
if decoder_attention_mask is not None:
decoder_attention_mask = decoder_attention_mask.to(
self.decoder.first_device)
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
past_key_values=past_key_values,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
#### NEW CODE ####
delete_gate_mask=encoder_outputs.delete_gate_mask,
#### NEW CODE ####
)
sequence_output = decoder_outputs[0]
# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.encoder.first_device)
self.lm_head = self.lm_head.to(self.encoder.first_device)
sequence_output = sequence_output.to(self.lm_head.weight.device)
if self.config.tie_word_embeddings:
# Rescale output before projecting on vocab
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
sequence_output = sequence_output * (self.model_dim**-0.5)
lm_logits = self.lm_head(sequence_output)
loss = None
if labels is not None: