Added --fixed-param-names argument for freezing model parameters (#320)
* Added fixed_param_names argument for freezing model parameters
David Vilar authored and fhieber committed Mar 9, 2018
1 parent ac47312 commit 618e813
Showing 8 changed files with 32 additions and 12 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,11 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

+## [1.17.4]
+### Added
+- Added a flag `--fixed-param-names` to prevent certain parameters from being optimized during training.
+  This is useful if you want to keep pre-trained embeddings fixed during training.
+
## [1.17.3]
### Changed
- `sockeye.evaluate` can now handle multiple hypotheses files by simply specifying `--hypotheses file1 file2...`.
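The changelog entry above describes the new flag at the command-line level: parameter names are passed as a space-separated list and collected into a Python list. A minimal argparse sketch of that behaviour (the parameter names below are only illustrative; the real names depend on the model):

    # Minimal sketch of how a flag like --fixed-param-names is parsed (illustrative names).
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--fixed-param-names',
                        default=[],
                        nargs='*',
                        help="Names of parameters to fix at training time.")

    # e.g. a training call ending in: --fixed-param-names source_embed_weight target_embed_weight
    args = parser.parse_args(['--fixed-param-names', 'source_embed_weight', 'target_embed_weight'])
    print(args.fixed_param_names)  # ['source_embed_weight', 'target_embed_weight']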
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

-__version__ = '1.17.3'
+__version__ = '1.17.4'
10 changes: 8 additions & 2 deletions sockeye/arguments.py
@@ -895,6 +895,11 @@ def add_training_args(params):
choices=[C.RNN_INIT_ORTHOGONAL, C.RNN_INIT_ORTHOGONAL_STACKED, C.RNN_INIT_DEFAULT],
help="Initialization method for RNN parameters. Default: %(default)s.")

+train_params.add_argument('--fixed-param-names',
+default=[],
+nargs='*',
+help="Names of parameters to fix at training time. Default: %(default)s.")
+
train_params.add_argument(C.TRAIN_ARGS_MONITOR_BLEU,
default=0,
type=int,
@@ -1088,8 +1093,9 @@ def add_init_embedding_args(params):
help='List of input vocabularies as token-index dictionaries in .json format.')
params.add_argument('--vocabularies-out', '-o', required=True, nargs='+',
help='List of output vocabularies as token-index dictionaries in .json format.')
-params.add_argument('--names', '-n', required=True, nargs='+',
-help='List of Sockeye parameter names for (embedding) weights.')
+params.add_argument('--names', '-n', nargs='+',
+help='List of Sockeye parameter names for (embedding) weights. Default: %(default)s.',
+default=[n + "weight" for n in [C.SOURCE_EMBEDDING_PREFIX, C.TARGET_EMBEDDING_PREFIX]])
params.add_argument('--file', '-f', required=True,
help='File to write initialized parameters to.')
params.add_argument('--encoding', '-c', type=str, default=C.VOCAB_ENCODING,
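The second hunk above gives `--names` a computed default instead of making it required. As a rough sketch of what that default expands to, assuming the embedding prefixes are 'source_embed_' and 'target_embed_' (the actual values come from sockeye.constants, imported as C):

    # Sketch of the --names default; the prefix values are assumptions for illustration.
    SOURCE_EMBEDDING_PREFIX = "source_embed_"
    TARGET_EMBEDDING_PREFIX = "target_embed_"

    default_names = [n + "weight" for n in [SOURCE_EMBEDDING_PREFIX, TARGET_EMBEDDING_PREFIX]]
    print(default_names)  # ['source_embed_weight', 'target_embed_weight']

With this default, sockeye.init_embedding targets the source and target embedding weights unless other parameter names are given explicitly.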
8 changes: 4 additions & 4 deletions sockeye/init_embedding.py
@@ -139,10 +139,10 @@ def main():
"'output vocabularies' and 'Sockeye parameter names' should be provided.")
sys.exit(1)

-params = {} # type: Dict[str, mx.nd.NDArray]
-weight_file_cache = {} # type: Dict[str, np.ndarray]
-for weight_file, vocab_in_file, vocab_out_file, name in zip(args.weight_files, args.vocabularies_in, \
-args.vocabularies_out, args.names):
+params = {} # type: Dict[str, mx.nd.NDArray]
+weight_file_cache = {} # type: Dict[str, np.ndarray]
+for weight_file, vocab_in_file, vocab_out_file, name in zip(args.weight_files, args.vocabularies_in,
+args.vocabularies_out, args.names):
weight = load_weight(weight_file, name, weight_file_cache)
logger.info('Loading input/output vocabularies: %s %s', vocab_in_file, vocab_out_file)
vocab_in = vocab.vocab_from_json(vocab_in_file, encoding=args.encoding)
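The visible change in this hunk is stylistic: inside an open parenthesis, Python continues a statement across lines implicitly, so the trailing backslash in the old zip(...) call is unnecessary. A tiny illustration of the two equivalent forms:

    # Explicit continuation with a backslash (works, but fragile if trailing whitespace sneaks in):
    pairs = list(zip([1, 2, 3], \
                     ['a', 'b', 'c']))

    # Implicit continuation inside the parentheses, as in the updated code:
    pairs = list(zip([1, 2, 3],
                     ['a', 'b', 'c']))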
2 changes: 1 addition & 1 deletion sockeye/model.py
@@ -14,7 +14,7 @@
import copy
import logging
import os
-from typing import cast, Dict, Optional, Tuple
+from typing import cast, Dict, Optional, Tuple, List

import mxnet as mx

3 changes: 2 additions & 1 deletion sockeye/train.py
@@ -640,7 +640,8 @@ def create_training_model(config: model.ModelConfig,
provide_label=train_iter.provide_label,
default_bucket_key=train_iter.default_bucket_key,
bucketing=not args.no_bucketing,
-gradient_compression_params=gradient_compression_params(args))
+gradient_compression_params=gradient_compression_params(args),
+fixed_param_names=args.fixed_param_names)

return training_model

13 changes: 10 additions & 3 deletions sockeye/training.py
@@ -53,6 +53,7 @@ class TrainingModel(model.SockeyeModel):
:param bucketing: If True bucketing will be used, if False the computation graph will always be
unrolled to the full length.
:param gradient_compression_params: Optional dictionary of gradient compression parameters.
+:param fixed_param_names: Optional list of params to fix during training (i.e. their values will not be trained).
"""

def __init__(self,
@@ -63,10 +64,12 @@ def __init__(self,
provide_label: List[mx.io.DataDesc],
default_bucket_key: Tuple[int, int],
bucketing: bool,
-gradient_compression_params: Optional[Dict[str, Any]] = None) -> None:
+gradient_compression_params: Optional[Dict[str, Any]] = None,
+fixed_param_names: Optional[List[str]] = None) -> None:
super().__init__(config)
self.context = context
self.output_dir = output_dir
+self.fixed_param_names = fixed_param_names
self._bucketing = bucketing
self._gradient_compression_params = gradient_compression_params
self._initialize(provide_data, provide_label, default_bucket_key)
@@ -147,7 +150,8 @@ def sym_gen(seq_lens):
logger=logger,
default_bucket_key=default_bucket_key,
context=self.context,
-compression_params=self._gradient_compression_params)
+compression_params=self._gradient_compression_params,
+fixed_param_names=self.fixed_param_names)
else:
logger.info("No bucketing. Unrolled to (%d,%d)",
self.config.config_data.max_seq_len_source, self.config.config_data.max_seq_len_target)
@@ -157,7 +161,8 @@
label_names=label_names,
logger=logger,
context=self.context,
-compression_params=self._gradient_compression_params)
+compression_params=self._gradient_compression_params,
+fixed_param_names=self.fixed_param_names)

self.module.bind(data_shapes=provide_data,
label_shapes=provide_label,
@@ -290,6 +295,8 @@ def log_parameters(self):
info.append("%s: %s" % (name, array.shape))
total_parameters += reduce(lambda x, y: x * y, array.shape)
logger.info("Model parameters: %s", ", ".join(info))
+if self.fixed_param_names:
+logger.info("Fixed model parameters: %s", ", ".join(self.fixed_param_names))
logger.info("Total # of parameters: %d", total_parameters)

def save_params_to_file(self, fname: str):
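The changes above thread fixed_param_names from the TrainingModel constructor through to the underlying MXNet Module/BucketingModule, which excludes the named parameters from optimizer updates. A minimal, self-contained sketch of that MXNet mechanism with a toy symbol (not Sockeye's actual graph; shapes and names are illustrative):

    # Toy MXNet example: the embedding weight is initialized but never updated,
    # while the output layer's parameters still train normally.
    import mxnet as mx

    data = mx.sym.Variable('data')
    label = mx.sym.Variable('softmax_label')
    embed = mx.sym.Embedding(data=data, input_dim=100, output_dim=8, name='source_embed')
    fc = mx.sym.FullyConnected(data=mx.sym.flatten(embed), num_hidden=2, name='out')
    net = mx.sym.SoftmaxOutput(data=fc, label=label, name='softmax')

    mod = mx.mod.Module(net,
                        data_names=['data'],
                        label_names=['softmax_label'],
                        fixed_param_names=['source_embed_weight'])
    mod.bind(data_shapes=[('data', (4, 10))], label_shapes=[('softmax_label', (4,))])
    mod.init_params()
    mod.init_optimizer(optimizer='sgd')
    # After forward/backward/update steps, 'source_embed_weight' keeps its initial
    # values; only the non-fixed parameters (e.g. 'out_weight', 'out_bias') change.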
1 change: 1 addition & 0 deletions test/unit/test_arguments.py
@@ -167,6 +167,7 @@ def test_model_parameters(test_params, expected_params):
rnn_decoder_hidden_dropout=.0,
cnn_hidden_dropout=0.0,
rnn_forget_bias=0.0,
+fixed_param_names=[],
rnn_h2h_init=C.RNN_INIT_ORTHOGONAL,
decode_and_evaluate=0,
decode_and_evaluate_use_cpu=False,
