Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: Implements Encoder-Decoder Attention Model #28

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added i6_models/decoder/__init__.py
Empty file.
209 changes: 209 additions & 0 deletions i6_models/decoder/attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn

from .zoneout_lstm import ZoneoutLSTMCell


@dataclass
class AdditiveAttentionConfig:
    """
    Settings consumed by :class:`AdditiveAttention`.

    Attributes:
        attention_dim: size of the feature axis on which the energies are computed; it is used as the
            input dimension of the linear layer that maps each frame to a scalar energy
        att_weights_dropout: dropout probability applied to the attention weights
    """

    attention_dim: int
    att_weights_dropout: float


class AdditiveAttention(nn.Module):
    """
    Additive attention mechanism with weight feedback.

    Written with the argument names of :meth:`forward`, the module computes:
        energies = v^T * tanh(key + query + weight_feedback)   (v is the learned vector held in ``self.linear``)
        weights = softmax(energies), normalized over the encoder time axis T
        context = sum_t weights_t * value_t
    """

    def __init__(self, cfg: AdditiveAttentionConfig):
        """
        :param cfg: attention settings (energy feature dimension and attention weights dropout)
        """
        super().__init__()
        # the vector v of the formula above: maps every [attention_dim] frame to one scalar energy
        self.linear = nn.Linear(cfg.attention_dim, 1, bias=False)
        self.att_weights_drop = nn.Dropout(cfg.att_weights_dropout)

    def forward(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        query: torch.Tensor,
        weight_feedback: torch.Tensor,
        enc_seq_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param key: encoder keys of shape [B,T,D_k]
        :param value: encoder values of shape [B,T,D_v]
        :param query: query of shape [B,D_k]
        :param weight_feedback: shape is [B,T,D_k]
        :param enc_seq_len: encoder sequence lengths [B]
        :return: attention context [B,D_v], attention weights [B,T,1].
            The returned weights are the ones after dropout; frames at positions >= enc_seq_len get weight 0.
        """
        # all inputs are already projected
        # torch.tanh is the direct equivalent of nn.functional.tanh, which warns as deprecated on older PyTorch
        energies = self.linear(torch.tanh(key + query.unsqueeze(1) + weight_feedback))  # [B,T,1]

        # the lengths may live on another device than the features (e.g. CPU), so move them before comparing
        seq_lens = enc_seq_len.to(energies.device)  # [B]
        time_arange = torch.arange(energies.size(1), device=energies.device)  # [T]
        seq_len_mask = time_arange[None, :] < seq_lens[:, None]  # [B,T], True for valid frames

        # -inf energies become exactly 0 after the softmax, i.e. padded frames are ignored
        energies = energies.masked_fill(~seq_len_mask.unsqueeze(2), float("-inf"))
        weights = nn.functional.softmax(energies, dim=1)  # [B,T,1]
        weights = self.att_weights_drop(weights)

        context = torch.bmm(weights.transpose(1, 2), value)  # [B,1,D_v]
        context = context.reshape(context.size(0), -1)  # [B,D_v]
        return context, weights


@dataclass
class AttentionLSTMDecoderV1Config:
    """
    Settings consumed by :class:`AttentionLSTMDecoderV1`.

    Attributes:
        encoder_dim: feature dimension of the encoder outputs
        vocab_size: number of target labels (size of the embedding table and of the output layer)
        target_embed_dim: dimension of the target label embeddings
        target_embed_dropout: dropout probability applied to the target embeddings
        lstm_hidden_size: hidden size of the decoder LSTM cell
        zoneout_drop_h: zoneout drop probability for the LSTM hidden state
        zoneout_drop_c: zoneout drop probability for the LSTM cell state
        attention_cfg: config of the additive attention module
        output_proj_dim: dimension of the readout projection before the MaxOut (has to be even)
        output_dropout: dropout probability applied before the output layer
    """

    encoder_dim: int
    vocab_size: int
    target_embed_dim: int
    target_embed_dropout: float
    lstm_hidden_size: int
    zoneout_drop_h: float
    zoneout_drop_c: float
    attention_cfg: AdditiveAttentionConfig
    output_proj_dim: int
    output_dropout: float


class AttentionLSTMDecoderV1(nn.Module):
    """
    Single-headed Attention decoder with additive attention mechanism.

    Per decoder step, a (zoneout) LSTM cell consumes the previous target embedding and the previous attention
    context, its output is used as attention query, and the accumulated attention weights are fed back into
    the attention energies. The logits are computed for all steps at once after the loop.
    """

    def __init__(self, cfg: AttentionLSTMDecoderV1Config):
        """
        :param cfg: decoder settings
        """
        super().__init__()

        self.target_embed = nn.Embedding(num_embeddings=cfg.vocab_size, embedding_dim=cfg.target_embed_dim)
        self.target_embed_dropout = nn.Dropout(cfg.target_embed_dropout)

        lstm_cell = nn.LSTMCell(
            input_size=cfg.target_embed_dim + cfg.encoder_dim,
            hidden_size=cfg.lstm_hidden_size,
        )
        self.lstm_hidden_size = cfg.lstm_hidden_size
        # if zoneout drop probs are 0, then it is equivalent to normal LSTMCell
        self.s = ZoneoutLSTMCell(
            cell=lstm_cell,
            zoneout_h=cfg.zoneout_drop_h,
            zoneout_c=cfg.zoneout_drop_c,
        )

        self.s_transformed = nn.Linear(cfg.lstm_hidden_size, cfg.attention_cfg.attention_dim, bias=False)  # query

        # for attention
        self.enc_ctx = nn.Linear(cfg.encoder_dim, cfg.attention_cfg.attention_dim)
        self.attention = AdditiveAttention(cfg.attention_cfg)

        # for weight feedback
        self.inv_fertility = nn.Linear(cfg.encoder_dim, 1, bias=False)  # followed by sigmoid
        self.weight_feedback = nn.Linear(1, cfg.attention_cfg.attention_dim, bias=False)

        self.readout_in = nn.Linear(cfg.lstm_hidden_size + cfg.target_embed_dim + cfg.encoder_dim, cfg.output_proj_dim)
        assert cfg.output_proj_dim % 2 == 0, "output projection dimension must be even for the MaxOut op of 2 pieces"
        self.output = nn.Linear(cfg.output_proj_dim // 2, cfg.vocab_size)
        self.output_dropout = nn.Dropout(cfg.output_dropout)

    def forward(
        self,
        encoder_outputs: torch.Tensor,
        labels: torch.Tensor,
        enc_seq_len: torch.Tensor,
        state: Optional[Tuple[torch.Tensor, ...]] = None,
        shift_embeddings: bool = True,
    ):
        """
        :param encoder_outputs: encoder outputs of shape [B,T,D], same for training and search
        :param labels:
            training: labels of shape [B,N]
            (greedy-)search: hypotheses last label as [B,1]
        :param enc_seq_len: encoder sequence lengths of shape [B], same for training and search
        :param state: decoder state, a tuple ``(lstm_state, att_context, accum_att_weights)`` with
            lstm_state: tuple (h, c) of the LSTM cell, each of shape [B,lstm_hidden_size]
            att_context: attention context of the previous step, shape [B,D]
            accum_att_weights: accumulated attention weights, shape [B,T,1]
            If None, all of them are initialized with zeros.
            training: Usually None, unless decoding should be initialized with a certain state (e.g. for context init)
            search: current state of the active hypotheses
        :param shift_embeddings: shift the embeddings by one position along N, padding with zero in front and drop last
            training: this should be "True", in order to start with a zero target embedding
            search: use True for the first step in order to start with a zero embedding, False otherwise
        :return: tuple ``(decoder_logits, state)`` with
            decoder_logits: unnormalized label scores of shape [B,N,vocab_size]
            state: decoder state after the last step, same structure as the ``state`` argument
        """
        if state is None:
            zeros = encoder_outputs.new_zeros((encoder_outputs.size(0), self.lstm_hidden_size))
            lstm_state = (zeros, zeros)
            att_context = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(2)))
            accum_att_weights = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1))
        else:
            lstm_state, att_context, accum_att_weights = state

        target_embeddings = self.target_embed(labels)  # [B,N,D_emb]
        target_embeddings = self.target_embed_dropout(target_embeddings)

        if shift_embeddings:
            # pad for BOS and remove last token as this represents history and last token is not used
            # NOTE(review): a reviewer suggested to handle this outside of the module instead, by prepending a
            # begin-token to the labels (an all-zero embedding is possible via padding_idx of nn.Embedding);
            # the flag is kept to preserve the interface
            target_embeddings = nn.functional.pad(target_embeddings, (0, 0, 1, 0), value=0)[:, :-1, :]  # [B,N,D_emb]

        # step-independent encoder projections, computed once outside of the decoder loop
        enc_ctx = self.enc_ctx(encoder_outputs)  # [B,T,D_att]
        # torch.sigmoid is the direct equivalent of nn.functional.sigmoid, which warns as deprecated on older PyTorch
        enc_inv_fertility = torch.sigmoid(self.inv_fertility(encoder_outputs))  # [B,T,1]

        num_steps = labels.size(1)  # N

        # collect for computing later the decoder logits outside the loop
        s_list = []
        att_context_list = []

        # decoder loop
        for step in range(num_steps):
            target_embed = target_embeddings[:, step, :]  # [B,D_emb]

            lstm_state = self.s(torch.cat([target_embed, att_context], dim=-1), lstm_state)
            lstm_out = lstm_state[0]
            s_transformed = self.s_transformed(lstm_out)  # project query
            s_list.append(lstm_out)

            # attention mechanism
            weight_feedback = self.weight_feedback(accum_att_weights)  # [B,T,D_att]
            att_context, att_weights = self.attention(
                key=enc_ctx,
                value=encoder_outputs,
                query=s_transformed,
                weight_feedback=weight_feedback,
                enc_seq_len=enc_seq_len,
            )
            att_context_list.append(att_context)
            accum_att_weights = accum_att_weights + att_weights * enc_inv_fertility * 0.5

        # output layer
        s_stacked = torch.stack(s_list, dim=1)  # [B,N,lstm_hidden_size]
        att_context_stacked = torch.stack(att_context_list, dim=1)  # [B,N,D]
        readout_in = self.readout_in(
            torch.cat([s_stacked, target_embeddings, att_context_stacked], dim=-1)
        )  # [B,N,output_proj_dim]

        # maxout layer: take the maximum over pairs of neighboring features
        readout_in = readout_in.view(readout_in.size(0), readout_in.size(1), -1, 2)  # [B,N,output_proj_dim/2,2]
        readout, _ = torch.max(readout_in, dim=-1)  # [B,N,output_proj_dim/2]

        readout_drop = self.output_dropout(readout)
        decoder_logits = self.output(readout_drop)  # [B,N,vocab_size]

        state = lstm_state, att_context, accum_att_weights

        return decoder_logits, state
47 changes: 47 additions & 0 deletions i6_models/decoder/zoneout_lstm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import torch
from torch import nn

from typing import Tuple


class ZoneoutLSTMCell(nn.Module):
    """
    Wrap an LSTM cell with Zoneout regularization (https://arxiv.org/abs/1606.01305)
    """

    def __init__(self, cell: nn.RNNCellBase, zoneout_h: float, zoneout_c: float):
        """
        :param cell: LSTM cell
        :param zoneout_h: zoneout drop probability for hidden state
        :param zoneout_c: zoneout drop probability for cell state
        """
        super().__init__()
        self.cell = cell
        assert 0.0 <= zoneout_h <= 1.0 and 0.0 <= zoneout_c <= 1.0, "Zoneout drop probability must be in [0, 1]"
        self.zoneout_h = zoneout_h
        self.zoneout_c = zoneout_c

    def forward(
        self, inputs: torch.Tensor, state: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run one step of the wrapped cell and apply zoneout to the resulting states.

        :param inputs: input of the wrapped cell for this step
        :param state: previous (hidden state, cell state) of the wrapped cell
        :return: new (hidden state, cell state) after zoneout
        """
        prev_h, prev_c = state
        # The previous state has to be handed to the wrapped cell: without it the cell starts from an
        # all-zero state in every step and the recurrence is lost.
        h, c = self.cell(inputs, state)
        h = self._zoneout(prev_h, h, self.zoneout_h)
        c = self._zoneout(prev_c, c, self.zoneout_c)
        return h, c

    def _zoneout(self, prev_state: torch.Tensor, curr_state: torch.Tensor, factor: float):
        """
        Apply Zoneout.

        :param prev_state: previous state tensor
        :param curr_state: current state tensor
        :param factor: drop probability
        :return: in training each unit keeps its previous value with probability ``factor``,
            otherwise the states are interpolated with weight ``factor`` on the previous state
        """
        if factor == 0.0:
            return curr_state
        if self.training:
            # mask entry 1 means: keep the previous value of this unit
            mask = curr_state.new_empty(size=curr_state.size()).bernoulli_(factor)
            return mask * prev_state + (1 - mask) * curr_state
        else:
            return factor * prev_state + (1 - factor) * curr_state
78 changes: 78 additions & 0 deletions tests/test_enc_dec_att.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import torch
from torch import nn

from i6_models.decoder.attention import AdditiveAttention, AdditiveAttentionConfig
from i6_models.decoder.attention import AttentionLSTMDecoderV1, AttentionLSTMDecoderV1Config


def test_additive_attention():
    """Check the output shapes of the additive attention and that padded frames get zero weight."""
    batch, time, dim = 10, 20, 5
    att = AdditiveAttention(AdditiveAttentionConfig(attention_dim=dim, att_weights_dropout=0.1))

    key = torch.rand((batch, time, dim))
    value = torch.rand((batch, time, dim))
    query = torch.rand((batch, dim))
    enc_seq_len = torch.arange(start=10, end=20)  # [10, ..., 19]

    # the keys are reused as weight feedback, only the shapes matter here
    context, weights = att(key=key, value=value, query=query, weight_feedback=key, enc_seq_len=enc_seq_len)
    assert context.shape == (batch, dim)
    assert weights.shape == (batch, time, 1)

    # Masking: sequence i has length 10 + i, so all weights from that position on must be exactly 0
    # (for the first sequence this is half of the frames)
    for seq_idx in (0, 5):
        padded_weights = weights[seq_idx, 10 + seq_idx :, 0]
        assert torch.eq(padded_weights, torch.tensor(0.0)).all()


def test_encoder_decoder_attention_model():
    """Run the decoder once in the training setup (full label sequence) and check the logits shape."""
    batch, enc_time, enc_dim = 10, 20, 5
    num_labels, vocab_size = 7, 15

    encoder = torch.rand((batch, enc_time, enc_dim))
    encoder_seq_len = torch.arange(start=10, end=20)  # [10, ..., 19]

    att_cfg = AdditiveAttentionConfig(attention_dim=10, att_weights_dropout=0.1)
    decoder_cfg = AttentionLSTMDecoderV1Config(
        encoder_dim=enc_dim,
        vocab_size=vocab_size,
        target_embed_dim=3,
        target_embed_dropout=0.1,
        lstm_hidden_size=12,
        attention_cfg=att_cfg,
        output_proj_dim=12,
        output_dropout=0.1,
        zoneout_drop_c=0.0,
        zoneout_drop_h=0.0,
    )
    decoder = AttentionLSTMDecoderV1(decoder_cfg)
    target_labels = torch.randint(low=0, high=vocab_size, size=(batch, num_labels))  # [B,N]

    outputs = decoder(encoder_outputs=encoder, labels=target_labels, enc_seq_len=encoder_seq_len)
    decoder_logits = outputs[0]

    assert decoder_logits.shape == (batch, num_labels, vocab_size)


def test_zoneout_lstm_cell():
    """Run the decoder with and without zoneout in the LSTM cell and check the logits shape."""
    encoder = torch.rand((10, 20, 5))
    encoder_seq_len = torch.arange(start=10, end=20)  # [10, ..., 19]
    target_labels = torch.randint(low=0, high=15, size=(10, 7))  # [B,N]

    def forward_decoder(zoneout_drop_c: float, zoneout_drop_h: float):
        """Build a fresh decoder with the given zoneout probabilities and return its logits."""
        decoder = AttentionLSTMDecoderV1(
            AttentionLSTMDecoderV1Config(
                encoder_dim=5,
                vocab_size=15,
                target_embed_dim=3,
                target_embed_dropout=0.1,
                lstm_hidden_size=12,
                attention_cfg=AdditiveAttentionConfig(attention_dim=10, att_weights_dropout=0.1),
                output_proj_dim=12,
                output_dropout=0.1,
                zoneout_drop_c=zoneout_drop_c,
                zoneout_drop_h=zoneout_drop_h,
            )
        )
        logits, _ = decoder(encoder_outputs=encoder, labels=target_labels, enc_seq_len=encoder_seq_len)
        return logits

    # first with zoneout enabled, then with plain LSTM cell behavior
    for drop_c, drop_h in ((0.15, 0.05), (0.0, 0.0)):
        decoder_logits = forward_decoder(zoneout_drop_c=drop_c, zoneout_drop_h=drop_h)
        assert decoder_logits.shape == (10, 7, 15)
Loading