Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Sep 23, 2024
1 parent 582dd2c commit 550b423
Show file tree
Hide file tree
Showing 24 changed files with 234 additions and 152 deletions.
2 changes: 2 additions & 0 deletions sota-implementations/a2c/a2c_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def main(cfg: "DictConfig"): # noqa: F821
weight_decay=cfg.optim.weight_decay,
eps=cfg.optim.eps,
)
if cfg.loss.compile:
loss_module = torch.compile(loss_module)

# Create logger
logger = None
Expand Down
2 changes: 2 additions & 0 deletions sota-implementations/a2c/a2c_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def main(cfg: "DictConfig"): # noqa: F821
entropy_coef=cfg.loss.entropy_coef,
critic_coef=cfg.loss.critic_coef,
)
if cfg.loss.compile:
loss_module = torch.compile(loss_module)

# Create optimizers
actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr)
Expand Down
1 change: 1 addition & 0 deletions sota-implementations/a2c/config_atari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ loss:
critic_coef: 0.25
entropy_coef: 0.01
loss_critic_type: l2
compile: True
1 change: 1 addition & 0 deletions sota-implementations/a2c/config_mujoco.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ loss:
critic_coef: 0.25
entropy_coef: 0.0
loss_critic_type: l2
compile: True
82 changes: 43 additions & 39 deletions sota-implementations/cql/cql_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import numpy as np
import torch
import tqdm
from tensordict import TensorDict

from torchrl._utils import logger as torchrl_logger
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.record.loggers import generate_exp_name, get_logger
Expand Down Expand Up @@ -81,66 +83,68 @@ def main(cfg: "DictConfig"): # noqa: F821
alpha_prime_optim,
) = make_continuous_cql_optimizer(cfg, loss_module)

pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)

gradient_steps = cfg.optim.gradient_steps
policy_eval_start = cfg.optim.policy_eval_start
evaluation_interval = cfg.logger.eval_iter
eval_steps = cfg.logger.eval_steps

# Training loop
start_time = time.time()
for i in range(gradient_steps):
pbar.update(1)
# sample data
data = replay_buffer.sample()
# compute loss
loss_vals = loss_module(data.clone().to(device))
def update(data, i):
critic_optim.zero_grad()
q_loss, metadata = loss_module.q_loss(data)
cql_loss, cql_metadata = loss_module.cql_loss(data)
q_loss = q_loss + cql_loss
q_loss.backward()
critic_optim.step()
metadata.update(cql_metadata)

# official cql implementation uses behavior cloning loss for first few updating steps as it helps for some tasks
policy_optim.zero_grad()
if i >= policy_eval_start:
actor_loss = loss_vals["loss_actor"]
actor_loss, actor_metadata = loss_module.actor_loss(data)
else:
actor_loss = loss_vals["loss_actor_bc"]
q_loss = loss_vals["loss_qvalue"]
cql_loss = loss_vals["loss_cql"]

q_loss = q_loss + cql_loss

# update model
alpha_loss = loss_vals["loss_alpha"]
alpha_prime_loss = loss_vals["loss_alpha_prime"]
actor_loss, actor_metadata = loss_module.actor_bc_loss(data)
actor_loss.backward()
policy_optim.step()
metadata.update(actor_metadata)

alpha_optim.zero_grad()
alpha_loss, alpha_metadata = loss_module.alpha_loss(actor_metadata)
alpha_loss.backward()
alpha_optim.step()

policy_optim.zero_grad()
actor_loss.backward()
policy_optim.step()
metadata.update(alpha_metadata)

if alpha_prime_optim is not None:
alpha_prime_optim.zero_grad()
alpha_prime_loss.backward(retain_graph=True)
alpha_prime_loss, alpha_prime_metadata = loss_module.alpha_prime_loss(data)
alpha_prime_loss.backward()
alpha_prime_optim.step()
metadata.update(alpha_prime_metadata)

critic_optim.zero_grad()
# TODO: we have the option to compute losses independently retain is not needed?
q_loss.backward(retain_graph=False)
critic_optim.step()
loss_vals = TensorDict(metadata)
loss_vals["loss_qvalue"] = q_loss
loss_vals["loss_cql"] = cql_loss
loss_vals["loss_alpha"] = alpha_loss
loss = actor_loss + q_loss + alpha_loss
if alpha_prime_optim is not None:
loss_vals["loss_alpha_prime"] = alpha_prime_loss
loss = loss + alpha_prime_loss
loss_vals["loss"] = loss

return loss_vals.detach()

if cfg.loss.compile:
update = torch.compile(update, mode=cfg.loss.compile_mode)

loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss
# Training loop
start_time = time.time()
pbar = tqdm.tqdm(range(gradient_steps))
for i in pbar:
# sample data
data = replay_buffer.sample().to(device)
loss_vals = update(data, i)

