diff --git a/.gitignore b/.gitignore index d53318df..c2d62341 100644 --- a/.gitignore +++ b/.gitignore @@ -123,3 +123,5 @@ venv.bak/ # private macros macros_private.py +*.pyc +act/detr/models/__pycache__ \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 6f64af97..79837098 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,4 @@ imageio-ffmpeg matplotlib egl_probe>=1.0.1 torch -torchvision +torchvision \ No newline at end of file diff --git a/robomimic/algo/algo.py b/robomimic/algo/algo.py index 321db01d..330b065e 100644 --- a/robomimic/algo/algo.py +++ b/robomimic/algo/algo.py @@ -118,6 +118,7 @@ def __init__( self.global_config = global_config self.ac_dim = ac_dim + self.ac_key = global_config.train.ac_key self.device = device self.obs_key_shapes = obs_key_shapes @@ -201,7 +202,7 @@ def process_batch_for_training(self, batch): """ return batch - def postprocess_batch_for_training(self, batch, obs_normalization_stats): + def postprocess_batch_for_training(self, batch, normalization_stats, normalize_actions=True): """ Does some operations (like channel swap, uint8 to float conversion, normalization) after @process_batch_for_training is called, in order to ensure these operations @@ -222,7 +223,11 @@ def postprocess_batch_for_training(self, batch, obs_normalization_stats): """ # ensure obs_normalization_stats are torch Tensors on proper device - obs_normalization_stats = TensorUtils.to_float(TensorUtils.to_device(TensorUtils.to_tensor(obs_normalization_stats), self.device)) + normalization_stats = TensorUtils.to_float( + TensorUtils.to_device( + TensorUtils.to_tensor(normalization_stats), self.device + ) + ) # we will search the nested batch dictionary for the following special batch dict keys # and apply the processing function to their values (which correspond to observations) @@ -236,14 +241,16 @@ def recurse_helper(d): if k in obs_keys: # found key - stop search and process observation if d[k] is not None: - d[k] = ObsUtils.process_obs_dict(d[k]) - if obs_normalization_stats is not None: - d[k] = ObsUtils.normalize_obs(d[k], obs_normalization_stats=obs_normalization_stats) + d[k] = ObsUtils.process_obs_dict(d[k], imagenet_normalize=self.global_config.train.imagenet_normalize_images) elif isinstance(d[k], dict): # search down into dictionary recurse_helper(d[k]) recurse_helper(batch) + if normalization_stats is not None: + batch = ObsUtils.normalize_batch( + batch, normalization_stats=normalization_stats, normalize_actions=normalize_actions + ) return batch def train_on_batch(self, batch, epoch, validate=False): @@ -502,8 +509,10 @@ def _prepare_observation(self, ob): # ensure obs_normalization_stats are torch Tensors on proper device obs_normalization_stats = TensorUtils.to_float(TensorUtils.to_device(TensorUtils.to_tensor(self.obs_normalization_stats), self.policy.device)) # limit normalization to obs keys being used, in case environment includes extra keys - ob = { k : ob[k] for k in self.policy.global_config.all_obs_keys } - ob = ObsUtils.normalize_obs(ob, obs_normalization_stats=obs_normalization_stats) + ob = {k: ob[k] for k in self.policy.global_config.all_obs_keys} + ob = ObsUtils.normalize_batch( + ob, obs_normalization_stats=obs_normalization_stats + ) return ob def __repr__(self): diff --git a/robomimic/algo/bc.py b/robomimic/algo/bc.py index 091be78e..02c7dfe8 100644 --- a/robomimic/algo/bc.py +++ b/robomimic/algo/bc.py @@ -107,7 +107,8 @@ def process_batch_for_training(self, batch): will be used for training """ input_batch = 
dict() - input_batch["obs"] = {k: batch["obs"][k][:, 0, :] for k in batch["obs"]} + #input_batch["obs"] = {k: batch["obs"][k][:, 0, :] for k in batch["obs"]} + input_batch["obs"] = {k: v[:, 0, :] if v.ndim != 1 else v for k, v in batch['obs'].items()} input_batch["goal_obs"] = batch.get("goal_obs", None) # goals may not be present input_batch["actions"] = batch["actions"][:, 0, :] # we move to device first before float conversion because image observation modalities will be uint8 - diff --git a/robomimic/algo/diffusion_policy.py b/robomimic/algo/diffusion_policy.py new file mode 100644 index 00000000..f1ad2610 --- /dev/null +++ b/robomimic/algo/diffusion_policy.py @@ -0,0 +1,693 @@ +""" +Implementation of Diffusion Policy https://diffusion-policy.cs.columbia.edu/ by Cheng Chi +""" +from typing import Callable, Union +import math +from collections import OrderedDict, deque +from packaging.version import parse as parse_version +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F +# requires diffusers==0.11.1 +from diffusers.schedulers.scheduling_ddpm import DDPMScheduler +from diffusers.schedulers.scheduling_ddim import DDIMScheduler +from diffusers.training_utils import EMAModel + +import robomimic.models.obs_nets as ObsNets +import robomimic.utils.tensor_utils as TensorUtils +import robomimic.utils.torch_utils as TorchUtils +import robomimic.utils.obs_utils as ObsUtils + +from robomimic.algo import register_algo_factory_func, PolicyAlgo + +@register_algo_factory_func("diffusion_policy") +def algo_config_to_class(algo_config): + """ + Maps algo config to the BC algo class to instantiate, along with additional algo kwargs. + + Args: + algo_config (Config instance): algo config + + Returns: + algo_class: subclass of Algo + algo_kwargs (dict): dictionary of additional kwargs to pass to algorithm + """ + + if algo_config.unet.enabled: + return DiffusionPolicyUNet, {} + elif algo_config.transformer.enabled: + raise NotImplementedError() + else: + raise RuntimeError() + +class DiffusionPolicyUNet(PolicyAlgo): + def _create_networks(self): + """ + Creates networks and places them into @self.nets. + """ + # set up different observation groups for @MIMO_MLP + observation_group_shapes = OrderedDict() + observation_group_shapes["obs"] = OrderedDict(self.obs_shapes) + encoder_kwargs = ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder) + + obs_encoder = ObsNets.ObservationGroupEncoder( + observation_group_shapes=observation_group_shapes, + encoder_kwargs=encoder_kwargs, + ) + # IMPORTANT! + # replace all BatchNorm with GroupNorm to work with EMA + # performance will tank if you forget to do this! 
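+        # (EMA keeps an average of the raw weights; BatchNorm's running statistics are not
+        # averaged and would drift out of sync with those weights, while GroupNorm keeps no batch statistics)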
+ obs_encoder = replace_bn_with_gn(obs_encoder) + + obs_dim = obs_encoder.output_shape()[0] + + # create network object + noise_pred_net = ConditionalUnet1D( + input_dim=self.ac_dim, + global_cond_dim=obs_dim*self.algo_config.horizon.observation_horizon + ) + + # the final arch has 2 parts + nets = nn.ModuleDict({ + 'policy': nn.ModuleDict({ + 'obs_encoder': obs_encoder, + 'noise_pred_net': noise_pred_net + }) + }) + + nets = nets.float().to(self.device) + + # setup noise scheduler + noise_scheduler = None + if self.algo_config.ddpm.enabled: + noise_scheduler = DDPMScheduler( + num_train_timesteps=self.algo_config.ddpm.num_train_timesteps, + beta_schedule=self.algo_config.ddpm.beta_schedule, + clip_sample=self.algo_config.ddpm.clip_sample, + prediction_type=self.algo_config.ddpm.prediction_type + ) + elif self.algo_config.ddim.enabled: + noise_scheduler = DDIMScheduler( + num_train_timesteps=self.algo_config.ddim.num_train_timesteps, + beta_schedule=self.algo_config.ddim.beta_schedule, + clip_sample=self.algo_config.ddim.clip_sample, + set_alpha_to_one=self.algo_config.ddim.set_alpha_to_one, + steps_offset=self.algo_config.ddim.steps_offset, + prediction_type=self.algo_config.ddim.prediction_type + ) + else: + raise RuntimeError() + + # setup EMA + ema = None + if self.algo_config.ema.enabled: + ema = EMAModel(parameters=nets.parameters(), power=self.algo_config.ema.power) + + # set attrs + self.nets = nets + self._shadow_nets = copy.deepcopy(self.nets).eval() + self._shadow_nets.requires_grad_(False) + self.noise_scheduler = noise_scheduler + self.ema = ema + self.action_check_done = False + self.obs_queue = None + self.action_queue = None + + def process_batch_for_training(self, batch): + """ + Processes input batch from a data loader to filter out + relevant information and prepare the batch for training. + + Args: + batch (dict): dictionary with torch.Tensors sampled + from a data loader + + Returns: + input_batch (dict): processed and filtered batch that + will be used for training + """ + To = self.algo_config.horizon.observation_horizon + Ta = self.algo_config.horizon.action_horizon + Tp = self.algo_config.horizon.prediction_horizon + + input_batch = dict() + input_batch["obs"] = {k: batch["obs"][k][:, :To, :] for k in batch["obs"]} + input_batch["goal_obs"] = batch.get("goal_obs", None) # goals may not be present + input_batch["actions"] = batch["actions"][:, :Tp, :] + + # check if actions are normalized to [-1,1] + if not self.action_check_done: + actions = input_batch["actions"] + in_range = (-1 <= actions) & (actions <= 1) + all_in_range = torch.all(in_range).item() + if not all_in_range: + raise ValueError('"actions" must be in range [-1,1] for Diffusion Policy! Check if hdf5_normalize_action is enabled.') + self.action_check_done = True + + return TensorUtils.to_device(TensorUtils.to_float(input_batch), self.device) + + def train_on_batch(self, batch, epoch, validate=False): + """ + Training on a single batch of data. + + Args: + batch (dict): dictionary with torch.Tensors sampled + from a data loader and filtered by @process_batch_for_training + + epoch (int): epoch number - required by some Algos that need + to perform staged training and early stopping + + validate (bool): if True, don't perform any learning updates. 
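+                No gradient step or EMA update is taken in that case.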
+ + Returns: + info (dict): dictionary of relevant inputs, outputs, and losses + that might be relevant for logging + """ + To = self.algo_config.horizon.observation_horizon + Ta = self.algo_config.horizon.action_horizon + Tp = self.algo_config.horizon.prediction_horizon + action_dim = self.ac_dim + B = batch['actions'].shape[0] + + + with TorchUtils.maybe_no_grad(no_grad=validate): + info = super(DiffusionPolicyUNet, self).train_on_batch(batch, epoch, validate=validate) + actions = batch['actions'] + + # encode obs + inputs = { + 'obs': batch["obs"], + 'goal': batch["goal_obs"] + } + for k in self.obs_shapes: + # first two dimensions should be [B, T] for inputs + assert inputs['obs'][k].ndim - 2 == len(self.obs_shapes[k]) + + obs_features = TensorUtils.time_distributed(inputs, self.nets['policy']['obs_encoder'], inputs_as_kwargs=True) + assert obs_features.ndim == 3 # [B, T, D] + + obs_cond = obs_features.flatten(start_dim=1) + + # sample noise to add to actions + noise = torch.randn(actions.shape, device=self.device) + + # sample a diffusion iteration for each data point + timesteps = torch.randint( + 0, self.noise_scheduler.config.num_train_timesteps, + (B,), device=self.device + ).long() + + # add noise to the clean actions according to the noise magnitude at each diffusion iteration + # (this is the forward diffusion process) + noisy_actions = self.noise_scheduler.add_noise( + actions, noise, timesteps) + + # predict the noise residual + noise_pred = self.nets['policy']['noise_pred_net']( + noisy_actions, timesteps, global_cond=obs_cond) + + # L2 loss + loss = F.mse_loss(noise_pred, noise) + + # logging + losses = { + 'l2_loss': loss + } + info["losses"] = TensorUtils.detach(losses) + + if not validate: + # gradient step + policy_grad_norms = TorchUtils.backprop_for_loss( + net=self.nets, + optim=self.optimizers["policy"], + loss=loss, + ) + + # update Exponential Moving Average of the model weights + if self.ema is not None: + self.ema.step(self.nets.parameters()) + + step_info = { + 'policy_grad_norms': policy_grad_norms + } + info.update(step_info) + + return info + + def log_info(self, info): + """ + Process info dictionary from @train_on_batch to summarize + information to pass to tensorboard for logging. + + Args: + info (dict): dictionary of info + + Returns: + loss_log (dict): name -> summary statistic + """ + log = super(DiffusionPolicyUNet, self).log_info(info) + log["Loss"] = info["losses"]["l2_loss"].item() + if "policy_grad_norms" in info: + log["Policy_Grad_Norms"] = info["policy_grad_norms"] + return log + + def reset(self): + """ + Reset algo state to prepare for environment rollouts. + """ + # setup inference queues + To = self.algo_config.horizon.observation_horizon + Ta = self.algo_config.horizon.action_horizon + obs_queue = deque(maxlen=To) + action_queue = deque(maxlen=Ta) + self.obs_queue = obs_queue + self.action_queue = action_queue + + def get_action(self, obs_dict, goal_dict=None): + """ + Get policy action outputs. 
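+        Actions are served from an internal action queue; diffusion inference (via
+        _get_action_trajectory) only runs when the queue is empty, and each call pops one action.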
+ + Args: + obs_dict (dict): current observation [1, Do] + goal_dict (dict): (optional) goal + + Returns: + action (torch.Tensor): action tensor [1, Da] + """ + # obs_dict: key: [1,D] + To = self.algo_config.horizon.observation_horizon + Ta = self.algo_config.horizon.action_horizon + + # TODO: obs_queue already handled by frame_stack + # make sure we have at least To observations in obs_queue + # if not enough, repeat + # if already full, append one to the obs_queue + # n_repeats = max(To - len(self.obs_queue), 1) + # self.obs_queue.extend([obs_dict] * n_repeats) + + if len(self.action_queue) == 0: + # no actions left, run inference + # turn obs_queue into dict of tensors (concat at T dim) + # import pdb; pdb.set_trace() + # obs_dict_list = TensorUtils.list_of_flat_dict_to_dict_of_list(list(self.obs_queue)) + # obs_dict_tensor = dict((k, torch.cat(v, dim=0).unsqueeze(0)) for k,v in obs_dict_list.items()) + + # run inference + # [1,T,Da] + action_sequence = self._get_action_trajectory(obs_dict=obs_dict) + + # put actions into the queue + self.action_queue.extend(action_sequence[0]) + + # has action, execute from left to right + # [Da] + action = self.action_queue.popleft() + + # [1,Da] + action = action.unsqueeze(0) + return action + + def _get_action_trajectory(self, obs_dict, goal_dict=None): + assert not self.nets.training + To = self.algo_config.horizon.observation_horizon + Ta = self.algo_config.horizon.action_horizon + Tp = self.algo_config.horizon.prediction_horizon + action_dim = self.ac_dim + if self.algo_config.ddpm.enabled is True: + num_inference_timesteps = self.algo_config.ddpm.num_inference_timesteps + elif self.algo_config.ddim.enabled is True: + num_inference_timesteps = self.algo_config.ddim.num_inference_timesteps + else: + raise ValueError + + # select network + nets = self.nets + if self.ema is not None: + self.ema.copy_to(parameters=self._shadow_nets.parameters()) + nets = self._shadow_nets + + # encode obs + inputs = { + 'obs': obs_dict, + 'goal': goal_dict + } + for k in self.obs_shapes: + # first two dimensions should be [B, T] for inputs + assert inputs['obs'][k].ndim - 2 == len(self.obs_shapes[k]) + obs_features = TensorUtils.time_distributed(inputs, self.nets['policy']['obs_encoder'], inputs_as_kwargs=True) + assert obs_features.ndim == 3 # [B, T, D] + B = obs_features.shape[0] + + # reshape observation to (B,obs_horizon*obs_dim) + obs_cond = obs_features.flatten(start_dim=1) + + # initialize action from Guassian noise + noisy_action = torch.randn( + (B, Tp, action_dim), device=self.device) + naction = noisy_action + + # init scheduler + self.noise_scheduler.set_timesteps(num_inference_timesteps) + + for k in self.noise_scheduler.timesteps: + # predict noise + noise_pred = nets['policy']['noise_pred_net']( + sample=naction, + timestep=k, + global_cond=obs_cond + ) + + # inverse diffusion step (remove noise) + naction = self.noise_scheduler.step( + model_output=noise_pred, + timestep=k, + sample=naction + ).prev_sample + + # process action using Ta + start = To - 1 + end = start + Ta + action = naction[:,start:end] + return action + + def serialize(self): + """ + Get dictionary of current model parameters. + """ + return { + "nets": self.nets.state_dict(), + "ema": self.ema.state_dict() if self.ema is not None else None, + } + + def deserialize(self, model_dict): + """ + Load model from a checkpoint. 
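+        Also restores the EMA state if one was saved by serialize().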
+ + Args: + model_dict (dict): a dictionary saved by self.serialize() that contains + the same keys as @self.network_classes + """ + self.nets.load_state_dict(model_dict["nets"]) + if model_dict.get("ema", None) is not None: + self.ema.load_state_dict(model_dict["ema"]) + + +# =================== Vision Encoder Utils ===================== +def replace_submodules( + root_module: nn.Module, + predicate: Callable[[nn.Module], bool], + func: Callable[[nn.Module], nn.Module]) -> nn.Module: + """ + Replace all submodules selected by the predicate with + the output of func. + + predicate: Return true if the module is to be replaced. + func: Return new module to use. + """ + if predicate(root_module): + return func(root_module) + + if parse_version(torch.__version__) < parse_version('1.9.0'): + raise ImportError('This function requires pytorch >= 1.9.0') + + bn_list = [k.split('.') for k, m + in root_module.named_modules(remove_duplicate=True) + if predicate(m)] + for *parent, k in bn_list: + parent_module = root_module + if len(parent) > 0: + parent_module = root_module.get_submodule('.'.join(parent)) + if isinstance(parent_module, nn.Sequential): + src_module = parent_module[int(k)] + else: + src_module = getattr(parent_module, k) + tgt_module = func(src_module) + if isinstance(parent_module, nn.Sequential): + parent_module[int(k)] = tgt_module + else: + setattr(parent_module, k, tgt_module) + # verify that all modules are replaced + bn_list = [k.split('.') for k, m + in root_module.named_modules(remove_duplicate=True) + if predicate(m)] + assert len(bn_list) == 0 + return root_module + +def replace_bn_with_gn( + root_module: nn.Module, + features_per_group: int=16) -> nn.Module: + """ + Relace all BatchNorm layers with GroupNorm. + """ + replace_submodules( + root_module=root_module, + predicate=lambda x: isinstance(x, nn.BatchNorm2d), + func=lambda x: nn.GroupNorm( + num_groups=x.num_features//features_per_group, + num_channels=x.num_features) + ) + return root_module + +# =================== UNet for Diffusion ============== + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class Downsample1d(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = nn.Conv1d(dim, dim, 3, 2, 1) + + def forward(self, x): + return self.conv(x) + +class Upsample1d(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class Conv1dBlock(nn.Module): + ''' + Conv1d --> GroupNorm --> Mish + ''' + + def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): + super().__init__() + + self.block = nn.Sequential( + nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), + nn.GroupNorm(n_groups, out_channels), + nn.Mish(), + ) + + def forward(self, x): + return self.block(x) + + +class ConditionalResidualBlock1D(nn.Module): + def __init__(self, + in_channels, + out_channels, + cond_dim, + kernel_size=3, + n_groups=8): + super().__init__() + + self.blocks = nn.ModuleList([ + Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups), + Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups), + ]) + + # FiLM 
modulation https://arxiv.org/abs/1709.07871 + # predicts per-channel scale and bias + cond_channels = out_channels * 2 + self.out_channels = out_channels + self.cond_encoder = nn.Sequential( + nn.Mish(), + nn.Linear(cond_dim, cond_channels), + nn.Unflatten(-1, (-1, 1)) + ) + + # make sure dimensions compatible + self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \ + if in_channels != out_channels else nn.Identity() + + def forward(self, x, cond): + ''' + x : [ batch_size x in_channels x horizon ] + cond : [ batch_size x cond_dim] + + returns: + out : [ batch_size x out_channels x horizon ] + ''' + out = self.blocks[0](x) + embed = self.cond_encoder(cond) + + embed = embed.reshape( + embed.shape[0], 2, self.out_channels, 1) + scale = embed[:,0,...] + bias = embed[:,1,...] + out = scale * out + bias + + out = self.blocks[1](out) + out = out + self.residual_conv(x) + return out + + +class ConditionalUnet1D(nn.Module): + def __init__(self, + input_dim, + global_cond_dim, + diffusion_step_embed_dim=256, + down_dims=[256,512,1024], + kernel_size=5, + n_groups=8 + ): + """ + input_dim: Dim of actions. + global_cond_dim: Dim of global conditioning applied with FiLM + in addition to diffusion step embedding. This is usually obs_horizon * obs_dim + diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k + down_dims: Channel size for each UNet level. + The length of this array determines numebr of levels. + kernel_size: Conv kernel size + n_groups: Number of groups for GroupNorm + """ + + super().__init__() + all_dims = [input_dim] + list(down_dims) + start_dim = down_dims[0] + + dsed = diffusion_step_embed_dim + diffusion_step_encoder = nn.Sequential( + SinusoidalPosEmb(dsed), + nn.Linear(dsed, dsed * 4), + nn.Mish(), + nn.Linear(dsed * 4, dsed), + ) + cond_dim = dsed + global_cond_dim + + in_out = list(zip(all_dims[:-1], all_dims[1:])) + mid_dim = all_dims[-1] + self.mid_modules = nn.ModuleList([ + ConditionalResidualBlock1D( + mid_dim, mid_dim, cond_dim=cond_dim, + kernel_size=kernel_size, n_groups=n_groups + ), + ConditionalResidualBlock1D( + mid_dim, mid_dim, cond_dim=cond_dim, + kernel_size=kernel_size, n_groups=n_groups + ), + ]) + + down_modules = nn.ModuleList([]) + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (len(in_out) - 1) + down_modules.append(nn.ModuleList([ + ConditionalResidualBlock1D( + dim_in, dim_out, cond_dim=cond_dim, + kernel_size=kernel_size, n_groups=n_groups), + ConditionalResidualBlock1D( + dim_out, dim_out, cond_dim=cond_dim, + kernel_size=kernel_size, n_groups=n_groups), + Downsample1d(dim_out) if not is_last else nn.Identity() + ])) + + up_modules = nn.ModuleList([]) + for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): + is_last = ind >= (len(in_out) - 1) + up_modules.append(nn.ModuleList([ + ConditionalResidualBlock1D( + dim_out*2, dim_in, cond_dim=cond_dim, + kernel_size=kernel_size, n_groups=n_groups), + ConditionalResidualBlock1D( + dim_in, dim_in, cond_dim=cond_dim, + kernel_size=kernel_size, n_groups=n_groups), + Upsample1d(dim_in) if not is_last else nn.Identity() + ])) + + final_conv = nn.Sequential( + Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size), + nn.Conv1d(start_dim, input_dim, 1), + ) + + self.diffusion_step_encoder = diffusion_step_encoder + self.up_modules = up_modules + self.down_modules = down_modules + self.final_conv = final_conv + + print("number of parameters: {:e}".format( + sum(p.numel() for p in self.parameters())) + ) + + def forward(self, + sample: torch.Tensor, + 
timestep: Union[torch.Tensor, float, int], + global_cond=None): + """ + x: (B,T,input_dim) + timestep: (B,) or int, diffusion step + global_cond: (B,global_cond_dim) + output: (B,T,input_dim) + """ + # (B,T,C) + sample = sample.moveaxis(-1,-2) + # (B,C,T) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + global_feature = self.diffusion_step_encoder(timesteps) + # breakpoint() + if global_cond is not None: + global_feature = torch.cat([ + global_feature, global_cond + ], axis=-1) + + x = sample + h = [] + for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): + x = resnet(x, global_feature) + x = resnet2(x, global_feature) + h.append(x) + x = downsample(x) + + for mid_module in self.mid_modules: + x = mid_module(x, global_feature) + + for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): + x = torch.cat((x, h.pop()), dim=1) + x = resnet(x, global_feature) + x = resnet2(x, global_feature) + x = upsample(x) + + x = self.final_conv(x) + + # (B,C,T) + x = x.moveaxis(-1,-2) + # (B,T,C) + return x \ No newline at end of file diff --git a/robomimic/envs/env_robosuite.py b/robomimic/envs/env_robosuite.py index 7f983cd3..7d398ff2 100644 --- a/robomimic/envs/env_robosuite.py +++ b/robomimic/envs/env_robosuite.py @@ -511,4 +511,4 @@ def __repr__(self): """ Pretty-print env description. """ - return self.name + "\n" + json.dumps(self._init_kwargs, sort_keys=True, indent=4) + return self.name + "\n" + json.dumps(self._init_kwargs, sort_keys=True, indent=4) \ No newline at end of file diff --git a/robomimic/exps/templates/act.json b/robomimic/exps/templates/act.json new file mode 100644 index 00000000..4512ecdf --- /dev/null +++ b/robomimic/exps/templates/act.json @@ -0,0 +1,160 @@ +{ + "algo_name": "act", + "experiment": { + "name": "test", + "validate": false, + "logging": { + "terminal_output_to_txt": true, + "log_tb": true, + "log_wandb": false, + "wandb_proj_name": "debug" + }, + "mse":{}, + "save": { + "enabled": true, + "every_n_seconds": null, + "every_n_epochs": 40, + "epochs": [], + "on_best_validation": false, + "on_best_rollout_return": false, + "on_best_rollout_success_rate": true + }, + "epoch_every_n_steps": 500, + "validation_epoch_every_n_steps": 10, + "env": null, + "additional_envs": null, + "render": false, + "render_video": true, + "keep_all_videos": false, + "video_skip": 5, + "rollout": { + "enabled": true, + "n": 50, + "horizon": 400, + "rate": 40, + "warmstart": 0, + "terminate_on_success": true + } + }, + "train": { + "data": null, + "output_dir":"../act_trained_models", + "num_data_workers": 4, + "hdf5_cache_mode": "low_dim", + "hdf5_use_swmr": true, + "hdf5_load_next_obs": false, + "hdf5_normalize_obs": false, + "hdf5_filter_key": null, + "seq_length": 10, + "pad_seq_length": true, + "frame_stack": 1, + "pad_frame_stack": true, + "dataset_keys": [ + "actions" + ], + "goal_mode": null, + "cuda": true, + "batch_size": 128, + "num_epochs": 10000, + "seed": 1 + }, + "algo": { + "optim_params": { + "policy": { + "optimizer_type": "adamw", + "learning_rate": { + "initial": 0.00005, + "decay_factor": 1, + "epoch_schedule": [ + 100 + ], + "scheduler_type": "linear" + }, + "regularization": { + "L2": 0.0001 + } + } 
+ }, + "loss": { + "l2_weight": 0.0, + "l1_weight": 1.0, + "cos_weight": 0.0 + }, + "act": { + "hidden_dim": 512, + "dim_feedforward": 3200, + "backbone": "resnet18", + "enc_layers": 4, + "dec_layers": 7, + "nheads": 8, + "latent_dim": 32, + "kl_weight": 20 + } + }, + "observation": { + "modalities": { + "obs": { + "low_dim": [ + "robot0_eef_pos", + "robot0_eef_quat", + "robot0_gripper_qpos", + "object" + ], + "rgb": [], + "depth": [], + "scan": [] + }, + "goal": { + "low_dim": [], + "rgb": [], + "depth": [], + "scan": [] + } + }, + "encoder": { + "low_dim": { + "core_class": null, + "core_kwargs": {}, + "obs_randomizer_class": null, + "obs_randomizer_kwargs": {} + }, + "rgb": { + "core_class": "VisualCore", + "core_kwargs": { + "feature_dimension": 64, + "backbone_class": "ResNet18Conv", + "backbone_kwargs": { + "pretrained": false, + "input_coord_conv": false + }, + "pool_class": "SpatialSoftmax", + "pool_kwargs": { + "num_kp": 32, + "learnable_temperature": false, + "temperature": 1.0, + "noise_std": 0.0 + } + }, + "obs_randomizer_class": "CropRandomizer", + "obs_randomizer_kwargs": { + "crop_height": 76, + "crop_width": 76, + "num_crops": 1, + "pos_enc": false + } + }, + "depth": { + "core_class": "VisualCore", + "core_kwargs": {}, + "obs_randomizer_class": null, + "obs_randomizer_kwargs": {} + }, + "scan": { + "core_class": "ScanCore", + "core_kwargs": {}, + "obs_randomizer_class": null, + "obs_randomizer_kwargs": {} + } + } + } +} \ No newline at end of file diff --git a/robomimic/models/base_nets.py b/robomimic/models/base_nets.py index 93e76e6b..611a0dd7 100644 --- a/robomimic/models/base_nets.py +++ b/robomimic/models/base_nets.py @@ -16,7 +16,10 @@ from torchvision import models as vision_models import robomimic.utils.tensor_utils as TensorUtils +from robomimic.models.vit_rein import Reins, LoRAReins, MLPhead +from robomimic.utils.log_utils import bcolors +from peft import LoraConfig, get_peft_model CONV_ACTIVATIONS = { "relu": nn.ReLU, @@ -486,7 +489,6 @@ def forward(self, inputs): ) return x - class ResNet18Conv(ConvBase): """ A ResNet18 block that can be used to process input images. @@ -540,7 +542,305 @@ def __repr__(self): """Pretty print network.""" header = '{}'.format(str(self.__class__.__name__)) return header + '(input_channel={}, input_coord_conv={})'.format(self._input_channel, self._input_coord_conv) + +class ViT_Rein(ConvBase): + """ + ViT LoRA using Rein method + """ + def __init__( + self, + input_channel=3, + vit_model_class="vit_b", + lora_dim=16, + patch_size=16, + freeze=True, + return_key="x_norm_patchtokens" + ): + """ + Using pretrained observation encoder network proposed in Vision Transformers + git clone https://github.com/facebookresearch/dinov2 + pip install -r requirements.txt + Args: + input_channel (int): number of input channels for input images to the network. + If not equal to 3, modifies first conv layer to handle the number + of input channels. + vit_model_class (str): select one of the vit pretrained model "vit_b", "vit_l", "vit_s" or "vit_g" + freeze (bool): if True, use a frozen ViT pretrained model. 
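+            lora_dim (int): rank of the low-rank token factorization used by LoRAReins
+            patch_size (int): patch size passed to the Rein adapter layers
+            return_key (str): feature to return, either "x_norm_patchtokens" (patch tokens)
+                or "x_norm_clstoken" (pooled feature passed through the MLP head)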
+ """ + super(ViT_Rein, self).__init__() + print(f"{bcolors.WARNING}BACKBONE FREEZE: {freeze}{bcolors.ENDC}") + assert input_channel == 3 + assert vit_model_class in [ + "vit_b", + "vit_l", + "vit_g", + "vit_s", + ] # make sure the selected vit model do exist + + # cut the last fc layer + self._input_channel = input_channel + self._vit_model_class = vit_model_class + self._freeze = freeze + self._input_coord_conv = False + self._pretrained = False + self._lora_dim = lora_dim + self._patch_size = patch_size + self._out_indices = ([7, 11, 15, 23],) + self.return_key = return_key + if self.return_key not in ["x_norm_patchtokens", "x_norm_clstoken"]: + raise ValueError(f"return_key {self.return_key} not supported") + + self.preprocess = nn.Sequential( + transforms.Resize((294,294)), + # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ) + + try: + if self._vit_model_class == "vit_s": + self.nets = dinov2_vits14 = torch.hub.load( + "facebookresearch/dinov2", "dinov2_vits14" + ) + if self._vit_model_class == "vit_l": + self.nets = dinov2_vits14 = torch.hub.load( + "facebookresearch/dinov2", "dinov2_vitl14" + ) + if self._vit_model_class == "vit_g": + self.nets = dinov2_vits14 = torch.hub.load( + "facebookresearch/dinov2", "dinov2_vitg14" + ) + if self._vit_model_class == "vit_b": + self.nets = dinov2_vits14 = torch.hub.load( + "facebookresearch/dinov2", "dinov2_vitb14" + ) + except ImportError: + print("WARNING: could not load Vit") + + try: + self._rein_layers = LoRAReins( + lora_dim=self._lora_dim, + num_layers=len(self.nets.blocks), + embed_dims=self.nets.patch_embed.proj.out_channels, + patch_size=self._patch_size, + ) + self._mlp_lora_head = MLPhead( + in_dim=3 * self.nets.patch_embed.proj.out_channels, + out_dim=5 * self.nets.patch_embed.proj.out_channels, + ) + except ImportError: + print("WARNING: could not load rein layer") + + + if self._freeze: + for param in self.nets.parameters(): + param.requires_grad = False + self.nets.eval() + + def forward(self, inputs): + x = self.preprocess(inputs) + x = self.nets.patch_embed(x) + for idx, blk in enumerate(self.nets.blocks): + x = blk(x) + x = self._rein_layers.forward( + x, + idx, + batch_first=True, + has_cls_token=True, + ) + if self.return_key == "x_norm_patchtokens": + return x + q_avg = x.mean(dim=1).unsqueeze(1) + q_max = torch.max(x,1)[0].unsqueeze(1) + q_N = x[:,x.shape[1]-1,:].unsqueeze(1) + + _q = torch.cat((q_avg, q_max, q_N), dim=1) + + x = self.nets.norm(_q) + x = x.flatten(-2, -1) + x = self._mlp_lora_head(x) + if self.return_key == "x_norm_clstoken": + return x + + + def output_shape(self, input_shape): + """ + Function to compute output shape from inputs to this module. + Args: + input_shape (iterable of int): shape of input. Does not include batch dimension. + Some modules may not need this argument, if their output does not depend + on the size of the input, or if they assume fixed size input. 
+ Returns: + out_shape ([int]): list of integers corresponding to output shape + """ + assert(len(input_shape) == 3) + + C, H, W = input_shape + out_dim = self._mlp_lora_head._out_dim + + if self.return_key == "x_norm_patchtokens": + return [441, out_dim] + elif self.return_key == "x_norm_clstoken": + return [out_dim] + else: + raise NotImplementedError + + def __repr__(self): + """Pretty print network.""" + # print( + # "**Number of learnable params:", + # sum(p.numel() for p in self.nets.parameters() if p.requires_grad), + # " Freeze:", + # self._freeze, + # ) + # print("**Number of params:", sum(p.numel() for p in self.nets.parameters())) + + header = "{}".format(str(self.__class__.__name__)) + return ( + header + + "(input_channel={}, input_coord_conv={}, pretrained={}, freeze={})".format( + self._input_channel, + self._input_coord_conv, + self._pretrained, + self._freeze, + ) + ) + header = '{}'.format(str(self.__class__.__name__)) + return header + '(input_channel={}, input_coord_conv={}, pretrained={}, freeze={})'.format(self._input_channel, self._input_coord_conv, self._pretrained, self._freeze) + + +class Vit(ConvBase): + """ + Vision transformer with optional peft lora + """ + + def __init__(self, input_channel=3, vit_model_class="vit_b", freeze=True, return_key="x_norm_patchtokens", use_lora=False, **kwargs): + """ + Using pretrained observation encoder network proposed in Vision Transformers + git clone https://github.com/facebookresearch/dinov2 + pip install -r requirements.txt + Args: + input_channel (int): number of input channels for input images to the network. + If not equal to 3, modifies first conv layer to handle the number + of input channels. + vit_model_class (str): select one of the vit pretrained model "vit_b", "vit_l", "vit_s", "vit_g" or "radio" + freeze (bool): if True, use a frozen ViT pretrained model. 
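+            return_key (str): backbone output to return, either "x_norm_patchtokens" or "x_norm_clstoken"
+            use_lora (bool): if True, wrap the backbone with a PEFT LoRA adapter instead of freezing it outright
+            model_version (str, optional): RADIO model version to load when vit_model_class is "radio"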
+ """ + super(Vit, self).__init__() + + assert input_channel == 3 + assert vit_model_class in ["vit_b", "vit_l" ,"vit_g", "vit_s", "radio"] # make sure the selected vit model do exist + + # cut the last fc layer + self._input_channel = input_channel + self._vit_model_class = vit_model_class + + self._model_version = kwargs.get("model_version", None) + + self._freeze = freeze + self._input_coord_conv = False + self._pretrained = False + self.return_key = return_key + if self.return_key not in ["x_norm_patchtokens", "x_norm_clstoken"]: + raise ValueError(f"return_key {self.return_key} not supported") + + self.use_lora = use_lora + + self.preprocess = nn.Sequential( + transforms.Resize((224, 224)), + # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ) + + + try: + if self._vit_model_class == "vit_s": + self.nets = dinov2_vits14 = torch.hub.load( + "facebookresearch/dinov2", "dinov2_vits14" + ) + self.patch_size = self.nets.patch_embed.patch_size + if self._vit_model_class == "vit_l": + self.nets = dinov2_vits14 = torch.hub.load( + "facebookresearch/dinov2", "dinov2_vitl14" + ) + self.patch_size = self.nets.patch_embed.patch_size + if self._vit_model_class == "vit_g": + self.nets = dinov2_vits14 = torch.hub.load( + "facebookresearch/dinov2", "dinov2_vitg14" + ) + self.patch_size = self.nets.patch_embed.patch_size + if self._vit_model_class == "vit_b": + self.nets = dinov2_vits14 = torch.hub.load( + "facebookresearch/dinov2", "dinov2_vitb14" + ) + self.patch_size = self.nets.patch_embed.patch_size + if self._vit_model_class == "radio": + radio_model_version = self._model_version if self._model_version is not None else "radio_v2.5-l" + self.nets = torch.hub.load( + 'NVlabs/RADIO', 'radio_model', version=radio_model_version, progress=True, skip_validation=True + ) + self.preprocess = nn.Sequential( + transforms.Resize((224, 224)), + # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ) + self.patch_size = self.nets.patch_size + + + except ImportError: + print("WARNING: could not load Vit") + + if self.use_lora: + lora_config = LoraConfig( + r=8, + lora_alpha=32, + target_modules=["qkv", "query", "key", "value"], + lora_dropout=0.1, + bias="none", + task_type="SEQ_2_SEQ_LM" + ) + self.nets = get_peft_model(self.nets, lora_config) + + if self._freeze and not self.use_lora: + for param in self.nets.parameters(): + param.requires_grad = False + self.nets.eval() + + def forward(self, inputs): + + x = self.preprocess(inputs) + # x = self.nets(x) + if "vit" in self._vit_model_class: + x = self.nets.forward_features(x)[self.return_key] + else: + summary, x = self.nets(x) + + return x + + def output_shape(self, input_shape): + """ + Function to compute output shape from inputs to this module. + Args: + input_shape (iterable of int): shape of input. Does not include batch dimension. + Some modules may not need this argument, if their output does not depend + on the size of the input, or if they assume fixed size input. 
+ Returns: + out_shape ([int]): list of integers corresponding to output shape + """ + assert(len(input_shape) == 3) + + C, H, W = input_shape + out_dim = self.nets.patch_embed.proj.out_channels + + if self.return_key == "x_norm_patchtokens": + return [(H / self.patch_size) * (W / self.patch_size), out_dim] + elif self.return_key == "x_norm_clstoken": + return [out_dim] + + def __repr__(self): + """Pretty print network.""" + print("**Number of learnable params:",sum(p.numel() for p in self.nets.parameters() if p.requires_grad)," Freeze:",self._freeze) + print("**Number of params:",sum(p.numel() for p in self.nets.parameters())) + + header = '{}'.format(str(self.__class__.__name__)) + return header + '(input_channel={}, input_coord_conv={}, pretrained={}, freeze={})'.format(self._input_channel, self._input_coord_conv, self._pretrained, self._freeze, self.use_lora) class R3MConv(ConvBase): """ @@ -583,7 +883,7 @@ def __init__( preprocess = nn.Sequential( transforms.Resize(256), transforms.CenterCrop(224), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ) self.nets = Sequential(*([preprocess] + list(net.module.convnet.children())), has_output_shape = False) if freeze: diff --git a/robomimic/models/obs_core.py b/robomimic/models/obs_core.py index 41830438..81f12b66 100644 --- a/robomimic/models/obs_core.py +++ b/robomimic/models/obs_core.py @@ -11,8 +11,9 @@ import torch import torch.nn as nn -from torchvision.transforms import Lambda, Compose +from torchvision.transforms import Lambda, Compose, RandomResizedCrop import torchvision.transforms.functional as TVF +import torchvision.transforms as TT import robomimic.models.base_nets as BaseNets import robomimic.utils.tensor_utils as TensorUtils @@ -23,6 +24,9 @@ from robomimic.models.base_nets import * from robomimic.utils.vis_utils import visualize_image_randomizer from robomimic.macros import VISUALIZE_RANDOMIZER +import datetime +import matplotlib.pyplot as plt + """ @@ -578,6 +582,164 @@ def __repr__(self): return msg +class CropResizeRandomizer(Randomizer): + """ + Randomly sample crop, then resize to specified size + """ + def __init__( + self, + input_shape, + size, + scale, + ratio, + num_crops=1, + pos_enc=False, + ): + """ + Args: + input_shape (tuple, list): shape of input (not including batch dimension) + crop_height (int): crop height + crop_width (int): crop width + resize_height (int): resize height + resize_width (int): resize width + num_crops (int): number of random crops to take + pos_enc (bool): if True, add 2 channels to the output to encode the spatial + location of the cropped pixels in the source image + """ + super(CropResizeRandomizer, self).__init__() + + assert len(input_shape) == 3 # (C, H, W) + # assert crop_height < input_shape[1] + # assert crop_width < input_shape[2] + + self.input_shape = input_shape + self.size = size + self.scale = scale + self.ratio = ratio + self.num_crops = num_crops + self.pos_enc = pos_enc + + self.resize_crop = RandomResizedCrop(size=size, scale=scale, ratio=ratio, interpolation=TVF.InterpolationMode.BILINEAR) + + def output_shape_in(self, input_shape=None): + shape = [self.input_shape[0], self.size[0], self.size[1]] + return shape + + def output_shape_out(self, input_shape=None): + return list(input_shape) + + def _visualize(self, pre_random_input, randomized_input, num_samples_to_visualize=2): + """ + pre_random_input: (B, C, H, W) + randomized_input: (B, C, H, W) + 
num_samples_to_visualize: + Use plt.imsave to save a plot with the original input and the randomized input side by side. Save it to debug/augIms/ with a unique name. + """ + fig, axes = plt.subplots(num_samples_to_visualize, 2, figsize=(10, 5*num_samples_to_visualize)) + for i in range(num_samples_to_visualize): + axes[i, 0].imshow(pre_random_input[i].permute(1, 2, 0).cpu().numpy()) + axes[i, 0].set_title("Original Input") + axes[i, 1].imshow(randomized_input[i].permute(1, 2, 0).cpu().numpy()) + axes[i, 1].set_title("Randomized Input") + plt.tight_layout() + plt.savefig(f"debug/augIms/sample_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.png") + plt.close(fig) + # plt.close(fig) + # fig, axes = plt.subplots(1, 2) + # axes[0].imshow(pre_random_input[i].permute(1, 2, 0).cpu().numpy()) + # axes[0].set_title("Original Input") + # axes[1].imshow(randomized_input[i].permute(1, 2, 0).cpu().numpy()) + # axes[1].set_title("Randomized Input") + # plt.savefig(f"debug/augIms/sample_{i}.png") + # plt.close(fig) + + def _forward_in(self, inputs): + """ + Samples single random crop for each input + """ + # assert len(inputs.shape) >= 3 # must have at least (C, H, W) dimensions + # out, _ = ObsUtils.sample_random_image_crops( + # images=inputs, + # crop_height=self.crop_height, + # crop_width=self.crop_width, + # num_crops=self.num_crops, + # pos_enc=self.pos_enc, + # ) + # # [B, N, ...] -> [B * N, ...] + # out = TensorUtils.join_dimensions(out, 0, 1) + out = self.resize_crop(inputs) + # self._visualize(inputs, out) + + return out + + def _forward_in_eval(self, inputs): + """ + Do center crops during eval + """ + # assert len(inputs.shape) >= 3 # must have at least (C, H, W) dimensions + # inputs = inputs.permute(*range(inputs.dim()-3), inputs.dim()-2, inputs.dim()-1, inputs.dim()-3) + # out = ObsUtils.center_crop(inputs, self.crop_height, self.crop_width) + # out = out.permute(*range(out.dim()-3), out.dim()-1, out.dim()-3, out.dim()-2) + # return out + + # just resize + return TVF.resize(inputs, size=self.size, interpolation=TVF.InterpolationMode.BILINEAR) + + + def _forward_out(self, inputs): + """ + Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N + to result in shape [B, ...] to make sure the network output is consistent with + what would have happened if there were no randomization. 
+ + In this class I assume N = 1 so I just return input + """ + + return inputs + +class CropResizeColorRandomizer(CropResizeRandomizer): + """ + Does the same thing as CropResizeRandomizer, but additionally performs color jitter + """ + def __init__( + self, + input_shape, + size, + scale, + ratio, + num_crops=1, + pos_enc=False, + brightness_min=1.0, + brightness_max=1.0, + contrast_min=1.0, + contrast_max=1.0, + saturation_min=1.0, + saturation_max=1.0, + hue_min=0.0, + hue_max=0.0 + ): + super(CropResizeColorRandomizer, self).__init__( + input_shape=input_shape, + size=size, + scale=scale, + ratio=ratio, + num_crops=num_crops, + pos_enc=pos_enc, + ) + self.color_jitter = TT.ColorJitter(brightness=(brightness_min, brightness_max), contrast=(contrast_min, contrast_max), saturation=(saturation_min, saturation_max), hue=(hue_min, hue_max)) + + def _forward_in(self, inputs): + out = super(CropResizeColorRandomizer, self)._forward_in(inputs) + out = self.color_jitter(out) + # self._visualize(inputs, out) + return out + + def _forward_in_eval(self, inputs): + out = super(CropResizeColorRandomizer, self)._forward_in_eval(inputs) + return out + + + class ColorRandomizer(Randomizer): """ Randomly sample color jitter at input, and then average across color jtters at output. diff --git a/robomimic/models/obs_nets.py b/robomimic/models/obs_nets.py index b3284185..4a0b9483 100644 --- a/robomimic/models/obs_nets.py +++ b/robomimic/models/obs_nets.py @@ -25,6 +25,7 @@ FeatureAggregator from robomimic.models.obs_core import VisualCore, Randomizer from robomimic.models.transformers import PositionalEncoding, GPT_Backbone +from robomimic.models.base_nets import Vit def obs_encoder_factory( @@ -101,13 +102,13 @@ class ObservationEncoder(Module): Module that processes inputs by observation key and then concatenates the processed observation keys together. Each key is processed with an encoder head network. Call @register_obs_key to register observation keys with the encoder and then - finally call @make to create the encoder networks. + finally call @make to create the encoder networks. """ def __init__(self, feature_activation=nn.ReLU): """ Args: feature_activation: non-linearity to apply after each obs net - defaults to ReLU. Pass - None to apply no activation. + None to apply no activation. """ super(ObservationEncoder, self).__init__() self.obs_shapes = OrderedDict() @@ -120,12 +121,12 @@ def __init__(self, feature_activation=nn.ReLU): self._locked = False def register_obs_key( - self, + self, name, - shape, - net_class=None, - net_kwargs=None, - net=None, + shape, + net_class=None, + net_kwargs=None, + net=None, randomizer=None, share_net_from=None, ): @@ -143,7 +144,7 @@ def register_obs_key( instead of creating a different net randomizer (Randomizer instance): if provided, use this Module to augment observation keys coming in to the encoder, and possibly augment the processed output as well - share_net_from (str): if provided, use the same instance of @net_class + share_net_from (str): if provided, use the same instance of @net_class as another observation key. This observation key must already exist in this encoder. Warning: Note that this does not share the observation key randomizer """ @@ -362,7 +363,7 @@ class ObservationGroupEncoder(Module): The class takes a dictionary of dictionaries, @observation_group_shapes. Each key corresponds to a observation group (e.g. 
'obs', 'subgoal', 'goal') - and each OrderedDict should be a map between modalities and + and each OrderedDict should be a map between modalities and expected input shapes (e.g. { 'image' : (3, 120, 160) }). """ def __init__( @@ -403,7 +404,7 @@ def __init__( # type checking assert isinstance(observation_group_shapes, OrderedDict) assert np.all([isinstance(observation_group_shapes[k], OrderedDict) for k in observation_group_shapes]) - + self.observation_group_shapes = observation_group_shapes # create an observation encoder per observation group @@ -421,7 +422,7 @@ def forward(self, **inputs): Args: inputs (dict): dictionary that maps observation groups to observation - dictionaries of torch.Tensor batches that agree with + dictionaries of torch.Tensor batches that agree with @self.observation_group_shapes. All observation groups in @self.observation_group_shapes must be present, but additional observation groups can also be present. Note that these are specified @@ -567,7 +568,7 @@ def output_shape(self, input_shape=None): """ return { k : list(self.output_shapes[k]) for k in self.output_shapes } - def forward(self, **inputs): + def forward(self, return_latent=False, **inputs): """ Process each set of inputs in its own observation group. @@ -583,6 +584,8 @@ def forward(self, **inputs): """ enc_outputs = self.nets["encoder"](**inputs) mlp_out = self.nets["mlp"](enc_outputs) + if return_latent: + return self.nets["decoder"](mlp_out), enc_outputs.detach(), mlp_out.detach() return self.nets["decoder"](mlp_out) def _to_string(self): @@ -604,7 +607,6 @@ def __repr__(self): msg = header + '(' + msg + '\n)' return msg - class RNN_MIMO_MLP(Module): """ A wrapper class for a multi-step RNN and a per-step MLP and a decoder. diff --git a/robomimic/models/vit_rein.py b/robomimic/models/vit_rein.py new file mode 100644 index 00000000..e73f9327 --- /dev/null +++ b/robomimic/models/vit_rein.py @@ -0,0 +1,162 @@ +""" +Contains torch Modules for implementation of rein method +for domain adaptation of DINOv2 +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from functools import reduce +from operator import mul +from torch import Tensor + +class MLPhead(nn.Module): + def __init__(self, + in_dim: int, + out_dim: int, + **kwargs) -> None: + super().__init__(**kwargs) + self._in_dim = in_dim + self._out_dim = out_dim + + self._mlp = nn.Linear(self._in_dim, self._out_dim) + + def forward(self, x: Tensor) -> Tensor: + x = self._mlp.forward(x) + return x + +class Reins(nn.Module): + def __init__( + self, + num_layers: int, + embed_dims: int, + patch_size: int, + query_dims: int = 256, + token_length: int = 100, + use_softmax: bool = True, + link_token_to_query: bool = True, + scale_init: float = 0.001, + zero_mlp_delta_f: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + self.embed_dims = embed_dims + self.patch_size = patch_size + self.query_dims = query_dims + self.token_length = token_length + self.link_token_to_query = link_token_to_query + self.scale_init = scale_init + self.use_softmax = use_softmax + self.zero_mlp_delta_f = zero_mlp_delta_f + self.create_model() + + def create_model(self): + self.learnable_tokens = nn.Parameter( + torch.empty([self.num_layers, self.token_length, self.embed_dims]) + ) + self.scale = nn.Parameter(torch.tensor(self.scale_init)) + self.mlp_token2feat = nn.Linear(self.embed_dims, self.embed_dims) + self.mlp_delta_f = nn.Linear(self.embed_dims, self.embed_dims) + val = math.sqrt( + 6.0 + / float( + 3 * 
reduce(mul, (self.patch_size, self.patch_size), 1) + self.embed_dims + ) + ) + nn.init.uniform_(self.learnable_tokens.data, -val, val) + nn.init.kaiming_uniform_(self.mlp_delta_f.weight, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.mlp_token2feat.weight, a=math.sqrt(5)) + self.transform = nn.Linear(self.embed_dims, self.query_dims) + self.merge = nn.Linear(self.query_dims * 3, self.query_dims) + if self.zero_mlp_delta_f: + del self.scale + self.scale = 1.0 + nn.init.zeros_(self.mlp_delta_f.weight) + nn.init.zeros_(self.mlp_delta_f.bias) + + def return_auto(self, feats): + if self.link_token_to_query: + tokens = self.transform(self.get_tokens(-1)).permute(1, 2, 0) + tokens = torch.cat( + [ + F.max_pool1d(tokens, kernel_size=self.num_layers), + F.avg_pool1d(tokens, kernel_size=self.num_layers), + tokens[:, :, -1].unsqueeze(-1), + ], + dim=-1, + ) + querys = self.merge(tokens.flatten(-2, -1)) + return feats, querys + else: + return feats + + def get_tokens(self, layer: int) -> Tensor: + if layer == -1: + # return all + return self.learnable_tokens + else: + return self.learnable_tokens[layer] + + def forward( + self, feats: Tensor, layer: int, batch_first=False, has_cls_token=True + ) -> Tensor: + if batch_first: + feats = feats.permute(1, 0, 2) + if has_cls_token: + cls_token, feats = torch.tensor_split(feats, [1], dim=0) + tokens = self.get_tokens(layer) + delta_feat = self.forward_delta_feat( + feats, + tokens, + layer, + ) + delta_feat = delta_feat * self.scale + feats = feats + delta_feat + if has_cls_token: + feats = torch.cat([cls_token, feats], dim=0) + if batch_first: + feats = feats.permute(1, 0, 2) + return feats + + def forward_delta_feat(self, feats: Tensor, tokens: Tensor, layers: int) -> Tensor: + attn = torch.einsum("nbc,mc->nbm", feats, tokens) + if self.use_softmax: + attn = attn * (self.embed_dims**-0.5) + attn = F.softmax(attn, dim=-1) + delta_f = torch.einsum( + "nbm,mc->nbc", + attn[:, :, 1:], + self.mlp_token2feat(tokens[1:, :]), + ) + delta_f = self.mlp_delta_f(delta_f + feats) + return delta_f + +class LoRAReins(Reins): + def __init__(self, lora_dim=16, **kwargs): + self.lora_dim = lora_dim + super().__init__(**kwargs) + + def create_model(self): + super().create_model() + del self.learnable_tokens + self.learnable_tokens_a = nn.Parameter( + torch.empty([self.num_layers, self.token_length, self.lora_dim]) + ) + self.learnable_tokens_b = nn.Parameter( + torch.empty([self.num_layers, self.lora_dim, self.embed_dims]) + ) + val = math.sqrt( + 6.0 + / float( + 3 * reduce(mul, (self.patch_size, self.patch_size), 1) + + (self.embed_dims * self.lora_dim) ** 0.5 + ) + ) + nn.init.uniform_(self.learnable_tokens_a.data, -val, val) + nn.init.uniform_(self.learnable_tokens_b.data, -val, val) + + def get_tokens(self, layer): + if layer == -1: + return self.learnable_tokens_a @ self.learnable_tokens_b + else: + return self.learnable_tokens_a[layer] @ self.learnable_tokens_b[layer] \ No newline at end of file diff --git a/robomimic/utils/dataset.py b/robomimic/utils/dataset.py index 075c4d59..065c442c 100644 --- a/robomimic/utils/dataset.py +++ b/robomimic/utils/dataset.py @@ -13,7 +13,9 @@ import robomimic.utils.tensor_utils as TensorUtils import robomimic.utils.obs_utils as ObsUtils import robomimic.utils.log_utils as LogUtils - +import time +import scipy +import matplotlib.pyplot as plt class SequenceDataset(torch.utils.data.Dataset): def __init__( @@ -21,6 +23,7 @@ def __init__( hdf5_path, obs_keys, dataset_keys, + ac_key, frame_stack=1, seq_length=1, pad_frame_stack=True, @@ 
-32,6 +35,8 @@ def __init__( hdf5_normalize_obs=False, filter_by_attribute=None, load_next_obs=True, + prestacked_actions=False, + hdf5_normalize_actions=False, ): """ Dataset class for fetching sequences of experience. @@ -79,13 +84,19 @@ def __init__( demonstrations to load load_next_obs (bool): whether to load next_obs from the dataset + + imagenet_normalize_images (bool): if True, normalize images using ImageNet mean and std """ super(SequenceDataset, self).__init__() + self.prestacked_actions = prestacked_actions + self.hdf5_path = os.path.expanduser(hdf5_path) self.hdf5_use_swmr = hdf5_use_swmr self.hdf5_normalize_obs = hdf5_normalize_obs + self.hdf5_normalize_actions = hdf5_normalize_actions self._hdf5_file = None + self.ac_key = ac_key assert hdf5_cache_mode in ["all", "low_dim", None] self.hdf5_cache_mode = hdf5_cache_mode @@ -212,7 +223,8 @@ def hdf5_file(self): This property allows for a lazy hdf5 file open. """ if self._hdf5_file is None: - self._hdf5_file = h5py.File(self.hdf5_path, 'r', swmr=self.hdf5_use_swmr, libver='latest') + print("opening hdf5") + self._hdf5_file = h5py.File(self.hdf5_path, 'r', swmr=self.hdf5_use_swmr, libver='latest', rdcc_nbytes=1e10) return self._hdf5_file def close_and_delete_hdf5_handle(self): @@ -304,52 +316,31 @@ def normalize_obs(self): Computes a dataset-wide mean and standard deviation for the observations (per dimension and per obs key) and returns it. """ - def _compute_traj_stats(traj_obs_dict): - """ - Helper function to compute statistics over a single trajectory of observations. - """ - traj_stats = { k : {} for k in traj_obs_dict } - for k in traj_obs_dict: - traj_stats[k]["n"] = traj_obs_dict[k].shape[0] - traj_stats[k]["mean"] = traj_obs_dict[k].mean(axis=0, keepdims=True) # [1, ...] - traj_stats[k]["sqdiff"] = ((traj_obs_dict[k] - traj_stats[k]["mean"]) ** 2).sum(axis=0, keepdims=True) # [1, ...] - return traj_stats - - def _aggregate_traj_stats(traj_stats_a, traj_stats_b): - """ - Helper function to aggregate trajectory statistics. - See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm - for more information. - """ - merged_stats = {} - for k in traj_stats_a: - n_a, avg_a, M2_a = traj_stats_a[k]["n"], traj_stats_a[k]["mean"], traj_stats_a[k]["sqdiff"] - n_b, avg_b, M2_b = traj_stats_b[k]["n"], traj_stats_b[k]["mean"], traj_stats_b[k]["sqdiff"] - n = n_a + n_b - mean = (n_a * avg_a + n_b * avg_b) / n - delta = (avg_b - avg_a) - M2 = M2_a + M2_b + (delta ** 2) * (n_a * n_b) / n - merged_stats[k] = dict(n=n, mean=mean, sqdiff=M2) - return merged_stats - - # Run through all trajectories. For each one, compute minimal observation statistics, and then aggregate - # with the previous statistics. 
- ep = self.demos[0] - obs_traj = {k: self.hdf5_file["data/{}/obs/{}".format(ep, k)][()].astype('float32') for k in self.obs_keys} - obs_traj = ObsUtils.process_obs_dict(obs_traj) - merged_stats = _compute_traj_stats(obs_traj) - print("SequenceDataset: normalizing observations...") - for ep in LogUtils.custom_tqdm(self.demos[1:]): - obs_traj = {k: self.hdf5_file["data/{}/obs/{}".format(ep, k)][()].astype('float32') for k in self.obs_keys} - obs_traj = ObsUtils.process_obs_dict(obs_traj) - traj_stats = _compute_traj_stats(obs_traj) - merged_stats = _aggregate_traj_stats(merged_stats, traj_stats) - - obs_normalization_stats = { k : {} for k in merged_stats } - for k in merged_stats: - # note we add a small tolerance of 1e-3 for std - obs_normalization_stats[k]["mean"] = merged_stats[k]["mean"].astype(np.float32) - obs_normalization_stats[k]["std"] = (np.sqrt(merged_stats[k]["sqdiff"] / merged_stats[k]["n"]) + 1e-3).astype(np.float32) + def _calc_helper(hdf5_key): + obs = [] + demo_keys = [k for k in self.hdf5_file["data"].keys() if "demo" in k] + for ep in demo_keys: + obs_traj = self.hdf5_file[f"data/{ep}/{hdf5_key}"][()].astype('float32') + obs.append(obs_traj) + if len(obs) == 0: + breakpoint() + obs = np.concatenate(obs, axis=0) + mean = obs.mean(axis=0, keepdims=True) + std = obs.std(axis=0, keepdims=True) + 1e-3 + return dict(mean=mean, std=std) + + + obs_normalization_stats = {} + # keys_to_norm = [f"obs/{k}" for k in self.obs_keys if ObsUtils.key_is_obs_modality(k, "low_dim")] + ["actions"] + for key in self.obs_keys: + # hardcoded language key not normalized for now + if ObsUtils.key_is_obs_modality(key, "low_dim") and "lang" not in key: + obs_normalization_stats[key] = _calc_helper(f"obs/{key}") + + for key in self.dataset_keys: + if "actions" in key: + obs_normalization_stats[key] = _calc_helper(key) + return obs_normalization_stats def get_obs_normalization_stats(self): @@ -363,7 +354,10 @@ def get_obs_normalization_stats(self): with a "mean" and "std" of shape (1, ...) where ... is the default shape for the observation. """ - assert self.hdf5_normalize_obs, "not using observation normalization!" + # assert self.hdf5_normalize_obs, "not using observation normalization!" + if not self.hdf5_normalize_obs: + print("Warning: not using observation normalization!") + return None return deepcopy(self.obs_normalization_stats) def get_dataset_for_ep(self, ep, key): @@ -371,6 +365,7 @@ def get_dataset_for_ep(self, ep, key): Helper utility to get a dataset for a specific demonstration. Takes into account whether the dataset has been loaded into memory. """ + # check if this key should be in memory key_should_be_in_memory = (self.hdf5_cache_mode in ["all", "low_dim"]) @@ -381,7 +376,7 @@ def get_dataset_for_ep(self, ep, key): assert(key1 in ['obs', 'next_obs']) if key2 not in self.obs_keys_in_memory: key_should_be_in_memory = False - + if key_should_be_in_memory: # read cache if '/' in key: @@ -433,7 +428,6 @@ def get_item(self, index): goal_index = None if self.goal_mode == "last": goal_index = end_index_in_demo - 1 - meta["obs"] = self.get_obs_sequence_from_demo( demo_id, index_in_demo=index_in_demo, @@ -466,7 +460,15 @@ def get_item(self, index): return meta - def get_sequence_from_demo(self, demo_id, index_in_demo, keys, num_frames_to_stack=0, seq_length=1): + def get_sequence_from_demo( + self, + demo_id, + index_in_demo, + keys, + num_frames_to_stack=0, + seq_length=1, + dont_load_fut=None, + ): """ Extract a (sub)sequence of data items from a demo given the @keys of the items. 
@@ -476,10 +478,14 @@ def get_sequence_from_demo(self, demo_id, index_in_demo, keys, num_frames_to_sta keys (tuple): list of keys to extract num_frames_to_stack (int): numbers of frame to stack. Seq gets prepended with repeated items if out of range seq_length (int): sequence length to extract. Seq gets post-pended with repeated items if out of range + dont_load_fut (list): list of keys to not load future items for Returns: a dictionary of extracted items. """ + + if dont_load_fut is None: + dont_load_fut = [] assert num_frames_to_stack >= 0 assert seq_length >= 1 @@ -503,16 +509,20 @@ def get_sequence_from_demo(self, demo_id, index_in_demo, keys, num_frames_to_sta # fetch observation from the dataset file seq = dict() for k in keys: + t = time.time() data = self.get_dataset_for_ep(demo_id, k) - seq[k] = data[seq_begin_index: seq_end_index] + true_end_index = seq_begin_index + 1 if k.split("/")[-1] in dont_load_fut else seq_end_index + seq[k] = data[seq_begin_index: true_end_index] - seq = TensorUtils.pad_sequence(seq, padding=(seq_begin_pad, seq_end_pad), pad_same=True) + for k in seq: + if k.split("/")[-1] not in dont_load_fut: + seq[k] = TensorUtils.pad_sequence(seq[k], padding=(seq_begin_pad, seq_end_pad), pad_same=True) pad_mask = np.array([0] * seq_begin_pad + [1] * (seq_end_index - seq_begin_index) + [0] * seq_end_pad) pad_mask = pad_mask[:, None].astype(bool) return seq, pad_mask - def get_obs_sequence_from_demo(self, demo_id, index_in_demo, keys, num_frames_to_stack=0, seq_length=1, prefix="obs"): + def get_obs_sequence_from_demo(self, demo_id, index_in_demo, keys, num_frames_to_stack=0, seq_length=1, prefix="obs", dont_load_fut=False): """ Extract a (sub)sequence of observation items from a demo given the @keys of the items. @@ -527,16 +537,25 @@ def get_obs_sequence_from_demo(self, demo_id, index_in_demo, keys, num_frames_to Returns: a dictionary of extracted items. """ + seq_length_to_load = 1 if self.prestacked_actions else seq_length obs, pad_mask = self.get_sequence_from_demo( demo_id, index_in_demo=index_in_demo, keys=tuple('{}/{}'.format(prefix, k) for k in keys), num_frames_to_stack=num_frames_to_stack, - seq_length=seq_length, + seq_length=seq_length_to_load, + dont_load_fut=dont_load_fut ) obs = {k.split('/')[1]: obs[k] for k in obs} # strip the prefix if self.get_pad_mask: obs["pad_mask"] = pad_mask + + # Interpolate obs + # to_interp = [k for k in obs if ObsUtils.key_is_obs_modality(k, "low_dim")] + to_interp = ["pad_mask"] + # t = time.time() + obs["pad_mask"] = np.repeat(obs["pad_mask"], seq_length, axis=0) + # print("Interpolation time: ", time.time() - t) return obs @@ -554,15 +573,30 @@ def get_dataset_sequence_from_demo(self, demo_id, index_in_demo, keys, num_frame Returns: a dictionary of extracted items. 
""" + seq_length_to_load = 1 if self.prestacked_actions else seq_length data, pad_mask = self.get_sequence_from_demo( demo_id, index_in_demo=index_in_demo, keys=keys, num_frames_to_stack=num_frames_to_stack, - seq_length=seq_length, + seq_length=seq_length_to_load, ) if self.get_pad_mask: data["pad_mask"] = pad_mask + + # interpolate actions + to_interp = [k for k in data] + # t = time.time() + for k in data: + if k == "pad_mask": + continue + if data[k].shape[0] == 1 and len(data[k].shape) == 3: + data[k] = data[k][0] + if not "actions" in k: + raise ValueError("Interpolating actions, but key is not an action, key: ", k) + + data["pad_mask"] = np.repeat(data["pad_mask"], seq_length, axis=0) + # print("Interpolation time: ", time.time() - t) return data def get_trajectory_at_index(self, index): diff --git a/robomimic/utils/file_utils.py b/robomimic/utils/file_utils.py index 65db00fd..a79e5efc 100644 --- a/robomimic/utils/file_utils.py +++ b/robomimic/utils/file_utils.py @@ -19,6 +19,7 @@ from robomimic.config import config_factory from robomimic.algo import algo_factory from robomimic.algo import RolloutPolicy +from robomimic.utils.log_utils import bcolors def create_hdf5_filter_key(hdf5_path, demo_keys, key_name): @@ -111,7 +112,7 @@ def get_env_metadata_from_dataset(dataset_path, set_env_specific_obs_processors= return env_meta -def get_shape_metadata_from_dataset(dataset_path, all_obs_keys=None, verbose=False): +def get_shape_metadata_from_dataset(dataset_path, all_obs_keys=None, verbose=False, ac_key="actions"): """ Retrieves shape metadata from dataset. @@ -120,6 +121,7 @@ def get_shape_metadata_from_dataset(dataset_path, all_obs_keys=None, verbose=Fal all_obs_keys (list): list of all modalities used by the model. If not provided, all modalities present in the file are used. verbose (bool): if True, include print statements + ac_dim (bool): whether to pull ac_dim Returns: shape_meta (dict): shape metadata. 
Contains the following keys: @@ -140,7 +142,9 @@ def get_shape_metadata_from_dataset(dataset_path, all_obs_keys=None, verbose=Fal demo = f["data/{}".format(demo_id)] # action dimension - shape_meta['ac_dim'] = f["data/{}/actions".format(demo_id)].shape[1] + shape_meta["ac_dim"] = f[f"data/{demo_id}/{ac_key}"].shape[-1] + if len(f[f"data/{demo_id}/{ac_key}"].shape) > 2: + print(f"{bcolors.WARNING}Warning: action shape has more than 2 dims, if these aren't prepacked actions something may be wrong?{bcolors.ENDC}") # observation dimensions all_shapes = OrderedDict() @@ -150,6 +154,10 @@ def get_shape_metadata_from_dataset(dataset_path, all_obs_keys=None, verbose=Fal all_obs_keys = [k for k in demo["obs"]] for k in sorted(all_obs_keys): + if k not in demo["obs"]: + if verbose: + print(f"Warning: {k} not in some demos['obs']") + continue initial_shape = demo["obs/{}".format(k)].shape[1:] if verbose: print("obs key {} with shape {}".format(k, initial_shape)) diff --git a/robomimic/utils/log_utils.py b/robomimic/utils/log_utils.py index 1e1be989..1f8aa6ae 100644 --- a/robomimic/utils/log_utils.py +++ b/robomimic/utils/log_utils.py @@ -16,7 +16,16 @@ # global list of warning messages can be populated with @log_warning and flushed with @flush_warnings WARNINGS_BUFFER = [] - +class bcolors: + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKCYAN = '\033[96m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' class PrintLogger(object): """ @@ -43,7 +52,7 @@ class DataLogger(object): """ Logging class to log metrics to tensorboard and/or retrieve running statistics about logged data. """ - def __init__(self, log_dir, config, log_tb=True, log_wandb=False): + def __init__(self, log_dir, config, log_tb=True, log_wandb=False, uid=None): """ Args: log_dir (str): base path to store logs @@ -56,6 +65,7 @@ def __init__(self, log_dir, config, log_tb=True, log_wandb=False): if log_tb: from tensorboardX import SummaryWriter self._tb_logger = SummaryWriter(os.path.join(log_dir, 'tb')) + if log_wandb: import wandb @@ -79,7 +89,7 @@ def __init__(self, log_dir, config, log_tb=True, log_wandb=False): self._wandb_logger.init( entity=Macros.WANDB_ENTITY, project=config.experiment.logging.wandb_proj_name, - name=config.experiment.name, + name=uid if uid else config.experiment.name, dir=log_dir, mode=("offline" if attempt == num_attempts - 1 else "online"), ) diff --git a/robomimic/utils/obs_utils.py b/robomimic/utils/obs_utils.py index 66fb1272..d05ae229 100644 --- a/robomimic/utils/obs_utils.py +++ b/robomimic/utils/obs_utils.py @@ -325,7 +325,7 @@ def batch_image_chw_to_hwc(im): return im.permute(start_dims + [s + 2, s + 3, s + 1]) -def process_obs(obs, obs_modality=None, obs_key=None): +def process_obs(obs, obs_modality=None, obs_key=None, imagenet_normalize=False): """ Process observation @obs corresponding to @obs_modality modality (or implicitly inferred from @obs_key) to prepare for network input. @@ -345,10 +345,10 @@ def process_obs(obs, obs_modality=None, obs_key=None): assert obs_modality is not None or obs_key is not None, "Either obs_modality or obs_key must be specified!" 
if obs_key is not None: obs_modality = OBS_KEYS_TO_MODALITIES[obs_key] - return OBS_MODALITY_CLASSES[obs_modality].process_obs(obs) + return OBS_MODALITY_CLASSES[obs_modality].process_obs(obs, imagenet_normalize=imagenet_normalize) -def process_obs_dict(obs_dict): +def process_obs_dict(obs_dict, imagenet_normalize=False): """ Process observations in observation dictionary to prepare for network input. @@ -359,10 +359,11 @@ def process_obs_dict(obs_dict): Returns: new_dict (dict): dictionary where observation keys have been processed by their corresponding processors """ - return { k : process_obs(obs=obs, obs_key=k) for k, obs in obs_dict.items() } # shallow copy + return { k : process_obs(obs=obs, obs_key=k, imagenet_normalize=imagenet_normalize) for k, obs in obs_dict.items() } # shallow copy -def process_frame(frame, channel_dim, scale): + +def process_frame(frame, channel_dim, scale, imagenet_normalize=False): """ Given frame fetched from dataset, process for network input. Converts array to float (from uint8), normalizes pixels from range [0, @scale] to [0, 1], and channel swaps @@ -382,6 +383,10 @@ def process_frame(frame, channel_dim, scale): if scale is not None: frame = frame / scale frame = frame.clip(0.0, 1.0) + if imagenet_normalize: + mean = np.array([0.485, 0.456, 0.406]) + std = np.array([0.229, 0.224, 0.225]) + frame = (frame - mean) / std frame = batch_image_hwc_to_chw(frame) return frame @@ -462,7 +467,7 @@ def get_processed_shape(obs_modality, input_shape): return list(process_obs(obs=np.zeros(input_shape), obs_modality=obs_modality).shape) -def normalize_obs(obs_dict, obs_normalization_stats): +def normalize_batch(batch, normalization_stats, normalize_actions=True): """ Normalize observations using the provided "mean" and "std" entries for each observation key. The observation dictionary will be @@ -481,18 +486,16 @@ def normalize_obs(obs_dict, obs_normalization_stats): """ # ensure we have statistics for each modality key in the observation - assert set(obs_dict.keys()).issubset(obs_normalization_stats) - - for m in obs_dict: - # get rid of extra dimension - we will pad for broadcasting later - mean = obs_normalization_stats[m]["mean"][0] - std = obs_normalization_stats[m]["std"][0] + # assert set(obs_dict.keys()).issubset(obs_normalization_stats) + def _norm_helper(obs, mean, std): # shape consistency checks m_num_dims = len(mean.shape) - shape_len_diff = len(obs_dict[m].shape) - m_num_dims + shape_len_diff = len(obs.shape) - m_num_dims assert shape_len_diff >= 0, "shape length mismatch in @normalize_obs" - assert obs_dict[m].shape[-m_num_dims:] == mean.shape, "shape mismatch in @normalize_obs" + assert ( + obs.shape[-m_num_dims:] == mean.shape + ), "shape mismatch in @normalize_obs" # Obs can have one or more leading batch dims - prepare for broadcasting. 
# @@ -501,10 +504,92 @@ def normalize_obs(obs_dict, obs_normalization_stats): reshape_padding = tuple([1] * shape_len_diff) mean = mean.reshape(reshape_padding + tuple(mean.shape)) std = std.reshape(reshape_padding + tuple(std.shape)) + if isinstance(obs, torch.Tensor) and isinstance(mean, np.ndarray): + mean = torch.from_numpy(mean).to(obs.device) + std = torch.from_numpy(std).to(obs.device) + + return (obs - mean) / std + + for m in batch["obs"]: + if m not in normalization_stats: + continue + # get rid of extra dimension - we will pad for broadcasting later + mean = normalization_stats[m]["mean"][0] + std = normalization_stats[m]["std"][0] + + batch["obs"][m] = _norm_helper(batch["obs"][m], mean, std) + + if normalize_actions: + for k in batch: + if "actions" in k: + ac_mean = normalization_stats[k]["mean"][0] + ac_std = normalization_stats[k]["std"][0] + + batch[k] = _norm_helper(batch[k], ac_mean, ac_std) + + + return batch + +def unnormalize_batch(batch, normalization_stats): + """ + Unnormalize observations using the provided "mean" and "std" entries + for each observation key. The observation dictionary will be + modified in-place. + + Args: + obs_dict (dict): dictionary mapping observation key to np.array or + torch.Tensor. Can have any number of leading batch dimensions. + + obs_normalization_stats (dict): this should map observation keys to dicts + with a "mean" and "std" of shape (1, ...) where ... is the default + shape for the observation. + + Returns: + obs_dict (dict): obs dict with unnormalized observation arrays + """ + + # ensure we have statistics for each modality key in the observation + # assert set(obs_dict.keys()).issubset(obs_normalization_stats) + + def _unnorm_helper(obs, mean, std): + # shape consistency checks + m_num_dims = len(mean.shape) + shape_len_diff = len(obs.shape) - m_num_dims + assert shape_len_diff >= 0, "shape length mismatch in @normalize_obs" + assert ( + obs.shape[-m_num_dims:] == mean.shape + ), "shape mismatch in @normalize_obs" + + # Obs can have one or more leading batch dims - prepare for broadcasting. + # + # As an example, if the obs has shape [B, T, D] and our mean / std stats are shape [D] + # then we should pad the stats to shape [1, 1, D]. + reshape_padding = tuple([1] * shape_len_diff) + mean = torch.from_numpy(mean.reshape(reshape_padding + tuple(mean.shape))).to(obs.device) + std = torch.from_numpy(std.reshape(reshape_padding + tuple(std.shape))).to(obs.device) + + return (obs * std) + mean + + if "obs" in batch: + for m in batch["obs"]: + if m not in normalization_stats: + continue + # get rid of extra dimension - we will pad for broadcasting later + mean = normalization_stats[m]["mean"][0] + std = normalization_stats[m]["std"][0] + + batch["obs"][m] = _unnorm_helper(batch["obs"][m], mean, std) - obs_dict[m] = (obs_dict[m] - mean) / std + for k in batch: + if "actions" in k: + if normalization_stats is None: + continue + ac_mean = normalization_stats[k]["mean"][0] + ac_std = normalization_stats[k]["std"][0] - return obs_dict + batch[k] = _unnorm_helper(batch[k], ac_mean, ac_std) + + return batch def has_modality(modality, obs_keys): @@ -810,7 +895,7 @@ def _default_obs_unprocessor(cls, obs): raise NotImplementedError @classmethod - def process_obs(cls, obs): + def process_obs(cls, obs, imagenet_normalize=False): """ Prepares an observation @obs of this modality for network input. 
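To make the broadcasting logic shared by normalize_batch and unnormalize_batch above concrete: the stored stats have shape (1, D), the leading 1 is dropped, and the remaining (D,) arrays are padded with leading singleton dims so they broadcast against data with any number of leading batch dims. A small numpy round-trip under those assumptions (the in-repo helpers additionally move the stats onto torch when the data is a torch.Tensor):

import numpy as np

def pad_stats(x, mean, std):
    # pad (D,) stats to (1, ..., 1, D) so they broadcast against x of shape (B, T, D)
    pad = (1,) * (x.ndim - mean.ndim)
    return mean.reshape(pad + mean.shape), std.reshape(pad + std.shape)

x = np.random.randn(4, 10, 7).astype("float32")   # e.g. a (B, T, D) action batch
mean = x.reshape(-1, 7).mean(axis=0)              # (D,)
std = x.reshape(-1, 7).std(axis=0) + 1e-3         # (D,)

m, s = pad_stats(x, mean, std)
x_norm = (x - m) / s        # what normalize_batch applies
x_back = x_norm * s + m     # what unnormalize_batch applies
assert np.allclose(x, x_back, atol=1e-5)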
@@ -822,7 +907,10 @@ def process_obs(cls, obs):
        """
        processor = cls._custom_obs_processor if \
            cls._custom_obs_processor is not None else cls._default_obs_processor
-        return processor(obs)
+        if issubclass(cls, ImageModality):
+            return processor(obs, imagenet_normalize=imagenet_normalize)
+        else:
+            return processor(obs)

    @classmethod
    def unprocess_obs(cls, obs):
@@ -869,7 +957,7 @@ class ImageModality(Modality):
    name = "rgb"

    @classmethod
-    def _default_obs_processor(cls, obs):
+    def _default_obs_processor(cls, obs, imagenet_normalize=False):
        """
        Given image fetched from dataset, process for network input. Converts array
        to float (from uint8), normalizes pixels from range [0, 255] to [0, 1], and channel swaps
@@ -881,7 +969,7 @@ def _default_obs_processor(cls, obs):
        Returns:
            processed_obs (np.array or torch.Tensor): processed image
        """
-        return process_frame(frame=obs, channel_dim=3, scale=255.)
+        return process_frame(frame=obs, channel_dim=3, scale=255., imagenet_normalize=imagenet_normalize)

    @classmethod
    def _default_obs_unprocessor(cls, obs):
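Taken together, the image path after this diff is: uint8 HWC frame -> float in [0, 1] -> optional ImageNet mean/std normalization -> CHW, with imagenet_normalize plumbed from the train config through process_obs_dict down to process_frame. A self-contained sketch of that pipeline, assuming numpy inputs (the function name is illustrative; the constants match the ones added to process_frame):

import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

def process_rgb_frame(frame, imagenet_normalize=False):
    # frame: (..., H, W, 3) uint8 image as stored in the hdf5 dataset
    frame = frame.astype("float32") / 255.0
    frame = frame.clip(0.0, 1.0)
    if imagenet_normalize:
        # broadcasts over the trailing channel dimension
        frame = (frame - IMAGENET_MEAN) / IMAGENET_STD
    # channel swap HWC -> CHW (same role as batch_image_hwc_to_chw)
    return np.moveaxis(frame, -1, -3)

out = process_rgb_frame(np.zeros((84, 84, 3), dtype=np.uint8), imagenet_normalize=True)
assert out.shape == (3, 84, 84)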