-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft: Implements Encoder-Decoder Attention Model #28
base: main
Are you sure you want to change the base?
Changes from 18 commits
3476c07
b06aac0
3f5ada0
4378fb9
085f9f3
523ce6c
396a664
9c22aa2
c4d5710
cd366ab
d0ed59b
5a78e40
163315a
6d200eb
1b8dd52
dcf0381
6a147b7
a988e85
bb8fa4e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
from dataclasses import dataclass | ||
from typing import Optional, Tuple | ||
|
||
import torch | ||
from torch import nn | ||
|
||
from .zoneout_lstm import ZoneoutLSTMCell | ||
|
||
|
||
@dataclass | ||
class AdditiveAttentionConfig: | ||
""" | ||
Attributes: | ||
attention_dim: attention dimension | ||
att_weights_dropout: attention weights dropout | ||
""" | ||
|
||
attention_dim: int | ||
att_weights_dropout: float | ||
|
||
|
||
class AdditiveAttention(nn.Module): | ||
""" | ||
Additive attention mechanism. This is defined as: | ||
energies = v^T * tanh(h + s + beta) where beta is weight feedback information | ||
weights = softmax(energies) | ||
context = sum_t weights_t * h_t | ||
""" | ||
|
||
def __init__(self, cfg: AdditiveAttentionConfig): | ||
super().__init__() | ||
self.linear = nn.Linear(cfg.attention_dim, 1, bias=False) | ||
self.att_weights_drop = nn.Dropout(cfg.att_weights_dropout) | ||
|
||
def forward( | ||
self, | ||
key: torch.Tensor, | ||
value: torch.Tensor, | ||
query: torch.Tensor, | ||
weight_feedback: torch.Tensor, | ||
enc_seq_len: torch.Tensor, | ||
) -> Tuple[torch.Tensor, torch.Tensor]: | ||
""" | ||
:param key: encoder keys of shape [B,T,D_k] | ||
:param value: encoder values of shape [B,T,D_v] | ||
:param query: query of shape [B,D_k] | ||
:param weight_feedback: shape is [B,T,D_k] | ||
:param enc_seq_len: encoder sequence lengths [B] | ||
:return: attention context [B,D_v], attention weights [B,T,1] | ||
""" | ||
# all inputs are already projected | ||
energies = self.linear(nn.functional.tanh(key + query.unsqueeze(1) + weight_feedback)) # [B,T,1] | ||
time_arange = torch.arange(energies.size(1), device=energies.device) # [T] | ||
seq_len_mask = torch.less(time_arange[None, :], enc_seq_len[:, None]) # [B,T] | ||
energies = torch.where(seq_len_mask.unsqueeze(2), energies, energies.new_tensor(-float("inf"))) | ||
weights = nn.functional.softmax(energies, dim=1) # [B,T,1] | ||
weights = self.att_weights_drop(weights) | ||
context = torch.bmm(weights.transpose(1, 2), value) # [B,1,D_v] | ||
context = context.reshape(context.size(0), -1) # [B,D_v] | ||
return context, weights | ||
|
||
|
||
@dataclass | ||
class AttentionLSTMDecoderV1Config: | ||
""" | ||
Attributes: | ||
encoder_dim: encoder dimension | ||
vocab_size: vocabulary size | ||
target_embed_dim: embedding dimension | ||
target_embed_dropout: embedding dropout | ||
lstm_hidden_size: LSTM hidden size | ||
zoneout_drop_h: zoneout drop probability for hidden state | ||
zoneout_drop_c: zoneout drop probability for cell state | ||
attention_cfg: attention config | ||
output_proj_dim: output projection dimension | ||
output_dropout: output dropout | ||
""" | ||
|
||
encoder_dim: int | ||
vocab_size: int | ||
target_embed_dim: int | ||
target_embed_dropout: float | ||
lstm_hidden_size: int | ||
zoneout_drop_h: float | ||
zoneout_drop_c: float | ||
attention_cfg: AdditiveAttentionConfig | ||
output_proj_dim: int | ||
output_dropout: float | ||
|
||
|
||
class AttentionLSTMDecoderV1(nn.Module): | ||
""" | ||
Single-headed Attention decoder with additive attention mechanism. | ||
""" | ||
|
||
def __init__(self, cfg: AttentionLSTMDecoderV1Config): | ||
super().__init__() | ||
|
||
self.target_embed = nn.Embedding(num_embeddings=cfg.vocab_size, embedding_dim=cfg.target_embed_dim) | ||
self.target_embed_dropout = nn.Dropout(cfg.target_embed_dropout) | ||
|
||
lstm_cell = nn.LSTMCell( | ||
input_size=cfg.target_embed_dim + cfg.encoder_dim, | ||
hidden_size=cfg.lstm_hidden_size, | ||
) | ||
self.lstm_hidden_size = cfg.lstm_hidden_size | ||
# if zoneout drop probs are 0, then it is equivalent to normal LSTMCell | ||
self.s = ZoneoutLSTMCell( | ||
cell=lstm_cell, | ||
zoneout_h=cfg.zoneout_drop_h, | ||
zoneout_c=cfg.zoneout_drop_c, | ||
) | ||
|
||
self.s_transformed = nn.Linear(cfg.lstm_hidden_size, cfg.attention_cfg.attention_dim, bias=False) # query | ||
|
||
# for attention | ||
self.enc_ctx = nn.Linear(cfg.encoder_dim, cfg.attention_cfg.attention_dim) | ||
self.attention = AdditiveAttention(cfg.attention_cfg) | ||
|
||
# for weight feedback | ||
self.inv_fertility = nn.Linear(cfg.encoder_dim, 1, bias=False) # followed by sigmoid | ||
self.weight_feedback = nn.Linear(1, cfg.attention_cfg.attention_dim, bias=False) | ||
|
||
self.readout_in = nn.Linear(cfg.lstm_hidden_size + cfg.target_embed_dim + cfg.encoder_dim, cfg.output_proj_dim) | ||
assert cfg.output_proj_dim % 2 == 0, "output projection dimension must be even for the MaxOut op of 2 pieces" | ||
self.output = nn.Linear(cfg.output_proj_dim // 2, cfg.vocab_size) | ||
self.output_dropout = nn.Dropout(cfg.output_dropout) | ||
|
||
def forward( | ||
self, | ||
encoder_outputs: torch.Tensor, | ||
labels: torch.Tensor, | ||
enc_seq_len: torch.Tensor, | ||
state: Optional[Tuple[torch.Tensor, ...]] = None, | ||
shift_embeddings: bool = True, | ||
): | ||
""" | ||
:param encoder_outputs: encoder outputs of shape [B,T,D], same for training and search | ||
:param labels: | ||
training: labels of shape [B,N] | ||
(greedy-)search: hypotheses last label as [B,1] | ||
:param enc_seq_len: encoder sequence lengths of shape [B,T], same for training and search | ||
:param state: decoder state | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shape info for state tensors is missing. |
||
training: Usually None, unless decoding should be initialized with a certain state (e.g. for context init) | ||
search: current state of the active hypotheses | ||
:param shift_embeddings: shift the embeddings by one position along U, padding with zero in front and drop last | ||
training: this should be "True", in order to start with a zero target embedding | ||
search: use True for the first step in order to start with a zero embedding, False otherwise | ||
Comment on lines
+146
to
+148
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not a fan of this |
||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Docs for the return values are missing. |
||
if state is None: | ||
zeros = encoder_outputs.new_zeros((encoder_outputs.size(0), self.lstm_hidden_size)) | ||
lstm_state = (zeros, zeros) | ||
att_context = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(2))) | ||
accum_att_weights = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1)) | ||
else: | ||
lstm_state, att_context, accum_att_weights = state | ||
|
||
target_embeddings = self.target_embed(labels) # [B,N,D] | ||
target_embeddings = self.target_embed_dropout(target_embeddings) | ||
|
||
if shift_embeddings: | ||
# pad for BOS and remove last token as this represents history and last token is not used | ||
target_embeddings = nn.functional.pad(target_embeddings, (0, 0, 1, 0), value=0)[:, :-1, :] # [B,N,D] | ||
|
||
enc_ctx = self.enc_ctx(encoder_outputs) # [B,T,D] | ||
enc_inv_fertility = nn.functional.sigmoid(self.inv_fertility(encoder_outputs)) # [B,T,1] | ||
|
||
num_steps = labels.size(1) # N | ||
|
||
# collect for computing later the decoder logits outside the loop | ||
s_list = [] | ||
att_context_list = [] | ||
|
||
# decoder loop | ||
for step in range(num_steps): | ||
target_embed = target_embeddings[:, step, :] # [B,D] | ||
|
||
lstm_state = self.s(torch.cat([target_embed, att_context], dim=-1), lstm_state) | ||
lstm_out = lstm_state[0] | ||
s_transformed = self.s_transformed(lstm_out) # project query | ||
s_list.append(lstm_out) | ||
|
||
# attention mechanism | ||
weight_feedback = self.weight_feedback(accum_att_weights) | ||
att_context, att_weights = self.attention( | ||
key=enc_ctx, | ||
value=encoder_outputs, | ||
query=s_transformed, | ||
weight_feedback=weight_feedback, | ||
enc_seq_len=enc_seq_len, | ||
) | ||
att_context_list.append(att_context) | ||
accum_att_weights = accum_att_weights + att_weights * enc_inv_fertility * 0.5 | ||
|
||
# output layer | ||
s_stacked = torch.stack(s_list, dim=1) # [B,N,D] | ||
att_context_stacked = torch.stack(att_context_list, dim=1) # [B,N,D] | ||
readout_in = self.readout_in(torch.cat([s_stacked, target_embeddings, att_context_stacked], dim=-1)) # [B,N,D] | ||
|
||
# maxout layer | ||
readout_in = readout_in.view(readout_in.size(0), readout_in.size(1), -1, 2) # [B,N,D/2,2] | ||
readout, _ = torch.max(readout_in, dim=-1) # [B,N,D/2] | ||
|
||
readout_drop = self.output_dropout(readout) | ||
decoder_logits = self.output(readout_drop) | ||
|
||
state = lstm_state, att_context, accum_att_weights | ||
|
||
return decoder_logits, state |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import torch | ||
from torch import nn | ||
|
||
from typing import Tuple | ||
|
||
|
||
class ZoneoutLSTMCell(nn.Module): | ||
Atticus1806 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Wrap an LSTM cell with Zoneout regularization (https://arxiv.org/abs/1606.01305) | ||
""" | ||
|
||
def __init__(self, cell: nn.RNNCellBase, zoneout_h: float, zoneout_c: float): | ||
""" | ||
:param cell: LSTM cell | ||
:param zoneout_h: zoneout drop probability for hidden state | ||
:param zoneout_c: zoneout drop probability for cell state | ||
""" | ||
super().__init__() | ||
self.cell = cell | ||
assert 0.0 <= zoneout_h <= 1.0 and 0.0 <= zoneout_c <= 1.0, "Zoneout drop probability must be in [0, 1]" | ||
self.zoneout_h = zoneout_h | ||
self.zoneout_c = zoneout_c | ||
|
||
def forward( | ||
self, inputs: torch.Tensor, state: Tuple[torch.Tensor, torch.Tensor] | ||
) -> Tuple[torch.Tensor, torch.Tensor]: | ||
h, c = self.cell(inputs) | ||
prev_h, prev_c = state | ||
h = self._zoneout(prev_h, h, self.zoneout_h) | ||
c = self._zoneout(prev_c, c, self.zoneout_c) | ||
return h, c | ||
|
||
def _zoneout(self, prev_state: torch.Tensor, curr_state: torch.Tensor, factor: float): | ||
""" | ||
Apply Zoneout. | ||
|
||
:param prev: previous state tensor | ||
:param curr: current state tensor | ||
:param factor: drop probability | ||
""" | ||
if factor == 0.0: | ||
return curr_state | ||
if self.training: | ||
mask = curr_state.new_empty(size=curr_state.size()).bernoulli_(factor) | ||
return mask * prev_state + (1 - mask) * curr_state | ||
else: | ||
return factor * prev_state + (1 - factor) * curr_state |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import torch | ||
from torch import nn | ||
|
||
from i6_models.decoder.attention import AdditiveAttention, AdditiveAttentionConfig | ||
from i6_models.decoder.attention import AttentionLSTMDecoderV1, AttentionLSTMDecoderV1Config | ||
|
||
|
||
def test_additive_attention(): | ||
cfg = AdditiveAttentionConfig(attention_dim=5, att_weights_dropout=0.1) | ||
att = AdditiveAttention(cfg) | ||
key = torch.rand((10, 20, 5)) | ||
value = torch.rand((10, 20, 5)) | ||
query = torch.rand((10, 5)) | ||
|
||
enc_seq_len = torch.arange(start=10, end=20) # [10, ..., 19] | ||
|
||
# pass key as weight feedback just for testing | ||
context, weights = att(key=key, value=value, query=query, weight_feedback=key, enc_seq_len=enc_seq_len) | ||
assert context.shape == (10, 5) | ||
assert weights.shape == (10, 20, 1) | ||
|
||
# Testing attention weights masking: | ||
# for first seq, the enc seq length is 10 so half the weights should be 0 | ||
assert torch.eq(weights[0, 10:, 0], torch.tensor(0.0)).all() | ||
# test for other seqs | ||
assert torch.eq(weights[5, 15:, 0], torch.tensor(0.0)).all() | ||
|
||
|
||
def test_encoder_decoder_attention_model(): | ||
encoder = torch.rand((10, 20, 5)) | ||
encoder_seq_len = torch.arange(start=10, end=20) # [10, ..., 19] | ||
decoder_cfg = AttentionLSTMDecoderV1Config( | ||
encoder_dim=5, | ||
vocab_size=15, | ||
target_embed_dim=3, | ||
target_embed_dropout=0.1, | ||
lstm_hidden_size=12, | ||
attention_cfg=AdditiveAttentionConfig(attention_dim=10, att_weights_dropout=0.1), | ||
output_proj_dim=12, | ||
output_dropout=0.1, | ||
zoneout_drop_c=0.0, | ||
zoneout_drop_h=0.0, | ||
) | ||
decoder = AttentionLSTMDecoderV1(decoder_cfg) | ||
target_labels = torch.randint(low=0, high=15, size=(10, 7)) # [B,N] | ||
|
||
decoder_logits, _ = decoder(encoder_outputs=encoder, labels=target_labels, enc_seq_len=encoder_seq_len) | ||
|
||
assert decoder_logits.shape == (10, 7, 15) | ||
|
||
|
||
def test_zoneout_lstm_cell(): | ||
encoder = torch.rand((10, 20, 5)) | ||
encoder_seq_len = torch.arange(start=10, end=20) # [10, ..., 19] | ||
target_labels = torch.randint(low=0, high=15, size=(10, 7)) # [B,N] | ||
|
||
def forward_decoder(zoneout_drop_c: float, zoneout_drop_h: float): | ||
decoder_cfg = AttentionLSTMDecoderV1Config( | ||
encoder_dim=5, | ||
vocab_size=15, | ||
target_embed_dim=3, | ||
target_embed_dropout=0.1, | ||
lstm_hidden_size=12, | ||
attention_cfg=AdditiveAttentionConfig(attention_dim=10, att_weights_dropout=0.1), | ||
output_proj_dim=12, | ||
output_dropout=0.1, | ||
zoneout_drop_c=zoneout_drop_c, | ||
zoneout_drop_h=zoneout_drop_h, | ||
) | ||
decoder = AttentionLSTMDecoderV1(decoder_cfg) | ||
decoder_logits, _ = decoder(encoder_outputs=encoder, labels=target_labels, enc_seq_len=encoder_seq_len) | ||
return decoder_logits | ||
|
||
decoder_logits = forward_decoder(zoneout_drop_c=0.15, zoneout_drop_h=0.05) | ||
assert decoder_logits.shape == (10, 7, 15) | ||
|
||
decoder_logits = forward_decoder(zoneout_drop_c=0.0, zoneout_drop_h=0.0) | ||
assert decoder_logits.shape == (10, 7, 15) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The symbols in this docstring are partly undefined/different to the parameter names in
forward
. It would be easier to understand if the naming was unified.