forked from thuiar/MMSA
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMulT.py
More file actions
157 lines (137 loc) · 6.84 KB
/
MulT.py
File metadata and controls
157 lines (137 loc) · 6.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""
paper: Multimodal Transformer for Unaligned Multimodal Language Sequences
From: https://github.com/yaohungt/Multimodal-Transformer
"""
import torch
from torch import nn
import torch.nn.functional as F
from models.subNets.transformers.transformer import TransformerEncoder
__all__ = ['MulT']
class MulT(nn.Module):
    """Multimodal Transformer (MulT) for unaligned multimodal sequences.

    Paper: "Multimodal Transformer for Unaligned Multimodal Language Sequences"
    Reference implementation: https://github.com/yaohungt/Multimodal-Transformer

    Each target modality attends to the other two via crossmodal transformers,
    the two crossmodal streams are concatenated and passed through a
    self-attention "memory" transformer, and the last time step of each active
    stream feeds a small residual MLP head.
    """

    def __init__(self, args):
        """Construct a MulT model.

        Args:
            args: configuration namespace; must provide
                ``dst_feature_dim_nheads`` (tuple of destination dim and
                number of attention heads), ``feature_dims`` (original
                text/audio/vision dims), the modality flags
                ``lonly``/``aonly``/``vonly``, ``nlevels``, the various
                ``*_dropout`` rates, ``attn_mask``, ``num_classes`` and
                ``conv1d_kernel_size_{l,a,v}``.
        """
        super(MulT, self).__init__()
        dst_feature_dims, nheads = args.dst_feature_dim_nheads
        self.orig_d_l, self.orig_d_a, self.orig_d_v = args.feature_dims
        # All three modalities are projected to the same destination dimension.
        self.d_l = self.d_a = self.d_v = dst_feature_dims
        self.vonly, self.aonly, self.lonly = args.vonly, args.aonly, args.lonly
        self.num_heads = nheads
        self.layers = args.nlevels
        self.attn_dropout = args.attn_dropout
        self.attn_dropout_a = args.attn_dropout_a
        self.attn_dropout_v = args.attn_dropout_v
        self.relu_dropout = args.relu_dropout
        self.embed_dropout = args.embed_dropout
        self.res_dropout = args.res_dropout
        self.output_dropout = args.output_dropout
        self.text_dropout = args.text_dropout
        self.attn_mask = args.attn_mask

        # Number of active target modalities (1 = single-target, 3 = all).
        self.partial_mode = self.lonly + self.aonly + self.vonly
        # Each active stream has width 2*d (two crossmodal branches concatenated).
        if self.partial_mode == 1:
            combined_dim = 2 * self.d_l  # assuming d_l == d_a == d_v
        else:
            combined_dim = 2 * (self.d_l + self.d_a + self.d_v)
        output_dim = args.num_classes  # This is actually not a hyperparameter :-)

        # 1. Temporal convolutional layers projecting raw feature dims to d_*.
        #    NOTE(review): padding=0 with kernel_size > 1 shortens the sequence.
        self.proj_l = nn.Conv1d(self.orig_d_l, self.d_l, kernel_size=args.conv1d_kernel_size_l, padding=0, bias=False)
        self.proj_a = nn.Conv1d(self.orig_d_a, self.d_a, kernel_size=args.conv1d_kernel_size_a, padding=0, bias=False)
        self.proj_v = nn.Conv1d(self.orig_d_v, self.d_v, kernel_size=args.conv1d_kernel_size_v, padding=0, bias=False)

        # 2. Crossmodal attentions (only built for the active target modalities).
        if self.lonly:
            self.trans_l_with_a = self.get_network(self_type='la')
            self.trans_l_with_v = self.get_network(self_type='lv')
        if self.aonly:
            self.trans_a_with_l = self.get_network(self_type='al')
            self.trans_a_with_v = self.get_network(self_type='av')
        if self.vonly:
            self.trans_v_with_l = self.get_network(self_type='vl')
            self.trans_v_with_a = self.get_network(self_type='va')

        # 3. Self Attentions (Could be replaced by LSTMs, GRUs, etc.)
        # [e.g., self.trans_x_mem = nn.LSTM(self.d_x, self.d_x, 1)
        self.trans_l_mem = self.get_network(self_type='l_mem', layers=3)
        self.trans_a_mem = self.get_network(self_type='a_mem', layers=3)
        self.trans_v_mem = self.get_network(self_type='v_mem', layers=3)

        # Projection layers: residual MLP head on top of the fused features.
        self.proj1 = nn.Linear(combined_dim, combined_dim)
        self.proj2 = nn.Linear(combined_dim, combined_dim)
        self.out_layer = nn.Linear(combined_dim, output_dim)

    def get_network(self, self_type='l', layers=-1):
        """Build a TransformerEncoder for the given stream.

        Args:
            self_type: stream identifier — 'xy' means modality x attending
                to modality y; 'x_mem' is the post-fusion self-attention
                stack (double width, since two branches were concatenated).
            layers: override for the depth; the effective depth is
                ``max(self.layers, layers)``.

        Raises:
            ValueError: if ``self_type`` is not a known stream identifier.
        """
        if self_type in ['l', 'al', 'vl']:
            embed_dim, attn_dropout = self.d_l, self.attn_dropout
        elif self_type in ['a', 'la', 'va']:
            embed_dim, attn_dropout = self.d_a, self.attn_dropout_a
        elif self_type in ['v', 'lv', 'av']:
            embed_dim, attn_dropout = self.d_v, self.attn_dropout_v
        elif self_type == 'l_mem':
            embed_dim, attn_dropout = 2 * self.d_l, self.attn_dropout
        elif self_type == 'a_mem':
            embed_dim, attn_dropout = 2 * self.d_a, self.attn_dropout
        elif self_type == 'v_mem':
            embed_dim, attn_dropout = 2 * self.d_v, self.attn_dropout
        else:
            raise ValueError("Unknown network type")
        return TransformerEncoder(embed_dim=embed_dim,
                                  num_heads=self.num_heads,
                                  layers=max(self.layers, layers),
                                  attn_dropout=attn_dropout,
                                  relu_dropout=self.relu_dropout,
                                  res_dropout=self.res_dropout,
                                  embed_dropout=self.embed_dropout,
                                  attn_mask=self.attn_mask)

    def forward(self, x_l, x_a, x_v):
        """Run the crossmodal transformer stack.

        Args:
            x_l, x_a, x_v: text, audio and vision features, each of shape
                [batch_size, seq_len, n_features].

        Returns:
            dict with keys 'Feature_t'/'Feature_a'/'Feature_v' (final
            per-modality representations; None for inactive modalities)
            and 'M' (the prediction from the residual head).

        Raises:
            RuntimeError: if none of lonly/aonly/vonly is set.
        """
        # Conv1d expects [batch, channels, seq_len].
        x_l = F.dropout(x_l.transpose(1, 2), p=self.text_dropout, training=self.training)
        x_a = x_a.transpose(1, 2)
        x_v = x_v.transpose(1, 2)

        # Project the textual/visual/audio features (identity when dims already match).
        proj_x_l = x_l if self.orig_d_l == self.d_l else self.proj_l(x_l)
        proj_x_a = x_a if self.orig_d_a == self.d_a else self.proj_a(x_a)
        proj_x_v = x_v if self.orig_d_v == self.d_v else self.proj_v(x_v)
        # Transformer layout: [seq_len, batch, d].
        proj_x_a = proj_x_a.permute(2, 0, 1)
        proj_x_v = proj_x_v.permute(2, 0, 1)
        proj_x_l = proj_x_l.permute(2, 0, 1)

        # Initialize so the result dict is well-defined in partial modes
        # (otherwise inactive modalities would leave these names unbound
        # and the `res` construction below would raise NameError).
        last_h_l = last_h_a = last_h_v = last_hs = None

        if self.lonly:
            # (V,A) --> L
            h_l_with_as = self.trans_l_with_a(proj_x_l, proj_x_a, proj_x_a)  # Dimension (L, N, d_l)
            h_l_with_vs = self.trans_l_with_v(proj_x_l, proj_x_v, proj_x_v)  # Dimension (L, N, d_l)
            h_ls = torch.cat([h_l_with_as, h_l_with_vs], dim=2)
            h_ls = self.trans_l_mem(h_ls)
            if isinstance(h_ls, tuple):
                h_ls = h_ls[0]
            last_h_l = last_hs = h_ls[-1]  # Take the last output for prediction

        if self.aonly:
            # (L,V) --> A
            h_a_with_ls = self.trans_a_with_l(proj_x_a, proj_x_l, proj_x_l)
            h_a_with_vs = self.trans_a_with_v(proj_x_a, proj_x_v, proj_x_v)
            h_as = torch.cat([h_a_with_ls, h_a_with_vs], dim=2)
            h_as = self.trans_a_mem(h_as)
            if isinstance(h_as, tuple):
                h_as = h_as[0]
            last_h_a = last_hs = h_as[-1]

        if self.vonly:
            # (L,A) --> V
            h_v_with_ls = self.trans_v_with_l(proj_x_v, proj_x_l, proj_x_l)
            h_v_with_as = self.trans_v_with_a(proj_x_v, proj_x_a, proj_x_a)
            h_vs = torch.cat([h_v_with_ls, h_v_with_as], dim=2)
            h_vs = self.trans_v_mem(h_vs)
            if isinstance(h_vs, tuple):
                h_vs = h_vs[0]
            last_h_v = last_hs = h_vs[-1]

        if self.partial_mode == 3:
            last_hs = torch.cat([last_h_l, last_h_a, last_h_v], dim=1)
        if last_hs is None:
            # No modality flag was set; fail loudly instead of with an
            # obscure NameError further down.
            raise RuntimeError("MulT requires at least one of lonly/aonly/vonly to be set")

        # A residual block
        last_hs_proj = self.proj2(F.dropout(F.relu(self.proj1(last_hs), inplace=True), p=self.output_dropout, training=self.training))
        last_hs_proj += last_hs
        output = self.out_layer(last_hs_proj)

        res = {
            'Feature_t': last_h_l,
            'Feature_a': last_h_a,
            'Feature_v': last_h_v,
            'M': output
        }
        return res