From f407b365c0b63677fad7caeee41b5cf8e644c292 Mon Sep 17 00:00:00 2001 From: typoverflow Date: Tue, 10 Jan 2023 13:39:40 +0800 Subject: [PATCH 1/5] feat: implement rambo --- .gitignore | 3 +- offlinerlkit/policy/__init__.py | 4 +- offlinerlkit/policy/model_based/rambo.py | 196 ++++++++++++++ .../policy_trainer/mb_policy_trainer.py | 10 +- run_example/run_rambo.py | 240 ++++++++++++++++++ 5 files changed, 450 insertions(+), 3 deletions(-) create mode 100644 offlinerlkit/policy/model_based/rambo.py create mode 100644 run_example/run_rambo.py diff --git a/.gitignore b/.gitignore index 3456339..c3aa7b4 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,5 @@ **/dist **/*.egg-info **/*.png -**/*.txt \ No newline at end of file +**/*.txt +**/.vscode \ No newline at end of file diff --git a/offlinerlkit/policy/__init__.py b/offlinerlkit/policy/__init__.py index 9acaa36..5b7e959 100755 --- a/offlinerlkit/policy/__init__.py +++ b/offlinerlkit/policy/__init__.py @@ -15,6 +15,7 @@ from offlinerlkit.policy.model_based.mopo_ensemble import MOPOEnsemblePolicy from offlinerlkit.policy.model_based.mobile import MOBILEPolicy from offlinerlkit.policy.model_based.mobile_ensemble import MOBILEEnsemblePolicy +from offlinerlkit.policy.model_based.rambo import RAMBOPolicy __all__ = [ @@ -30,5 +31,6 @@ "CMOPOPolicy", "MOPOEnsemblePolicy", "MOBILEPolicy", - "MOBILEEnsemble" + "MOBILEEnsemble", + "RAMBOPolicy" ] \ No newline at end of file diff --git a/offlinerlkit/policy/model_based/rambo.py b/offlinerlkit/policy/model_based/rambo.py new file mode 100644 index 0000000..e323038 --- /dev/null +++ b/offlinerlkit/policy/model_based/rambo.py @@ -0,0 +1,196 @@ +import numpy as np +import torch +import torch.nn as nn +import gym +import os + +from torch.nn import functional as F +from typing import Dict, Union, Tuple +from collections import defaultdict +from operator import itemgetter +from offlinerlkit.policy import MOPOPolicy +from offlinerlkit.dynamics import BaseDynamics + + +class RAMBOPolicy(MOPOPolicy): + """ + RAMBO-RL: Robust Adversarial Model-Based Offline Reinforcement Learning + """ + + def __init__( + self, + dynamics: BaseDynamics, + actor: nn.Module, + critic1: nn.Module, + critic2: nn.Module, + actor_optim: torch.optim.Optimizer, + critic1_optim: torch.optim.Optimizer, + critic2_optim: torch.optim.Optimizer, + dynamics_adv_optim: torch.optim.Optimizer, + tau: float = 0.005, + gamma: float = 0.99, + alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2, + adv_weight: float=0, + adv_rollout_batch_size: int=256, + adv_rollout_length: int=5, + include_ent_in_adv: bool=False, # CHECK 这里是不是False + device="cpu" + ) -> None: + super().__init__( + dynamics, + actor, + critic1, + critic2, + actor_optim, + critic1_optim, + critic2_optim, + tau=tau, + gamma=gamma, + alpha=alpha + ) + + self._dynmics_adv_optim = dynamics_adv_optim + self._adv_weight = adv_weight + self._adv_rollout_batch_size = adv_rollout_batch_size + self._adv_rollout_length = adv_rollout_length + self._include_ent_in_adv = include_ent_in_adv + self.device = device + + def load(self, path): + self.load_state_dict(torch.load(path, map_location="cpu")) + + def pretrain(self, data: Dict, n_epoch, batch_size, lr, logger) -> None: + self._bc_optim = torch.optim.Adam(self.actor.parameters(), lr=lr) + observations = data["observations"] + actions = data["actions"] + sample_num = observations.shape[0] + idxs = np.arange(sample_num) + + logger.log("Pretraining policy") + self.actor.train() + for i_epoch in range(n_epoch): + 
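+            # each behavior-cloning epoch: reshuffle the dataset indices, then maximize the
+            # actor's log-likelihood of the dataset actions mini-batch by mini-batch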
np.random.shuffle(idxs) + sum_loss = 0 + for i_batch in range(sample_num // batch_size): + batch_obs = observations[i_batch * batch_size: (i_batch + 1) * batch_size] + batch_act = actions[i_batch * batch_size: (i_batch + 1) * batch_size] + batch_obs = torch.from_numpy(batch_obs).to(self.device) + batch_act = torch.from_numpy(batch_act).to(self.device) + dist = self.actor(batch_obs) + log_prob = dist.log_prob(batch_act) + bc_loss = - log_prob.mean() + + self._bc_optim.zero_grad() + bc_loss.backward() + self._bc_optim.step() + sum_loss += bc_loss.cpu().item() + print(f"Epoch {i_epoch}, mean bc loss {sum_loss/i_batch}") + torch.save(self.state_dict(), os.path.join(logger.model_dir, "rambo_pretrain.pt")) + + + def update_dynamics( + self, + real_buffer, + ) -> Tuple[Dict[str, np.ndarray], Dict]: + all_loss_info = { + "all_loss": 0, + "sl_loss": 0, + "adv_loss": 0 + } + steps = 0 + while steps < 1000: + init_obss = real_buffer.sample(self._adv_batch_size)["observations"].cpu().numpy() + observations = init_obss + for t in range(self._adv_rollout_length): + actions = self.select_action(observations) + sl_observations, sl_actions, sl_next_observations, sl_rewards = \ + itemgetter("observations", "actions", "next_observations", "rewards")(real_buffer.sample(self._adv_batch_size)) + next_observations, terminals, loss_info = self.dynamics_step_and_forward(observations, actions, sl_observations, sl_actions, sl_next_observations, sl_rewards) + all_loss_info["all_loss"] += loss_info["all_loss"] + all_loss_info["adv_loss"] += loss_info["adv_loss"] + all_loss_info["sl_loss"] += loss_info["sl_loss"] + + nonterm_mask = (~terminals).flatten() + steps += 1 + observations = next_observations[nonterm_mask] + if nonterm_mask.sum() == 0: + break + if steps == 1000: + break + return {_key: _value/steps for _key, _value in all_loss_info.items()} + + + def dynamics_step_and_forward( + self, + observations, + actions, + sl_observations, + sl_actions, + sl_next_observations, + sl_rewards, + ): + obs_act = np.concatenate([observations, actions], axis=-1) + obs_act = self.dynamics.scaler.transform(obs_act) + with torch.no_grad(): + mean, logvar = self.dynamics.model(obs_act) + # mean = mean.cpu().numpy() + # logvar = logvar.cpu().numpy() + observations = torch.from_numpy(observations).to(mean.device) + mean[..., :-1] += observations + std = torch.sqrt(torch.exp(logvar)) + _noise_generator = torch.distributions.Normal(torch.zeros_like(mean), torch.ones_like(mean)) + noise = _noise_generator.sample() + + # select the next observations + sample_size = mean.shape[1] + selected_indexes = np.random.randint(0, noise.shape[0], size=sample_size) + noise = noise[selected_indexes, np.arange(sample_size)] + sample = mean + noise * std + next_observations = sample[..., :-1][selected_indexes, np.arange(sample_size)] + rewards = sample[..., -1][selected_indexes, np.arange(sample_size)] + terminals = np.squeeze(self.dynamics.terminal_fn(observations.detach().cpu().numpy(), actions, next_observations.detach().cpu().numpy())) + # terminals = torch.from_numpy(terminals).to(mean.device) + # evaluate the noises + log_prob = _noise_generator.log_prob(noise) + log_prob = log_prob.exp().sum(dim=0).log().sum(-1) + + # compute the advantage + with torch.no_grad(): + next_actions, next_policy_log_prob = self.actforward(next_observations, deterministic=True) + next_q = torch.minimum( + self.critic1(next_observations, next_actions), + self.critic2(next_observations, next_actions) + ) + if self._include_ent_in_adv: + next_q = next_q - self._alpha * 
next_policy_log_prob + value = rewards.unsqueeze(1) + (1-torch.from_numpy(terminals).to(mean.device).float().unsqueeze(1)) * self._gamma * next_q + + q = torch.minimum( + self.critic1(observations, actions), + self.critic2(observations, actions) + ) + advantage = q - value + adv_loss = (log_prob * advantage).mean() + + # compute the supervised loss + sl_input = torch.cat([sl_observations, sl_actions], dim=-1).cpu().numpy() + sl_target = torch.cat([sl_next_observations-sl_observations, sl_rewards], dim=-1) + sl_input = self.dynamics.transform(sl_input) + sl_mean, sl_logvar = self.dynamics.model(sl_input) + sl_inv_var = torch.exp(-sl_logvar) + sl_mse_loss_inv = (torch.pow(sl_mean - sl_target, 2) * sl_inv_var).mean(dim=(1, 2)) + sl_var_loss = sl_logvar.mean(dim=(1, 2)) + sl_loss = sl_mse_loss_inv.sum() + sl_var_loss.sum() + sl_loss = sl_loss + self.dynamics.model.get_decay_loss() + sl_loss = sl_loss + 0.01 * self.dynamics.model.max_logvar.sum() - 0.01 * self.dynamics.model.min_logvar.sum() + + all_loss = self._adv_weight * adv_loss + sl_loss + self._dynmics_adv_optim.zero_grad() + all_loss.backward() + self._dynmics_adv_optim.step() + + return next_observations.cpu().numpy(), terminals, { + "all_loss": all_loss.cpu().item(), + "sl_loss": sl_loss.cpu().item(), + "adv_loss": adv_loss.cpu().item() + } \ No newline at end of file diff --git a/offlinerlkit/policy_trainer/mb_policy_trainer.py b/offlinerlkit/policy_trainer/mb_policy_trainer.py index a75cc37..6d622b2 100755 --- a/offlinerlkit/policy_trainer/mb_policy_trainer.py +++ b/offlinerlkit/policy_trainer/mb_policy_trainer.py @@ -30,7 +30,8 @@ def __init__( eval_episodes: int = 10, normalize_obs: bool = False, lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, - oracle_dynamics=None + oracle_dynamics=None, + dynamics_update_freq: int = 0 ) -> None: self.policy = policy self.eval_env = eval_env @@ -40,6 +41,7 @@ def __init__( self._rollout_freq, self._rollout_batch_size, \ self._rollout_length = rollout_setting + self._dynamics_update_freq = dynamics_update_freq self._epoch = epoch self._step_per_epoch = step_per_epoch @@ -85,6 +87,12 @@ def train(self) -> Dict[str, float]: for k, v in loss.items(): self.logger.logkv_mean(k, v) + # update the dynamics if necessary + if 0 < self._dynamics_update_freq and (num_timesteps+1)%self._dynamics_update_freq == 0: + dynamics_update_info = self.policy.update_dynamics(self.real_buffer) + for k, v in dynamics_update_info.items(): + self.logger.logkv_mean(k, v) + num_timesteps += 1 if self.lr_scheduler is not None: diff --git a/run_example/run_rambo.py b/run_example/run_rambo.py new file mode 100644 index 0000000..f8ae07f --- /dev/null +++ b/run_example/run_rambo.py @@ -0,0 +1,240 @@ +import argparse +import os +import sys +import random + +import gym +import d4rl + +import numpy as np +import torch + + +from offlinerlkit.nets import MLP +from offlinerlkit.modules import ActorProb, Critic, TanhDiagGaussian, EnsembleDynamicsModel +from offlinerlkit.dynamics import EnsembleDynamics +from offlinerlkit.dynamics import MujocoOracleDynamics +from offlinerlkit.utils.scaler import StandardScaler +from offlinerlkit.utils.termination_fns import get_termination_fn +from offlinerlkit.utils.load_dataset import qlearning_dataset +from offlinerlkit.buffer import ReplayBuffer +from offlinerlkit.utils.logger import Logger, make_log_dirs +from offlinerlkit.policy_trainer import MBPolicyTrainer +from offlinerlkit.policy import RAMBOPolicy + + +def get_args(): + parser = argparse.ArgumentParser() + 
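    # experiment and SAC hyperparameters first, then dynamics-model, rollout and RAMBO-specific settings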
parser.add_argument("--algo-name", type=str, default="rambo") + parser.add_argument("--task", type=str, default="hopper-medium-replay-v2") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--actor-lr", type=float, default=1e-4) + parser.add_argument("--critic-lr", type=float, default=3e-4) + parser.add_argument("--dynamics-lr", type=float, default=3e-4) + parser.add_argument("--dynamics-adv-lr", type=float, default=3e-4) + parser.add_argument("--hidden-dims", type=int, nargs='*', default=[256, 256]) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--tau", type=float, default=0.005) + parser.add_argument("--alpha", type=float, default=0.2) + parser.add_argument("--auto-alpha", default=True) + parser.add_argument("--target-entropy", type=int, default=-3) + parser.add_argument("--alpha-lr", type=float, default=3e-4) + + parser.add_argument("--dynamics-hidden-dims", type=int, nargs='*', default=[200, 200, 200, 200]) + parser.add_argument("--dynamics-weight-decay", type=float, nargs='*', default=[2.5e-5, 5e-5, 7.5e-5, 7.5e-5, 1e-4]) + parser.add_argument("--n-ensemble", type=int, default=7) + parser.add_argument("--n-elites", type=int, default=5) + parser.add_argument("--rollout-freq", type=int, default=250) + parser.add_argument("--dynamics-update-freq", type=int, default=1000) + parser.add_argument("--adv-batch-size", type=int, default=256) + parser.add_argument("--rollout-batch-size", type=int, default=50000) + parser.add_argument("--rollout-length", type=int, default=5) + parser.add_argument("--adv-weight", type=float, default=0) + parser.add_argument("--model-retain-epochs", type=int, default=5) + parser.add_argument("--real-ratio", type=float, default=0.5) + parser.add_argument("--load-dynamics-path", type=str, default=None) + + parser.add_argument("--epoch", type=int, default=1000) + parser.add_argument("--step-per-epoch", type=int, default=1000) + parser.add_argument("--eval_episodes", type=int, default=10) + parser.add_argument("--batch-size", type=int, default=256) + parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") + + parser.add_argument("--include-ent-in-adv", type=int, default=0) + parser.add_argument("--do-bc", type=int, default=1) + parser.add_argument("--load-bc", type=str, default=None) + parser.add_argument("--bc-lr", type=float, default=1e-4) + parser.add_argument("--bc-epoch", type=int, default=50) + parser.add_argument("--bc-batch-size", type=int, default=256) + + return parser.parse_args() + + +def train(args=get_args()): + # create env and dataset + env = gym.make(args.task) + dataset = qlearning_dataset(env) + if 'antmaze' in args.task: + dataset["rewards"] -= 1.0 + args.obs_shape = env.observation_space.shape + args.action_dim = np.prod(env.action_space.shape) + args.max_action = env.action_space.high[0] + + # seed + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + torch.backends.cudnn.deterministic = True + env.seed(args.seed) + + # create policy model + actor_backbone = MLP(input_dim=np.prod(args.obs_shape), hidden_dims=args.hidden_dims) + critic1_backbone = MLP(input_dim=np.prod(args.obs_shape) + args.action_dim, hidden_dims=args.hidden_dims) + critic2_backbone = MLP(input_dim=np.prod(args.obs_shape) + args.action_dim, hidden_dims=args.hidden_dims) + dist = TanhDiagGaussian( + latent_dim=getattr(actor_backbone, "output_dim"), + output_dim=args.action_dim, + unbounded=True, + 
conditioned_sigma=True + ) + actor = ActorProb(actor_backbone, dist, args.device) + critic1 = Critic(critic1_backbone, args.device) + critic2 = Critic(critic2_backbone, args.device) + actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + + # CHECK: do anealing? + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(actor_optim, args.epoch) + + if args.auto_alpha: + target_entropy = args.target_entropy if args.target_entropy \ + else -np.prod(env.action_space.shape) + + args.target_entropy = target_entropy + + log_alpha = torch.zeros(1, requires_grad=True, device=args.device) + alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) + alpha = (target_entropy, log_alpha, alpha_optim) + else: + alpha = args.alpha + + # create dynamics + load_dynamics_model = True if args.load_dynamics_path else False + dynamics_model = EnsembleDynamicsModel( + obs_dim=np.prod(args.obs_shape), + action_dim=args.action_dim, + hidden_dims=args.dynamics_hidden_dims, + num_ensemble=args.n_ensemble, + num_elites=args.n_elites, + weight_decays=args.dynamics_weight_decay, + device=args.device + ) + dynamics_optim = torch.optim.Adam( + dynamics_model.parameters(), + lr=args.dynamics_lr + ) + dynamics_adv_optim = torch.optim.Adam( + dynamics_model.parameters(), + lr=args.dynamics_adv_lr + ) + scaler = StandardScaler() + termination_fn = get_termination_fn(task=args.task) + dynamics = EnsembleDynamics( + dynamics_model, + dynamics_optim, + scaler, + termination_fn, + ) + + if args.load_dynamics_path: + dynamics.load(args.load_dynamics_path) + + oracle_dynamics = MujocoOracleDynamics(env) + + # create policy + policy = RAMBOPolicy( + dynamics, + actor, + critic1, + critic2, + actor_optim, + critic1_optim, + critic2_optim, + dynamics_adv_optim, + tau=args.tau, + gamma=args.gamma, + alpha=alpha, + adv_weight=args.adv_weight, + adv_rollout_length=args.rollout_length, + adv_rollout_batch_size=args.adv_batch_size, + device=args.device + ).to(args.device) + + # create buffer + real_buffer = ReplayBuffer( + buffer_size=len(dataset["observations"]), + obs_shape=args.obs_shape, + obs_dtype=np.float32, + action_dim=args.action_dim, + action_dtype=np.float32, + device=args.device + ) + real_buffer.load_dataset(dataset) + fake_buffer_size = args.step_per_epoch // args.rollout_freq * args.model_retain_epochs * args.rollout_batch_size * args.rollout_length + fake_buffer = ReplayBuffer( + buffer_size=fake_buffer_size, + obs_shape=args.obs_shape, + obs_dtype=np.float32, + action_dim=args.action_dim, + action_dtype=np.float32, + device=args.device + ) + + # log + log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args)) + # key: output file name, value: output handler type + output_config = { + "consoleout_backup": "stdout", + "policy_training_progress": "csv", + "dynamics_training_progress": "csv", + "tb": "tensorboard" + } + logger = Logger(log_dirs, output_config) + logger.log_hyperparameters(vars(args)) + + # create policy trainer + policy_trainer = MBPolicyTrainer( + policy=policy, + eval_env=env, + real_buffer=real_buffer, + fake_buffer=fake_buffer, + logger=logger, + rollout_setting=(args.rollout_freq, args.rollout_batch_size, args.rollout_length), + dynamics_update_freq=args.dynamics_update_freq, + epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + batch_size=args.batch_size, + real_ratio=args.real_ratio, + 
eval_episodes=args.eval_episodes, + lr_scheduler=lr_scheduler, + oracle_dynamics=oracle_dynamics, + normalize_obs=True + ) + + # train + if args.do_bc: + if args.load_bc is not None: + policy.load(args.load_bc) + policy.to(args.device) + else: + policy.pretrain(real_buffer.sample_all(), args.bc_epoch, args.bc_batch_size, args.bc_lr, logger) + if not load_dynamics_model: + dynamics.train(real_buffer.sample_all(), logger) + + policy_trainer.train() + + +if __name__ == "__main__": + train() \ No newline at end of file From bc9a74f9cab5147fadbab32867662d6a1b0e3619 Mon Sep 17 00:00:00 2001 From: typoverflow Date: Tue, 10 Jan 2023 19:02:38 +0800 Subject: [PATCH 2/5] chore: updates --- offlinerlkit/dynamics/ensemble_dynamics.py | 4 +- offlinerlkit/policy/model_based/mopo.py | 1 - offlinerlkit/policy/model_based/rambo.py | 72 ++++++++++--------- .../policy_trainer/mb_policy_trainer.py | 17 +++-- offlinerlkit/utils/termination_fns.py | 6 ++ run_example/run_rambo.py | 72 ++++++++++--------- 6 files changed, 93 insertions(+), 79 deletions(-) diff --git a/offlinerlkit/dynamics/ensemble_dynamics.py b/offlinerlkit/dynamics/ensemble_dynamics.py index 069816e..a13c829 100755 --- a/offlinerlkit/dynamics/ensemble_dynamics.py +++ b/offlinerlkit/dynamics/ensemble_dynamics.py @@ -136,7 +136,7 @@ def train( ) -> None: inputs, targets = self.format_samples_for_training(data) data_size = inputs.shape[0] - holdout_size = min(int(data_size * 0.2), 1000) + holdout_size = min(int(data_size * 0.15), 1000) train_size = data_size - holdout_size train_splits, holdout_splits = torch.utils.data.random_split(range(data_size), (train_size, holdout_size)) train_inputs, train_targets = inputs[train_splits.indices], targets[train_splits.indices] @@ -208,7 +208,7 @@ def learn(self, inputs: np.ndarray, targets: np.ndarray, batch_size: int = 256) var_loss = logvar.mean(dim=(1, 2)) loss = mse_loss_inv.sum() + var_loss.sum() loss = loss + self.model.get_decay_loss() - loss = loss + 0.01 * self.model.max_logvar.sum() - 0.01 * self.model.min_logvar.sum() + loss = loss + 0.001 * self.model.max_logvar.sum() - 0.001 * self.model.min_logvar.sum() self.optim.zero_grad() loss.backward() diff --git a/offlinerlkit/policy/model_based/mopo.py b/offlinerlkit/policy/model_based/mopo.py index cd302e7..38d065c 100755 --- a/offlinerlkit/policy/model_based/mopo.py +++ b/offlinerlkit/policy/model_based/mopo.py @@ -57,7 +57,6 @@ def rollout( for _ in range(rollout_length): actions = self.select_action(observations) next_observations, rewards, terminals, info = self.dynamics.step(observations, actions) - rollout_transitions["obss"].append(observations) rollout_transitions["next_obss"].append(next_observations) rollout_transitions["actions"].append(actions) diff --git a/offlinerlkit/policy/model_based/rambo.py b/offlinerlkit/policy/model_based/rambo.py index e323038..f815604 100644 --- a/offlinerlkit/policy/model_based/rambo.py +++ b/offlinerlkit/policy/model_based/rambo.py @@ -31,9 +31,10 @@ def __init__( gamma: float = 0.99, alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2, adv_weight: float=0, + adv_train_steps: int=1000, adv_rollout_batch_size: int=256, adv_rollout_length: int=5, - include_ent_in_adv: bool=False, # CHECK 这里是不是False + include_ent_in_adv: bool=False, # 这里是不是False device="cpu" ) -> None: super().__init__( @@ -51,13 +52,14 @@ def __init__( self._dynmics_adv_optim = dynamics_adv_optim self._adv_weight = adv_weight + self._adv_train_steps = adv_train_steps self._adv_rollout_batch_size = adv_rollout_batch_size 
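        # batch size and horizon of the model rollouts used for the adversarial dynamics update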
self._adv_rollout_length = adv_rollout_length self._include_ent_in_adv = include_ent_in_adv self.device = device def load(self, path): - self.load_state_dict(torch.load(path, map_location="cpu")) + self.load_state_dict(torch.load(os.path.join(path, "rambo.pt"), map_location="cpu")) def pretrain(self, data: Dict, n_epoch, batch_size, lr, logger) -> None: self._bc_optim = torch.optim.Adam(self.actor.parameters(), lr=lr) @@ -85,9 +87,11 @@ def pretrain(self, data: Dict, n_epoch, batch_size, lr, logger) -> None: self._bc_optim.step() sum_loss += bc_loss.cpu().item() print(f"Epoch {i_epoch}, mean bc loss {sum_loss/i_batch}") + # logger.logkv("loss/pretrain_bc", sum_loss/i_batch) + # logger.set_timestep(i_epoch) + # logger.dumpkvs(exclude) torch.save(self.state_dict(), os.path.join(logger.model_dir, "rambo_pretrain.pt")) - def update_dynamics( self, real_buffer, @@ -97,9 +101,10 @@ def update_dynamics( "sl_loss": 0, "adv_loss": 0 } + self.dynamics.train() steps = 0 - while steps < 1000: - init_obss = real_buffer.sample(self._adv_batch_size)["observations"].cpu().numpy() + while steps < self._adv_train_steps: + init_obss = real_buffer.sample(self._adv_rollout_batch_size)["observations"].cpu().numpy() observations = init_obss for t in range(self._adv_rollout_length): actions = self.select_action(observations) @@ -110,13 +115,15 @@ def update_dynamics( all_loss_info["adv_loss"] += loss_info["adv_loss"] all_loss_info["sl_loss"] += loss_info["sl_loss"] - nonterm_mask = (~terminals).flatten() + # nonterm_mask = (~terminals).flatten() steps += 1 - observations = next_observations[nonterm_mask] - if nonterm_mask.sum() == 0: - break + # observations = next_observations[nonterm_mask] + observations = next_observations + # if nonterm_mask.sum() == 0: + # break if steps == 1000: break + self.dynamics.eval() return {_key: _value/steps for _key, _value in all_loss_info.items()} @@ -131,28 +138,25 @@ def dynamics_step_and_forward( ): obs_act = np.concatenate([observations, actions], axis=-1) obs_act = self.dynamics.scaler.transform(obs_act) - with torch.no_grad(): - mean, logvar = self.dynamics.model(obs_act) - # mean = mean.cpu().numpy() - # logvar = logvar.cpu().numpy() + mean, logvar = self.dynamics.model(obs_act) observations = torch.from_numpy(observations).to(mean.device) mean[..., :-1] += observations std = torch.sqrt(torch.exp(logvar)) - _noise_generator = torch.distributions.Normal(torch.zeros_like(mean), torch.ones_like(mean)) - noise = _noise_generator.sample() - + + dist = torch.distributions.Normal(mean, std) + ensemble_sample = dist.sample() + ensemble_size, batch_size, _ = ensemble_size.shape + # select the next observations - sample_size = mean.shape[1] - selected_indexes = np.random.randint(0, noise.shape[0], size=sample_size) - noise = noise[selected_indexes, np.arange(sample_size)] - sample = mean + noise * std - next_observations = sample[..., :-1][selected_indexes, np.arange(sample_size)] - rewards = sample[..., -1][selected_indexes, np.arange(sample_size)] + selected_indexes = np.random.randint(0, ensemble_size, size=batch_size) # CHECK 这里有可能应该使用所有模型 + sample = ensemble_sample[selected_indexes, np.arange(batch_size)] + next_observations = sample[..., :-1] + rewards = sample[..., -1:] terminals = np.squeeze(self.dynamics.terminal_fn(observations.detach().cpu().numpy(), actions, next_observations.detach().cpu().numpy())) - # terminals = torch.from_numpy(terminals).to(mean.device) - # evaluate the noises - log_prob = _noise_generator.log_prob(noise) - log_prob = 
log_prob.exp().sum(dim=0).log().sum(-1) + + # compute logprob + log_prob = dist.log_prob(sample) + log_prob = log_prob.exp().mean(dim=0).log().sum(-1) # compute the advantage with torch.no_grad(): @@ -163,26 +167,26 @@ def dynamics_step_and_forward( ) if self._include_ent_in_adv: next_q = next_q - self._alpha * next_policy_log_prob - value = rewards.unsqueeze(1) + (1-torch.from_numpy(terminals).to(mean.device).float().unsqueeze(1)) * self._gamma * next_q + value = rewards + (1-torch.from_numpy(terminals).to(mean.device).float().unsqueeze(1)) * self._gamma * next_q - q = torch.minimum( + value_baseline = torch.minimum( self.critic1(observations, actions), self.critic2(observations, actions) ) - advantage = q - value + advantage = value - value_baseline adv_loss = (log_prob * advantage).mean() # compute the supervised loss sl_input = torch.cat([sl_observations, sl_actions], dim=-1).cpu().numpy() sl_target = torch.cat([sl_next_observations-sl_observations, sl_rewards], dim=-1) - sl_input = self.dynamics.transform(sl_input) + sl_input = self.dynamics.scaler.transform(sl_input) sl_mean, sl_logvar = self.dynamics.model(sl_input) sl_inv_var = torch.exp(-sl_logvar) sl_mse_loss_inv = (torch.pow(sl_mean - sl_target, 2) * sl_inv_var).mean(dim=(1, 2)) sl_var_loss = sl_logvar.mean(dim=(1, 2)) sl_loss = sl_mse_loss_inv.sum() + sl_var_loss.sum() sl_loss = sl_loss + self.dynamics.model.get_decay_loss() - sl_loss = sl_loss + 0.01 * self.dynamics.model.max_logvar.sum() - 0.01 * self.dynamics.model.min_logvar.sum() + sl_loss = sl_loss + 0.001 * self.dynamics.model.max_logvar.sum() - 0.001 * self.dynamics.model.min_logvar.sum() all_loss = self._adv_weight * adv_loss + sl_loss self._dynmics_adv_optim.zero_grad() @@ -190,7 +194,7 @@ def dynamics_step_and_forward( self._dynmics_adv_optim.step() return next_observations.cpu().numpy(), terminals, { - "all_loss": all_loss.cpu().item(), - "sl_loss": sl_loss.cpu().item(), - "adv_loss": adv_loss.cpu().item() + "adv_dynamics_update/all_loss": all_loss.cpu().item(), + "adv_dynamics_update/sl_loss": sl_loss.cpu().item(), + "adv_dynamics_update/adv_loss": adv_loss.cpu().item() } \ No newline at end of file diff --git a/offlinerlkit/policy_trainer/mb_policy_trainer.py b/offlinerlkit/policy_trainer/mb_policy_trainer.py index 6d622b2..b440f1c 100755 --- a/offlinerlkit/policy_trainer/mb_policy_trainer.py +++ b/offlinerlkit/policy_trainer/mb_policy_trainer.py @@ -28,10 +28,12 @@ def __init__( batch_size: int = 256, real_ratio: float = 0.05, eval_episodes: int = 10, - normalize_obs: bool = False, + # normalize_obs: bool = False, lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, oracle_dynamics=None, - dynamics_update_freq: int = 0 + dynamics_update_freq: int = 0, + obs_mean: Optional[np.ndarray]=None, + obs_std: Optional[np.ndarray]=None ) -> None: self.policy = policy self.eval_env = eval_env @@ -48,9 +50,10 @@ def __init__( self._batch_size = batch_size self._real_ratio = real_ratio self._eval_episodes = eval_episodes - self._normalize_obs = normalize_obs - if normalize_obs: - self._obs_mean, self._obs_std = self.real_buffer.normalize_obs() + self._obs_mean = obs_mean + self._obs_std = obs_std + if self._obs_mean is None or self._obs_std is None: + self._obs_mean, self._obs_std = 0, 1 self.lr_scheduler = lr_scheduler self.oracle_dynamics = oracle_dynamics @@ -130,8 +133,8 @@ def _evaluate(self) -> Dict[str, List[float]]: episode_reward, episode_length = 0, 0 while num_episodes < self._eval_episodes: - if self._normalize_obs: - obs = 
(np.array(obs).reshape(1,-1) - self._obs_mean) / self._obs_std + # if self._normalize_obs: + obs = (np.array(obs).reshape(1,-1) - self._obs_mean) / self._obs_std action = self.policy.select_action(obs, deterministic=True) next_obs, reward, terminal, _ = self.eval_env.step(action.flatten()) episode_reward += reward diff --git a/offlinerlkit/utils/termination_fns.py b/offlinerlkit/utils/termination_fns.py index 0b27e5c..af110a8 100755 --- a/offlinerlkit/utils/termination_fns.py +++ b/offlinerlkit/utils/termination_fns.py @@ -1,5 +1,11 @@ import numpy as np +def obs_unnormalization(termination_fn, obs_mean, obs_std): + def thunk(obs, act, next_obs): + obs = obs*obs_std + obs_mean + next_obs = next_obs*obs_std + obs_mean + return termination_fn(obs, act, next_obs) + return thunk def termination_fn_halfcheetah(obs, act, next_obs): assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 diff --git a/run_example/run_rambo.py b/run_example/run_rambo.py index f8ae07f..68b0838 100644 --- a/run_example/run_rambo.py +++ b/run_example/run_rambo.py @@ -15,7 +15,7 @@ from offlinerlkit.dynamics import EnsembleDynamics from offlinerlkit.dynamics import MujocoOracleDynamics from offlinerlkit.utils.scaler import StandardScaler -from offlinerlkit.utils.termination_fns import get_termination_fn +from offlinerlkit.utils.termination_fns import get_termination_fn, obs_unnormalization from offlinerlkit.utils.load_dataset import qlearning_dataset from offlinerlkit.buffer import ReplayBuffer from offlinerlkit.utils.logger import Logger, make_log_dirs @@ -106,7 +106,8 @@ def train(args=get_args()): critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) # CHECK: do anealing? - lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(actor_optim, args.epoch) + # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(actor_optim, args.epoch) + lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(actor_optim, 1, -1) if args.auto_alpha: target_entropy = args.target_entropy if args.target_entropy \ @@ -120,8 +121,28 @@ def train(args=get_args()): else: alpha = args.alpha + # create buffer + real_buffer = ReplayBuffer( + buffer_size=len(dataset["observations"]), + obs_shape=args.obs_shape, + obs_dtype=np.float32, + action_dim=args.action_dim, + action_dtype=np.float32, + device=args.device + ) + real_buffer.load_dataset(dataset) + obs_mean, obs_std = real_buffer.normalize_obs() + fake_buffer_size = args.step_per_epoch // args.rollout_freq * args.model_retain_epochs * args.rollout_batch_size * args.rollout_length + fake_buffer = ReplayBuffer( + buffer_size=fake_buffer_size, + obs_shape=args.obs_shape, + obs_dtype=np.float32, + action_dim=args.action_dim, + action_dtype=np.float32, + device=args.device + ) + # create dynamics - load_dynamics_model = True if args.load_dynamics_path else False dynamics_model = EnsembleDynamicsModel( obs_dim=np.prod(args.obs_shape), action_dim=args.action_dim, @@ -139,8 +160,8 @@ def train(args=get_args()): dynamics_model.parameters(), lr=args.dynamics_adv_lr ) - scaler = StandardScaler() - termination_fn = get_termination_fn(task=args.task) + scaler = StandardScaler() # CHECK 这里换成dummy scaler + termination_fn = obs_unnormalization(get_termination_fn(task=args.task), obs_mean, obs_std) dynamics = EnsembleDynamics( dynamics_model, dynamics_optim, @@ -148,8 +169,6 @@ def train(args=get_args()): termination_fn, ) - if args.load_dynamics_path: - dynamics.load(args.load_dynamics_path) oracle_dynamics = MujocoOracleDynamics(env) @@ -172,26 +191,6 @@ def 
train(args=get_args()): device=args.device ).to(args.device) - # create buffer - real_buffer = ReplayBuffer( - buffer_size=len(dataset["observations"]), - obs_shape=args.obs_shape, - obs_dtype=np.float32, - action_dim=args.action_dim, - action_dtype=np.float32, - device=args.device - ) - real_buffer.load_dataset(dataset) - fake_buffer_size = args.step_per_epoch // args.rollout_freq * args.model_retain_epochs * args.rollout_batch_size * args.rollout_length - fake_buffer = ReplayBuffer( - buffer_size=fake_buffer_size, - obs_shape=args.obs_shape, - obs_dtype=np.float32, - action_dim=args.action_dim, - action_dtype=np.float32, - device=args.device - ) - # log log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args)) # key: output file name, value: output handler type @@ -220,17 +219,20 @@ def train(args=get_args()): eval_episodes=args.eval_episodes, lr_scheduler=lr_scheduler, oracle_dynamics=oracle_dynamics, - normalize_obs=True + obs_mean=obs_mean, + obs_std=obs_std ) # train - if args.do_bc: - if args.load_bc is not None: - policy.load(args.load_bc) - policy.to(args.device) - else: - policy.pretrain(real_buffer.sample_all(), args.bc_epoch, args.bc_batch_size, args.bc_lr, logger) - if not load_dynamics_model: + if args.load_bc_path is not None: + policy.load(args.load_bc) + policy.to(args.device) + else: + policy.pretrain(real_buffer.sample_all(), args.bc_epoch, args.bc_batch_size, args.bc_lr, logger) + if args.load_dynamics_path: + dynamics.load(args.load_dynamics_path) + dynamics.to(args.device) + else: dynamics.train(real_buffer.sample_all(), logger) policy_trainer.train() From d24ed311142b49578bd8335ae3ab80e5842c3627 Mon Sep 17 00:00:00 2001 From: typoverflow Date: Wed, 11 Jan 2023 22:02:42 +0800 Subject: [PATCH 3/5] fix: logs --- offlinerlkit/policy/model_based/rambo.py | 34 +++++++++--------- .../policy_trainer/mb_policy_trainer.py | 35 ++++++++++--------- run_example/run_rambo.py | 10 +++--- 3 files changed, 41 insertions(+), 38 deletions(-) diff --git a/offlinerlkit/policy/model_based/rambo.py b/offlinerlkit/policy/model_based/rambo.py index f815604..42b8312 100644 --- a/offlinerlkit/policy/model_based/rambo.py +++ b/offlinerlkit/policy/model_based/rambo.py @@ -59,7 +59,7 @@ def __init__( self.device = device def load(self, path): - self.load_state_dict(torch.load(os.path.join(path, "rambo.pt"), map_location="cpu")) + self.load_state_dict(torch.load(os.path.join(path, "rambo_pretrain.pt"), map_location="cpu")) def pretrain(self, data: Dict, n_epoch, batch_size, lr, logger) -> None: self._bc_optim = torch.optim.Adam(self.actor.parameters(), lr=lr) @@ -97,11 +97,11 @@ def update_dynamics( real_buffer, ) -> Tuple[Dict[str, np.ndarray], Dict]: all_loss_info = { - "all_loss": 0, - "sl_loss": 0, - "adv_loss": 0 + "adv_dynamics_update/all_loss": 0, + "adv_dynamics_update/sl_loss": 0, + "adv_dynamics_update/adv_loss": 0 } - self.dynamics.train() + # self.dynamics.train() steps = 0 while steps < self._adv_train_steps: init_obss = real_buffer.sample(self._adv_rollout_batch_size)["observations"].cpu().numpy() @@ -109,12 +109,10 @@ def update_dynamics( for t in range(self._adv_rollout_length): actions = self.select_action(observations) sl_observations, sl_actions, sl_next_observations, sl_rewards = \ - itemgetter("observations", "actions", "next_observations", "rewards")(real_buffer.sample(self._adv_batch_size)) + itemgetter("observations", "actions", "next_observations", "rewards")(real_buffer.sample(self._adv_rollout_batch_size)) next_observations, terminals, loss_info = 
self.dynamics_step_and_forward(observations, actions, sl_observations, sl_actions, sl_next_observations, sl_rewards) - all_loss_info["all_loss"] += loss_info["all_loss"] - all_loss_info["adv_loss"] += loss_info["adv_loss"] - all_loss_info["sl_loss"] += loss_info["sl_loss"] - + for _key in loss_info: + all_loss_info[_key] += loss_info[_key] # nonterm_mask = (~terminals).flatten() steps += 1 # observations = next_observations[nonterm_mask] @@ -123,7 +121,7 @@ def update_dynamics( # break if steps == 1000: break - self.dynamics.eval() + # self.dynamics.eval() return {_key: _value/steps for _key, _value in all_loss_info.items()} @@ -138,17 +136,21 @@ def dynamics_step_and_forward( ): obs_act = np.concatenate([observations, actions], axis=-1) obs_act = self.dynamics.scaler.transform(obs_act) - mean, logvar = self.dynamics.model(obs_act) - observations = torch.from_numpy(observations).to(mean.device) - mean[..., :-1] += observations + diff_mean, logvar = self.dynamics.model(obs_act) + observations = torch.from_numpy(observations).to(diff_mean.device) + diff_obs, diff_reward = torch.split(diff_mean, [diff_mean.shape[-1]-1, 1], dim=-1) + mean = torch.cat([diff_obs + observations, diff_reward], dim=-1) + # mean[..., :-1] = mean[..., :-1] + observations + # mean[..., :-1] += observations std = torch.sqrt(torch.exp(logvar)) dist = torch.distributions.Normal(mean, std) ensemble_sample = dist.sample() - ensemble_size, batch_size, _ = ensemble_size.shape + ensemble_size, batch_size, _ = ensemble_sample.shape # select the next observations - selected_indexes = np.random.randint(0, ensemble_size, size=batch_size) # CHECK 这里有可能应该使用所有模型 + selected_indexes = self.dynamics.model.random_elite_idxs(batch_size) + # selected_indexes = np.random.randint(0, ensemble_size, size=batch_size) # CHECK 这里有可能应该使用所有模型 sample = ensemble_sample[selected_indexes, np.arange(batch_size)] next_observations = sample[..., :-1] rewards = sample[..., -1:] diff --git a/offlinerlkit/policy_trainer/mb_policy_trainer.py b/offlinerlkit/policy_trainer/mb_policy_trainer.py index b440f1c..affe247 100755 --- a/offlinerlkit/policy_trainer/mb_policy_trainer.py +++ b/offlinerlkit/policy_trainer/mb_policy_trainer.py @@ -101,23 +101,24 @@ def train(self) -> Dict[str, float]: if self.lr_scheduler is not None: self.lr_scheduler.step() - # evaluate current policy - eval_info = self._evaluate() - ep_reward_mean, ep_reward_std = np.mean(eval_info["eval/episode_reward"]), np.std(eval_info["eval/episode_reward"]) - ep_length_mean, ep_length_std = np.mean(eval_info["eval/episode_length"]), np.std(eval_info["eval/episode_length"]) - norm_ep_rew_mean = self.eval_env.get_normalized_score(ep_reward_mean) * 100 - norm_ep_rew_std = self.eval_env.get_normalized_score(ep_reward_std) * 100 - last_10_performance.append(norm_ep_rew_mean) - - self.logger.logkv("eval/normalized_episode_reward", norm_ep_rew_mean) - self.logger.logkv("eval/normalized_episode_reward_std", norm_ep_rew_std) - self.logger.logkv("eval/episode_length", ep_length_mean) - self.logger.logkv("eval/episode_length_std", ep_length_std) - self.logger.set_timestep(num_timesteps) - self.logger.dumpkvs(exclude=["dynamics_training_progress"]) - - # save checkpoint - torch.save(self.policy.state_dict(), os.path.join(self.logger.checkpoint_dir, "policy.pth")) + if e % 10 == 0: + # evaluate current policy + eval_info = self._evaluate() + ep_reward_mean, ep_reward_std = np.mean(eval_info["eval/episode_reward"]), np.std(eval_info["eval/episode_reward"]) + ep_length_mean, ep_length_std = 
np.mean(eval_info["eval/episode_length"]), np.std(eval_info["eval/episode_length"]) + norm_ep_rew_mean = self.eval_env.get_normalized_score(ep_reward_mean) * 100 + norm_ep_rew_std = self.eval_env.get_normalized_score(ep_reward_std) * 100 + last_10_performance.append(norm_ep_rew_mean) + + self.logger.logkv("eval/normalized_episode_reward", norm_ep_rew_mean) + self.logger.logkv("eval/normalized_episode_reward_std", norm_ep_rew_std) + self.logger.logkv("eval/episode_length", ep_length_mean) + self.logger.logkv("eval/episode_length_std", ep_length_std) + self.logger.set_timestep(num_timesteps) + self.logger.dumpkvs(exclude=["dynamics_training_progress"]) + + # save checkpoint + torch.save(self.policy.state_dict(), os.path.join(self.logger.checkpoint_dir, "policy.pth")) self.logger.log("total time: {:.2f}s".format(time.time() - start_time)) torch.save(self.policy.state_dict(), os.path.join(self.logger.model_dir, "policy.pth")) diff --git a/run_example/run_rambo.py b/run_example/run_rambo.py index 68b0838..a71b6ff 100644 --- a/run_example/run_rambo.py +++ b/run_example/run_rambo.py @@ -26,7 +26,7 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--algo-name", type=str, default="rambo") - parser.add_argument("--task", type=str, default="hopper-medium-replay-v2") + parser.add_argument("--task", type=str, default="hopper-medium-v2") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--actor-lr", type=float, default=1e-4) parser.add_argument("--critic-lr", type=float, default=3e-4) @@ -62,7 +62,7 @@ def get_args(): parser.add_argument("--include-ent-in-adv", type=int, default=0) parser.add_argument("--do-bc", type=int, default=1) - parser.add_argument("--load-bc", type=str, default=None) + parser.add_argument("--load-bc-path", type=str, default=None) parser.add_argument("--bc-lr", type=float, default=1e-4) parser.add_argument("--bc-epoch", type=int, default=50) parser.add_argument("--bc-batch-size", type=int, default=256) @@ -224,14 +224,14 @@ def train(args=get_args()): ) # train - if args.load_bc_path is not None: - policy.load(args.load_bc) + if args.load_bc_path: + policy.load(args.load_bc_path) policy.to(args.device) else: policy.pretrain(real_buffer.sample_all(), args.bc_epoch, args.bc_batch_size, args.bc_lr, logger) if args.load_dynamics_path: dynamics.load(args.load_dynamics_path) - dynamics.to(args.device) + # dynamics.to(args.device) else: dynamics.train(real_buffer.sample_all(), logger) From 9a50204f7da2e9397c8e45cdb80bcce09a60cb22 Mon Sep 17 00:00:00 2001 From: typoverflow Date: Thu, 12 Jan 2023 22:15:36 +0800 Subject: [PATCH 4/5] update: comform with rambo initial implementation --- .gitignore | 3 ++- offlinerlkit/policy/model_based/rambo.py | 13 +++++++++---- offlinerlkit/policy_trainer/mb_policy_trainer.py | 3 +++ run_example/run_rambo.py | 2 +- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index c3aa7b4..a0b716a 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ **/*.egg-info **/*.png **/*.txt -**/.vscode \ No newline at end of file +**/.vscode +**/_log \ No newline at end of file diff --git a/offlinerlkit/policy/model_based/rambo.py b/offlinerlkit/policy/model_based/rambo.py index 42b8312..018897a 100644 --- a/offlinerlkit/policy/model_based/rambo.py +++ b/offlinerlkit/policy/model_based/rambo.py @@ -99,7 +99,9 @@ def update_dynamics( all_loss_info = { "adv_dynamics_update/all_loss": 0, "adv_dynamics_update/sl_loss": 0, - "adv_dynamics_update/adv_loss": 0 + "adv_dynamics_update/adv_loss": 0, 
+ "adv_dynamics_update/adv_advantage": 0, + "adv_dynamics_update/adv_log_prob": 0, } # self.dynamics.train() steps = 0 @@ -150,7 +152,6 @@ def dynamics_step_and_forward( # select the next observations selected_indexes = self.dynamics.model.random_elite_idxs(batch_size) - # selected_indexes = np.random.randint(0, ensemble_size, size=batch_size) # CHECK 这里有可能应该使用所有模型 sample = ensemble_sample[selected_indexes, np.arange(batch_size)] next_observations = sample[..., :-1] rewards = sample[..., -1:] @@ -158,6 +159,7 @@ def dynamics_step_and_forward( # compute logprob log_prob = dist.log_prob(sample) + log_prob = log_prob[self.dynamics.model.elites.data, ...] log_prob = log_prob.exp().mean(dim=0).log().sum(-1) # compute the advantage @@ -176,7 +178,8 @@ def dynamics_step_and_forward( self.critic2(observations, actions) ) advantage = value - value_baseline - adv_loss = (log_prob * advantage).mean() + advantage = (advantage - advantage.mean()) / (advantage.std()+1e-6) + adv_loss = (log_prob * advantage).sum() # compute the supervised loss sl_input = torch.cat([sl_observations, sl_actions], dim=-1).cpu().numpy() @@ -198,5 +201,7 @@ def dynamics_step_and_forward( return next_observations.cpu().numpy(), terminals, { "adv_dynamics_update/all_loss": all_loss.cpu().item(), "adv_dynamics_update/sl_loss": sl_loss.cpu().item(), - "adv_dynamics_update/adv_loss": adv_loss.cpu().item() + "adv_dynamics_update/adv_loss": adv_loss.cpu().item(), + "adv_dynamics_update/adv_advantage": advantage.mean().cpu().item(), + "adv_dynamics_update/adv_log_prob": log_prob.mean().cpu().item(), } \ No newline at end of file diff --git a/offlinerlkit/policy_trainer/mb_policy_trainer.py b/offlinerlkit/policy_trainer/mb_policy_trainer.py index affe247..dfab5cb 100755 --- a/offlinerlkit/policy_trainer/mb_policy_trainer.py +++ b/offlinerlkit/policy_trainer/mb_policy_trainer.py @@ -78,6 +78,8 @@ def train(self) -> Dict[str, float]: "num rollout transitions: {}, reward mean: {:.4f}".\ format(rollout_info["num_transitions"], rollout_info["reward_mean"]) ) + for _key, _value in rollout_transitions.items(): + self.logger.logkv_mean("rollout_info/"+_key, _value) real_sample_size = int(self._batch_size * self._real_ratio) fake_sample_size = self._batch_size - real_sample_size @@ -122,6 +124,7 @@ def train(self) -> Dict[str, float]: self.logger.log("total time: {:.2f}s".format(time.time() - start_time)) torch.save(self.policy.state_dict(), os.path.join(self.logger.model_dir, "policy.pth")) + self.policy.dynamics.save(self.logger.model_dir) self.logger.close() return {"last_10_performance": np.mean(last_10_performance)} diff --git a/run_example/run_rambo.py b/run_example/run_rambo.py index a71b6ff..233b85d 100644 --- a/run_example/run_rambo.py +++ b/run_example/run_rambo.py @@ -54,7 +54,7 @@ def get_args(): parser.add_argument("--real-ratio", type=float, default=0.5) parser.add_argument("--load-dynamics-path", type=str, default=None) - parser.add_argument("--epoch", type=int, default=1000) + parser.add_argument("--epoch", type=int, default=2000) parser.add_argument("--step-per-epoch", type=int, default=1000) parser.add_argument("--eval_episodes", type=int, default=10) parser.add_argument("--batch-size", type=int, default=256) From faf1a8cfee60e282ecc681c629c74bb1e50b7be0 Mon Sep 17 00:00:00 2001 From: typoverflow Date: Sat, 14 Jan 2023 23:24:35 +0800 Subject: [PATCH 5/5] fix bugs --- offlinerlkit/policy_trainer/mb_policy_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/offlinerlkit/policy_trainer/mb_policy_trainer.py b/offlinerlkit/policy_trainer/mb_policy_trainer.py index dfab5cb..17acdfd 100755 --- a/offlinerlkit/policy_trainer/mb_policy_trainer.py +++ b/offlinerlkit/policy_trainer/mb_policy_trainer.py @@ -78,7 +78,7 @@ def train(self) -> Dict[str, float]: "num rollout transitions: {}, reward mean: {:.4f}".\ format(rollout_info["num_transitions"], rollout_info["reward_mean"]) ) - for _key, _value in rollout_transitions.items(): + for _key, _value in rollout_info.items(): self.logger.logkv_mean("rollout_info/"+_key, _value) real_sample_size = int(self._batch_size * self._real_ratio)
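For reference, the adversarial dynamics update that this series adds in offlinerlkit/policy/model_based/rambo.py minimizes a weighted sum of two terms: a policy-gradient-style term (the log-likelihood of the sampled model transition multiplied by its normalized advantage under the current policy) and the usual Gaussian maximum-likelihood loss on real transitions. The sketch below restates that objective as a self-contained function; the function name, tensor shapes and the 3e-4 weight are illustrative stand-ins rather than the repository's API, and it omits the weight-decay and max/min-logvar penalty terms that the repository code also includes.

import torch

def adversarial_dynamics_loss(log_prob, advantage, sl_mean, sl_logvar, sl_target, adv_weight=3e-4):
    # adversarial term: shift model likelihood toward low-advantage (pessimistic) transitions
    advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-6)
    adv_loss = (log_prob * advantage).sum()
    # supervised Gaussian NLL term on real transitions, mirroring the ensemble dynamics trainer
    inv_var = torch.exp(-sl_logvar)
    mse_loss = ((sl_mean - sl_target).pow(2) * inv_var).mean(dim=(1, 2)).sum()
    var_loss = sl_logvar.mean(dim=(1, 2)).sum()
    return adv_weight * adv_loss + mse_loss + var_loss

# toy example: 7-member ensemble, batch of 256, 12-dimensional (next_obs, reward) target
log_prob = torch.randn(256, requires_grad=True)
advantage = torch.randn(256)
sl_mean = torch.randn(7, 256, 12, requires_grad=True)
sl_logvar = torch.randn(7, 256, 12, requires_grad=True)
sl_target = torch.randn(7, 256, 12)
adversarial_dynamics_loss(log_prob, advantage, sl_mean, sl_logvar, sl_target).backward()

Note that run_rambo.py defaults --adv-weight to 0, so the adversarial term drops out and the update reduces to the supervised loss unless a non-zero weight is passed on the command line.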