diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index a984b37faa9..f91c050dde0 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -151,6 +151,18 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/di replay_buffer.size=120 \ env.name=CartPole-v1 \ logger.backend= +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/crossq/crossq.py \ + collector.total_frames=48 \ + collector.init_random_frames=10 \ + collector.frames_per_batch=16 \ + collector.env_per_collector=2 \ + collector.device= \ + optim.batch_size=10 \ + optim.utd_ratio=1 \ + replay_buffer.size=120 \ + env.name=Pendulum-v1 \ + network.device= \ + logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \ collector.total_frames=200 \ collector.init_random_frames=10 \ diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index ccd6cb23ed0..b46d789ed15 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -317,6 +317,7 @@ Regular modules Conv3dNet SqueezeLayer Squeeze2dLayer + BatchRenorm Algorithm-specific modules ~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index ef9bc1ee907..96a887196aa 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -121,6 +121,15 @@ REDQ REDQLoss +CrossQ +---- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + CrossQ + IQL ---- diff --git a/sota-check/run_crossq.sh b/sota-check/run_crossq.sh new file mode 100644 index 00000000000..2ae4ea51c49 --- /dev/null +++ b/sota-check/run_crossq.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=crossq +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/crossq_%j.txt +#SBATCH --error=slurm_errors/crossq_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="crossq" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/sota-implementations/crossq/crossq.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-implementations/crossq/config.yaml b/sota-implementations/crossq/config.yaml new file mode 100644 index 00000000000..1dcbd3db92d --- /dev/null +++ b/sota-implementations/crossq/config.yaml @@ -0,0 +1,58 @@ +# environment and task +env: + name: HalfCheetah-v4 + task: "" + library: gym + max_episode_steps: 1000 + seed: 42 + +# collector +collector: + total_frames: 1_000_000 + init_random_frames: 25000 + frames_per_batch: 1000 + init_env_steps: 1000 + device: cpu + env_per_collector: 1 + reset_at_each_iter: False + +# replay buffer +replay_buffer: + size: 1000000 + prb: 0 # use prioritized experience replay + scratch_dir: null + +# optim +optim: + utd_ratio: 1.0 + policy_update_delay: 3 + gamma: 0.99 + loss_function: l2 + lr: 1.0e-3 + weight_decay: 0.0 + batch_size: 256 + alpha_init: 1.0 + adam_eps: 1.0e-8 + beta1: 0.5 + beta2: 0.999 + +# network +network: + batch_norm_momentum: 0.01 + warmup_steps: 100000 + critic_hidden_sizes: [2048, 2048] + actor_hidden_sizes: [256, 256] + critic_activation: relu + actor_activation: relu + default_policy_scale: 1.0 + scale_lb: 0.1 + device: "cuda:0" + +# logging +logger: + backend: wandb + project_name: torchrl_example_crossQ + group_name: null + exp_name: ${env.name}_CrossQ + mode: online + eval_iter: 25000 diff --git a/sota-implementations/crossq/crossq.py b/sota-implementations/crossq/crossq.py new file mode 100644 index 00000000000..df34d4ae68d --- /dev/null +++ b/sota-implementations/crossq/crossq.py @@ -0,0 +1,229 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""CrossQ Example. + +This is a simple self-contained example of a CrossQ training script. + +It supports state environments like MuJoCo. + +The helper functions are coded in the utils.py associated with this script. +""" +import time + +import hydra + +import numpy as np +import torch +import torch.cuda +import tqdm +from torchrl._utils import logger as torchrl_logger +from torchrl.envs.utils import ExplorationType, set_exploration_type + +from torchrl.record.loggers import generate_exp_name, get_logger +from utils import ( + log_metrics, + make_collector, + make_crossQ_agent, + make_crossQ_optimizer, + make_environment, + make_loss_module, + make_replay_buffer, +) + + +@hydra.main(version_base="1.1", config_path=".", config_name="config") +def main(cfg: "DictConfig"): # noqa: F821 + device = cfg.network.device + if device in ("", None): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") + device = torch.device(device) + + # Create logger + exp_name = generate_exp_name("CrossQ", cfg.logger.exp_name) + logger = None + if cfg.logger.backend: + logger = get_logger( + logger_type=cfg.logger.backend, + logger_name="crossq_logging", + experiment_name=exp_name, + wandb_kwargs={ + "mode": cfg.logger.mode, + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, + ) + + torch.manual_seed(cfg.env.seed) + np.random.seed(cfg.env.seed) + + # Create environments + train_env, eval_env = make_environment(cfg) + + # Create agent + model, exploration_policy = make_crossQ_agent(cfg, train_env, device) + + # Create CrossQ loss + loss_module = make_loss_module(cfg, model) + + # Create off-policy collector + collector = make_collector(cfg, train_env, exploration_policy.eval(), device=device) + + # Create replay buffer + replay_buffer = make_replay_buffer( + batch_size=cfg.optim.batch_size, + prb=cfg.replay_buffer.prb, + buffer_size=cfg.replay_buffer.size, + scratch_dir=cfg.replay_buffer.scratch_dir, + device="cpu", + ) + + # Create optimizers + ( + optimizer_actor, + optimizer_critic, + optimizer_alpha, + ) = make_crossQ_optimizer(cfg, loss_module) + + # Main loop + start_time = time.time() + collected_frames = 0 + pbar = tqdm.tqdm(total=cfg.collector.total_frames) + + init_random_frames = cfg.collector.init_random_frames + num_updates = int( + cfg.collector.env_per_collector + * cfg.collector.frames_per_batch + * cfg.optim.utd_ratio + ) + prb = cfg.replay_buffer.prb + eval_iter = cfg.logger.eval_iter + frames_per_batch = cfg.collector.frames_per_batch + eval_rollout_steps = cfg.env.max_episode_steps + + sampling_start = time.time() + update_counter = 0 + delayed_updates = cfg.optim.policy_update_delay + for _, tensordict in enumerate(collector): + sampling_time = time.time() - sampling_start + + # Update weights of the inference policy + collector.update_policy_weights_() + + pbar.update(tensordict.numel()) + + tensordict = tensordict.reshape(-1) + current_frames = tensordict.numel() + # Add to replay buffer + replay_buffer.extend(tensordict.cpu()) + collected_frames += current_frames + + # Optimization steps + training_start = time.time() + if collected_frames >= init_random_frames: + ( + actor_losses, + alpha_losses, + q_losses, + ) = ([], [], []) + for _ in range(num_updates): + + # Update actor every delayed_updates + update_counter += 1 + update_actor = update_counter % delayed_updates == 0 + # Sample from replay buffer + sampled_tensordict = replay_buffer.sample() + if sampled_tensordict.device != device: + sampled_tensordict = sampled_tensordict.to(device) + else: + sampled_tensordict = sampled_tensordict.clone() + + # Compute loss + q_loss, *_ = loss_module.qvalue_loss(sampled_tensordict) + q_loss = q_loss.mean() + # Update critic + optimizer_critic.zero_grad() + q_loss.backward() + optimizer_critic.step() + q_losses.append(q_loss.detach().item()) + + if update_actor: + actor_loss, metadata_actor = loss_module.actor_loss( + sampled_tensordict + ) + actor_loss = actor_loss.mean() + alpha_loss = loss_module.alpha_loss( + log_prob=metadata_actor["log_prob"] + ).mean() + + # Update actor + optimizer_actor.zero_grad() + actor_loss.backward() + optimizer_actor.step() + + # Update alpha + optimizer_alpha.zero_grad() + alpha_loss.backward() + optimizer_alpha.step() + + actor_losses.append(actor_loss.detach().item()) + alpha_losses.append(alpha_loss.detach().item()) + + # Update priority + if prb: + replay_buffer.update_priority(sampled_tensordict) + + training_time = time.time() - training_start + episode_end = ( + tensordict["next", "done"] + if tensordict["next", "done"].any() + else tensordict["next", "truncated"] + ) + episode_rewards = tensordict["next", "episode_reward"][episode_end] + + # Logging + metrics_to_log = {} + if len(episode_rewards) > 0: + episode_length = tensordict["next", "step_count"][episode_end] + metrics_to_log["train/reward"] = episode_rewards.mean().item() + metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( + episode_length + ) + if collected_frames >= init_random_frames: + metrics_to_log["train/q_loss"] = np.mean(q_losses).item() + metrics_to_log["train/actor_loss"] = np.mean(actor_losses).item() + metrics_to_log["train/alpha_loss"] = np.mean(alpha_losses).item() + metrics_to_log["train/sampling_time"] = sampling_time + metrics_to_log["train/training_time"] = training_time + + # Evaluation + if abs(collected_frames % eval_iter) < frames_per_batch: + with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_start = time.time() + eval_rollout = eval_env.rollout( + eval_rollout_steps, + model[0], + auto_cast_to_device=True, + break_when_any_done=True, + ) + eval_time = time.time() - eval_start + eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() + metrics_to_log["eval/reward"] = eval_reward + metrics_to_log["eval/time"] = eval_time + if logger is not None: + log_metrics(logger, metrics_to_log, collected_frames) + sampling_start = time.time() + + collector.shutdown() + end_time = time.time() + execution_time = end_time - start_time + torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/sota-implementations/crossq/utils.py b/sota-implementations/crossq/utils.py new file mode 100644 index 00000000000..9883bc50b17 --- /dev/null +++ b/sota-implementations/crossq/utils.py @@ -0,0 +1,310 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from tensordict.nn import InteractionType, TensorDictModule +from tensordict.nn.distributions import NormalParamExtractor +from torch import nn, optim +from torchrl.collectors import SyncDataCollector +from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer +from torchrl.data.replay_buffers.storages import LazyMemmapStorage +from torchrl.envs import ( + CatTensors, + Compose, + DMControlEnv, + DoubleToFloat, + EnvCreator, + ParallelEnv, + TransformedEnv, +) +from torchrl.envs.libs.gym import GymEnv, set_gym_backend +from torchrl.envs.transforms import InitTracker, RewardSum, StepCounter +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules import MLP, ProbabilisticActor, ValueOperator +from torchrl.modules.distributions import TanhNormal + +from torchrl.modules.models.batchrenorm import BatchRenorm1d +from torchrl.objectives import CrossQLoss + +# ==================================================================== +# Environment utils +# ----------------- + + +def env_maker(cfg, device="cpu"): + lib = cfg.env.library + if lib in ("gym", "gymnasium"): + with set_gym_backend(lib): + return GymEnv( + cfg.env.name, + device=device, + ) + elif lib == "dm_control": + env = DMControlEnv(cfg.env.name, cfg.env.task) + return TransformedEnv( + env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") + ) + else: + raise NotImplementedError(f"Unknown lib {lib}.") + + +def apply_env_transforms(env, max_episode_steps=1000): + transformed_env = TransformedEnv( + env, + Compose( + InitTracker(), + StepCounter(max_episode_steps), + DoubleToFloat(), + RewardSum(), + ), + ) + return transformed_env + + +def make_environment(cfg): + """Make environments for training and evaluation.""" + parallel_env = ParallelEnv( + cfg.collector.env_per_collector, + EnvCreator(lambda cfg=cfg: env_maker(cfg)), + serial_for_single=True, + ) + parallel_env.set_seed(cfg.env.seed) + + train_env = apply_env_transforms(parallel_env, cfg.env.max_episode_steps) + + eval_env = TransformedEnv( + ParallelEnv( + cfg.collector.env_per_collector, + EnvCreator(lambda cfg=cfg: env_maker(cfg)), + serial_for_single=True, + ), + train_env.transform.clone(), + ) + return train_env, eval_env + + +# ==================================================================== +# Collector and replay buffer +# --------------------------- + + +def make_collector(cfg, train_env, actor_model_explore, device): + """Make collector.""" + collector = SyncDataCollector( + train_env, + actor_model_explore, + init_random_frames=cfg.collector.init_random_frames, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + device=device, + ) + collector.set_seed(cfg.env.seed) + return collector + + +def make_replay_buffer( + batch_size, + prb=False, + buffer_size=1000000, + scratch_dir=None, + device="cpu", + prefetch=3, +): + if prb: + replay_buffer = TensorDictPrioritizedReplayBuffer( + alpha=0.7, + beta=0.5, + pin_memory=False, + prefetch=prefetch, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=scratch_dir, + ), + batch_size=batch_size, + ) + else: + replay_buffer = TensorDictReplayBuffer( + pin_memory=False, + prefetch=prefetch, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=scratch_dir, + ), + batch_size=batch_size, + ) + replay_buffer.append_transform(lambda x: x.to(device, non_blocking=True)) + return replay_buffer + + +# ==================================================================== +# Model +# ----- + + +def make_crossQ_agent(cfg, train_env, device): + """Make CrossQ agent.""" + # Define Actor Network + in_keys = ["observation"] + action_spec = train_env.action_spec + if train_env.batch_size: + action_spec = action_spec[(0,) * len(train_env.batch_size)] + actor_net_kwargs = { + "num_cells": cfg.network.actor_hidden_sizes, + "out_features": 2 * action_spec.shape[-1], + "activation_class": get_activation(cfg.network.actor_activation), + "norm_class": BatchRenorm1d, + "norm_kwargs": { + "momentum": cfg.network.batch_norm_momentum, + "num_features": cfg.network.actor_hidden_sizes[-1], + "warmup_steps": cfg.network.warmup_steps, + }, + } + + actor_net = MLP(**actor_net_kwargs) + + dist_class = TanhNormal + dist_kwargs = { + "low": action_spec.space.low, + "high": action_spec.space.high, + "tanh_loc": False, + } + + actor_extractor = NormalParamExtractor( + scale_mapping=f"biased_softplus_{cfg.network.default_policy_scale}", + scale_lb=cfg.network.scale_lb, + ) + actor_net = nn.Sequential(actor_net, actor_extractor) + + in_keys_actor = in_keys + actor_module = TensorDictModule( + actor_net, + in_keys=in_keys_actor, + out_keys=[ + "loc", + "scale", + ], + ) + actor = ProbabilisticActor( + spec=action_spec, + in_keys=["loc", "scale"], + module=actor_module, + distribution_class=dist_class, + distribution_kwargs=dist_kwargs, + default_interaction_type=InteractionType.RANDOM, + return_log_prob=False, + ) + + # Define Critic Network + qvalue_net_kwargs = { + "num_cells": cfg.network.critic_hidden_sizes, + "out_features": 1, + "activation_class": get_activation(cfg.network.critic_activation), + "norm_class": BatchRenorm1d, + "norm_kwargs": { + "momentum": cfg.network.batch_norm_momentum, + "num_features": cfg.network.critic_hidden_sizes[-1], + "warmup_steps": cfg.network.warmup_steps, + }, + } + + qvalue_net = MLP( + **qvalue_net_kwargs, + ) + + qvalue = ValueOperator( + in_keys=["action"] + in_keys, + module=qvalue_net, + ) + + model = nn.ModuleList([actor, qvalue]).to(device) + + # init nets + with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): + td = train_env.fake_tensordict() + td = td.to(device) + for net in model: + net.eval() + net(td) + net.train() + del td + + return model, model[0] + + +# ==================================================================== +# CrossQ Loss +# --------- + + +def make_loss_module(cfg, model): + """Make loss module and target network updater.""" + # Create CrossQ loss + loss_module = CrossQLoss( + actor_network=model[0], + qvalue_network=model[1], + num_qvalue_nets=2, + loss_function=cfg.optim.loss_function, + alpha_init=cfg.optim.alpha_init, + ) + loss_module.make_value_estimator(gamma=cfg.optim.gamma) + + return loss_module + + +def split_critic_params(critic_params): + critic1_params = [] + critic2_params = [] + + for param in critic_params: + data1, data2 = param.data.chunk(2, dim=0) + critic1_params.append(nn.Parameter(data1)) + critic2_params.append(nn.Parameter(data2)) + return critic1_params, critic2_params + + +def make_crossQ_optimizer(cfg, loss_module): + critic_params = list(loss_module.qvalue_network_params.flatten_keys().values()) + actor_params = list(loss_module.actor_network_params.flatten_keys().values()) + + optimizer_actor = optim.Adam( + actor_params, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.adam_eps, + betas=(cfg.optim.beta1, cfg.optim.beta2), + ) + optimizer_critic = optim.Adam( + critic_params, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.adam_eps, + betas=(cfg.optim.beta1, cfg.optim.beta2), + ) + optimizer_alpha = optim.Adam( + [loss_module.log_alpha], + lr=cfg.optim.lr, + ) + return optimizer_actor, optimizer_critic, optimizer_alpha + + +# ==================================================================== +# General utils +# --------- + + +def log_metrics(logger, metrics, step): + for metric_name, metric_value in metrics.items(): + logger.log_scalar(metric_name, metric_value, step) + + +def get_activation(activation: str): + if activation == "relu": + return nn.ReLU + elif activation == "tanh": + return nn.Tanh + elif activation == "leaky_relu": + return nn.LeakyReLU + else: + raise NotImplementedError diff --git a/test/test_cost.py b/test/test_cost.py index 2f187c8e3ba..7921a32e02e 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -98,6 +98,7 @@ A2CLoss, ClipPPOLoss, CQLLoss, + CrossQLoss, DDPGLoss, DiscreteCQLLoss, DiscreteIQLLoss, @@ -4940,6 +4941,698 @@ def test_discrete_sac_reduction(self, reduction): assert loss[key].shape == torch.Size([]) +class TestCrossQ(LossModuleTestBase): + seed = 0 + + def _create_mock_actor( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + observation_key="observation", + action_key="action", + ): + # Actor + action_spec = BoundedTensorSpec( + -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) + ) + net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + module = TensorDictModule( + net, in_keys=[observation_key], out_keys=["loc", "scale"] + ) + actor = ProbabilisticActor( + module=module, + in_keys=["loc", "scale"], + spec=action_spec, + distribution_class=TanhNormal, + out_keys=[action_key], + ) + return actor.to(device) + + def _create_mock_qvalue( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + observation_key="observation", + action_key="action", + out_keys=None, + ): + class ValueClass(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(obs_dim + action_dim, 1) + + def forward(self, obs, act): + return self.linear(torch.cat([obs, act], -1)) + + module = ValueClass() + qvalue = ValueOperator( + module=module, + in_keys=[observation_key, action_key], + out_keys=out_keys, + ) + return qvalue.to(device) + + def _create_mock_common_layer_setup( + self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2 + ): + common = MLP( + num_cells=ncells, + in_features=n_obs, + depth=3, + out_features=n_hidden, + ) + actor_net = MLP( + num_cells=ncells, + in_features=n_hidden, + depth=1, + out_features=2 * n_act, + ) + qvalue = MLP( + in_features=n_hidden + n_act, + num_cells=ncells, + depth=1, + out_features=1, + ) + batch = [batch] + td = TensorDict( + { + "obs": torch.randn(*batch, n_obs), + "action": torch.randn(*batch, n_act), + "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), + "next": { + "obs": torch.randn(*batch, n_obs), + "reward": torch.randn(*batch, 1), + "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), + }, + }, + batch, + ) + common = Mod(common, in_keys=["obs"], out_keys=["hidden"]) + actor = ProbSeq( + common, + Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), + Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), + ProbMod( + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=TanhNormal, + ), + ) + qvalue_head = Mod( + qvalue, in_keys=["hidden", "action"], out_keys=["state_action_value"] + ) + qvalue = Seq(common, qvalue_head) + return actor, qvalue, common, td + + def _create_mock_distributional_actor( + self, batch=2, obs_dim=3, action_dim=4, atoms=5, vmin=1, vmax=5 + ): + raise NotImplementedError + + def _create_mock_data_crossq( + self, + batch=16, + obs_dim=3, + action_dim=4, + atoms=None, + device="cpu", + observation_key="observation", + action_key="action", + done_key="done", + terminated_key="terminated", + reward_key="reward", + ): + # create a tensordict + obs = torch.randn(batch, obs_dim, device=device) + next_obs = torch.randn(batch, obs_dim, device=device) + if atoms: + raise NotImplementedError + else: + action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, 1, device=device) + done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch,), + source={ + observation_key: obs, + "next": { + observation_key: next_obs, + done_key: done, + terminated_key: terminated, + reward_key: reward, + }, + action_key: action, + }, + device=device, + ) + return td + + def _create_seq_mock_data_crossq( + self, batch=8, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu" + ): + # create a tensordict + total_obs = torch.randn(batch, T + 1, obs_dim, device=device) + obs = total_obs[:, :T] + next_obs = total_obs[:, 1:] + if atoms: + action = torch.randn(batch, T, atoms, action_dim, device=device).clamp( + -1, 1 + ) + else: + action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, T, 1, device=device) + done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + mask = torch.ones(batch, T, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch, T), + source={ + "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "next": { + "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "done": done, + "terminated": terminated, + "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), + }, + "collector": {"mask": mask}, + "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0), + }, + names=[None, "time"], + device=device, + ) + return td + + @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) + def test_crossq( + self, + num_qvalue, + device, + td_est, + ): + torch.manual_seed(self.seed) + td = self._create_mock_data_crossq(device=device) + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + + loss_fn = CrossQLoss( + actor_network=actor, + qvalue_network=qvalue, + num_qvalue_nets=num_qvalue, + loss_function="l2", + ) + + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + with pytest.raises(NotImplementedError): + loss_fn.make_value_estimator(td_est) + return + if td_est is not None: + loss_fn.make_value_estimator(td_est) + + with _check_td_steady(td): + loss = loss_fn(td) + + assert loss_fn.tensor_keys.priority in td.keys() + + # check that losses are independent + for k in loss.keys(): + if not k.startswith("loss"): + continue + loss[k].sum().backward(retain_graph=True) + if k == "loss_actor": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + elif k == "loss_qvalue": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + elif k == "loss_alpha": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + else: + raise NotImplementedError(k) + loss_fn.zero_grad() + + sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ).backward() + named_parameters = list(loss_fn.named_parameters()) + named_buffers = list(loss_fn.named_buffers()) + + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) + + for name, p in named_parameters: + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + + @pytest.mark.parametrize("num_qvalue", [2]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_crossq_state_dict( + self, + num_qvalue, + device, + ): + torch.manual_seed(self.seed) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + + loss_fn = CrossQLoss( + actor_network=actor, + qvalue_network=qvalue, + num_qvalue_nets=num_qvalue, + loss_function="l2", + ) + sd = loss_fn.state_dict() + loss_fn2 = CrossQLoss( + actor_network=actor, + qvalue_network=qvalue, + num_qvalue_nets=num_qvalue, + loss_function="l2", + ) + loss_fn2.load_state_dict(sd) + + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("separate_losses", [False, True]) + def test_crossq_separate_losses( + self, + separate_losses, + device, + ): + n_act = 4 + torch.manual_seed(self.seed) + actor, qvalue, common, td = self._create_mock_common_layer_setup(n_act=n_act) + + loss_fn = CrossQLoss( + actor_network=actor, + qvalue_network=qvalue, + action_spec=UnboundedContinuousTensorSpec(shape=(n_act,)), + num_qvalue_nets=1, + separate_losses=separate_losses, + ) + loss = loss_fn(td) + + assert loss_fn.tensor_keys.priority in td.keys() + + # check that losses are independent + for k in loss.keys(): + if not k.startswith("loss"): + continue + loss[k].sum().backward(retain_graph=True) + if k == "loss_actor": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + elif k == "loss_qvalue": + common_layers_no = len(list(common.parameters())) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + if separate_losses: + common_layers = itertools.islice( + loss_fn.qvalue_network_params.values(True, True), + common_layers_no, + ) + assert all( + (p.grad is None) or (p.grad == 0).all() for p in common_layers + ) + qvalue_layers = itertools.islice( + loss_fn.qvalue_network_params.values(True, True), + common_layers_no, + None, + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() for p in qvalue_layers + ) + else: + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values(True, True) + ) + elif k == "loss_alpha": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + else: + raise NotImplementedError(k) + loss_fn.zero_grad() + + @pytest.mark.parametrize("n", range(1, 4)) + @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_crossq_batcher( + self, + n, + num_qvalue, + device, + ): + torch.manual_seed(self.seed) + td = self._create_seq_mock_data_crossq(device=device) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + + loss_fn = CrossQLoss( + actor_network=actor, + qvalue_network=qvalue, + num_qvalue_nets=num_qvalue, + loss_function="l2", + ) + + ms = MultiStep(gamma=0.9, n_steps=n).to(device) + + td_clone = td.clone() + ms_td = ms(td_clone) + + torch.manual_seed(0) + np.random.seed(0) + + with _check_td_steady(ms_td): + loss_ms = loss_fn(ms_td) + assert loss_fn.tensor_keys.priority in ms_td.keys() + + with torch.no_grad(): + torch.manual_seed(0) # log-prob is computed with a random action + np.random.seed(0) + loss = loss_fn(td) + if n == 1: + assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) + _loss = sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ) + _loss_ms = sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ) + assert ( + abs(_loss - _loss_ms) < 1e-3 + ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" + else: + with pytest.raises(AssertionError): + assert_allclose_td(loss, loss_ms) + sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ).backward() + named_parameters = loss_fn.named_parameters() + for name, p in named_parameters: + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + + # Check param update effect on targets + target_actor = [ + p.clone() + for p in loss_fn.target_actor_network_params.values( + include_nested=True, leaves_only=True + ) + ] + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + target_actor2 = [ + p.clone() + for p in loss_fn.target_actor_network_params.values( + include_nested=True, leaves_only=True + ) + ] + + assert not any((p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2)) + + # check that policy is updated after parameter update + parameters = [p.clone() for p in actor.parameters()] + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) + + @pytest.mark.parametrize( + "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda] + ) + def test_crossq_tensordict_keys(self, td_est): + + actor = self._create_mock_actor() + qvalue = self._create_mock_qvalue() + value = None + + loss_fn = CrossQLoss( + actor_network=actor, + qvalue_network=qvalue, + num_qvalue_nets=2, + loss_function="l2", + ) + + default_keys = { + "priority": "td_error", + "state_action_value": "state_action_value", + "action": "action", + "log_prob": "_log_prob", + "reward": "reward", + "done": "done", + "terminated": "terminated", + } + + self.tensordict_keys_test( + loss_fn, + default_keys=default_keys, + td_est=td_est, + ) + + qvalue = self._create_mock_qvalue() + loss_fn = CrossQLoss( + actor, + qvalue, + loss_function="l2", + ) + + key_mapping = { + "reward": ("reward", "reward_test"), + "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), + } + self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) + + @pytest.mark.parametrize("action_key", ["action", "action2"]) + @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) + @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) + @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + def test_crossq_notensordict( + self, action_key, observation_key, reward_key, done_key, terminated_key + ): + torch.manual_seed(self.seed) + td = self._create_mock_data_crossq( + action_key=action_key, + observation_key=observation_key, + reward_key=reward_key, + done_key=done_key, + terminated_key=terminated_key, + ) + + actor = self._create_mock_actor( + observation_key=observation_key, action_key=action_key + ) + qvalue = self._create_mock_qvalue( + observation_key=observation_key, + action_key=action_key, + out_keys=["state_action_value"], + ) + + loss = CrossQLoss( + actor_network=actor, + qvalue_network=qvalue, + ) + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + ) + + kwargs = { + action_key: td.get(action_key), + observation_key: td.get(observation_key), + f"next_{reward_key}": td.get(("next", reward_key)), + f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), + f"next_{observation_key}": td.get(("next", observation_key)), + } + td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") + + # setting the seed for each loss so that drawing the random samples from value network + # leads to same numbers for both runs + torch.manual_seed(self.seed) + loss_val = loss(**kwargs) + + torch.manual_seed(self.seed) + + loss_val_td = loss(td) + assert len(loss_val) == 5 + + torch.testing.assert_close(loss_val_td.get("loss_actor"), loss_val[0]) + torch.testing.assert_close(loss_val_td.get("loss_qvalue"), loss_val[1]) + torch.testing.assert_close(loss_val_td.get("loss_alpha"), loss_val[2]) + torch.testing.assert_close(loss_val_td.get("alpha"), loss_val[3]) + torch.testing.assert_close(loss_val_td.get("entropy"), loss_val[4]) + + # test select + torch.manual_seed(self.seed) + loss.select_out_keys("loss_actor", "loss_alpha") + if torch.__version__ >= "2.0.0": + loss_actor, loss_alpha = loss(**kwargs) + else: + with pytest.raises( + RuntimeError, + match="You are likely using tensordict.nn.dispatch with keyword arguments", + ): + loss_actor, loss_alpha = loss(**kwargs) + return + assert loss_actor == loss_val_td["loss_actor"] + assert loss_alpha == loss_val_td["loss_alpha"] + + def test_state_dict( + self, + ): + + model = torch.nn.Linear(3, 4) + actor_module = TensorDictModule(model, in_keys=["obs"], out_keys=["logits"]) + policy = ProbabilisticActor( + module=actor_module, + in_keys=["logits"], + out_keys=["action"], + distribution_class=TanhDelta, + ) + value = ValueOperator(module=model, in_keys=["obs"], out_keys="value") + + loss = CrossQLoss( + actor_network=policy, + qvalue_network=value, + action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + ) + state = loss.state_dict() + + loss = CrossQLoss( + actor_network=policy, + qvalue_network=value, + action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + ) + loss.load_state_dict(state) + + # with an access in between + loss = CrossQLoss( + actor_network=policy, + qvalue_network=value, + action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + ) + loss.target_entropy + state = loss.state_dict() + + loss = CrossQLoss( + actor_network=policy, + qvalue_network=value, + action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + ) + loss.load_state_dict(state) + + @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) + def test_crossq_reduction(self, reduction): + torch.manual_seed(self.seed) + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda") + ) + td = self._create_mock_data_crossq(device=device) + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + + loss_fn = CrossQLoss( + actor_network=actor, + qvalue_network=qvalue, + loss_function="l2", + reduction=reduction, + ) + loss_fn.make_value_estimator() + loss = loss_fn(td) + if reduction == "none": + for key in loss.keys(): + if key.startswith("loss"): + assert loss[key].shape == td.shape + else: + for key in loss.keys(): + if not key.startswith("loss"): + continue + assert loss[key].shape == torch.Size([]) + + @pytest.mark.skipif( not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" ) diff --git a/test/test_modules.py b/test/test_modules.py index 59adbea653d..592464f0a96 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -34,7 +34,14 @@ VDNMixer, ) from torchrl.modules.distributions.utils import safeatanh, safetanh -from torchrl.modules.models import Conv3dNet, ConvNet, MLP, NoisyLazyLinear, NoisyLinear +from torchrl.modules.models import ( + BatchRenorm1d, + Conv3dNet, + ConvNet, + MLP, + NoisyLazyLinear, + NoisyLinear, +) from torchrl.modules.models.decision_transformer import ( _has_transformers, DecisionTransformer, @@ -1438,6 +1445,40 @@ def test_python_gru(device, bias, dropout, batch_first, num_layers): torch.testing.assert_close(h1, h2) +class TestBatchRenorm: + @pytest.mark.parametrize("num_steps", [0, 5]) + @pytest.mark.parametrize("smooth", [False, True]) + def test_batchrenorm(self, num_steps, smooth): + torch.manual_seed(0) + bn = torch.nn.BatchNorm1d(5, momentum=0.1, eps=1e-5) + brn = BatchRenorm1d( + 5, + momentum=0.1, + eps=1e-5, + warmup_steps=num_steps, + max_d=10000, + max_r=10000, + smooth=smooth, + ) + bn.train() + brn.train() + data_train = torch.randn(100, 5).split(25) + data_test = torch.randn(100, 5) + for i, d in enumerate(data_train): + b = bn(d) + a = brn(d) + if num_steps > 0 and ( + (i < num_steps and not smooth) or (i == 0 and smooth) + ): + torch.testing.assert_close(a, b) + else: + assert not torch.isclose(a, b).all(), i + + bn.eval() + brn.eval() + torch.testing.assert_close(bn(data_test), brn(data_test)) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 7f462782757..4241f6613a0 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -26,8 +26,8 @@ LazyStackedTensorDict, TensorDict, TensorDictBase, + unravel_key, ) -from tensordict._tensordict import unravel_key from torch import multiprocessing as mp from torchrl._utils import ( _check_for_faulty_process, diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index fb0cc0135b8..62ccf53c30a 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -6,6 +6,8 @@ from torchrl.modules.tensordict_module.common import DistributionalDQNnet +from .batchrenorm import BatchRenorm1d + from .decision_transformer import DecisionTransformer from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise from .model_based import ( diff --git a/torchrl/modules/models/batchrenorm.py b/torchrl/modules/models/batchrenorm.py new file mode 100644 index 00000000000..26a2f9d50d2 --- /dev/null +++ b/torchrl/modules/models/batchrenorm.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import torch +import torch.nn as nn + + +class BatchRenorm1d(nn.Module): + """BatchRenorm Module (https://arxiv.org/abs/1702.03275). + + The code is adapted from https://github.com/google-research/corenet + + BatchRenorm is an enhanced version of the standard BatchNorm. Unlike BatchNorm, + it utilizes running statistics to normalize batches after an initial warmup phase. + This approach reduces the impact of "outlier" batches that may occur during + extended training periods, making BatchRenorm more robust for long training runs. + + During the warmup phase, BatchRenorm functions identically to a BatchNorm layer. + + Args: + num_features (int): Number of features in the input tensor. + + Keyword Args: + momentum (float, optional): Momentum factor for computing the running mean and variance. + Defaults to ``0.01``. + eps (float, optional): Small value added to the variance to avoid division by zero. + Defaults to ``1e-5``. + max_r (float, optional): Maximum value for the scaling factor r. + Defaults to ``3.0``. + max_d (float, optional): Maximum value for the bias factor d. + Defaults to ``5.0``. + warmup_steps (int, optional): Number of warm-up steps for the running mean and variance. + Defaults to ``10000``. + smooth (bool, optional): if ``True``, the behaviour smoothly transitions from regular + batch-norm (when ``iter=0``) to batch-renorm (when ``iter=warmup_steps``). + Otherwise, the behaviour will transition from batch-norm to batch-renorm when + ``iter=warmup_steps``. Defaults to ``False``. + """ + + def __init__( + self, + num_features: int, + *, + momentum: float = 0.01, + eps: float = 1e-5, + max_r: float = 3.0, + max_d: float = 5.0, + warmup_steps: int = 10000, + smooth: bool = False, + ): + super().__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.max_r = max_r + self.max_d = max_d + self.warmup_steps = warmup_steps + self.smooth = smooth + + self.register_buffer( + "running_mean", torch.zeros(num_features, dtype=torch.float32) + ) + self.register_buffer( + "running_var", torch.ones(num_features, dtype=torch.float32) + ) + self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.int64)) + self.weight = nn.Parameter(torch.ones(num_features, dtype=torch.float32)) + self.bias = nn.Parameter(torch.zeros(num_features, dtype=torch.float32)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not x.dim() >= 2: + raise ValueError( + f"The {type(self).__name__} expects a 2D (or more) tensor, got {x.dim()}." + ) + + view_dims = [1, x.shape[1]] + [1] * (x.dim() - 2) + + def _v(v): + return v.view(view_dims) + + running_std = (self.running_var + self.eps).sqrt_() + + if self.training: + reduce_dims = [i for i in range(x.dim()) if i != 1] + b_mean = x.mean(reduce_dims) + b_var = x.var(reduce_dims, unbiased=False) + b_std = (b_var + self.eps).sqrt_() + + r = torch.clamp((b_std.detach() / running_std), 1 / self.max_r, self.max_r) + d = torch.clamp( + (b_mean.detach() - self.running_mean) / running_std, + -self.max_d, + self.max_d, + ) + + # Compute warmup factor (0 during warmup, 1 after warmup) + if self.warmup_steps > 0: + if self.smooth: + warmup_factor = self.num_batches_tracked / self.warmup_steps + else: + warmup_factor = self.num_batches_tracked // self.warmup_steps + r = 1.0 + (r - 1.0) * warmup_factor + d = d * warmup_factor + + x = (x - _v(b_mean)) / _v(b_std) * _v(r) + _v(d) + + unbiased_var = b_var.detach() * x.shape[0] / (x.shape[0] - 1) + self.running_var += self.momentum * (unbiased_var - self.running_var) + self.running_mean += self.momentum * (b_mean.detach() - self.running_mean) + self.num_batches_tracked += 1 + self.num_batches_tracked.clamp_max(self.warmup_steps) + else: + x = (x - _v(self.running_mean)) / _v(running_std) + + x = _v(self.weight) * x + _v(self.bias) + return x diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index 674c06123ad..aa13a88c7e9 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -6,6 +6,7 @@ from .a2c import A2CLoss from .common import LossModule from .cql import CQLLoss, DiscreteCQLLoss +from .crossq import CrossQLoss from .ddpg import DDPGLoss from .decision_transformer import DTLoss, OnlineDTLoss from .dqn import DistributionalDQNLoss, DQNLoss diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py new file mode 100644 index 00000000000..22d35bd5799 --- /dev/null +++ b/torchrl/objectives/crossq.py @@ -0,0 +1,662 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import math +from dataclasses import dataclass +from functools import wraps +from typing import Dict, Tuple, Union + +import torch +from tensordict import TensorDict, TensorDictBase, TensorDictParams + +from tensordict.nn import dispatch, TensorDictModule +from tensordict.utils import NestedKey +from torch import Tensor +from torchrl.data.tensor_specs import CompositeSpec +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules import ProbabilisticActor +from torchrl.objectives.common import LossModule + +from torchrl.objectives.utils import ( + _cache_values, + _reduce, + _vmap_func, + default_value_kwargs, + distance_loss, + ValueEstimators, +) +from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator + + +def _delezify(func): + @wraps(func) + def new_func(self, *args, **kwargs): + self.target_entropy + return func(self, *args, **kwargs) + + return new_func + + +class CrossQLoss(LossModule): + """TorchRL implementation of the CrossQ loss. + + Presented in "CROSSQ: BATCH NORMALIZATION IN DEEP REINFORCEMENT LEARNING + FOR GREATER SAMPLE EFFICIENCY AND SIMPLICITY" https://openreview.net/pdf?id=PczQtTsTIX + + This class has three loss functions that will be called sequentially by the `forward` method: + :meth:`~.qvalue_loss`, :meth:`~.actor_loss` and :meth:`~.alpha_loss`. Alternatively, they can + be called by the user that order. + + Args: + actor_network (ProbabilisticActor): stochastic actor + qvalue_network (TensorDictModule): Q(s, a) parametric model. + This module typically outputs a ``"state_action_value"`` entry. + + Keyword Args: + num_qvalue_nets (integer, optional): number of Q-Value networks used. + Defaults to ``2``. + loss_function (str, optional): loss function to be used with + the value function loss. Default is `"smooth_l1"`. + alpha_init (float, optional): initial entropy multiplier. + Default is 1.0. + min_alpha (float, optional): min value of alpha. + Default is None (no minimum value). + max_alpha (float, optional): max value of alpha. + Default is None (no maximum value). + action_spec (TensorSpec, optional): the action tensor spec. If not provided + and the target entropy is ``"auto"``, it will be retrieved from + the actor. + fixed_alpha (bool, optional): if ``True``, alpha will be fixed to its + initial value. Otherwise, alpha will be optimized to + match the 'target_entropy' value. + Default is ``False``. + target_entropy (float or str, optional): Target entropy for the + stochastic policy. Default is "auto", where target entropy is + computed as :obj:`-prod(n_actions)`. + priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead] + Tensordict key where to write the + priority (for prioritized replay buffer usage). Defaults to ``"td_error"``. + separate_losses (bool, optional): if ``True``, shared parameters between + policy and critic will only be trained on the policy loss. + Defaults to ``False``, ie. gradients are propagated to shared + parameters for both policy and critic losses. + reduction (str, optional): Specifies the reduction to apply to the output: + ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, + ``"mean"``: the sum of the output will be divided by the number of + elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``. + + Examples: + >>> import torch + >>> from torch import nn + >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator + >>> from torchrl.modules.tensordict_module.common import SafeModule + >>> from torchrl.objectives.crossq import CrossQLoss + >>> from tensordict import TensorDict + >>> n_act, n_obs = 4, 3 + >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) + >>> actor = ProbabilisticActor( + ... module=module, + ... in_keys=["loc", "scale"], + ... spec=spec, + ... distribution_class=TanhNormal) + >>> class ValueClass(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.linear = nn.Linear(n_obs + n_act, 1) + ... def forward(self, obs, act): + ... return self.linear(torch.cat([obs, act], -1)) + >>> module = ValueClass() + >>> qvalue = ValueOperator( + ... module=module, + ... in_keys=['observation', 'action']) + >>> loss = CrossQLoss(actor, qvalue) + >>> batch = [2, ] + >>> action = spec.rand(batch) + >>> data = TensorDict({ + ... "observation": torch.randn(*batch, n_obs), + ... "action": action, + ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "reward"): torch.randn(*batch, 1), + ... ("next", "observation"): torch.randn(*batch, n_obs), + ... }, batch) + >>> loss(data) + TensorDict( + fields={ + alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + + This class is compatible with non-tensordict based modules too and can be + used without recurring to any tensordict-related primitive. In this case, + the expected keyword arguments are: + ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network. + The return value is a tuple of tensors in the following order: + ``["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"]`` + + Examples: + >>> import torch + >>> from torch import nn + >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator + >>> from torchrl.modules.tensordict_module.common import SafeModule + >>> from torchrl.objectives import CrossQLoss + >>> _ = torch.manual_seed(42) + >>> n_act, n_obs = 4, 3 + >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) + >>> actor = ProbabilisticActor( + ... module=module, + ... in_keys=["loc", "scale"], + ... spec=spec, + ... distribution_class=TanhNormal) + >>> class ValueClass(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.linear = nn.Linear(n_obs + n_act, 1) + ... def forward(self, obs, act): + ... return self.linear(torch.cat([obs, act], -1)) + >>> module = ValueClass() + >>> qvalue = ValueOperator( + ... module=module, + ... in_keys=['observation', 'action']) + >>> loss = CrossQLoss(actor, qvalue) + >>> batch = [2, ] + >>> action = spec.rand(batch) + >>> loss_actor, loss_qvalue, _, _, _ = loss( + ... observation=torch.randn(*batch, n_obs), + ... action=action, + ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_observation=torch.zeros(*batch, n_obs), + ... next_reward=torch.randn(*batch, 1)) + >>> loss_actor.backward() + + The output keys can also be filtered using the :meth:`CrossQLoss.select_out_keys` + method. + + Examples: + >>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue') + >>> loss_actor, loss_qvalue = loss( + ... observation=torch.randn(*batch, n_obs), + ... action=action, + ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_observation=torch.zeros(*batch, n_obs), + ... next_reward=torch.randn(*batch, 1)) + >>> loss_actor.backward() + """ + + @dataclass + class _AcceptedKeys: + """Maintains default values for all configurable tensordict keys. + + This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their + default values. + + Attributes: + action (NestedKey): The input tensordict key where the action is expected. + Defaults to ``"advantage"``. + state_action_value (NestedKey): The input tensordict key where the + state action value is expected. Defaults to ``"state_action_value"``. + priority (NestedKey): The input tensordict key where the target priority is written to. + Defaults to ``"td_error"``. + reward (NestedKey): The input tensordict key where the reward is expected. + Will be used for the underlying value estimator. Defaults to ``"reward"``. + done (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is done. Will be used for the underlying value estimator. + Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. + log_prob (NestedKey): The input tensordict key where the log probability is expected. + Defaults to ``"_log_prob"``. + """ + + action: NestedKey = "action" + state_action_value: NestedKey = "state_action_value" + priority: NestedKey = "td_error" + reward: NestedKey = "reward" + done: NestedKey = "done" + terminated: NestedKey = "terminated" + log_prob: NestedKey = "_log_prob" + + default_keys = _AcceptedKeys() + default_value_estimator = ValueEstimators.TD0 + + actor_network: ProbabilisticActor + actor_network_params: TensorDictParams + qvalue_network: TensorDictModule + qvalue_network_params: TensorDictParams + target_actor_network_params: TensorDictParams + target_qvalue_network_params: TensorDictParams + + def __init__( + self, + actor_network: ProbabilisticActor, + qvalue_network: TensorDictModule, + *, + num_qvalue_nets: int = 2, + loss_function: str = "smooth_l1", + alpha_init: float = 1.0, + min_alpha: float = None, + max_alpha: float = None, + action_spec=None, + fixed_alpha: bool = False, + target_entropy: Union[str, float] = "auto", + priority_key: str = None, + separate_losses: bool = False, + reduction: str = None, + ) -> None: + self._in_keys = None + self._out_keys = None + if reduction is None: + reduction = "mean" + super().__init__() + self._set_deprecated_ctor_keys(priority_key=priority_key) + + # Actor + self.convert_to_functional( + actor_network, + "actor_network", + create_target_params=False, + ) + if separate_losses: + # we want to make sure there are no duplicates in the params: the + # params of critic must be refs to actor if they're shared + policy_params = list(actor_network.parameters()) + else: + policy_params = None + q_value_policy_params = None + + # Q value + self.num_qvalue_nets = num_qvalue_nets + + q_value_policy_params = policy_params + self.convert_to_functional( + qvalue_network, + "qvalue_network", + num_qvalue_nets, + create_target_params=False, + compare_against=q_value_policy_params, + ) + + self.loss_function = loss_function + try: + device = next(self.parameters()).device + except AttributeError: + device = torch.device("cpu") + self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) + if bool(min_alpha) ^ bool(max_alpha): + min_alpha = min_alpha if min_alpha else 0.0 + if max_alpha == 0: + raise ValueError("max_alpha must be either None or greater than 0.") + max_alpha = max_alpha if max_alpha else 1e9 + if min_alpha: + self.register_buffer( + "min_log_alpha", torch.tensor(min_alpha, device=device).log() + ) + else: + self.min_log_alpha = None + if max_alpha: + self.register_buffer( + "max_log_alpha", torch.tensor(max_alpha, device=device).log() + ) + else: + self.max_log_alpha = None + self.fixed_alpha = fixed_alpha + if fixed_alpha: + self.register_buffer( + "log_alpha", torch.tensor(math.log(alpha_init), device=device) + ) + else: + self.register_parameter( + "log_alpha", + torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)), + ) + + self._target_entropy = target_entropy + self._action_spec = action_spec + self._vmap_qnetworkN0 = _vmap_func( + self.qvalue_network, (None, 0), randomness=self.vmap_randomness + ) + self.reduction = reduction + + @property + def target_entropy_buffer(self): + """The target entropy. + + This value can be controlled via the `target_entropy` kwarg in the constructor. + """ + return self.target_entropy + + @property + def target_entropy(self): + target_entropy = self._buffers.get("_target_entropy", None) + if target_entropy is not None: + return target_entropy + target_entropy = self._target_entropy + action_spec = self._action_spec + actor_network = self.actor_network + device = next(self.parameters()).device + if target_entropy == "auto": + action_spec = ( + action_spec + if action_spec is not None + else getattr(actor_network, "spec", None) + ) + if action_spec is None: + raise RuntimeError( + "Cannot infer the dimensionality of the action. Consider providing " + "the target entropy explicitely or provide the spec of the " + "action tensor in the actor network." + ) + if not isinstance(action_spec, CompositeSpec): + action_spec = CompositeSpec({self.tensor_keys.action: action_spec}) + if ( + isinstance(self.tensor_keys.action, tuple) + and len(self.tensor_keys.action) > 1 + ): + action_container_shape = action_spec[self.tensor_keys.action[:-1]].shape + else: + action_container_shape = action_spec.shape + target_entropy = -float( + action_spec[self.tensor_keys.action] + .shape[len(action_container_shape) :] + .numel() + ) + delattr(self, "_target_entropy") + self.register_buffer( + "_target_entropy", torch.tensor(target_entropy, device=device) + ) + return self._target_entropy + + state_dict = _delezify(LossModule.state_dict) + load_state_dict = _delezify(LossModule.load_state_dict) + + def _forward_value_estimator_keys(self, **kwargs) -> None: + if self._value_estimator is not None: + self._value_estimator.set_keys( + value=self.tensor_keys.value, + reward=self.tensor_keys.reward, + done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, + ) + self._set_in_keys() + + def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): + if value_type is None: + value_type = self.default_value_estimator + self.value_type = value_type + + value_net = None + hp = dict(default_value_kwargs(value_type)) + hp.update(hyperparams) + if value_type is ValueEstimators.TD1: + self._value_estimator = TD1Estimator( + **hp, + value_network=value_net, + ) + elif value_type is ValueEstimators.TD0: + self._value_estimator = TD0Estimator( + **hp, + value_network=value_net, + ) + elif value_type is ValueEstimators.GAE: + raise NotImplementedError( + f"Value type {value_type} it not implemented for loss {type(self)}." + ) + elif value_type is ValueEstimators.TDLambda: + self._value_estimator = TDLambdaEstimator( + **hp, + value_network=value_net, + ) + else: + raise NotImplementedError(f"Unknown value type {value_type}") + + tensor_keys = { + "reward": self.tensor_keys.reward, + "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, + } + self._value_estimator.set_keys(**tensor_keys) + + @property + def device(self) -> torch.device: + for p in self.parameters(): + return p.device + raise RuntimeError( + "At least one of the networks of SACLoss must have trainable " "parameters." + ) + + def _set_in_keys(self): + keys = [ + self.tensor_keys.action, + ("next", self.tensor_keys.reward), + ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), + *self.actor_network.in_keys, + *[("next", key) for key in self.actor_network.in_keys], + *self.qvalue_network.in_keys, + ] + self._in_keys = list(set(keys)) + + @property + def in_keys(self): + if self._in_keys is None: + self._set_in_keys() + return self._in_keys + + @in_keys.setter + def in_keys(self, values): + self._in_keys = values + + @property + def out_keys(self): + if self._out_keys is None: + keys = ["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"] + self._out_keys = keys + return self._out_keys + + @out_keys.setter + def out_keys(self, values): + self._out_keys = values + + @dispatch + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + """The forward method. + + Computes successively the :meth:`~.qvalue_loss`, :meth:`~.actor_loss` and :meth:`~.alpha_loss`, and returns + a tensordict with these values along with the `"alpha"` value and the `"entropy"` value (detached). + To see what keys are expected in the input tensordict and what keys are expected as output, check the + class's `"in_keys"` and `"out_keys"` attributes. + """ + shape = None + if tensordict.ndimension() > 1: + shape = tensordict.shape + tensordict_reshape = tensordict.reshape(-1) + else: + tensordict_reshape = tensordict + + loss_qvalue, value_metadata = self.qvalue_loss(tensordict_reshape) + loss_actor, metadata_actor = self.actor_loss(tensordict_reshape) + loss_alpha = self.alpha_loss(log_prob=metadata_actor["log_prob"]) + tensordict_reshape.set(self.tensor_keys.priority, value_metadata["td_error"]) + if loss_actor.shape != loss_qvalue.shape: + raise RuntimeError( + f"Losses shape mismatch: {loss_actor.shape} and {loss_qvalue.shape}" + ) + if shape: + tensordict.update(tensordict_reshape.view(shape)) + entropy = -metadata_actor["log_prob"] + out = { + "loss_actor": loss_actor, + "loss_qvalue": loss_qvalue, + "loss_alpha": loss_alpha, + "alpha": self._alpha, + "entropy": entropy.detach().mean(), + **metadata_actor, + **value_metadata, + } + td_out = TensorDict(out, []) + # td_out = td_out.named_apply( + # lambda name, value: ( + # _reduce(value, reduction=self.reduction) + # if name.startswith("loss_") + # else value + # ), + # batch_size=[], + # ) + return td_out + + @property + @_cache_values + def _cached_detached_qvalue_params(self): + return self.qvalue_network_params.detach() + + def actor_loss( + self, tensordict: TensorDictBase + ) -> Tuple[Tensor, Dict[str, Tensor]]: + """Compute the actor loss. + + The actor loss should be computed after the :meth:`~.qvalue_loss` and before the `~.alpha_loss` which + requires the `log_prob` field of the `metadata` returned by this method. + + Args: + tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields + are required for this to be computed. + + Returns: a differentiable tensor with the alpha loss along with a metadata dictionary containing the detached `"log_prob"` of the sampled action. + """ + with set_exploration_type( + ExplorationType.RANDOM + ), self.actor_network_params.to_module(self.actor_network): + dist = self.actor_network.get_dist(tensordict) + a_reparm = dist.rsample() + log_prob = dist.log_prob(a_reparm) + + td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False) + self.qvalue_network.eval() + td_q.set(self.tensor_keys.action, a_reparm) + td_q = self._vmap_qnetworkN0( + td_q, + self._cached_detached_qvalue_params, + ) + + min_q = td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1) + self.qvalue_network.train() + + if log_prob.shape != min_q.shape: + raise RuntimeError( + f"Losses shape mismatch: {log_prob.shape} and {min_q.shape}" + ) + actor_loss = self._alpha * log_prob - min_q + return _reduce(actor_loss, reduction=self.reduction), { + "log_prob": log_prob.detach() + } + + def qvalue_loss( + self, tensordict: TensorDictBase + ) -> Tuple[Tensor, Dict[str, Tensor]]: + """Compute the q-value loss. + + The q-value loss should be computed before the :meth:`~.actor_loss`. + + Args: + tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields + are required for this to be computed. + + Returns: a differentiable tensor with the qvalue loss along with a metadata dictionary containing + the detached `"td_error"` to be used for prioritized sampling. + """ + # # compute next action + with torch.no_grad(): + with set_exploration_type( + ExplorationType.RANDOM + ), self.actor_network_params.to_module(self.actor_network): + next_tensordict = tensordict.get("next").clone(False) + next_dist = self.actor_network.get_dist(next_tensordict) + next_action = next_dist.sample() + next_tensordict.set(self.tensor_keys.action, next_action) + next_sample_log_prob = next_dist.log_prob(next_action) + + combined = torch.cat( + [ + tensordict.select(*self.qvalue_network.in_keys, strict=False), + next_tensordict.select(*self.qvalue_network.in_keys, strict=False), + ] + ) + pred_qs = self._vmap_qnetworkN0(combined, self.qvalue_network_params).get( + self.tensor_keys.state_action_value + ) + (current_state_action_value, next_state_action_value) = pred_qs.split( + tensordict.batch_size[0], dim=1 + ) + + # compute target value + if ( + next_state_action_value.shape[-len(next_sample_log_prob.shape) :] + != next_sample_log_prob.shape + ): + next_sample_log_prob = next_sample_log_prob.unsqueeze(-1) + next_state_action_value = next_state_action_value.min(0)[0] + next_state_action_value = ( + next_state_action_value - self._alpha * next_sample_log_prob + ).detach() + + target_value = self.value_estimator.value_estimate( + tensordict, next_value=next_state_action_value + ).squeeze(-1) + + # get current q-values + pred_val = current_state_action_value.squeeze(-1) + + # compute loss + td_error = abs(pred_val - target_value) + loss_qval = distance_loss( + pred_val, + target_value.expand_as(pred_val), + loss_function=self.loss_function, + ).sum(0) + metadata = {"td_error": td_error.detach().max(0)[0]} + return _reduce(loss_qval, reduction=self.reduction), metadata + + def alpha_loss(self, log_prob: Tensor) -> Tensor: + """Compute the entropy loss. + + The entropy loss should be computed last. + + Args: + log_prob (torch.Tensor): a log-probability as computed by the :meth:`~.actor_loss` and returned in the `metadata`. + + Returns: a differentiable tensor with the entropy loss. + """ + if self.target_entropy is not None: + # we can compute this loss even if log_alpha is not a parameter + alpha_loss = -self.log_alpha * (log_prob + self.target_entropy) + else: + # placeholder + alpha_loss = torch.zeros_like(log_prob) + return _reduce(alpha_loss, reduction=self.reduction) + + @property + def _alpha(self): + if self.min_log_alpha is not None: + self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) + with torch.no_grad(): + alpha = self.log_alpha.exp() + return alpha