From 6aa4b5351f042e9dddebfa7988698ebab6170569 Mon Sep 17 00:00:00 2001 From: kurtamohler Date: Mon, 9 Sep 2024 23:48:56 -0700 Subject: [PATCH] [Performance] Faster `SliceSampler._tensor_slices_from_startend` (#2423) --- torchrl/data/replay_buffers/samplers.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 8338fdff74b..94d74bed468 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -1076,9 +1076,28 @@ def _tensor_slices_from_startend(self, seq_length, start, storage_length): # seq_length is a 1d tensor indicating the desired length of each sequence if isinstance(seq_length, int): - result = torch.cat( - [self._start_to_end(_start, length=seq_length) for _start in start] + arange = torch.arange(seq_length, device=start.device, dtype=start.dtype) + ndims = start.shape[-1] - 1 if (start.ndim - 1) else 0 + if ndims: + arange_reshaped = torch.empty( + arange.shape + torch.Size([ndims + 1]), + device=start.device, + dtype=start.dtype, + ) + arange_reshaped[..., 0] = arange + arange_reshaped[..., 1:] = 0 + else: + arange_reshaped = arange.unsqueeze(-1) + arange_expanded = arange_reshaped.expand( + torch.Size([start.shape[0]]) + arange_reshaped.shape ) + if start.shape != arange_expanded.shape: + n_missing_dims = arange_expanded.dim() - start.dim() + start_expanded = start[ + (slice(None),) + (None,) * n_missing_dims + ].expand_as(arange_expanded) + result = (start_expanded + arange_expanded).flatten(0, 1) + else: # when padding is needed result = torch.cat(