Commit

Add an option to keep the first checkpoint, in addition to the last n

mitchellgordon95 authored and fhieber committed Dec 21, 2018
1 parent 390acde commit f42abe1
Showing 9 changed files with 38 additions and 10 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,10 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

+## [1.18.67]
+### Added
+- Added an option for training to keep the initializations of the model via `--keep-initializations`. When set, the trainer will avoid deleting the params file for the first checkpoint, no matter what `--keep-last-params` is set to.

## [1.18.66]
### Fixed
- Fix to argument names that are allowed to differ for resuming training.
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

-__version__ = '1.18.66'
+__version__ = '1.18.67'
4 changes: 4 additions & 0 deletions sockeye/arguments.py
@@ -1063,6 +1063,10 @@ def add_training_args(params):
default=-1,
help='Keep only the last n params files, use -1 to keep all files. Default: %(default)s')

+train_params.add_argument('--keep-initializations',
+action="store_true",
+help='In addition to keeping the last n params files, also keep params from checkpoint 0.')

train_params.add_argument('--dry-run',
action='store_true',
help="Do not perform any actual training, but print statistics about the model"
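For intuition about the flag's semantics, here is a minimal standalone argparse sketch (hypothetical, not Sockeye's full `add_training_args`). As a `store_true` flag, `--keep-initializations` defaults to `False`, which matches the expectation added to `test_arguments.py` below:

```python
import argparse

# Hypothetical stand-alone parser mirroring just these two training options.
parser = argparse.ArgumentParser()
parser.add_argument('--keep-last-params', type=int, default=-1,
                    help='Keep only the last n params files; -1 keeps all.')
parser.add_argument('--keep-initializations', action='store_true',
                    help='Also keep the params file for checkpoint 0.')

args = parser.parse_args(['--keep-last-params', '5', '--keep-initializations'])
assert args.keep_last_params == 5
assert args.keep_initializations is True
```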
1 change: 1 addition & 0 deletions sockeye/image_captioning/train.py
@@ -371,6 +371,7 @@ def train(args: argparse.Namespace):
optimizer_config=create_optimizer_config(args, [1.0],
extra_initializers),
max_params_files_to_keep=args.keep_last_params,
+keep_initializations=args.keep_initializations,
source_vocabs=[None],
target_vocab=target_vocab)

1 change: 1 addition & 0 deletions sockeye/train.py
@@ -869,6 +869,7 @@ def train(args: argparse.Namespace) -> training.TrainState:
trainer = training.EarlyStoppingTrainer(model=training_model,
optimizer_config=create_optimizer_config(args, source_vocab_sizes),
max_params_files_to_keep=args.keep_last_params,
+keep_initializations=args.keep_initializations,
source_vocabs=source_vocabs,
target_vocab=target_vocab)

7 changes: 5 additions & 2 deletions sockeye/training.py
@@ -420,6 +420,7 @@ class EarlyStoppingTrainer:
:param model: TrainingModel instance.
:param optimizer_config: The optimizer configuration.
:param max_params_files_to_keep: Maximum number of params files to keep in the output folder (last n are kept).
+:param keep_initializations: Regardless of number of params to keep, never delete the first checkpoint.
:param source_vocabs: Source vocabulary (and optional source factor vocabularies).
:param target_vocab: Target vocabulary.
"""
@@ -428,11 +429,13 @@ def __init__(self,
model: TrainingModel,
optimizer_config: OptimizerConfig,
max_params_files_to_keep: int,
+keep_initializations: bool,
source_vocabs: List[vocab.Vocab],
target_vocab: vocab.Vocab) -> None:
self.model = model
self.optimizer_config = optimizer_config
self.max_params_files_to_keep = max_params_files_to_keep
+self.keep_initializations = keep_initializations
self.tflogger = TensorboardLogger(logdir=os.path.join(model.output_dir, C.TENSORBOARD_NAME),
source_vocab=source_vocabs[0],
target_vocab=target_vocab)
@@ -758,7 +761,7 @@ def _cleanup(self, lr_decay_opt_states_reset: str, process_manager: Optional['De
Cleans parameter files, training state directory and waits for remaining decoding processes.
"""
utils.cleanup_params_files(self.model.output_dir, self.max_params_files_to_keep,
-self.state.checkpoint, self.state.best_checkpoint)
+self.state.checkpoint, self.state.best_checkpoint, self.keep_initializations)
if process_manager is not None:
result = process_manager.collect_results()
if result is not None:
@@ -922,7 +925,7 @@ def _save_params(self):
"""
self.model.save_params_to_file(self.current_params_fname)
utils.cleanup_params_files(self.model.output_dir, self.max_params_files_to_keep, self.state.checkpoint,
-self.state.best_checkpoint)
+self.state.best_checkpoint, self.keep_initializations)

def _save_training_state(self, train_iter: data_io.BaseParallelSampleIter):
"""
5 changes: 3 additions & 2 deletions sockeye/utils.py
@@ -898,20 +898,21 @@ def metric_value_is_better(new: float, old: float, metric: str) -> bool:
return new < old


-def cleanup_params_files(output_folder: str, max_to_keep: int, checkpoint: int, best_checkpoint: int):
+def cleanup_params_files(output_folder: str, max_to_keep: int, checkpoint: int, best_checkpoint: int, keep_first: bool):
"""
Deletes oldest parameter files from a model folder.
:param output_folder: Folder where param files are located.
:param max_to_keep: Maximum number of files to keep, negative to keep all.
:param checkpoint: Current checkpoint (i.e. index of last params file created).
:param best_checkpoint: Best checkpoint. The parameter file corresponding to this checkpoint will not be deleted.
+:param keep_first: Don't delete the first checkpoint.
"""
if max_to_keep <= 0:
return
existing_files = glob.glob(os.path.join(output_folder, C.PARAMS_PREFIX + "*"))
params_name_with_dir = os.path.join(output_folder, C.PARAMS_NAME)
-for n in range(0, max(1, checkpoint - max_to_keep + 1)):
+for n in range(1 if keep_first else 0, max(1, checkpoint - max_to_keep + 1)):
if n != best_checkpoint:
param_fname_n = params_name_with_dir % n
if param_fname_n in existing_files:
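To make the changed loop boundary concrete, here is a small self-contained sketch (an illustrative reimplementation, not Sockeye's code; `surviving_checkpoints` is a hypothetical helper) that computes which checkpoint numbers survive cleanup:

```python
def surviving_checkpoints(checkpoint: int, max_to_keep: int,
                          best_checkpoint: int, keep_first: bool) -> set:
    """Checkpoints left on disk after cleanup, assuming files 0..checkpoint exist."""
    all_ckpts = set(range(checkpoint + 1))
    if max_to_keep <= 0:  # non-positive keeps everything, as in cleanup_params_files
        return all_ckpts
    start = 1 if keep_first else 0  # the changed boundary: checkpoint 0 is skipped
    deletable = range(start, max(1, checkpoint - max_to_keep + 1))
    return all_ckpts - {n for n in deletable if n != best_checkpoint}

# 41 checkpoints, keep the last 5, best checkpoint is 16:
assert surviving_checkpoints(40, 5, 16, keep_first=True) == {0, 16, 36, 37, 38, 39, 40}
assert surviving_checkpoints(40, 5, 16, keep_first=False) == {16, 36, 37, 38, 39, 40}
```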
1 change: 1 addition & 0 deletions test/unit/test_arguments.py
@@ -186,6 +186,7 @@ def test_model_parameters(test_params, expected_params):
decode_and_evaluate_device_id=None,
seed=13,
keep_last_params=-1,
+keep_initializations=False,
rnn_enc_last_hidden_concat_to_embedding=False,
dry_run=False)),
])
23 changes: 18 additions & 5 deletions test/unit/test_params.py
@@ -22,13 +22,26 @@


def test_cleanup_param_files():
-with tempfile.TemporaryDirectory() as tmpDir:
+with tempfile.TemporaryDirectory() as tmp_dir:
for n in itertools.chain(range(1, 20, 2), range(21, 41)):
# Create empty files
-open(os.path.join(tmpDir, C.PARAMS_NAME % n), "w").close()
-sockeye.utils.cleanup_params_files(tmpDir, 5, 40, 17)
+open(os.path.join(tmp_dir, C.PARAMS_NAME % n), "w").close()
+sockeye.utils.cleanup_params_files(tmp_dir, 5, 40, 17, False)

-expectedSurviving = set([os.path.join(tmpDir, C.PARAMS_NAME % n)
+expectedSurviving = set([os.path.join(tmp_dir, C.PARAMS_NAME % n)
for n in [17, 36, 37, 38, 39, 40]])
# 17 must survive because it is the best one
-assert set(glob.glob(os.path.join(tmpDir, C.PARAMS_PREFIX + "*"))) == expectedSurviving
+assert set(glob.glob(os.path.join(tmp_dir, C.PARAMS_PREFIX + "*"))) == expectedSurviving

+def test_cleanup_param_files_keep_first():
+with tempfile.TemporaryDirectory() as tmp_dir:
+for n in itertools.chain(range(0, 20, 2), range(21, 41)):
+# Create empty files
+open(os.path.join(tmp_dir, C.PARAMS_NAME % n), "w").close()
+sockeye.utils.cleanup_params_files(tmp_dir, 5, 40, 16, True)
+
+expectedSurviving = set([os.path.join(tmp_dir, C.PARAMS_NAME % n)
+for n in [0, 16, 36, 37, 38, 39, 40]])
+# 16 must survive because it is the best one
+# 0 should also survive because we set keep_first to True
+assert set(glob.glob(os.path.join(tmp_dir, C.PARAMS_PREFIX + "*"))) == expectedSurviving
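The new test can also be reproduced interactively. A quick standalone check (assuming a Sockeye checkout at this commit is importable) might look like this:

```python
import glob
import os
import tempfile

import sockeye.utils
from sockeye import constants as C

with tempfile.TemporaryDirectory() as tmp_dir:
    # Empty params files for checkpoints 0..40.
    for n in range(41):
        open(os.path.join(tmp_dir, C.PARAMS_NAME % n), "w").close()
    # Keep the last 5 files, protect best checkpoint 16, and keep checkpoint 0.
    sockeye.utils.cleanup_params_files(tmp_dir, 5, 40, 16, True)
    survivors = sorted(os.path.basename(f)
                       for f in glob.glob(os.path.join(tmp_dir, C.PARAMS_PREFIX + "*")))
    print(survivors)  # params files for checkpoints 0, 16, 36, 37, 38, 39, 40
```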
