Commit f5dc85e

added temporal ensembling to the diffusion model

mlkakram committed Aug 26, 2024
1 parent 3b7d0ba commit f5dc85e
Showing 2 changed files with 123 additions and 4 deletions.
@@ -136,6 +136,8 @@ class ForceDiffusionConfig:
    pretrained_backbone_weights: str | None = None
    use_group_norm: bool = True
    spatial_softmax_num_keypoints: int = 32
    temporal_ensemble_coeff: float | None = None

    # UNet / FiLM
    down_dims: tuple[int, ...] = (512, 1024, 2048)
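The new `temporal_ensemble_coeff` field is opt-in: leaving it at `None` keeps the old behavior. A minimal sketch of enabling it, assuming the config follows the usual lerobot dataclass layout (the config file's exact import path is not visible in this diff, so the path below is hypothetical):

```
# Hypothetical import path; the config file's name is not shown in this diff.
from lerobot.common.policies.diffusion.configuration_force_diffusion import ForceDiffusionConfig

# 0.01 is the coefficient the original ACT work uses; None (the default) disables ensembling.
config = ForceDiffusionConfig(temporal_ensemble_coeff=0.01)
```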
125 changes: 121 additions & 4 deletions lerobot/common/policies/diffusion/modeling_force_diffusion.py
@@ -111,8 +111,15 @@ def __init__(
        assert set(config.input_shapes).issuperset({*self.input_keys})
        assert set(config.output_shapes).issuperset({*self.output_keys})

        # TODO(malek): clean this
        self.output_sizes = [config.output_shapes[action][0] for action in self.output_keys]

        # self.max_blending_value = self.config.horizon - self.config.n_action_steps - self.config.n_obs_steps + 1
        self.max_horizon_prediction = self.config.horizon - self.config.n_obs_steps + 1
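        # Each inference predicts `horizon` steps starting n_obs_steps - 1 steps in the past, so
        # horizon - n_obs_steps + 1 of them lie at or after the current step; that count is the
        # chunk size handed to the temporal ensembler below.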

        if config.temporal_ensemble_coeff is not None:
            self.temporal_ensembler = DiffusionTemporalEnsembler(
                config.temporal_ensemble_coeff, self.max_horizon_prediction, self.config.n_action_steps
            )

        self.reset()

    def get_optimizer_parameters(self):
@@ -133,6 +140,8 @@ def reset(self):
"observation.state": deque(maxlen=self.config.n_obs_steps),
"action": deque(maxlen=self.config.n_action_steps),
}
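        # The ensembler keeps per-episode running averages, so reset it along with the queues.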
        if self.config.temporal_ensemble_coeff is not None:
            self.temporal_ensembler.reset()

    @torch.no_grad
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@@ -159,6 +168,12 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        batch = self.normalize_inputs(batch)
        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
        batch["observation.state"] = torch.cat([batch[k] for k in self.input_keys], dim=-1)

        # batch_size, n_obs_steps = batch["observation.state"].shape[:2]
        n_obs_steps = self.config.n_obs_steps
        start = n_obs_steps - 1
        end = start + self.config.n_action_steps
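        # The model predicts from n_obs_steps - 1 steps in the past, so the first n_obs_steps - 1
        # predictions overlap already-observed steps and are dropped; e.g. with n_obs_steps=2 and
        # n_action_steps=8 (illustrative values), start=1 and end=9.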

        # Note: It's important that this happens after stacking the images into a single key.
        self._queues = populate_queues(self._queues, batch)

@@ -167,6 +182,13 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
            actions = self.diffusion.generate_actions(batch)

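            # With ensembling enabled, hand every prediction from the current step onward to the
            # ensembler, which blends overlapping chunks online and returns the next n_action_steps
            # actions; otherwise just slice out the next n_action_steps actions directly.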
            if self.config.temporal_ensemble_coeff is not None:
                actions = actions[:, start:]
                actions = self.temporal_ensembler.update(actions)
            else:
                actions = actions[:, start:end]

            # TODO(rcadene): make above methods return output dictionary?
            # separate the action outputs into separate entities
            action_list = torch.split(actions, split_size_or_sections=self.output_sizes, dim=-1)
@@ -176,12 +198,13 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:

            actions = torch.cat([actions[k] for k in self.output_keys], dim=-1)

self._queues["action"].extend(actions.transpose(0, 1))
self._queues["action"].extend(actions.transpose(0, 1))

        action = self._queues["action"].popleft()
        # TODO(Malek): separate the actions again here
        action = dict(zip(self.output_keys, torch.split(action, self.output_sizes, dim=-1)))
        return action


    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Run the batch through the model and compute the loss for training or validation."""
@@ -211,6 +234,100 @@ def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
    raise ValueError(f"Unsupported noise scheduler type {name}")


class DiffusionTemporalEnsembler:
    def __init__(self, temporal_ensemble_coeff: float, chunk_size: int, n_action_steps: int) -> None:
        """Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705.

        The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action.
        They are then normalized to sum to 1 by dividing by Σwᵢ. Here's some intuition around how the
        coefficient works:
        - Setting it to 0 uniformly weighs all actions.
        - Setting it positive gives more weight to older actions.
        - Setting it negative gives more weight to newer actions.
        NOTE: The default value for `temporal_ensemble_coeff` used by the original ACT work is 0.01. This
        results in older actions being weighed more highly than newer actions (the experiments documented in
        https://github.com/huggingface/lerobot/pull/319 hint at why highly weighing new actions might be
        detrimental: doing so aggressively may diminish the benefits of action chunking).

        Here we use an online method for computing the average rather than caching a history of actions in
        order to compute the average offline. For a simple 1D sequence it looks something like:

        ```
        import torch

        seq = torch.linspace(8, 8.5, 100)
        print(seq)

        m = 0.01
        exp_weights = torch.exp(-m * torch.arange(len(seq)))
        print(exp_weights)

        # Calculate offline
        avg = (exp_weights * seq).sum() / exp_weights.sum()
        print("offline", avg)

        # Calculate online
        for i, item in enumerate(seq):
            if i == 0:
                avg = item
                continue
            avg *= exp_weights[:i].sum()
            avg += item * exp_weights[i]
            avg /= exp_weights[:i + 1].sum()
        print("online", avg)
        ```
        """
        self.chunk_size = chunk_size
        self.n_action_steps = n_action_steps
        self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
        self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
        self.reset()

    def reset(self):
        """Resets the online computation variables."""
        self.ensembled_actions = None
        # (chunk_size,) count of how many actions are in the ensemble for each time step in the sequence.
        self.ensembled_actions_count = None

    def update(self, actions: Tensor) -> Tensor:
        """
        Takes a (batch, chunk_size, action_dim) sequence of actions, updates the temporal ensemble for all
        time steps, and pops/returns the next n_action_steps actions in the sequence.
        """
        self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
        self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
        if self.ensembled_actions is None:
            # Initializes `self.ensembled_actions` to the sequence of actions predicted during the first
            # time step of the episode.
            self.ensembled_actions = actions.clone()
            # Note: The count gets a trailing singleton dimension so that it broadcasts properly in the
            # tensor operations below.
            self.ensembled_actions_count = torch.ones(
                (self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
            )
        else:
            # self.ensembled_actions has shape (batch_size, chunk_size - n_action_steps, action_dim) here,
            # since n_action_steps actions were consumed on the previous call. Compute the online update
            # for those entries.
            self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
            self.ensembled_actions += actions[:, :-self.n_action_steps] * self.ensemble_weights[self.ensembled_actions_count]  # edited
            self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
            self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
            # The last n_action_steps actions, which have no prior online average, need to be concatenated
            # onto the end.
            self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -self.n_action_steps:]], dim=1)  # edited
            self.ensembled_actions_count = torch.cat(
                [
                    self.ensembled_actions_count,
                    torch.ones((self.n_action_steps, 1), dtype=torch.long, device=self.ensembled_actions.device),
                ]  # edited
            )

        # "Consume" the first n_action_steps actions.
        action, self.ensembled_actions, self.ensembled_actions_count = (
            self.ensembled_actions[:, 0:self.n_action_steps],
            self.ensembled_actions[:, self.n_action_steps:],
            self.ensembled_actions_count[self.n_action_steps:],
        )

        return action

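For intuition about the chunked consume pattern, here is a small usage sketch (shapes and values are illustrative, and it assumes this fork's module is importable): each `update` call blends the overlap between the new chunk and the actions still in the ensemble, then returns the next `n_action_steps` actions.

```
import torch

# Assumes this fork is installed; the path matches the file shown in this diff.
from lerobot.common.policies.diffusion.modeling_force_diffusion import DiffusionTemporalEnsembler

chunk_size, n_action_steps, action_dim = 12, 4, 2  # illustrative values
ensembler = DiffusionTemporalEnsembler(
    temporal_ensemble_coeff=0.01, chunk_size=chunk_size, n_action_steps=n_action_steps
)

# First inference seeds the ensemble with the whole chunk and consumes the first 4 actions.
out = ensembler.update(torch.randn(1, chunk_size, action_dim))
assert out.shape == (1, n_action_steps, action_dim)

# Second inference: the first chunk_size - n_action_steps steps of the new chunk overlap the
# 8 actions still in the ensemble and are averaged into them online; 4 fresh actions are appended.
out = ensembler.update(torch.randn(1, chunk_size, action_dim))
assert out.shape == (1, n_action_steps, action_dim)
```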

class DiffusionModel(nn.Module):
    def __init__(self, config: ForceDiffusionConfig):
        super().__init__()
@@ -341,9 +458,9 @@ def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
print(f"the time for computing actions {timeit.default_timer() - start_time}")

        # Note: extracting `n_action_steps` worth of actions from the current observation now happens in
        # `select_action` above (optionally via the temporal ensembler).
        # start = n_obs_steps - 1
        # end = start + self.config.n_action_steps
        # actions = actions[:, start:end]

        return actions
