Skip to content

Commit

Permalink
ikostrikov#109 Add type to discount factor (tau) in target network up…
Browse files Browse the repository at this point in the history
…date argument
  • Loading branch information
idobrusin committed Oct 14, 2020
1 parent 0572b6c commit 975063b
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion a2c_ppo_acktr/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def get_args(sysargs):
parser.add_argument('--learning.consistency_loss.target-model-update-frequency', default=-1,
help="If set to value > 0, the target model in the consistency loss will be updated with given "
"frequency")
parser.add_argument('--learning.consistency_loss.target-model-discount', default=1.0,
parser.add_argument('--learning.consistency_loss.target-model-discount', default=1.0, type=float,
help="Discount of 1.0: use new weights for target model. Discount < 1.0: exponential moving avg.")

# experiment
Expand Down
6 changes: 3 additions & 3 deletions a2c_ppo_acktr/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ def update_state_dict(model, state_dict, tau=1.0):
``tau==1`` applies hard update, copying the values, ``0<tau<1``
applies soft update: ``tau * new + (1 - tau) * old``.
"""
if tau == 1:
if tau == 1.0:
model.load_state_dict(state_dict)
elif tau > 0:
update_sd = {k: tau * state_dict[k] + (1 - tau) * v
elif tau > 0.0:
update_sd = {k: tau * state_dict[k] + (1.0 - tau) * v
for k, v in model.state_dict().items()}
model.load_state_dict(update_sd)
7 changes: 6 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,12 @@ def setup_dirs_and_logging(cfg: SimpleNamespace):
# curriculum_log_file = open(os.path.join(cfg.log_dir, "curriculum_log.json"), 'w')

# Copy original configuration if present to new location
shutil.copyfile(str(cfg.config[0]), os.path.join(cfg.log_dir, "config.yaml"))
shutil.copyfile(str(cfg.config[0]), os.path.join(cfg.log_dir, "config_original.yaml"))

with open(os.path.join(cfg.log_dir, "config.yaml"), 'w+') as file:
yaml.dump(vars(cfg), file)



return tb_writer, tb_writer_img, log_file

Expand Down

0 comments on commit 975063b

Please sign in to comment.