From 975063b0144fdf93f3b0a9013f056e7359bd52f6 Mon Sep 17 00:00:00 2001 From: Ilia Dobrusin Date: Wed, 14 Oct 2020 12:55:14 +0200 Subject: [PATCH] #109 Add type to discount factor (tau) in target network update argument --- a2c_ppo_acktr/arguments.py | 2 +- a2c_ppo_acktr/utils.py | 6 +++--- main.py | 7 ++++++- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/a2c_ppo_acktr/arguments.py b/a2c_ppo_acktr/arguments.py index 43c5bec58..413ed7cc0 100644 --- a/a2c_ppo_acktr/arguments.py +++ b/a2c_ppo_acktr/arguments.py @@ -156,7 +156,7 @@ def get_args(sysargs): parser.add_argument('--learning.consistency_loss.target-model-update-frequency', default=-1, help="If set to value > 0, the target model in the consistency loss will be updated with given " "frequency") - parser.add_argument('--learning.consistency_loss.target-model-discount', default=1.0, + parser.add_argument('--learning.consistency_loss.target-model-discount', default=1.0, type=float, help="Discount of 1.0: use new weights for target model. Discount < 1.0: exponential moving avg.") # experiment diff --git a/a2c_ppo_acktr/utils.py b/a2c_ppo_acktr/utils.py index 9bac9c988..dd9e5497d 100644 --- a/a2c_ppo_acktr/utils.py +++ b/a2c_ppo_acktr/utils.py @@ -87,9 +87,9 @@ def update_state_dict(model, state_dict, tau=1.0): ``tau==1`` applies hard update, copying the values, ``0 0: - update_sd = {k: tau * state_dict[k] + (1 - tau) * v + elif tau > 0.0: + update_sd = {k: tau * state_dict[k] + (1.0 - tau) * v for k, v in model.state_dict().items()} model.load_state_dict(update_sd) diff --git a/main.py b/main.py index 05140bca0..2af17474a 100755 --- a/main.py +++ b/main.py @@ -138,7 +138,12 @@ def setup_dirs_and_logging(cfg: SimpleNamespace): # curriculum_log_file = open(os.path.join(cfg.log_dir, "curriculum_log.json"), 'w') # Copy original configuration if present to new location - shutil.copyfile(str(cfg.config[0]), os.path.join(cfg.log_dir, "config.yaml")) + shutil.copyfile(str(cfg.config[0]), os.path.join(cfg.log_dir, "config_original.yaml")) + + with open(os.path.join(cfg.log_dir, "config.yaml"), 'w+') as file: + yaml.dump(vars(cfg), file) + + return tb_writer, tb_writer_img, log_file