Skip to content

Commit

Permalink
Add fertility to count coverage (#612)
Browse files Browse the repository at this point in the history
* add fertility to count coverage
  • Loading branch information
bricksdont authored and fhieber committed Dec 21, 2018
1 parent f42abe1 commit 378ace5
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Each version section may have subsections for: _Added_, _Changed_, _Removed

## [1.18.67]
### Added
- Added `fertility` as a further type of attention coverage.
- Added an option for training to keep the initializations of the model via `--keep-initializations`. When set, the trainer will avoid deleting the params file for the first checkpoint, no matter what `--keep-last-params` is set to.

## [1.18.66]
Expand Down
12 changes: 9 additions & 3 deletions sockeye/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,12 +720,18 @@ def add_model_parameters(params):
'[Vaswani et al, 2017]')

model_params.add_argument('--rnn-attention-coverage-type',
choices=["tanh", "sigmoid", "relu", "softrelu", "gru", "count"],
default="count",
choices=C.COVERAGE_TYPES,
default=C.COVERAGE_COUNT,
help="Type of model for updating coverage vectors. 'count' refers to an update method "
"that accumulates attention scores. 'tanh', 'sigmoid', 'relu', 'softrelu' "
"that accumulates attention scores. 'fertility' accumulates attention scores as well "
"but also computes a fertility value for every source word. "
"'tanh', 'sigmoid', 'relu', 'softrelu' "
"use non-linear layers with the respective activation type, and 'gru' uses a "
"GRU to update the coverage vectors. Default: %(default)s.")
model_params.add_argument('--rnn-attention-coverage-max-fertility',
type=int,
default=2,
help="Maximum fertility for individual source words. Default: %(default)s.")
model_params.add_argument('--rnn-attention-coverage-num-hidden',
type=int,
default=1,
Expand Down
11 changes: 11 additions & 0 deletions sockeye/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,17 @@
CNN_PAD_LEFT = "left"
CNN_PAD_CENTERED = "centered"

# coverage types
# Names of the two accumulating coverage variants.
COVERAGE_COUNT = "count"
COVERAGE_FERTILITY = "fertility"
# Every selectable coverage type: the activation-based ones, the GRU, and the
# two accumulating variants declared above.
COVERAGE_TYPES = [TANH, SIGMOID, RELU, SOFT_RELU, GRU_TYPE, COVERAGE_COUNT, COVERAGE_FERTILITY]

# default I/O variable names
SOURCE_NAME = "source"
SOURCE_LENGTH_NAME = "source_length"
Expand Down
78 changes: 72 additions & 6 deletions sockeye/coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,16 @@ class CoverageConfig(config.Config):
:param type: Coverage name.
:param num_hidden: Number of hidden units for coverage networks.
:param layer_normalization: Apply layer normalization to coverage networks.
:param max_fertility: Maximum number of target words generated by a source word.
"""
def __init__(self,
type: str,
num_hidden: int,
layer_normalization: bool) -> None:
layer_normalization: bool,
max_fertility: int = 2) -> None:
super().__init__()
self.type = type
self.max_fertility = max_fertility
self.num_hidden = num_hidden
self.layer_normalization = layer_normalization

Expand All @@ -53,14 +56,16 @@ def get_coverage(config: CoverageConfig) -> 'Coverage':
:param config: Coverage configuration.
:return: Instance of Coverage.
"""
if config.type == 'count':
utils.check_condition(config.num_hidden == 1, "Count coverage requires coverage_num_hidden==1")
if config.type == "gru":
if config.type == C.COVERAGE_COUNT or config.type == C.COVERAGE_FERTILITY:
utils.check_condition(config.num_hidden == 1, "Count or fertility coverage requires coverage_num_hidden==1")
if config.type == C.GRU_TYPE:
return GRUCoverage(config.num_hidden, config.layer_normalization)
elif config.type in {"tanh", "sigmoid", "relu", "softrelu"}:
elif config.type in {C.TANH, C.SIGMOID, C.RELU, C.SOFT_RELU}:
return ActivationCoverage(config.num_hidden, config.type, config.layer_normalization)
elif config.type == "count":
elif config.type == C.COVERAGE_COUNT:
return CountCoverage()
elif config.type == C.COVERAGE_FERTILITY:
return FertilityCoverage(config.max_fertility)
else:
raise ValueError("Unknown coverage type %s" % config.type)

Expand Down Expand Up @@ -129,6 +134,67 @@ def update_coverage(prev_hidden: mx.sym.Symbol,
return update_coverage


class FertilityCoverage(Coverage):
    """
    Coverage that sums the attention weights of every source position, where each
    increment is divided by a fertility value predicted for that position.

    The fertility of a position is ``max_fertility * sigmoid(w . h)``: ``h`` is the encoder
    state at that position and ``w`` is the ``e2f_weight`` variable created here.

    :param max_fertility: Factor the sigmoid output is multiplied with, i.e. the upper
           bound of the fertility of a source position.
    """

    def __init__(self, max_fertility: int) -> None:
        super().__init__()
        # Projects encoder states to one fertility logit per source position.
        # NOTE(review): self.prefix is presumably set by Coverage.__init__ -- confirm.
        self.cov_e2f_weight = mx.sym.Variable("%se2f_weight" % self.prefix)
        self.max_fertility = max_fertility

    def on(self, source: mx.sym.Symbol, source_length: mx.sym.Symbol, source_seq_len: int) -> Callable:
        """
        Builds the fertility sub-graph once and returns the function that a sequence decoder
        calls at every step to update the coverage vectors.
        ``source_length`` and ``source_seq_len`` are accepted but not used by this coverage type.

        :param source: Shape: (batch_size, seq_len, encoder_num_hidden).
        :param source_length: Shape: (batch_size,).
        :param source_seq_len: Maximum length of source sequences.
        :return: Coverage callable.
        """

        # One logit per source position, no bias. Shape: (batch_size, seq_len, 1)
        fertility_logits = mx.sym.FullyConnected(data=source,
                                                 weight=self.cov_e2f_weight,
                                                 no_bias=True,
                                                 num_hidden=1,
                                                 flatten=False,
                                                 name="%ssource_fertility_fc" % self.prefix)

        # Squashed into (0, 1). Shape: (batch_size, seq_len, 1)
        fertility_gate = mx.sym.Activation(data=fertility_logits,
                                           act_type="sigmoid",
                                           name="%sactivation" % self.prefix)

        # Reciprocal of the fertility; computed once here, outside the per-step callable,
        # because it only depends on the source. Shape: (batch_size, seq_len, 1)
        inverse_fertility = 1 / (self.max_fertility * fertility_gate)

        def update_coverage(prev_hidden: mx.sym.Symbol,
                            attention_prob_scores: mx.sym.Symbol,
                            prev_coverage: mx.sym.Symbol) -> mx.sym.Symbol:
            """
            Adds the fertility-scaled attention scores of the current step to the coverage.
            ``prev_hidden`` is not used by this coverage type.

            :param prev_hidden: Previous hidden decoder state. Shape: (batch_size, decoder_num_hidden).
            :param attention_prob_scores: Current attention scores. Shape: (batch_size, source_seq_len).
            :param prev_coverage: Shape: (batch_size, source_seq_len, coverage_num_hidden).
            :return: Updated coverage matrix. Shape: (batch_size, source_seq_len, coverage_num_hidden).
            """

            # Add a trailing axis so the scores line up with the fertility tensor.
            # Shape: (batch_size, source_seq_len, 1)
            att_scores_3d = mx.sym.expand_dims(data=attention_prob_scores,
                                               axis=2,
                                               name="%sexpand_attention_scores" % self.prefix)

            # Accumulate: previous coverage plus attention divided by fertility.
            return prev_coverage + inverse_fertility * att_scores_3d

        return update_coverage


class GRUCoverage(Coverage):
"""
Implements a GRU whose state is the coverage vector.
Expand Down
1 change: 1 addition & 0 deletions sockeye/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,7 @@ def create_decoder_config(args: argparse.Namespace, encoder_num_hidden: int,
config_coverage = None
if args.rnn_attention_type == C.ATT_COV:
config_coverage = coverage.CoverageConfig(type=args.rnn_attention_coverage_type,
max_fertility=args.rnn_attention_coverage_max_fertility,
num_hidden=args.rnn_attention_coverage_num_hidden,
layer_normalization=args.layer_normalization)
config_attention = rnn_attention.AttentionConfig(type=args.rnn_attention_type,
Expand Down
1 change: 1 addition & 0 deletions test/unit/test_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def test_device_args(test_params, expected_params):
rnn_scale_dot_attention=False,
rnn_attention_coverage_type='count',
rnn_attention_coverage_num_hidden=1,
rnn_attention_coverage_max_fertility=2,
weight_tying=False,
weight_tying_type="trg_softmax",
rnn_attention_mhdot_heads=None,
Expand Down
5 changes: 3 additions & 2 deletions test/unit/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def test_att_mlp():


def test_att_cov():
config_coverage = sockeye.coverage.CoverageConfig(type='tanh', num_hidden=5, layer_normalization=True)
config_coverage = sockeye.coverage.CoverageConfig(type='tanh', max_fertility=2, num_hidden=5, layer_normalization=True)

config_attention = sockeye.rnn_attention.AttentionConfig(type=C.ATT_COV,
num_hidden=16,
Expand Down Expand Up @@ -216,7 +216,7 @@ def test_attention(attention_type,
assert np.isclose(attention_prob_result, np.asarray([[0.5, 0.5, 0.]])).all()


coverage_cases = [("gru", 10), ("tanh", 4), ("count", 1), ("sigmoid", 1), ("relu", 30)]
coverage_cases = [("gru", 10), ("tanh", 4), ("count", 1), ("sigmoid", 1), ("relu", 30), ("fertility", 1)]


@pytest.mark.parametrize("attention_coverage_type,attention_coverage_num_hidden", coverage_cases)
Expand All @@ -232,6 +232,7 @@ def test_coverage_attention(attention_coverage_type,
source_seq_len = 10

config_coverage = sockeye.coverage.CoverageConfig(type=attention_coverage_type,
max_fertility=2,
num_hidden=attention_coverage_num_hidden,
layer_normalization=False)
config_attention = sockeye.rnn_attention.AttentionConfig(type="coverage",
Expand Down
4 changes: 2 additions & 2 deletions test/unit/test_coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_gru_coverage():


def _test_activation_coverage(act_type):
config_coverage = sockeye.coverage.CoverageConfig(type=act_type, num_hidden=2, layer_normalization=False)
config_coverage = sockeye.coverage.CoverageConfig(type=act_type, max_fertility=2, num_hidden=2, layer_normalization=False)
encoder_num_hidden, decoder_num_hidden, source_seq_len, batch_size = 5, 5, 10, 4
# source: (batch_size, source_seq_len, encoder_num_hidden)
source = mx.sym.Variable("source")
Expand Down Expand Up @@ -89,7 +89,7 @@ def _test_activation_coverage(act_type):


def _test_gru_coverage():
config_coverage = sockeye.coverage.CoverageConfig(type="gru", num_hidden=2, layer_normalization=False)
config_coverage = sockeye.coverage.CoverageConfig(type="gru", num_hidden=2, max_fertility=2, layer_normalization=False)
encoder_num_hidden, decoder_num_hidden, source_seq_len, batch_size = 5, 5, 10, 4
# source: (batch_size, source_seq_len, encoder_num_hidden)
source = mx.sym.Variable("source")
Expand Down
1 change: 1 addition & 0 deletions test/unit/test_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def test_step(cell_type, context_gating,
states_shape = (batch_size, decoder_num_hidden)

config_coverage = sockeye.coverage.CoverageConfig(type="tanh",
max_fertility=2,
num_hidden=2,
layer_normalization=False)
config_attention = sockeye.rnn_attention.AttentionConfig(type="coverage",
Expand Down

0 comments on commit 378ace5

Please sign in to comment.