Implement RAMBO #1

Merged: 6 commits, Jan 25, 2023
4 changes: 3 additions & 1 deletion .gitignore
@@ -6,4 +6,6 @@
**/dist
**/*.egg-info
**/*.png
**/*.txt
**/*.txt
**/.vscode
**/_log
4 changes: 2 additions & 2 deletions offlinerlkit/dynamics/ensemble_dynamics.py
@@ -136,7 +136,7 @@ def train(
) -> None:
inputs, targets = self.format_samples_for_training(data)
data_size = inputs.shape[0]
holdout_size = min(int(data_size * 0.2), 1000)
holdout_size = min(int(data_size * 0.15), 1000)
train_size = data_size - holdout_size
train_splits, holdout_splits = torch.utils.data.random_split(range(data_size), (train_size, holdout_size))
train_inputs, train_targets = inputs[train_splits.indices], targets[train_splits.indices]
@@ -208,7 +208,7 @@ def learn(self, inputs: np.ndarray, targets: np.ndarray, batch_size: int = 256)
var_loss = logvar.mean(dim=(1, 2))
loss = mse_loss_inv.sum() + var_loss.sum()
loss = loss + self.model.get_decay_loss()
loss = loss + 0.01 * self.model.max_logvar.sum() - 0.01 * self.model.min_logvar.sum()
loss = loss + 0.001 * self.model.max_logvar.sum() - 0.001 * self.model.min_logvar.sum()

self.optim.zero_grad()
loss.backward()
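Note: both hunks in this file only retune constants: the validation holdout shrinks from 20% to 15% of the dataset (still capped at 1000 transitions), and the soft penalty on the learned log-variance bounds drops from 0.01 to 0.001. For reference, a minimal sketch of the loss that coefficient sits in, assuming the (ensemble, batch, dim) tensor shapes used by learn(); the function name ensemble_nll_loss is illustrative and not part of the repository.

import torch

def ensemble_nll_loss(mean, logvar, target, max_logvar, min_logvar, bound_coef=0.001):
    # mean, logvar, target: (ensemble_size, batch_size, obs_dim + 1)
    inv_var = torch.exp(-logvar)
    # inverse-variance-weighted MSE plus a term discouraging inflated variances
    mse_term = (torch.pow(mean - target, 2) * inv_var).mean(dim=(1, 2))
    var_term = logvar.mean(dim=(1, 2))
    loss = mse_term.sum() + var_term.sum()
    # soft penalty keeping the learnable logvar bounds tight; this PR lowers
    # its coefficient from 0.01 to 0.001
    return loss + bound_coef * max_logvar.sum() - bound_coef * min_logvar.sum()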
2 changes: 2 additions & 0 deletions offlinerlkit/policy/__init__.py
100755 → 100644
@@ -11,6 +11,7 @@
# model based
from offlinerlkit.policy.model_based.mopo import MOPOPolicy
from offlinerlkit.policy.model_based.mobile import MOBILEPolicy
from offlinerlkit.policy.model_based.rambo import RAMBOPolicy


__all__ = [
@@ -23,4 +24,5 @@
"TD3BCPolicy",
"MOPOPolicy",
"MOBILEPolicy",
"RAMBOPolicy"
]
1 change: 0 additions & 1 deletion offlinerlkit/policy/model_based/mopo.py
@@ -57,7 +57,6 @@ def rollout(
for _ in range(rollout_length):
actions = self.select_action(observations)
next_observations, rewards, terminals, info = self.dynamics.step(observations, actions)

rollout_transitions["obss"].append(observations)
rollout_transitions["next_obss"].append(next_observations)
rollout_transitions["actions"].append(actions)
207 changes: 207 additions & 0 deletions offlinerlkit/policy/model_based/rambo.py
@@ -0,0 +1,207 @@
import numpy as np
import torch
import torch.nn as nn
import gym
import os

from torch.nn import functional as F
from typing import Dict, Union, Tuple
from collections import defaultdict
from operator import itemgetter
from offlinerlkit.policy import MOPOPolicy
from offlinerlkit.dynamics import BaseDynamics


class RAMBOPolicy(MOPOPolicy):
"""
RAMBO-RL: Robust Adversarial Model-Based Offline Reinforcement Learning <Ref: https://arxiv.org/abs/2204.12581>
"""

def __init__(
self,
dynamics: BaseDynamics,
actor: nn.Module,
critic1: nn.Module,
critic2: nn.Module,
actor_optim: torch.optim.Optimizer,
critic1_optim: torch.optim.Optimizer,
critic2_optim: torch.optim.Optimizer,
dynamics_adv_optim: torch.optim.Optimizer,
tau: float = 0.005,
gamma: float = 0.99,
alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
adv_weight: float = 0,
adv_train_steps: int = 1000,
adv_rollout_batch_size: int = 256,
adv_rollout_length: int = 5,
include_ent_in_adv: bool = False,  # should this be False here?
device="cpu"
) -> None:
super().__init__(
dynamics,
actor,
critic1,
critic2,
actor_optim,
critic1_optim,
critic2_optim,
tau=tau,
gamma=gamma,
alpha=alpha
)

self._dynamics_adv_optim = dynamics_adv_optim
self._adv_weight = adv_weight
self._adv_train_steps = adv_train_steps
self._adv_rollout_batch_size = adv_rollout_batch_size
self._adv_rollout_length = adv_rollout_length
self._include_ent_in_adv = include_ent_in_adv
self.device = device

def load(self, path):
self.load_state_dict(torch.load(os.path.join(path, "rambo_pretrain.pt"), map_location="cpu"))

def pretrain(self, data: Dict, n_epoch: int, batch_size: int, lr: float, logger) -> None:
self._bc_optim = torch.optim.Adam(self.actor.parameters(), lr=lr)
observations = data["observations"]
actions = data["actions"]
sample_num = observations.shape[0]
idxs = np.arange(sample_num)

logger.log("Pretraining policy")
self.actor.train()
for i_epoch in range(n_epoch):
np.random.shuffle(idxs)
sum_loss = 0
for i_batch in range(sample_num // batch_size):
batch_obs = observations[i_batch * batch_size: (i_batch + 1) * batch_size]
batch_act = actions[i_batch * batch_size: (i_batch + 1) * batch_size]
batch_obs = torch.from_numpy(batch_obs).to(self.device)
batch_act = torch.from_numpy(batch_act).to(self.device)
dist = self.actor(batch_obs)
log_prob = dist.log_prob(batch_act)
bc_loss = - log_prob.mean()

self._bc_optim.zero_grad()
bc_loss.backward()
self._bc_optim.step()
sum_loss += bc_loss.cpu().item()
print(f"Epoch {i_epoch}, mean bc loss {sum_loss/i_batch}")
# logger.logkv("loss/pretrain_bc", sum_loss/i_batch)
# logger.set_timestep(i_epoch)
# logger.dumpkvs(exclude)
torch.save(self.state_dict(), os.path.join(logger.model_dir, "rambo_pretrain.pt"))

def update_dynamics(
self,
real_buffer,
) -> Tuple[Dict[str, np.ndarray], Dict]:
all_loss_info = {
"adv_dynamics_update/all_loss": 0,
"adv_dynamics_update/sl_loss": 0,
"adv_dynamics_update/adv_loss": 0,
"adv_dynamics_update/adv_advantage": 0,
"adv_dynamics_update/adv_log_prob": 0,
}
# self.dynamics.train()
steps = 0
while steps < self._adv_train_steps:
init_obss = real_buffer.sample(self._adv_rollout_batch_size)["observations"].cpu().numpy()
observations = init_obss
for t in range(self._adv_rollout_length):
actions = self.select_action(observations)
sl_observations, sl_actions, sl_next_observations, sl_rewards = \
itemgetter("observations", "actions", "next_observations", "rewards")(real_buffer.sample(self._adv_rollout_batch_size))
next_observations, terminals, loss_info = self.dynamics_step_and_forward(observations, actions, sl_observations, sl_actions, sl_next_observations, sl_rewards)
for _key in loss_info:
all_loss_info[_key] += loss_info[_key]
# nonterm_mask = (~terminals).flatten()
steps += 1
# observations = next_observations[nonterm_mask]
observations = next_observations
# if nonterm_mask.sum() == 0:
# break
if steps >= self._adv_train_steps:
break
# self.dynamics.eval()
return {_key: _value/steps for _key, _value in all_loss_info.items()}


def dynamics_step_and_forward(
self,
observations,
actions,
sl_observations,
sl_actions,
sl_next_observations,
sl_rewards,
):
obs_act = np.concatenate([observations, actions], axis=-1)
obs_act = self.dynamics.scaler.transform(obs_act)
diff_mean, logvar = self.dynamics.model(obs_act)
observations = torch.from_numpy(observations).to(diff_mean.device)
diff_obs, diff_reward = torch.split(diff_mean, [diff_mean.shape[-1]-1, 1], dim=-1)
mean = torch.cat([diff_obs + observations, diff_reward], dim=-1)
# mean[..., :-1] = mean[..., :-1] + observations
# mean[..., :-1] += observations
std = torch.sqrt(torch.exp(logvar))

dist = torch.distributions.Normal(mean, std)
ensemble_sample = dist.sample()
ensemble_size, batch_size, _ = ensemble_sample.shape

# select the next observations
selected_indexes = self.dynamics.model.random_elite_idxs(batch_size)
sample = ensemble_sample[selected_indexes, np.arange(batch_size)]
next_observations = sample[..., :-1]
rewards = sample[..., -1:]
terminals = np.squeeze(self.dynamics.terminal_fn(observations.detach().cpu().numpy(), actions, next_observations.detach().cpu().numpy()))

# compute logprob
log_prob = dist.log_prob(sample)
log_prob = log_prob[self.dynamics.model.elites.data, ...]
log_prob = log_prob.exp().mean(dim=0).log().sum(-1)

# compute the advantage
with torch.no_grad():
next_actions, next_policy_log_prob = self.actforward(next_observations, deterministic=True)
next_q = torch.minimum(
self.critic1(next_observations, next_actions),
self.critic2(next_observations, next_actions)
)
if self._include_ent_in_adv:
next_q = next_q - self._alpha * next_policy_log_prob
value = rewards + (1-torch.from_numpy(terminals).to(mean.device).float().unsqueeze(1)) * self._gamma * next_q

value_baseline = torch.minimum(
self.critic1(observations, actions),
self.critic2(observations, actions)
)
advantage = value - value_baseline
advantage = (advantage - advantage.mean()) / (advantage.std()+1e-6)
adv_loss = (log_prob * advantage).sum()

# compute the supervised loss
sl_input = torch.cat([sl_observations, sl_actions], dim=-1).cpu().numpy()
sl_target = torch.cat([sl_next_observations-sl_observations, sl_rewards], dim=-1)
sl_input = self.dynamics.scaler.transform(sl_input)
sl_mean, sl_logvar = self.dynamics.model(sl_input)
sl_inv_var = torch.exp(-sl_logvar)
sl_mse_loss_inv = (torch.pow(sl_mean - sl_target, 2) * sl_inv_var).mean(dim=(1, 2))
sl_var_loss = sl_logvar.mean(dim=(1, 2))
sl_loss = sl_mse_loss_inv.sum() + sl_var_loss.sum()
sl_loss = sl_loss + self.dynamics.model.get_decay_loss()
sl_loss = sl_loss + 0.001 * self.dynamics.model.max_logvar.sum() - 0.001 * self.dynamics.model.min_logvar.sum()

all_loss = self._adv_weight * adv_loss + sl_loss
self._dynamics_adv_optim.zero_grad()
all_loss.backward()
self._dynamics_adv_optim.step()

return next_observations.cpu().numpy(), terminals, {
"adv_dynamics_update/all_loss": all_loss.cpu().item(),
"adv_dynamics_update/sl_loss": sl_loss.cpu().item(),
"adv_dynamics_update/adv_loss": adv_loss.cpu().item(),
"adv_dynamics_update/adv_advantage": advantage.mean().cpu().item(),
"adv_dynamics_update/adv_log_prob": log_prob.mean().cpu().item(),
}
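Note: a minimal sketch of how the new policy might be wired up and warm-started, based only on the signatures above; the helper name, the hyperparameter values, and the choice of Adam for the adversarial dynamics optimizer are assumptions, not part of this PR.

import torch
from offlinerlkit.policy import RAMBOPolicy

def build_and_warmstart_rambo(dynamics, actor, critic1, critic2,
                              actor_optim, critic1_optim, critic2_optim,
                              dataset, real_buffer, logger, device="cuda"):
    policy = RAMBOPolicy(
        dynamics, actor, critic1, critic2,
        actor_optim, critic1_optim, critic2_optim,
        dynamics_adv_optim=torch.optim.Adam(dynamics.model.parameters(), lr=3e-4),
        adv_weight=3e-4,
        adv_rollout_length=5,
        device=device,
    )
    # behavior-cloning warm start of the actor; pretrain() also saves rambo_pretrain.pt
    policy.pretrain(dataset, n_epoch=50, batch_size=256, lr=1e-4, logger=logger)
    # one pass of adversarial dynamics updates, branching rollouts from the real buffer
    adv_info = policy.update_dynamics(real_buffer)
    return policy, adv_info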
63 changes: 39 additions & 24 deletions offlinerlkit/policy_trainer/mb_policy_trainer.py
@@ -28,9 +28,12 @@ def __init__(
batch_size: int = 256,
real_ratio: float = 0.05,
eval_episodes: int = 10,
normalize_obs: bool = False,
# normalize_obs: bool = False,
lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
oracle_dynamics=None
oracle_dynamics=None,
dynamics_update_freq: int = 0,
obs_mean: Optional[np.ndarray]=None,
obs_std: Optional[np.ndarray]=None
) -> None:
self.policy = policy
self.eval_env = eval_env
@@ -40,15 +43,17 @@ def __init__(

self._rollout_freq, self._rollout_batch_size, \
self._rollout_length = rollout_setting
self._dynamics_update_freq = dynamics_update_freq

self._epoch = epoch
self._step_per_epoch = step_per_epoch
self._batch_size = batch_size
self._real_ratio = real_ratio
self._eval_episodes = eval_episodes
self._normalize_obs = normalize_obs
if normalize_obs:
self._obs_mean, self._obs_std = self.real_buffer.normalize_obs()
self._obs_mean = obs_mean
self._obs_std = obs_std
if self._obs_mean is None or self._obs_std is None:
self._obs_mean, self._obs_std = 0, 1
self.lr_scheduler = lr_scheduler

self.oracle_dynamics = oracle_dynamics
@@ -73,6 +78,8 @@ def train(self) -> Dict[str, float]:
"num rollout transitions: {}, reward mean: {:.4f}".\
format(rollout_info["num_transitions"], rollout_info["reward_mean"])
)
for _key, _value in rollout_info.items():
self.logger.logkv_mean("rollout_info/"+_key, _value)

real_sample_size = int(self._batch_size * self._real_ratio)
fake_sample_size = self._batch_size - real_sample_size
@@ -85,31 +92,39 @@ def train(self) -> Dict[str, float]:
for k, v in loss.items():
self.logger.logkv_mean(k, v)

# update the dynamics if necessary
if 0 < self._dynamics_update_freq and (num_timesteps+1)%self._dynamics_update_freq == 0:
dynamics_update_info = self.policy.update_dynamics(self.real_buffer)
for k, v in dynamics_update_info.items():
self.logger.logkv_mean(k, v)

num_timesteps += 1

if self.lr_scheduler is not None:
self.lr_scheduler.step()

# evaluate current policy
eval_info = self._evaluate()
ep_reward_mean, ep_reward_std = np.mean(eval_info["eval/episode_reward"]), np.std(eval_info["eval/episode_reward"])
ep_length_mean, ep_length_std = np.mean(eval_info["eval/episode_length"]), np.std(eval_info["eval/episode_length"])
norm_ep_rew_mean = self.eval_env.get_normalized_score(ep_reward_mean) * 100
norm_ep_rew_std = self.eval_env.get_normalized_score(ep_reward_std) * 100
last_10_performance.append(norm_ep_rew_mean)

self.logger.logkv("eval/normalized_episode_reward", norm_ep_rew_mean)
self.logger.logkv("eval/normalized_episode_reward_std", norm_ep_rew_std)
self.logger.logkv("eval/episode_length", ep_length_mean)
self.logger.logkv("eval/episode_length_std", ep_length_std)
self.logger.set_timestep(num_timesteps)
self.logger.dumpkvs(exclude=["dynamics_training_progress"])

# save checkpoint
torch.save(self.policy.state_dict(), os.path.join(self.logger.checkpoint_dir, "policy.pth"))
if e % 10 == 0:
# evaluate current policy
eval_info = self._evaluate()
ep_reward_mean, ep_reward_std = np.mean(eval_info["eval/episode_reward"]), np.std(eval_info["eval/episode_reward"])
ep_length_mean, ep_length_std = np.mean(eval_info["eval/episode_length"]), np.std(eval_info["eval/episode_length"])
norm_ep_rew_mean = self.eval_env.get_normalized_score(ep_reward_mean) * 100
norm_ep_rew_std = self.eval_env.get_normalized_score(ep_reward_std) * 100
last_10_performance.append(norm_ep_rew_mean)

self.logger.logkv("eval/normalized_episode_reward", norm_ep_rew_mean)
self.logger.logkv("eval/normalized_episode_reward_std", norm_ep_rew_std)
self.logger.logkv("eval/episode_length", ep_length_mean)
self.logger.logkv("eval/episode_length_std", ep_length_std)
self.logger.set_timestep(num_timesteps)
self.logger.dumpkvs(exclude=["dynamics_training_progress"])

# save checkpoint
torch.save(self.policy.state_dict(), os.path.join(self.logger.checkpoint_dir, "policy.pth"))

self.logger.log("total time: {:.2f}s".format(time.time() - start_time))
torch.save(self.policy.state_dict(), os.path.join(self.logger.model_dir, "policy.pth"))
self.policy.dynamics.save(self.logger.model_dir)
self.logger.close()

return {"last_10_performance": np.mean(last_10_performance)}
@@ -122,8 +137,8 @@ def _evaluate(self) -> Dict[str, List[float]]:
episode_reward, episode_length = 0, 0

while num_episodes < self._eval_episodes:
if self._normalize_obs:
obs = (np.array(obs).reshape(1,-1) - self._obs_mean) / self._obs_std
# if self._normalize_obs:
obs = (np.array(obs).reshape(1,-1) - self._obs_mean) / self._obs_std
action = self.policy.select_action(obs, deterministic=True)
next_obs, reward, terminal, _ = self.eval_env.step(action.flatten())
episode_reward += reward
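Note: the trainer now takes explicit obs_mean / obs_std (defaulting to 0 and 1) in place of the old normalize_obs flag, plus a dynamics_update_freq that periodically calls policy.update_dynamics() on the real buffer. A small sketch of how a caller might compute those statistics; deriving them from the offline dataset is an assumption, only the keyword names and the (obs - mean) / std convention come from the diff above.

import numpy as np

def obs_stats(observations: np.ndarray, eps: float = 1e-3):
    # per-dimension statistics over the offline dataset, with a floor on the std
    obs_mean = observations.mean(axis=0, keepdims=True)
    obs_std = observations.std(axis=0, keepdims=True) + eps
    return obs_mean, obs_std

# passed to the trainer as obs_mean=..., obs_std=...; evaluation now always applies
# obs = (np.array(obs).reshape(1, -1) - obs_mean) / obs_std, so the defaults (0, 1)
# leave observations unchanged.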
6 changes: 6 additions & 0 deletions offlinerlkit/utils/termination_fns.py
@@ -1,5 +1,11 @@
import numpy as np

def obs_unnormalization(termination_fn, obs_mean, obs_std):
def thunk(obs, act, next_obs):
obs = obs*obs_std + obs_mean
next_obs = next_obs*obs_std + obs_mean
return termination_fn(obs, act, next_obs)
return thunk

def termination_fn_halfcheetah(obs, act, next_obs):
assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2
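Note: a usage sketch for the new obs_unnormalization wrapper; the placeholder statistics and the HalfCheetah dimensions are illustrative. The returned closure maps normalized observations back to the raw space before calling the underlying termination rule, so a dynamics model that works on normalized observations can reuse the existing termination functions.

import numpy as np
from offlinerlkit.utils.termination_fns import obs_unnormalization, termination_fn_halfcheetah

obs_mean = np.zeros((1, 17))  # placeholder statistics (HalfCheetah obs dim = 17)
obs_std = np.ones((1, 17))
terminal_fn = obs_unnormalization(termination_fn_halfcheetah, obs_mean, obs_std)

obs = np.zeros((4, 17))
act = np.zeros((4, 6))
next_obs = np.zeros((4, 17))
terminals = terminal_fn(obs, act, next_obs)  # boolean termination flags, one per transition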