-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathlayers.py
More file actions
157 lines (131 loc) · 6.5 KB
/
layers.py
File metadata and controls
157 lines (131 loc) · 6.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import torch
import torch.nn as nn
from einops import rearrange
import numpy as np
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise a layer in place and hand it back.

    The weight is filled with an orthogonal matrix scaled by ``std``; the
    bias, when the layer has one, is filled with ``bias_const``.

    Args:
        layer: Module exposing ``weight`` and ``bias`` attributes
            (``bias`` may be ``None``).
        std: Gain passed to the orthogonal initialiser.
        bias_const: Constant written into every bias element.

    Returns:
        The same ``layer`` object, so the call can wrap a constructor.
    """
    init = torch.nn.init
    init.orthogonal_(layer.weight, gain=std)
    bias = layer.bias
    if bias is not None:
        init.constant_(bias, val=bias_const)
    return layer
def batched_index_select(input, dim, index):
    """Gather, per batch element, the entries of ``input`` along ``dim``.

    ``index`` is given singleton axes for every non-batch axis other than
    ``dim``, expanded to the shape of ``input`` on those axes, and then
    used with ``torch.gather``.

    Args:
        input: Source tensor; axis 0 is treated as the batch axis.
        dim: Axis to select along. Negative values count from the last
            axis (e.g. ``-1``); selecting along the batch axis (0) is
            not supported.
        index: Integer tensor of positions along ``dim``.
            Presumably shaped (batch, k) -- TODO confirm against callers.

    Returns:
        Tensor shaped like ``input`` except that axis ``dim`` has the
        length of the index's selection axis.
    """
    ndim = input.dim()
    # A negative dim would never match the positive loop counter below and
    # the index would end up with one axis too many, so normalise it first.
    if dim < 0:
        dim += ndim

    # Insert singleton axes so the index lines up with every other axis.
    for axis in range(1, ndim):
        if axis != dim:
            index = index.unsqueeze(axis)

    # -1 keeps the index's own size on the batch and selection axes.
    target_shape = list(input.shape)
    target_shape[0] = -1
    target_shape[dim] = -1
    return torch.gather(input, dim, index.expand(target_shape))
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding over reversed positions.

    Args:
        dim: Upper bound of the frequency index range.
        min_timescale: Step of the frequency index range
            ``arange(0, dim, min_timescale)``; with the default of 2.0 and
            an even ``dim`` this yields ``dim / 2`` frequencies, so the
            encoding is ``dim`` wide.
        max_timescale: Base that is raised to ``-index / dim`` to obtain
            the inverse frequencies.
    """

    def __init__(self, dim, min_timescale=2.0, max_timescale=1e4):
        super().__init__()
        freqs = torch.arange(0, dim, min_timescale)
        inv_freqs = max_timescale ** (-freqs / dim)
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer("inv_freqs", inv_freqs)

    def forward(self, seq_len):
        """Return the encoding table for ``seq_len`` positions.

        Positions run from ``seq_len - 1`` down to 0, so the last row
        encodes position 0.

        Args:
            seq_len: Number of positions to encode.

        Returns:
            Tensor of shape (seq_len, 2 * number_of_frequencies): the sine
            terms followed by the cosine terms along the last axis.
        """
        # Build the positions on the buffer's device; otherwise the product
        # below fails once the module has been moved off the default device.
        seq = torch.arange(seq_len - 1, -1, -1.0, device=self.inv_freqs.device)
        # Outer product via broadcasting: (n, 1) * (1, d) -> (n, d).
        sinusoidal_inp = seq.unsqueeze(1) * self.inv_freqs.unsqueeze(0)
        return torch.cat((sinusoidal_inp.sin(), sinusoidal_inp.cos()), dim=-1)