# log metrics
to_log = {
"loss": loss.item(),
"loss_actor_bc": loss_vals["loss_actor_bc"].item(),
"loss_actor": loss_vals["loss_actor"].item(),
"loss_qvalue": q_loss.item(),
"loss_cql": cql_loss.item(),
"loss_alpha": alpha_loss.item(),
"loss_alpha_prime": alpha_prime_loss.item(),
}
to_log = loss_vals.mean().to_dict()

# update qnet_target params
target_net_updater.step()
Expand Down
123 changes: 82 additions & 41 deletions sota-implementations/cql/cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import torch
import tqdm
from tensordict import TensorDict
from torchrl._utils import logger as torchrl_logger
from tensordict.nn import CudaGraphModule

from torchrl._utils import logger as torchrl_logger, timeit
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.record.loggers import generate_exp_name, get_logger

Expand Down Expand Up @@ -111,17 +113,77 @@ def main(cfg: "DictConfig"): # noqa: F821
evaluation_interval = cfg.logger.log_interval
eval_rollout_steps = cfg.logger.eval_steps

def update(sampled_tensordict):
    """Run one optimization step for the critic, actor and temperature terms.

    Each loss is computed, back-propagated and applied through its own
    optimizer before the next loss is evaluated.

    Args:
        sampled_tensordict: batch sampled from the replay buffer; it is passed
            unchanged to the individual ``loss_module`` loss methods.

    Returns:
        A detached ``TensorDict`` built from the metadata returned by the loss
        methods, plus ``loss_actor``, ``loss_qvalue`` (Q-value loss with the
        CQL term already added), ``loss_cql``, ``loss_alpha``, ``loss`` (the
        sum of the optimized losses) and, when the Lagrange branch ran,
        ``loss_alpha_prime``.
    """
    # Critic: the Q-value loss and the CQL regularizer share one backward pass.
    critic_optim.zero_grad()
    q_loss, metadata = loss_module.q_loss(sampled_tensordict)
    cql_loss, metadata_cql = loss_module.cql_loss(sampled_tensordict)
    # Merge the CQL metadata (the original merged ``metadata`` into itself,
    # which silently dropped everything returned by ``cql_loss``).
    metadata.update(metadata_cql)
    q_loss = q_loss + cql_loss
    q_loss.backward()
    critic_optim.step()

    # Stays None unless the Lagrange branch runs, so the bookkeeping below is
    # gated on the value actually being available.
    alpha_prime_loss = None
    if loss_module.with_lagrange:
        alpha_prime_optim.zero_grad()
        alpha_prime_loss, metadata_aprime = loss_module.alpha_prime_loss(
            sampled_tensordict
        )
        metadata.update(metadata_aprime)
        alpha_prime_loss.backward()
        alpha_prime_optim.step()

    # Actor.
    policy_optim.zero_grad()
    actor_loss, actor_metadata = loss_module.actor_loss(sampled_tensordict)
    metadata.update(actor_metadata)
    actor_loss.backward()
    policy_optim.step()

    # Temperature (alpha): computed from the metadata returned by the actor loss.
    alpha_optim.zero_grad()
    alpha_loss, metadata_alpha = loss_module.alpha_loss(actor_metadata)
    metadata.update(metadata_alpha)
    alpha_loss.backward()
    alpha_optim.step()

    loss_td = TensorDict(metadata)
    loss_td["loss_actor"] = actor_loss
    loss_td["loss_qvalue"] = q_loss
    loss_td["loss_cql"] = cql_loss
    loss_td["loss_alpha"] = alpha_loss

    # ``q_loss`` already includes ``cql_loss`` at this point.
    loss = actor_loss + alpha_loss + q_loss
    if alpha_prime_loss is not None:
        # Store the loss under the key the logging code reads (the original
        # tried to read this key from ``loss_td`` instead of writing it).
        loss_td["loss_alpha_prime"] = alpha_prime_loss
        loss = loss + alpha_prime_loss

    loss_td["loss"] = loss
    return loss_td.detach()

if cfg.loss.compile:
update = torch.compile(update, mode=cfg.loss.compile_mode)

if cfg.loss.cudagraphs:
update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)

sampling_start = time.time()
for i, tensordict in enumerate(collector):
collector_iter = iter(collector)
for i in range(cfg.collector.total_frames):
timeit.print()
timeit.erase()
with timeit("collection"):
tensordict = next(collector_iter)
sampling_time = time.time() - sampling_start
pbar.update(tensordict.numel())
# update weights of the inference policy
collector.update_policy_weights_()
with timeit("update policies"):
# update weights of the inference policy
collector.update_policy_weights_()

tensordict = tensordict.view(-1)
tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
# add to replay buffer
replay_buffer.extend(tensordict.cpu())
with timeit("extend"):
replay_buffer.extend(tensordict.cpu())
collected_frames += current_frames

