From 224d6375b3ea84f22b5c20fc5a712c42dd8104ac Mon Sep 17 00:00:00 2001 From: Omar Younis <42100908+younik@users.noreply.github.com> Date: Tue, 17 Sep 2024 21:13:04 +0300 Subject: [PATCH] [CI] Fix Minari tests (#2419) Co-authored-by: Vincent Moens --- .../linux_libs/scripts_minari/environment.yml | 2 +- test/test_libs.py | 35 +++---------------- 2 files changed, 5 insertions(+), 32 deletions(-) diff --git a/.github/unittest/linux_libs/scripts_minari/environment.yml b/.github/unittest/linux_libs/scripts_minari/environment.yml index 23aedb4cc23..ad5bfc12650 100644 --- a/.github/unittest/linux_libs/scripts_minari/environment.yml +++ b/.github/unittest/linux_libs/scripts_minari/environment.yml @@ -17,4 +17,4 @@ dependencies: - pyyaml - scipy - hydra-core - - minari[gcs] + - minari[gcs,hdf5] diff --git a/test/test_libs.py b/test/test_libs.py index 6f5cc1bebeb..87c69bf000c 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -2823,34 +2823,11 @@ def _minari_selected_datasets(): torch.manual_seed(0) - # We rely on sorting the keys as v0 < v1 but if the version is greater than 9 this won't work - total_keys = sorted(minari.list_remote_datasets()) - assert not any( - key[-2:] == "10" for key in total_keys - ), "You should adapt the Minari test scripts as some dataset have a version >= 10 and sorting will fail." - total_keys_splits = [key.split("-") for key in total_keys] + total_keys = sorted( + minari.list_remote_datasets(latest_version=True, compatible_minari_version=True) + ) indices = torch.randperm(len(total_keys))[:20] keys = [total_keys[idx] for idx in indices] - keys = [ - key - for key in keys - if "=0.4" in minari.list_remote_datasets()[key]["minari_version"] - ] - - def _replace_with_max(key): - key_split = key.split("-") - same_entries = ( - torch.tensor( - [total_key[:-1] == key_split[:-1] for total_key in total_keys_splits] - ) - .nonzero() - .squeeze() - .tolist() - ) - last_same_entry = same_entries[-1] - return total_keys[last_same_entry] - - keys = [_replace_with_max(key) for key in keys] assert len(keys) > 5, keys _MINARI_DATASETS += keys @@ -2880,12 +2857,8 @@ def test_load(self, selected_dataset, split): break def test_minari_preproc(self, tmpdir): - global _MINARI_DATASETS - if not _MINARI_DATASETS: - _minari_selected_datasets() - selected_dataset = _MINARI_DATASETS[0] dataset = MinariExperienceReplay( - selected_dataset, + "D4RL/pointmaze/large-v2", batch_size=32, split_trajs=False, download="force",