class GatingMechanism(nn.Module):
    """Learned sigmoid gate blending two tensors (used for GTrXL).

    Args:
        dim: Size of the last axis of the tensors being blended.
    """

    def __init__(self, dim):
        super().__init__()
        projection = nn.Linear(dim, dim)
        self.gating_layer = nn.Sequential(projection, nn.Sigmoid())

    def forward(self, x, y):
        """Blend ``x`` and ``y`` with a gate computed from ``x``.

        Args:
            x: Tensor that drives the gate and is kept where the gate is
                close to 0.
            y: Tensor that is let through where the gate is close to 1.

        Returns:
            ``gate * y + (1 - gate) * x`` with ``gate = sigmoid(Linear(x))``.
        """
        mix = self.gating_layer(x)
        keep = 1 - mix
        return mix * y + keep * x
class MultiHeadAttention(nn.Module):
    """Multi-head attention without dropout.

    Inspired by https://github.com/aladdinpersson/Machine-Learning-Collection

    Each of the value/key/query projections is a single bias-free
    ``head_size x head_size`` linear map applied to every head, i.e. the
    projection weights are shared across heads.

    Args:
        embed_dim: Embedding width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_size = embed_dim // num_heads
        assert self.head_size * num_heads == embed_dim, "Embedding dimension needs to be divisible by the number of heads"

        head = self.head_size
        self.values = nn.Linear(head, head, bias=False)
        self.keys = nn.Linear(head, head, bias=False)
        self.queries = nn.Linear(head, head, bias=False)
        self.fc_out = nn.Linear(self.num_heads * head, embed_dim)

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` over ``keys`` / ``values``.

        Args:
            values: Tensor reshaped to (N, value_len, heads, head_size).
            keys: Tensor reshaped to (N, key_len, heads, head_size).
            query: Tensor reshaped to (N, query_len, heads, head_size);
                N is taken from its first axis.
            mask: Optional tensor; positions equal to 0 are excluded from
                attention. Two singleton axes are inserted after its first
                axis before it is applied to the (N, heads, query_len,
                key_len) scores. Presumably shaped (N, key_len) -- TODO
                confirm against callers.

        Returns:
            Tuple ``(output, attention)``: the output of the final linear
            layer with shape (N, query_len, embed_dim) and the attention
            weights with shape (N, heads, query_len, key_len).
        """
        batch = query.shape[0]
        query_len = query.shape[1]
        per_head = (self.num_heads, self.head_size)

        # Split the embedding axis into heads, then project each head.
        v = self.values(values.reshape(batch, values.shape[1], *per_head))
        k = self.keys(keys.reshape(batch, keys.shape[1], *per_head))
        q = self.queries(query.reshape(batch, query_len, *per_head))

        # Dot product between every query and key, per head.
        scores = torch.einsum("nqhd,nkhd->nhqk", q, k)

        if mask is not None:
            # A large finite negative is used because -inf causes NaN.
            blocked = mask[:, None, None] == 0
            scores = scores.masked_fill(blocked, float("-1e20"))

        # Scores are scaled by sqrt(embed_dim) before the softmax over keys.
        attention = torch.softmax(scores / self.embed_dim ** 0.5, dim=-1)

        # Weighted sum of the values, then merge the heads back together.
        context = torch.einsum("nhqk,nkhd->nqhd", attention, v)
        merged = context.reshape(batch, query_len, self.num_heads * self.head_size)
        return self.fc_out(merged), attention
