diff --git a/CHANGELOG.md b/CHANGELOG.md index 95d153b0e..d3dbe0b1e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ Each version section may have have subsections for: _Added_, _Changed_, _Removed ### Added - Added a flag `--fixed-param-names` to prevent certain parameters from being optimized during training. This is useful if you want to keep pre-trained embeddings fixed during training. +- Added a flag `--dry-run` to `sockeye.train` to not perform any actual training, but print statistics about the model + and mode of operation. ## [1.17.3] ### Changed diff --git a/sockeye/arguments.py b/sockeye/arguments.py index 722854e87..53708c89b 100644 --- a/sockeye/arguments.py +++ b/sockeye/arguments.py @@ -926,6 +926,11 @@ def add_training_args(params): default=-1, help='Keep only the last n params files, use -1 to keep all files. Default: %(default)s') + train_params.add_argument('--dry-run', + action='store_true', + help="Do not perform any actual training, but print statistics about the model" + " and mode of operation.") + def add_train_cli_args(params): add_training_io_args(params) diff --git a/sockeye/train.py b/sockeye/train.py index a886bef1c..32dd30b97 100644 --- a/sockeye/train.py +++ b/sockeye/train.py @@ -19,6 +19,7 @@ import os import shutil import sys +import tempfile from contextlib import ExitStack from typing import Any, cast, Optional, Dict, List, Tuple @@ -724,6 +725,13 @@ def main(): arguments.add_train_cli_args(params) args = params.parse_args() + if args.dry_run: + # Modify arguments so that we write to a temporary directory and + # perform 0 training iterations + temp_dir = tempfile.TemporaryDirectory() # Will be automatically removed + args.output = temp_dir.name + args.max_updates = 0 + utils.seedRNGs(args.seed) check_arg_compatibility(args) diff --git a/test/unit/test_arguments.py b/test/unit/test_arguments.py index f9b73430f..0522450c8 100644 --- a/test/unit/test_arguments.py +++ b/test/unit/test_arguments.py @@ -173,7 +173,8 @@ def test_model_parameters(test_params, expected_params): decode_and_evaluate_use_cpu=False, decode_and_evaluate_device_id=None, seed=13, - keep_last_params=-1)), + keep_last_params=-1, + dry_run=False)), ]) def test_training_arg(test_params, expected_params): _test_args(test_params, expected_params, arguments.add_training_args)