Skip to content

Commit

Permalink
fix lerobot
Browse files Browse the repository at this point in the history
  • Loading branch information
KeplerC committed Sep 23, 2024
1 parent 85f5266 commit e573046
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 19 deletions.
4 changes: 2 additions & 2 deletions benchmarks/openx_by_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import fog_x
import csv
import stat
from fog_x.loader.lerobot import LeRobotLoader
from fog_x.loader.lerobot import LeRobotLoader_ByFrame
from fog_x.loader.vla import get_vla_dataloader
from fog_x.loader.hdf5 import get_hdf5_dataloader

Expand Down Expand Up @@ -310,7 +310,7 @@ def __init__(

def get_loader(self):
path = os.path.join(self.exp_dir, "hf")
return LeRobotLoader(path, self.dataset_name, batch_size=self.batch_size)
return LeRobotLoader_ByFrame(path, self.dataset_name, batch_size=1, slice_length=self.batch_size)

def _recursively_load_data(self, data):
import torch
Expand Down
6 changes: 3 additions & 3 deletions evaluation.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ sudo echo "Use sudo access for clearning cache"

# Define a list of batch sizes to iterate through

batch_sizes=(4)
batch_sizes=(64)
num_batches=200
# batch_sizes=(1 2)

Expand All @@ -16,7 +16,7 @@ do
echo "Running benchmarks with batch size: $batch_size"

# python3 benchmarks/openx.py --dataset_names nyu_door_opening_surprising_effectiveness --num_batches $num_batches --batch_size $batch_size
python3 benchmarks/openx_by_frame.py --dataset_names berkeley_cable_routing --num_batches $num_batches --batch_size $batch_size
# python3 benchmarks/openx.py --dataset_names bridge --num_batches $num_batches --batch_size $batch_size
# python3 benchmarks/openx_by_frame.py --dataset_names berkeley_cable_routing --num_batches $num_batches --batch_size $batch_size
python3 benchmarks/openx_by_frame.py --dataset_names bridge --num_batches $num_batches --batch_size $batch_size
# python3 benchmarks/openx.py --dataset_names berkeley_autolab_ur5 --num_batches $num_batches --batch_size $batch_size
done
22 changes: 8 additions & 14 deletions fog_x/loader/lerobot.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ def get_batch(self):


class LeRobotLoader_ByFrame(BaseLoader):
def __init__(self, path, dataset_name, batch_size=1, delta_timestamps=None):
super(LeRobotLoader, self).__init__(path)
def __init__(self, path, dataset_name, batch_size=1, delta_timestamps=None, slice_length=16):
super(LeRobotLoader_ByFrame, self).__init__(path)
self.batch_size = batch_size
self.dataset = LeRobotDataset(root="/mnt/data/fog_x/hf/", repo_id=dataset_name, delta_timestamps=delta_timestamps)
self.episode_index = 0
self.slice_length = slice_length

def __len__(self):
return len(self.dataset.episode_data_index["from"])
Expand All @@ -78,33 +78,27 @@ def _frame_to_numpy(frame):
for attempt in range(max_retries):
try:
# repeat
if self.episode_index >= len(self.dataset):
self.episode_index = 0
self.episode_index = np.random.randint(0, len(self.dataset))
try:
from_idx = self.dataset.episode_data_index["from"][self.episode_index].item()
to_idx = self.dataset.episode_data_index["to"][self.episode_index].item()
except Exception as e:
self.episode_index = 0
continue

# Randomly select random_frames from episode
random_frames = 16
episode_length = to_idx - from_idx
if episode_length <= random_frames:
if episode_length <= self.slice_length:
random_from = from_idx
random_to = to_idx
else:
random_from = np.random.randint(from_idx, to_idx - 15)
random_to = random_from + 16
frames = [_frame_to_numpy(self.dataset[idx]) for idx in range(random_from, random_to)]
random_from = np.random.randint(from_idx, to_idx - self.slice_length)
random_to = random_from + self.slice_length
frames = [self.dataset[idx] for idx in range(random_from, random_to)]
episode.extend(frames)
self.episode_index += 1
break
except Exception as e:
if attempt == max_retries - 1:
raise e
self.episode_index += 1


batch_of_episodes.append((episode))

Expand Down

0 comments on commit e573046

Please sign in to comment.