From 378ace51259e8d2ef22e59ccafa687e3fe8edf88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathias=20M=C3=BCller?= Date: Fri, 21 Dec 2018 11:52:11 +0100 Subject: [PATCH] Add fertility to count coverage (#612) * add fertility to count coverage --- CHANGELOG.md | 1 + sockeye/arguments.py | 12 ++++-- sockeye/constants.py | 11 ++++++ sockeye/coverage.py | 78 ++++++++++++++++++++++++++++++++++--- sockeye/train.py | 1 + test/unit/test_arguments.py | 1 + test/unit/test_attention.py | 5 ++- test/unit/test_coverage.py | 4 +- test/unit/test_decoder.py | 1 + 9 files changed, 101 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 29a3519a7..93c6f375c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ Each version section may have have subsections for: _Added_, _Changed_, _Removed ## [1.18.67] ### Added +- Added `fertility` as a further type of attention coverage. - Added an option for training to keep the initializations of the model via `--keep-initializations`. When set, the trainer will avoid deleting the params file for the first checkpoint, no matter what `--keep-last-params` is set to. ## [1.18.66] diff --git a/sockeye/arguments.py b/sockeye/arguments.py index 44a8fda9b..2aaf4cf12 100644 --- a/sockeye/arguments.py +++ b/sockeye/arguments.py @@ -720,12 +720,18 @@ def add_model_parameters(params): '[Vaswani et al, 2017]') model_params.add_argument('--rnn-attention-coverage-type', - choices=["tanh", "sigmoid", "relu", "softrelu", "gru", "count"], - default="count", + choices=C.COVERAGE_TYPES, + default=C.COVERAGE_COUNT, help="Type of model for updating coverage vectors. 'count' refers to an update method " - "that accumulates attention scores. 'tanh', 'sigmoid', 'relu', 'softrelu' " + "that accumulates attention scores. 'fertility' accumulates attention scores as well " + "but also computes a fertility value for every source word. " + "'tanh', 'sigmoid', 'relu', 'softrelu' " "use non-linear layers with the respective activation type, and 'gru' uses a " "GRU to update the coverage vectors. Default: %(default)s.") + model_params.add_argument('--rnn-attention-coverage-max-fertility', + type=int, + default=2, + help="Maximum fertility for individual source words. Default: %(default)s.") model_params.add_argument('--rnn-attention-coverage-num-hidden', type=int, default=1, diff --git a/sockeye/constants.py b/sockeye/constants.py index 0feca80fe..161822dde 100644 --- a/sockeye/constants.py +++ b/sockeye/constants.py @@ -156,6 +156,17 @@ CNN_PAD_LEFT = "left" CNN_PAD_CENTERED = "centered" +# coverage types +COVERAGE_COUNT = "count" +COVERAGE_FERTILITY = "fertility" +COVERAGE_TYPES = [TANH, + SIGMOID, + RELU, + SOFT_RELU, + GRU_TYPE, + COVERAGE_COUNT, + COVERAGE_FERTILITY] + # default I/O variable names SOURCE_NAME = "source" SOURCE_LENGTH_NAME = "source_length" diff --git a/sockeye/coverage.py b/sockeye/coverage.py index 5b0925cee..83fba81ce 100644 --- a/sockeye/coverage.py +++ b/sockeye/coverage.py @@ -35,13 +35,16 @@ class CoverageConfig(config.Config): :param type: Coverage name. :param num_hidden: Number of hidden units for coverage networks. :param layer_normalization: Apply layer normalization to coverage networks. + :param max_fertility: Maximum number of target words generated by a source word. """ def __init__(self, type: str, num_hidden: int, - layer_normalization: bool) -> None: + layer_normalization: bool, + max_fertility: int = 2) -> None: super().__init__() self.type = type + self.max_fertility = max_fertility self.num_hidden = num_hidden self.layer_normalization = layer_normalization @@ -53,14 +56,16 @@ def get_coverage(config: CoverageConfig) -> 'Coverage': :param config: Coverage configuration. :return: Instance of Coverage. """ - if config.type == 'count': - utils.check_condition(config.num_hidden == 1, "Count coverage requires coverage_num_hidden==1") - if config.type == "gru": + if config.type == C.COVERAGE_COUNT or config.type == C.COVERAGE_FERTILITY: + utils.check_condition(config.num_hidden == 1, "Count or fertility coverage requires coverage_num_hidden==1") + if config.type == C.GRU_TYPE: return GRUCoverage(config.num_hidden, config.layer_normalization) - elif config.type in {"tanh", "sigmoid", "relu", "softrelu"}: + elif config.type in {C.TANH, C.SIGMOID, C.RELU, C.SOFT_RELU}: return ActivationCoverage(config.num_hidden, config.type, config.layer_normalization) - elif config.type == "count": + elif config.type == C.COVERAGE_COUNT: return CountCoverage() + elif config.type == C.COVERAGE_FERTILITY: + return FertilityCoverage(config.max_fertility) else: raise ValueError("Unknown coverage type %s" % config.type) @@ -129,6 +134,67 @@ def update_coverage(prev_hidden: mx.sym.Symbol, return update_coverage +class FertilityCoverage(Coverage): + """ + Coverage class that accumulates the attention weights for each source word, + and also computes a fertility value for each source word. + """ + + def __init__(self, max_fertility: int) -> None: + super().__init__() + self.max_fertility = max_fertility + # input (encoder) to fertility + self.cov_e2f_weight = mx.sym.Variable("%se2f_weight" % self.prefix) + + def on(self, source: mx.sym.Symbol, source_length: mx.sym.Symbol, source_seq_len: int) -> Callable: + """ + Returns callable to be used for updating coverage vectors in a sequence decoder. + + :param source: Shape: (batch_size, seq_len, encoder_num_hidden). + :param source_length: Shape: (batch_size,). + :param source_seq_len: Maximum length of source sequences. + :return: Coverage callable. + """ + + # (batch_size, seq_len, 1) + source_fertility = mx.sym.FullyConnected(data=source, + weight=self.cov_e2f_weight, + no_bias=True, + num_hidden=1, + flatten=False, + name="%ssource_fertility_fc" % self.prefix) + + # (batch_size, seq_len, 1) + fertility = mx.sym.Activation(data=source_fertility, + act_type="sigmoid", + name="%sactivation" % self.prefix) + + # (batch_size, seq_len, 1) + scaled_fertility = 1 / (self.max_fertility * fertility) + + def update_coverage(prev_hidden: mx.sym.Symbol, + attention_prob_scores: mx.sym.Symbol, + prev_coverage: mx.sym.Symbol): + """ + :param prev_hidden: Previous hidden decoder state. Shape: (batch_size, decoder_num_hidden). + :param attention_prob_scores: Current attention scores. Shape: (batch_size, source_seq_len). + :param prev_coverage: Shape: (batch_size, source_seq_len, coverage_num_hidden). + :return: Updated coverage matrix . Shape: (batch_size, source_seq_len, coverage_num_hidden). + """ + + # (batch_size, source_seq_len, 1) + expanded_att_scores = mx.sym.expand_dims(data=attention_prob_scores, + axis=2, + name="%sexpand_attention_scores" % self.prefix) + + # (batch_size, source_seq_len, 1) + new_coverage = scaled_fertility * expanded_att_scores + + return prev_coverage + new_coverage + + return update_coverage + + class GRUCoverage(Coverage): """ Implements a GRU whose state is the coverage vector. diff --git a/sockeye/train.py b/sockeye/train.py index f4cbcf4ea..d043c04c7 100644 --- a/sockeye/train.py +++ b/sockeye/train.py @@ -522,6 +522,7 @@ def create_decoder_config(args: argparse.Namespace, encoder_num_hidden: int, config_coverage = None if args.rnn_attention_type == C.ATT_COV: config_coverage = coverage.CoverageConfig(type=args.rnn_attention_coverage_type, + max_fertility=args.rnn_attention_coverage_max_fertility, num_hidden=args.rnn_attention_coverage_num_hidden, layer_normalization=args.layer_normalization) config_attention = rnn_attention.AttentionConfig(type=args.rnn_attention_type, diff --git a/test/unit/test_arguments.py b/test/unit/test_arguments.py index 6c3f0fa37..c067ca5b9 100644 --- a/test/unit/test_arguments.py +++ b/test/unit/test_arguments.py @@ -87,6 +87,7 @@ def test_device_args(test_params, expected_params): rnn_scale_dot_attention=False, rnn_attention_coverage_type='count', rnn_attention_coverage_num_hidden=1, + rnn_attention_coverage_max_fertility=2, weight_tying=False, weight_tying_type="trg_softmax", rnn_attention_mhdot_heads=None, diff --git a/test/unit/test_attention.py b/test/unit/test_attention.py index 2c86c9e9c..ab948de7d 100644 --- a/test/unit/test_attention.py +++ b/test/unit/test_attention.py @@ -151,7 +151,7 @@ def test_att_mlp(): def test_att_cov(): - config_coverage = sockeye.coverage.CoverageConfig(type='tanh', num_hidden=5, layer_normalization=True) + config_coverage = sockeye.coverage.CoverageConfig(type='tanh', max_fertility=2, num_hidden=5, layer_normalization=True) config_attention = sockeye.rnn_attention.AttentionConfig(type=C.ATT_COV, num_hidden=16, @@ -216,7 +216,7 @@ def test_attention(attention_type, assert np.isclose(attention_prob_result, np.asarray([[0.5, 0.5, 0.]])).all() -coverage_cases = [("gru", 10), ("tanh", 4), ("count", 1), ("sigmoid", 1), ("relu", 30)] +coverage_cases = [("gru", 10), ("tanh", 4), ("count", 1), ("sigmoid", 1), ("relu", 30), ("fertility", 1)] @pytest.mark.parametrize("attention_coverage_type,attention_coverage_num_hidden", coverage_cases) @@ -232,6 +232,7 @@ def test_coverage_attention(attention_coverage_type, source_seq_len = 10 config_coverage = sockeye.coverage.CoverageConfig(type=attention_coverage_type, + max_fertility=2, num_hidden=attention_coverage_num_hidden, layer_normalization=False) config_attention = sockeye.rnn_attention.AttentionConfig(type="coverage", diff --git a/test/unit/test_coverage.py b/test/unit/test_coverage.py index 3833b215f..183670a8f 100644 --- a/test/unit/test_coverage.py +++ b/test/unit/test_coverage.py @@ -42,7 +42,7 @@ def test_gru_coverage(): def _test_activation_coverage(act_type): - config_coverage = sockeye.coverage.CoverageConfig(type=act_type, num_hidden=2, layer_normalization=False) + config_coverage = sockeye.coverage.CoverageConfig(type=act_type, max_fertility=2, num_hidden=2, layer_normalization=False) encoder_num_hidden, decoder_num_hidden, source_seq_len, batch_size = 5, 5, 10, 4 # source: (batch_size, source_seq_len, encoder_num_hidden) source = mx.sym.Variable("source") @@ -89,7 +89,7 @@ def _test_activation_coverage(act_type): def _test_gru_coverage(): - config_coverage = sockeye.coverage.CoverageConfig(type="gru", num_hidden=2, layer_normalization=False) + config_coverage = sockeye.coverage.CoverageConfig(type="gru", num_hidden=2, max_fertility=2, layer_normalization=False) encoder_num_hidden, decoder_num_hidden, source_seq_len, batch_size = 5, 5, 10, 4 # source: (batch_size, source_seq_len, encoder_num_hidden) source = mx.sym.Variable("source") diff --git a/test/unit/test_decoder.py b/test/unit/test_decoder.py index 170eb74d2..034b14fa1 100644 --- a/test/unit/test_decoder.py +++ b/test/unit/test_decoder.py @@ -71,6 +71,7 @@ def test_step(cell_type, context_gating, states_shape = (batch_size, decoder_num_hidden) config_coverage = sockeye.coverage.CoverageConfig(type="tanh", + max_fertility=2, num_hidden=2, layer_normalization=False) config_attention = sockeye.rnn_attention.AttentionConfig(type="coverage",