# optimization steps
Expand All @@ -130,44 +192,22 @@ def main(cfg: "DictConfig"): # noqa: F821
log_loss_td = TensorDict({}, [num_updates])
for j in range(num_updates):
# sample from replay buffer
sampled_tensordict = replay_buffer.sample()
with timeit("sample"):
sampled_tensordict = replay_buffer.sample()
if sampled_tensordict.device != device:
sampled_tensordict = sampled_tensordict.to(
device, non_blocking=True
)
else:
sampled_tensordict = sampled_tensordict.clone()

loss_td = loss_module(sampled_tensordict)

actor_loss = loss_td["loss_actor"]
q_loss = loss_td["loss_qvalue"]
cql_loss = loss_td["loss_cql"]
q_loss = q_loss + cql_loss
alpha_loss = loss_td["loss_alpha"]
alpha_prime_loss = loss_td["loss_alpha_prime"]

alpha_optim.zero_grad()
alpha_loss.backward()
alpha_optim.step()

policy_optim.zero_grad()
actor_loss.backward()
policy_optim.step()

if alpha_prime_optim is not None:
alpha_prime_optim.zero_grad()
alpha_prime_loss.backward(retain_graph=True)
alpha_prime_optim.step()

critic_optim.zero_grad()
q_loss.backward(retain_graph=False)
critic_optim.step()

log_loss_td[j] = loss_td.detach()
with timeit("update"):
loss_td = update(sampled_tensordict)
log_loss_td[j] = loss_td

# update qnet_target params
target_net_updater.step()
with timeit("target net"):
# update qnet_target params
target_net_updater.step()

# update priority
if prb:
Expand All @@ -191,10 +231,11 @@ def main(cfg: "DictConfig"): # noqa: F821
metrics_to_log["train/loss_actor"] = log_loss_td.get("loss_actor").mean()
metrics_to_log["train/loss_qvalue"] = log_loss_td.get("loss_qvalue").mean()
metrics_to_log["train/loss_alpha"] = log_loss_td.get("loss_alpha").mean()
metrics_to_log["train/loss_alpha_prime"] = log_loss_td.get(
"loss_alpha_prime"
).mean()
metrics_to_log["train/entropy"] = log_loss_td.get("entropy").mean()
if alpha_prime_optim is not None:
metrics_to_log["train/loss_alpha_prime"] = log_loss_td.get(
"loss_alpha_prime"
).mean()
# metrics_to_log["train/entropy"] = log_loss_td.get("entropy").mean()
metrics_to_log["train/sampling_time"] = sampling_time
metrics_to_log["train/training_time"] = training_time

Expand All @@ -204,7 +245,7 @@ def main(cfg: "DictConfig"): # noqa: F821
cur_test_frame = (i * frames_per_batch) // evaluation_interval
final = current_frames >= collector.total_frames
if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(), timeit("eval"):
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
Expand Down
1 change: 1 addition & 0 deletions sota-implementations/cql/discrete_cql_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,4 @@ loss:
loss_function: l2
gamma: 0.99
tau: 0.005
compile: True
4 changes: 3 additions & 1 deletion sota-implementations/cql/offline_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,7 @@ loss:
max_q_backup: False
deterministic_backup: False
num_random: 10
with_lagrange: True
with_lagrange: False
lagrange_thresh: 5.0 # tau
compile: False
compile_mode: reduce-overhead
5 changes: 4 additions & 1 deletion sota-implementations/cql/online_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,8 @@ loss:
max_q_backup: False
deterministic_backup: False
num_random: 10
with_lagrange: True
with_lagrange: False
lagrange_thresh: 10.0
compile: False
compile_mode: reduce-overhead
cudagraphs: False
14 changes: 10 additions & 4 deletions sota-implementations/cql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,17 +202,21 @@ def make_cql_model(cfg, train_env, eval_env, device="cpu"):
# We use a ProbabilisticActor to make sure that we map the
# network output to the right space using a TanhDelta
# distribution.
high = action_spec.space.high
low = action_spec.space.low
if train_env.batch_size:
high = high[(0,) * len(train_env.batch_size)]
low = low[(0,) * len(train_env.batch_size)]
actor = ProbabilisticActor(
module=actor_module,
in_keys=["loc", "scale"],
spec=action_spec,
distribution_class=TanhNormal,
distribution_kwargs={
"low": action_spec.space.low[len(train_env.batch_size) :],
"high": action_spec.space.high[
len(train_env.batch_size) :
], # remove batch-size
"low": low.to(device),
"high": high.to(device),
"tanh_loc": False,
"safe_tanh": not cfg.loss.compile,
},
default_interaction_type=ExplorationType.RANDOM,
)
Expand Down Expand Up @@ -334,6 +338,8 @@ def make_discrete_loss(loss_cfg, model):
)
loss_module.make_value_estimator(gamma=loss_cfg.gamma)
target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau)
if loss_cfg.compile:
loss_module = torch.compile(loss_module)

return loss_module, target_net_updater

Expand Down
Loading

0 comments on commit 550b423

Please sign in to comment.