Skip to content

Commit

Permalink
Fixed the maximum input length calculation at inference. (#255)
Browse files Browse the repository at this point in the history
* Fixed the maximum input length calculation at inference.

* doc string
  • Loading branch information
tdomhan authored and fhieber committed Dec 19, 2017
1 parent 078070a commit 47dc73f
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 33 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [1.15.8]
### Fixed
- Take the BOS and EOS tags into account when calculating the maximum input length at inference.

## [1.15.7]
### Fixed
- fixed a problem with `--num-samples-per-shard` flag not being parsed as int.
- fixed a problem with `--num-samples-per-shard` flag not being parsed as int.

## [1.15.6]
### Added
Expand Down
2 changes: 1 addition & 1 deletion sockeye/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '1.15.7'
__version__ = '1.15.8'
6 changes: 3 additions & 3 deletions sockeye/data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,13 +267,13 @@ def sequence_pair(self,
source: List[int],
target: List[int],
bucket_idx: Optional[int]):
source_len = len(source)
target_len = len(target)

if bucket_idx is None:
self.num_discarded += 1
return

source_len = len(source)
target_len = len(target)

self._mean_len_target_per_bucket[bucket_idx].update(target_len)

self.num_sents += 1
Expand Down
83 changes: 60 additions & 23 deletions sockeye/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def __init__(self,

self._build_model_components()

self.max_input_length, self.get_max_output_length = get_max_input_output_length([self],
max_output_length_num_stds)
self.max_input_length, self.get_max_output_length = models_max_input_output_length([self],
max_output_length_num_stds)

self.encoder_module = None # type: Optional[mx.mod.BucketingModule]
self.encoder_default_bucket_key = None # type: Optional[int]
Expand Down Expand Up @@ -401,17 +401,18 @@ def load_models(context: mx.context.Context,
utils.check_condition(vocab.are_identical(*target_vocabs), "Target vocabulary ids do not match")

# set a common max_output length for all models.
max_input_len, get_max_output_length = get_max_input_output_length(models,
max_output_length_num_stds,
max_input_len)
max_input_len, get_max_output_length = models_max_input_output_length(models,
max_output_length_num_stds,
max_input_len)
for model in models:
model.initialize(max_input_len, get_max_output_length)

return models, source_vocabs[0], target_vocabs[0]


def models_max_input_output_length(models: List[InferenceModel],
                                   num_stds: int,
                                   forced_max_input_len: Optional[int] = None) -> Tuple[int, Callable]:
    """
    Aggregates the length statistics and length limits of a list of models and delegates to
    get_max_input_output_length to derive a common maximum input length and output-length function.

    :param models: List of models.
    :param num_stds: Number of standard deviations to add as a safety margin. Passed through unchanged
           to get_max_input_output_length.
    :param forced_max_input_len: An optional overwrite of the maximum input length.
    :return: The maximum input length and a function to get the output length given the input length.
    """
    # Length ratio statistics: take the largest mean and the largest std over all models.
    ratio_mean = max(model.length_ratio_mean for model in models)
    ratio_std = max(model.length_ratio_std for model in models)

    # Hard limits: the tightest limit among the models that declare one, None if no model declares any.
    source_limits = (model.max_supported_seq_len_source for model in models
                     if model.max_supported_seq_len_source is not None)
    supported_max_seq_len_source = min(source_limits, default=None)
    target_limits = (model.max_supported_seq_len_target for model in models
                     if model.max_supported_seq_len_target is not None)
    supported_max_seq_len_target = min(target_limits, default=None)

    # Smallest training-time maximum source length over all models.
    training_max_seq_len_source = min(model.training_max_seq_len_source for model in models)

    return get_max_input_output_length(supported_max_seq_len_source,
                                       supported_max_seq_len_target,
                                       training_max_seq_len_source,
                                       forced_max_input_len=forced_max_input_len,
                                       length_ratio_mean=ratio_mean,
                                       length_ratio_std=ratio_std,
                                       num_stds=num_stds)


def get_max_input_output_length(supported_max_seq_len_source: Optional[int],
                                supported_max_seq_len_target: Optional[int],
                                training_max_seq_len_source: Optional[int],
                                forced_max_input_len: Optional[int],
                                length_ratio_mean: float,
                                length_ratio_std: float,
                                num_stds: int) -> Tuple[int, Callable]:
    """
    Returns the maximum input length and a function to compute the maximum output length for a given
    input length. The output length is ceil(factor * input_length) plus room for BOS and EOS, where
    factor is length_ratio_mean + num_stds * length_ratio_std. Optional hard limits on source and
    target length are taken into account when choosing the maximum input length.

    :param supported_max_seq_len_source: The maximum source length supported by the models (None: no limit).
    :param supported_max_seq_len_target: The maximum target length supported by the models (None: no limit).
    :param training_max_seq_len_source: The maximum source length observed during training. Used when no
           source limit is given.
    :param forced_max_input_len: An optional overwrite of the maximum input length. If set, it is returned
           as is and the supported limits are not consulted.
    :param length_ratio_mean: The mean of the length ratio.
    :param length_ratio_std: The standard deviation of the length ratio.
    :param num_stds: The number of standard deviations the target length may exceed the mean target length
           (as long as the supported maximum length allows for this). If negative,
           C.TARGET_MAX_LENGTH_FACTOR is used as the factor instead.
    :return: The maximum input length and a function to get the output length given the input length.
    """
    space_for_bos = 1
    space_for_eos = 1

    if num_stds < 0:
        factor = C.TARGET_MAX_LENGTH_FACTOR  # type: float
    else:
        factor = length_ratio_mean + (length_ratio_std * num_stds)

    if forced_max_input_len is not None:
        max_input_len = forced_max_input_len
    elif supported_max_seq_len_target is None:
        # No hard target limit: only the source side can constrain the input. Without a source limit
        # either, fall back to the maximum length from training.
        if supported_max_seq_len_source is not None:
            max_input_len = supported_max_seq_len_source
        else:
            max_input_len = training_max_seq_len_source
    else:
        # There is a hard constraint on the target length (e.g. learned positional embeddings, which are
        # only defined up to the maximum length observed during training). Never exceed it, keeping in
        # mind that BOS and EOS also occupy target positions.
        if supported_max_seq_len_source is not None:
            source_limit = supported_max_seq_len_source
        else:
            source_limit = training_max_seq_len_source
        max_output_len = supported_max_seq_len_target - space_for_bos - space_for_eos
        if np.ceil(factor * source_limit) > max_output_len:
            max_input_len = int(np.floor(max_output_len / factor))
            # Floating point rounding can make floor(max_output_len / factor) one too large: for example
            # 187 / 1.1 == 170.0 but 1.1 * 170 == 187.00000000000003, whose ceiling is 188. Step down
            # until the output length computed below really fits into the target limit.
            while max_input_len > 0 and np.ceil(factor * max_input_len) > max_output_len:
                max_input_len -= 1
        else:
            max_input_len = source_limit

    def get_max_output_length(input_length: int):
        """
        Returns the maximum output length for the given input length: ceil(factor * input_length) plus
        explicit space for the BOS and EOS symbols in the target sequence.
        """
        return int(np.ceil(factor * input_length)) + space_for_bos + space_for_eos

    return max_input_len, get_max_output_length
Expand Down Expand Up @@ -737,7 +774,7 @@ def translate(self, trans_inputs: List[TranslatorInput]) -> List[TranslatorOutpu
translated_chunks = []

# split into chunks
input_chunks = [] # type: List[InputChunk]
input_chunks = [] # type: List[InputChunk]
for input_idx, trans_input in enumerate(trans_inputs):
if len(trans_input.tokens) == 0:
empty_translation = Translation(target_ids=[],
Expand Down Expand Up @@ -1043,7 +1080,7 @@ def _beam_search(self,
sliced_scores = scores if t == 1 and self.batch_size == 1 else scores[rows]
# TODO we could save some tiny amount of time here by not running smallest_k for a finished sent
(best_hyp_indices_np[rows], best_word_indices_np[rows]), \
scores_accumulated_np[rows] = utils.smallest_k(sliced_scores, self.beam_size, t == 1)
scores_accumulated_np[rows] = utils.smallest_k(sliced_scores, self.beam_size, t == 1)
# offsetting since the returned smallest_k() indices were slice-relative
best_hyp_indices_np[rows] += rows.start

Expand Down
3 changes: 2 additions & 1 deletion test/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ def tmp_digits_dataset(prefix: str,
" --output {output} {quiet}"

_TRAIN_PARAMS_PREPARED_DATA_COMMON = "--use-cpu --max-seq-len {max_len} --prepared-data {prepared_data}" \
" --validation-source {dev_source} --validation-target {dev_target} --output {model} {quiet}"
" --validation-source {dev_source} --validation-target {dev_target} " \
"--output {model} {quiet}"

_TRANSLATE_PARAMS_COMMON = "--use-cpu --models {model} --input {input} --output {output} {quiet}"

Expand Down
55 changes: 51 additions & 4 deletions test/unit/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@

import mxnet as mx
import numpy as np
import pytest

import sockeye.inference


_BOS = 0
_EOS = -1

Expand All @@ -26,7 +26,7 @@ def test_concat_translations():
NUM_SRC = 7

def length_penalty(length):
return 1./length
return 1. / length

expected_score = (1 + 2 + 3) / length_penalty(len(expected_target_ids))

Expand All @@ -53,16 +53,63 @@ def test_length_penalty_default():
def test_length_penalty():
    """LengthPenalty(.2, 5.0) applied to a column of lengths 1, 2, 3 matches the expected values."""
    penalty = sockeye.inference.LengthPenalty(.2, 5.0)
    lengths = mx.nd.array([[1], [2], [3]])

    # One expected value per row: numerators 6, 7, 8 raised to 0.2, each divided by 6 ** 0.2.
    expected_lp = np.array([[numerator ** 0.2 / 6 ** 0.2] for numerator in (6, 7, 8)])

    actual_lp = penalty(lengths).asnumpy()
    assert np.isclose(actual_lp, expected_lp).all()


def test_length_penalty_int_input():
    """LengthPenalty(.2, 5.0) called with the plain Python int 1 matches the expected value."""
    penalty = sockeye.inference.LengthPenalty(.2, 5.0)

    actual_lp = np.asarray([penalty(1)])
    expected_lp = np.asarray([6 ** 0.2 / 6 ** 0.2])

    assert np.isclose(actual_lp, expected_lp).all()


@pytest.mark.parametrize("supported_max_seq_len_source, supported_max_seq_len_target, training_max_seq_len_source, "
                         "forced_max_input_len, length_ratio_mean, length_ratio_std, "
                         "expected_max_input_len, expected_max_output_len",
                         [
                             (100, 100, 100, None, 0.9, 0.2, 89, 100),
                             (100, 100, 100, None, 1.1, 0.2, 75, 100),
                             # No source length constraints.
                             (None, 100, 100, None, 0.9, 0.1, 98, 100),
                             # No target length constraints.
                             (80, None, 100, None, 1.1, 0.4, 80, 122),
                             # No source/target length constraints. Source is max observed during training and
                             # target based on length ratios.
                             (None, None, 100, None, 1.0, 0.1, 100, 113),
                             # Force a maximum input length.
                             (100, 100, 100, 50, 1.1, 0.2, 50, 67),
                         ])
def test_get_max_input_output_length(
        supported_max_seq_len_source,
        supported_max_seq_len_target,
        training_max_seq_len_source,
        forced_max_input_len,
        length_ratio_mean,
        length_ratio_std,
        expected_max_input_len,
        expected_max_output_len):
    """
    Checks the maximum input length and the resulting maximum output length (with num_stds=1) against
    the supported limits and against the expected values of each parametrized case.
    """
    length_args = dict(supported_max_seq_len_source=supported_max_seq_len_source,
                       supported_max_seq_len_target=supported_max_seq_len_target,
                       training_max_seq_len_source=training_max_seq_len_source,
                       forced_max_input_len=forced_max_input_len,
                       length_ratio_mean=length_ratio_mean,
                       length_ratio_std=length_ratio_std,
                       num_stds=1)
    max_input_len, get_max_output_len = sockeye.inference.get_max_input_output_length(**length_args)
    max_output_len = get_max_output_len(max_input_len)

    # Supported limits, where given, must not be exceeded.
    assert supported_max_seq_len_source is None or max_input_len <= supported_max_seq_len_source
    assert supported_max_seq_len_target is None or max_output_len <= supported_max_seq_len_target

    # Exact expectations, where given.
    assert expected_max_input_len is None or max_input_len == expected_max_input_len
    assert expected_max_output_len is None or max_output_len == expected_max_output_len


0 comments on commit 47dc73f

Please sign in to comment.