Skip to content

Commit

Permalink
[CI] Fix Minari tests (#2419)
Browse files Browse the repository at this point in the history
Co-authored-by: Vincent Moens <[email protected]>
  • Loading branch information
younik and vmoens authored Sep 17, 2024
1 parent 2332909 commit 224d637
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .github/unittest/linux_libs/scripts_minari/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ dependencies:
- pyyaml
- scipy
- hydra-core
- minari[gcs]
- minari[gcs,hdf5]
35 changes: 4 additions & 31 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2823,34 +2823,11 @@ def _minari_selected_datasets():

torch.manual_seed(0)

# We rely on sorting the keys as v0 < v1 but if the version is greater than 9 this won't work
total_keys = sorted(minari.list_remote_datasets())
assert not any(
key[-2:] == "10" for key in total_keys
), "You should adapt the Minari test scripts as some dataset have a version >= 10 and sorting will fail."
total_keys_splits = [key.split("-") for key in total_keys]
total_keys = sorted(
minari.list_remote_datasets(latest_version=True, compatible_minari_version=True)
)
indices = torch.randperm(len(total_keys))[:20]
keys = [total_keys[idx] for idx in indices]
keys = [
key
for key in keys
if "=0.4" in minari.list_remote_datasets()[key]["minari_version"]
]

def _replace_with_max(key):
key_split = key.split("-")
same_entries = (
torch.tensor(
[total_key[:-1] == key_split[:-1] for total_key in total_keys_splits]
)
.nonzero()
.squeeze()
.tolist()
)
last_same_entry = same_entries[-1]
return total_keys[last_same_entry]

keys = [_replace_with_max(key) for key in keys]

assert len(keys) > 5, keys
_MINARI_DATASETS += keys
Expand Down Expand Up @@ -2880,12 +2857,8 @@ def test_load(self, selected_dataset, split):
break

def test_minari_preproc(self, tmpdir):
global _MINARI_DATASETS
if not _MINARI_DATASETS:
_minari_selected_datasets()
selected_dataset = _MINARI_DATASETS[0]
dataset = MinariExperienceReplay(
selected_dataset,
"D4RL/pointmaze/large-v2",
batch_size=32,
split_trajs=False,
download="force",
Expand Down

1 comment on commit 224d637

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 224d637 Previous: 2332909 Ratio
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 35.87699897523488 iter/sec (stddev: 0.1658359731567662) 203.0185644723699 iter/sec (stddev: 0.0011685412844671744) 5.66

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.