diff --git a/CHANGELOG.md b/CHANGELOG.md index 34e363492..c4fe39ea8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,10 @@ Note that Sockeye has checks in place to not translate with an old model that wa Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_. +## [1.18.27] +### Fixed +- Fix silent failing of NDArray splits during inference by using a version that always returns a list. This was causing incorrect behavior when using lexicon restriction and batch inference with a single source factor. + ## [1.18.26] ### Added - ROUGE score evaluation. It can be used as the stopping criterion for tasks such as summarization. diff --git a/sockeye/__init__.py b/sockeye/__init__.py index d05159401..503e14a61 100644 --- a/sockeye/__init__.py +++ b/sockeye/__init__.py @@ -11,4 +11,4 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -__version__ = '1.18.26' +__version__ = '1.18.27' diff --git a/sockeye/inference.py b/sockeye/inference.py index ac4508d40..d50a4d9da 100644 --- a/sockeye/inference.py +++ b/sockeye/inference.py @@ -1334,7 +1334,7 @@ def _beam_search(self, """ Translates multiple sentences using beam search. - :param source: Source ids. Shape: (batch_size, bucket_key). + :param source: Source ids. Shape: (batch_size, bucket_key, num_factors). :param source_length: Max source length. :param raw_constraint_list: A list of optional lists containing phrases (as lists of target word IDs) that must appear in each output. @@ -1383,9 +1383,9 @@ def _beam_search(self, pad_dist = self.pad_dist vocab_slice_ids = None # type: mx.nd.NDArray if self.restrict_lexicon: + source_words = utils.split(source, num_outputs=self.num_source_factors, axis=2, squeeze_axis=True)[0] # TODO: See note in method about migrating to pure MXNet when set operations are supported. # We currently convert source to NumPy and target ids back to NDArray. 
- source_words = source.split(num_outputs=self.num_source_factors, axis=2, squeeze_axis=True)[0] vocab_slice_ids = self.restrict_lexicon.get_trg_ids(source_words.astype("int32").asnumpy()) if any(raw_constraint_list): # Add the constraint IDs to the list of permissibled IDs, and then project them into the reduced space diff --git a/sockeye/utils.py b/sockeye/utils.py index 2583dc6f6..958e13a07 100644 --- a/sockeye/utils.py +++ b/sockeye/utils.py @@ -860,3 +860,28 @@ def uncast_conditionally(data: mx.sym.Symbol, dtype: str) -> mx.sym.Symbol: if dtype != C.DTYPE_FP32: return mx.sym.cast(data=data, dtype=C.DTYPE_FP32) return data + + +def split(data: mx.nd.NDArray, + num_outputs: int, + axis: int = 1, + squeeze_axis: bool = False) -> List[mx.nd.NDArray]: + """ + Version of mxnet.ndarray.split that always returns a list. The original + implementation only returns a list if num_outputs > 1: + https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.split + + Splits an array along a particular axis into multiple sub-arrays. + + :param data: The input. + :param num_outputs: Number of splits. Note that this should evenly divide + the length of the axis. + :param axis: Axis along which to split. + :param squeeze_axis: If true, removes the axis with length 1 from the shapes + of the output arrays. + :return: List of NDArrays resulting from the split. 
+ """ + ndarray_or_list = data.split(num_outputs=num_outputs, axis=axis, squeeze_axis=squeeze_axis) + if num_outputs == 1: + return [ndarray_or_list] + return ndarray_or_list diff --git a/test/unit/test_utils.py b/test/unit/test_utils.py index 9b718b90c..73504e586 100644 --- a/test/unit/test_utils.py +++ b/test/unit/test_utils.py @@ -338,4 +338,12 @@ def test_metric_value_is_better(new, old, metric, result): assert utils.metric_value_is_better(new, old, metric) == result - +@pytest.mark.parametrize("num_factors", [1, 2, 3]) +def test_split(num_factors): + batch_size = 4 + bucket_key = 10 + # Simulates splitting factored input + data = mx.nd.random.normal(shape=(batch_size, bucket_key, num_factors)) + result = utils.split(data, num_outputs=num_factors, axis=2, squeeze_axis=True) + assert isinstance(result, list) + assert result[0].shape == (batch_size, bucket_key)