Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 32 additions & 23 deletions signjoey/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,10 @@ def forward(self, query: Tensor = None, mask: Tensor = None, values: Tensor = No
"""
self._check_input_shapes_forward(query=query, mask=mask, values=values)

# Validate preconditions with explicit exceptions rather than asserts
# (asserts are stripped under `python -O`).
if mask is None:
    raise ValueError("mask is required")
if self.proj_keys is None:
    raise ValueError("projection keys have to get pre-computed")

# We first project the query (the decoder state).
# The projected keys (the encoder states) were already pre-computated.
Expand Down Expand Up @@ -117,11 +119,16 @@ def _check_input_shapes_forward(
:param values:
:return:
"""
# All three inputs must share a batch dimension. NOTE: the previous form
# `a != b != c` is a chained comparison (a != b and b != c) and does NOT
# test "all three equal" — e.g. shapes (4, x, 4) would slip through.
if not (query.shape[0] == values.shape[0] == mask.shape[0]):
    raise ValueError("Invalid input shape for query, mask, and values")
# query and mask carry a singleton "time" dimension (one decoder step).
if query.shape[1] != 1 or mask.shape[1] != 1:
    raise ValueError("Invalid input shape for query and mask")
if query.shape[2] != self.query_layer.in_features:
    raise ValueError("Invalid number of features in query")
if values.shape[2] != self.key_layer.in_features:
    raise ValueError("Invalid number of features in values")
# One mask entry per encoder position.
if mask.shape[2] != values.shape[1]:
    raise ValueError("Inconsistent key size with mask")

def __repr__(self):
    """Return a short identifier for this attention mechanism."""
    return "BahdanauAttention"
Expand Down Expand Up @@ -172,8 +179,10 @@ def forward(
"""
self._check_input_shapes_forward(query=query, mask=mask, values=values)

# Validate preconditions with explicit exceptions rather than asserts
# (asserts are stripped under `python -O`).
if self.proj_keys is None:
    raise ValueError("projection keys have to get pre-computed")
if mask is None:
    raise ValueError("mask is required")

# scores: batch_size x 1 x sgn_length
scores = query @ self.proj_keys.transpose(1, 2)
Expand Down Expand Up @@ -203,20 +212,20 @@ def compute_proj_keys(self, keys: Tensor):
def _check_input_shapes_forward(
self, query: torch.Tensor, mask: torch.Tensor, values: torch.Tensor
):
"""
Make sure that inputs to `self.forward` are of correct shape.
Same input semantics as for `self.forward`.

:param query:
:param mask:
:param values:
:return:
"""
assert query.shape[0] == values.shape[0] == mask.shape[0]
assert query.shape[1] == 1 == mask.shape[1]
assert query.shape[2] == self.key_layer.out_features
assert values.shape[2] == self.key_layer.in_features
assert mask.shape[2] == values.shape[1]
if not query.shape[0] == values.shape[0] == mask.shape[0]:
raise ValueError("Shapes mismatch for inputs")

if not query.shape[1] == 1 == mask.shape[1]:
raise ValueError("Shapes mismatch for inputs")

if not query.shape[2] == self.key_layer.out_features:
raise ValueError("Shapes mismatch for query and key_layer.out_features")

if not values.shape[2] == self.key_layer.in_features:
raise ValueError("Shapes mismatch for values and key_layer.in_features")

if not mask.shape[2] == values.shape[1]:
raise ValueError("Shapes mismatch for mask and values")

def __repr__(self):
    """Return a short identifier for this attention mechanism."""
    return "LuongAttention"
36 changes: 18 additions & 18 deletions signjoey/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ class Batch:
Input is a batch from a torch text iterator.
"""

# BUG(review): a class-body import binds `secrets` as a class attribute,
# which is NOT visible from method bodies — bare `secrets` inside __init__
# raises NameError. Imports belong at module top; also, `secrets` is
# non-seedable and inappropriate for data augmentation (use `random`/np).
import secrets

def __init__(
self,
torch_batch,
Expand All @@ -26,73 +28,71 @@ def __init__(
This batch extends torch text's batch attributes with sgn (sign),
gls (gloss), and txt (text) length, masks, number of non-padded tokens in txt.
Furthermore, it can be sorted by sgn length.

:param torch_batch:
:param txt_pad_index:
:param sgn_dim:
:param is_train:
:param use_cuda:
:param random_frame_subsampling
"""

# Sequence Information
self.sequence = torch_batch.sequence
self.signer = torch_batch.signer
# Sign
self.sgn, self.sgn_lengths = torch_batch.sgn

# Here be dragons
# Optionally subsample sign frames, keeping every k-th frame
# (k = frame_subsampling_ratio). During training the starting offset may
# be randomized as augmentation; otherwise a (roughly) centered offset
# is used so evaluation is deterministic.
if frame_subsampling_ratio:
    tmp_sgn = torch.zeros_like(self.sgn)
    tmp_sgn_lengths = torch.zeros_like(self.sgn_lengths)
    for idx, (features, length) in enumerate(zip(self.sgn, self.sgn_lengths)):
        features = features.clone()
        if random_frame_subsampling and is_train:
            # Use `random` (seedable, reproducible experiments) — NOT
            # `secrets`: the class-scope `import secrets` is unreachable
            # from here (NameError) and defeats seeding.
            init_frame = random.randint(0, frame_subsampling_ratio - 1)
        else:
            init_frame = (frame_subsampling_ratio - 1) // 2
        # Drop padding, then take every k-th frame starting at init_frame.
        tmp_data = features[: length.long(), :]
        tmp_data = tmp_data[init_frame::frame_subsampling_ratio]
        tmp_sgn[idx, 0 : tmp_data.shape[0]] = tmp_data
        tmp_sgn_lengths[idx] = tmp_data.shape[0]

    # Trim the batch to the longest subsampled sequence.
    self.sgn = tmp_sgn[:, : tmp_sgn_lengths.max().long(), :]
    self.sgn_lengths = tmp_sgn_lengths

# Optionally mask a random fraction of frames during training so the model
# becomes robust to missing frames. Masked frames are set to a tiny epsilon.
if random_frame_masking_ratio and is_train:
    tmp_sgn = torch.zeros_like(self.sgn)
    num_mask_frames = (
        (self.sgn_lengths * random_frame_masking_ratio).floor().long()
    )
    for idx, features in enumerate(self.sgn):
        features = features.clone()
        # Draw `num_mask_frames[idx]` DISTINCT frame indices uniformly at
        # random via a permutation prefix. NOTE: the previous
        # `secrets.choice(range(n), k)` was broken — `secrets.choice`
        # takes a single argument (TypeError), returns one element rather
        # than k distinct indices, and `secrets` is not even in scope here.
        mask_frame_idx = np.random.permutation(
            int(self.sgn_lengths[idx].long().numpy())
        )[: num_mask_frames[idx]]
        features[mask_frame_idx, :] = 1e-8
        tmp_sgn[idx] = features
    self.sgn = tmp_sgn

self.sgn_dim = sgn_dim
self.sgn_mask = (self.sgn != torch.zeros(sgn_dim))[..., 0].unsqueeze(1)

# Text
self.txt = None
self.txt_mask = None
self.txt_input = None
self.txt_lengths = None

# Gloss
self.gls = None
self.gls_lengths = None

# Other
self.num_txt_tokens = None
self.num_gls_tokens = None
self.use_cuda = use_cuda
self.num_seqs = self.sgn.size(0)

if hasattr(torch_batch, "txt"):
txt, txt_lengths = torch_batch.txt
# txt_input is used for teacher forcing, last one is cut off
Expand All @@ -103,11 +103,11 @@ def __init__(
# we exclude the padded areas from the loss computation
self.txt_mask = (self.txt_input != txt_pad_index).unsqueeze(1)
self.num_txt_tokens = (self.txt != txt_pad_index).data.sum().item()

if hasattr(torch_batch, "gls"):
self.gls, self.gls_lengths = torch_batch.gls
self.num_gls_tokens = self.gls_lengths.sum().detach().clone().numpy()

if use_cuda:
self._make_cuda()

Expand Down
46 changes: 34 additions & 12 deletions signjoey/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,26 +157,48 @@ def _check_shapes_input_forward_step(
"""
Make sure the input shapes to `self._forward_step` are correct.
Same inputs as `self._forward_step`.

:param prev_embed:
:param prev_att_vector:
:param encoder_output:
:param src_mask:
:param hidden:
"""
# Shape contract for one decoder step; each violation raises ValueError
# naming the offending input.
if prev_embed.shape[1:] != torch.Size([1, self.emb_size]):
    raise ValueError("Invalid shape for prev_embed")
if prev_att_vector.shape[1:] != torch.Size([1, self.hidden_size]):
    raise ValueError("Invalid shape for prev_att_vector")
if prev_att_vector.shape[0] != prev_embed.shape[0]:
    raise ValueError("Shapes of prev_att_vector and prev_embed do not match")
if encoder_output.shape[0] != prev_embed.shape[0]:
    raise ValueError("Shapes of encoder_output and prev_embed do not match")
if len(encoder_output.shape) != 3:
    raise ValueError("Invalid shape for encoder_output")
if src_mask.shape[0] != prev_embed.shape[0]:
    raise ValueError("Shapes of src_mask and prev_embed do not match")
if src_mask.shape[1] != 1:
    raise ValueError("Invalid shape for src_mask")
if src_mask.shape[2] != encoder_output.shape[1]:
    raise ValueError("Shapes of src_mask and encoder_output do not match")

if isinstance(hidden, tuple):  # LSTM hidden state is (h, c); validate h
    hidden = hidden[0]
# Distinct messages per failure mode (previously two checks shared the
# ambiguous message "Invalid shape for hidden").
if hidden.shape[0] != self.num_layers:
    raise ValueError("Invalid number of layers in hidden")
if hidden.shape[1] != prev_embed.shape[0]:
    raise ValueError("Shapes of hidden and prev_embed do not match")
if hidden.shape[2] != self.hidden_size:
    raise ValueError("Invalid hidden size in hidden")

def _check_shapes_input_forward(
self,
Expand Down