Implement target update in consistency with discounting factor
idobrusin committed Oct 11, 2020
1 parent badfc2a commit 0572b6c
Showing 5 changed files with 59 additions and 9 deletions.
38 changes: 30 additions & 8 deletions a2c_ppo_acktr/algo/consistency.py
@@ -6,6 +6,7 @@
from a2c_ppo_acktr.combi_policy import CombiPolicy

from a2c_ppo_acktr.augmentation.augmenters import Augmenter
from a2c_ppo_acktr.utils import update_state_dict


def are_models_equal(actor_critic, target_actor_critic):
@@ -77,12 +78,15 @@ def update(self, rollouts):

if self.actor_critic.return_cnn_output:
_, actor_critic_action, _, _, _ = self.actor_critic.act(obs_batch, None, None, deterministic=True)
_, target_action, _, _, _ = self.target_actor_critic.act(obs_batch, None, None, deterministic=True)
with torch.no_grad():
_, target_action, _, _, _ = self.target_actor_critic.act(obs_batch, None, None, deterministic=True)
else:
_, actor_critic_action, _, _ = self.actor_critic.act(obs_batch, None, None, deterministic=True)
_, target_action, _, _ = self.target_actor_critic.act(obs_batch, None, None, deterministic=True)
with torch.no_grad():
_, target_action, _, _ = self.target_actor_critic.act(obs_batch, None, None, deterministic=True)

action_loss = torch.nn.functional.mse_loss(target_action.detach(), actor_critic_action)
# print("Actions: ", actor_critic_action == target_action)

aug_obs_batch_orig = next(augmentation_data_loader_iter)
# Move data to model's device
@@ -97,15 +101,31 @@ def update(self, rollouts):
masks_batch=None,
return_images=self.return_images)

action_loss.retain_grad() # retain grad for norm calculation
action_loss_aug.retain_grad()
action_loss_aug_weight = self.augmentation_loss_weight_function(self.current_num_steps, action_loss.item())
# action_loss.retain_grad() # retain grad for norm calculation
# action_loss_aug.retain_grad()
action_loss_aug_weight = self.augmentation_loss_weight_function(self.current_num_steps, action_loss)
action_loss_aug_weighted = action_loss_aug_weight * action_loss_aug

if self.force_ignore_loss_aug:
action_loss_sum = action_loss
else:
action_loss_sum = action_loss + action_loss_aug_weighted
#
# print("="*20)
# print("Params Diff")
# for params in zip(list(self.actor_critic.base.cnn.parameters())[0],
# list(self.target_actor_critic.base.cnn.parameters())[0]):
# params_model = params[0][0][0][0]
# params_target = params[1][0][0][0]
# print("Model : {}".format(params_model))
# print("Target: {}".format(params_target))
# diff = params_model - params_target
#
# print("Diff: {}".format(diff))
# total_diff = list(self.actor_critic.base.cnn.parameters())[0] - list(self.target_actor_critic.base.cnn.parameters())[0]
# print(total_diff)
# print(torch.sum(total_diff))
# print("="*20)

self.optimizer.zero_grad()
action_loss_sum.retain_grad() # retain grad for norm calculation
@@ -136,6 +156,8 @@ def update(self, rollouts):
if max_actions_batch >= update_log['action_max_value']:
update_log['action_max_value'] = max_actions_batch

# if self.update_count():
# pass
return update_log['value_loss'], update_log['action_loss'], update_log['dist_entropy'], update_log

def init_update_logging(self, with_augmentation=False):
@@ -162,9 +184,9 @@ def init_update_logging(self, with_augmentation=False):

return update_log

def update_target_critic(self):
self.target_actor_critic.load_state_dict(self.actor_critic.state_dict())
def update_target_critic(self, tau=1.0):
        update_state_dict(model=self.target_actor_critic, state_dict=self.actor_critic.state_dict(), tau=tau)
self.target_actor_critic.eval()

for param in self.target_actor_critic.parameters():
param.requires_grad = False
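
The hunks above boil down to the following pattern: the target network's action is computed under torch.no_grad() and detached, the consistency term is an MSE between the online policy's action and the frozen target's action, and the target network is periodically re-synced from the online weights (see main.py below). The self-contained sketch uses a ToyPolicy stand-in rather than the repository's CombiPolicy and omits the augmentation loss, weighting, and logging:

import copy

import torch
import torch.nn as nn


class ToyPolicy(nn.Module):
    """Illustrative stand-in for the repository's policy network."""

    def __init__(self, obs_dim=8, action_dim=2):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(obs_dim, 32), nn.Tanh(), nn.Linear(32, action_dim))

    def act(self, obs):
        return self.net(obs)


policy = ToyPolicy()
target_policy = copy.deepcopy(policy)
target_policy.eval()
for param in target_policy.parameters():
    param.requires_grad = False  # the target only provides labels, never gradients

optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)

for step in range(1, 101):
    obs_batch = torch.randn(16, 8)

    action = policy.act(obs_batch)
    with torch.no_grad():  # target action is a constant w.r.t. the autograd graph
        target_action = target_policy.act(obs_batch)

    # Consistency term: keep the online policy close to the frozen target.
    action_loss = nn.functional.mse_loss(target_action.detach(), action)

    optimizer.zero_grad()
    action_loss.backward()
    optimizer.step()

    if step % 10 == 0:  # periodic hard sync, i.e. tau = 1.0
        target_policy.load_state_dict(policy.state_dict())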

5 changes: 5 additions & 0 deletions a2c_ppo_acktr/arguments.py
@@ -153,6 +153,11 @@ def get_args(sysargs):
help="If set, an additional evaluation will be performed on the given environment with the set eval_interval")
parser.add_argument('--learning.consistency_loss.force-disable-consistency', default=False, action="store_true",
help="If set to to true, the action loss is calculated without consistency loss")
parser.add_argument('--learning.consistency_loss.target-model-update-frequency', type=int, default=-1,
help="If set to a value > 0, the target model in the consistency loss will be updated with the given "
"frequency")
parser.add_argument('--learning.consistency_loss.target-model-discount', type=float, default=1.0,
help="Discount of 1.0: copy the new weights into the target model. Discount < 1.0: exponential moving average.")

# experiment
parser.add_argument('--experiment.num-bc-epochs', type=int, default=1000,
14 changes: 14 additions & 0 deletions a2c_ppo_acktr/utils.py
@@ -79,3 +79,17 @@ def init(module, weight_init, bias_init, gain=1):
weight_init(module.weight.data, gain=gain)
bias_init(module.bias.data)
return module


def update_state_dict(model, state_dict, tau=1.0):
"""
Update the state dict of ``model`` using the input ``state_dict``.
``tau == 1`` applies a hard update (copying the values); ``0 < tau < 1``
applies a soft update: ``tau * new + (1 - tau) * old``.
"""
if tau == 1:
model.load_state_dict(state_dict)
elif tau > 0:
update_sd = {k: tau * state_dict[k] + (1 - tau) * v
for k, v in model.state_dict().items()}
model.load_state_dict(update_sd)
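
A brief usage sketch for the new helper (the import path matches the one added to consistency.py; the two Linear modules are placeholders): tau=1.0 performs a hard copy, while 0 < tau < 1 performs a Polyak-style soft update, target <- tau * online + (1 - tau) * target.

import copy

import torch.nn as nn

from a2c_ppo_acktr.utils import update_state_dict

online = nn.Linear(4, 2)
target = copy.deepcopy(online)

# ... train `online` for a while ...

# Hard update: the target becomes an exact copy of the online weights.
update_state_dict(model=target, state_dict=online.state_dict(), tau=1.0)

# Soft update: the target moves 0.5% of the way towards the online weights.
update_state_dict(model=target, state_dict=online.state_dict(), tau=0.005)
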
5 changes: 4 additions & 1 deletion example_config_train.yaml
@@ -61,7 +61,10 @@ learning:
use_cnn_loss: false
use_action_loss_as_weight: false
clip_aug_actions: false
eval_target_env:
eval_target_env:
target_model_update_frequency: 0
target_model_discount: 1.0


experiment:
snapshot:
6 changes: 6 additions & 0 deletions main.py
@@ -642,6 +642,12 @@ def train(sysargs):
value_loss, action_loss, dist_entropy, update_log = agent.update(rollouts)
rollouts.after_update()

# Update target network in consistency loss
if cfg.learning.consistency_loss.target_model_update_frequency > 0 and \
(update_step % cfg.learning.consistency_loss.target_model_update_frequency == 0) and \
hasattr(agent, "update_target_critic"):
agent.update_target_critic(tau=cfg.learning.consistency_loss.target_model_discount)

# save for every interval-th episode or for the last epoch
if (update_step % cfg.experiment.save_interval == 0 or update_step == num_updates - 1) and cfg.save_dir != "":
last_model_save_path = save_model(cfg, envs, actor_critic, update_step)
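
Taken together, the two new options cover the usual target-update schemes: a positive target_model_update_frequency with target_model_discount: 1.0 yields a periodic hard copy of the online weights into the target network (DQN-style), while a discount below 1.0 applies an exponential moving average, target <- discount * online + (1 - discount) * target, every target_model_update_frequency update steps. With the argument default of -1 (or the 0 set in example_config_train.yaml), the condition above is never true and the target network is left untouched by this code path.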