class TransformerLayer(nn.Module):
    """Pre-layer-norm attention block followed by a ReLU projection.

    Both sub-blocks are wrapped either in a plain skip connection or, when
    ``is_gated`` is set, in a ``GatingMechanism``. The two gate modules are
    always constructed, whether or not gating is enabled.

    Args:
        dim: Embedding width.
        num_heads: Number of attention heads.
        is_gated: Use gates instead of additive skip connections.
    """

    def __init__(self, dim, num_heads, is_gated=False):
        super().__init__()
        self.attention = MultiHeadAttention(dim, num_heads)
        self.layer_norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(dim)
        self.layer_norm_attn = nn.LayerNorm(dim)
        self.fc_projection = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())
        self.is_gated = is_gated
        # Gating mechanisms
        self.attn_gate = GatingMechanism(dim)
        self.ffn_gate = GatingMechanism(dim)

    def _merge(self, gate, residual, update):
        """Combine a sub-block's output with its input (gate or skip)."""
        if self.is_gated:
            return gate(residual, update)
        return update + residual

    def forward(self, value, key, query, mask):
        """Run the layer.

        Args:
            value: Memory tensor; normalised and used as both keys and
                values of the attention.
            key: Ignored -- the normalised ``value`` is used as the key
                (K = V).
            query: Query tensor; also the residual input of the
                attention block.
            mask: Mask forwarded unchanged to the attention module.

        Returns:
            Tuple ``(out, attention_weights)``.
        """
        # Pre-layer normalisation of the query and of the memory.
        normed_query = self.layer_norm_q(query)
        memory = self.norm_kv(value)

        attended, attention_weights = self.attention(memory, memory, normed_query, mask)
        hidden = self._merge(self.attn_gate, query, attended)

        # Pre-layer normalisation, then the forward projection.
        projected = self.fc_projection(self.layer_norm_attn(hidden))
        return self._merge(self.ffn_gate, hidden, projected), attention_weights
class Transformer(nn.Module):
    """Stack of ``TransformerLayer`` blocks attending over per-layer memories.

    Args:
        num_layers: Number of transformer layers.
        dim: Embedding width.
        num_heads: Number of attention heads per layer.
        max_episode_steps: Number of positions covered by the positional
            encoding table.
        positional_encoding: ``"absolute"`` for the sinusoidal encoding,
            ``"learned"`` for a trainable table; any other value disables
            positional encoding.
        is_gated: Forwarded to every ``TransformerLayer``.
    """

    def __init__(self, num_layers, dim, num_heads, max_episode_steps, positional_encoding, is_gated=False):
        super().__init__()
        self.max_episode_steps = max_episode_steps
        self.positional_encoding = positional_encoding
        if positional_encoding == "absolute":
            self.pos_embedding = PositionalEncoding(dim)
        elif positional_encoding == "learned":
            self.pos_embedding = nn.Parameter(torch.randn(max_episode_steps, dim))
        self.transformer_layers = nn.ModuleList(
            [TransformerLayer(dim, num_heads, is_gated=is_gated) for _ in range(num_layers)]
        )

    def forward(self, x, memories, mask, memory_indices):
        """Run all layers and collect each layer's input as new memory.

        Args:
            x: Current input; a singleton axis is inserted at position 1
                to form the query. Presumably (batch, dim) -- TODO confirm.
            memories: Stored memories, indexed as ``memories[:, :, i]`` for
                layer ``i``. Presumably (batch, memory_len, num_layers,
                dim) -- TODO confirm.
            mask: Attention mask forwarded to every layer.
            memory_indices: Indices into the positional table for the
                stored memories.

        Returns:
            Tuple ``(x, new_memories)``: the output of the last layer and
            the detached inputs of all layers stacked along axis 1.
        """
        # Add the positional encoding to the memories of every layer.
        if self.positional_encoding == "absolute":
            pos_embedding = self.pos_embedding(self.max_episode_steps)[memory_indices]
            memories = memories + pos_embedding.unsqueeze(2)
        elif self.positional_encoding == "learned":
            memories = memories + self.pos_embedding[memory_indices].unsqueeze(2)

        # Forward the layers; the input of each layer becomes its new memory
        # (i.e. hidden state).
        out_memories = []
        for i, layer in enumerate(self.transformer_layers):
            out_memories.append(x.detach())
            layer_memory = memories[:, :, i]
            x, _ = layer(layer_memory, layer_memory, x.unsqueeze(1), mask)  # args: value, key, query, mask
            # Drop only the query axis added above. A bare squeeze() would
            # also drop a batch or feature axis of size 1.
            x = x.squeeze(1)
        return x, torch.stack(out_memories, dim=1)