Skip to content

Commit

Permalink
Fix DLWP-Healpix coupled dataloader not supporting extracting specifi…
Browse files Browse the repository at this point in the history
…c channels for constants array (#715)

* Add extracting channels for constants array in healpix dataset

* Update constant test to handle extracting constants

---------

Co-authored-by: David Pruitt <[email protected]>
  • Loading branch information
ivanauyeung and daviddpruitt authored Dec 3, 2024
1 parent 2975115 commit 1affed7
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
6 changes: 6 additions & 0 deletions modulus/datapipes/healpix/data_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,9 @@ def setup(self) -> None:
batch_size=self.batch_size,
)

if self.constants is not None:
dataset = dataset.sel(channel_c=list(self.constants.values()))

if self.splits is not None and self.forecast_init_times is None:
self.train_dataset = TimeSeriesDataset(
dataset.sel(
Expand Down Expand Up @@ -1033,6 +1036,9 @@ def setup(self) -> None:
batch_size=self.batch_size,
)

if self.constants is not None:
dataset = dataset.sel(channel_c=list(self.constants.values()))

if self.splits is not None and self.forecast_init_times is None:
self.train_dataset = CoupledTimeSeriesDataset(
dataset.sel(
Expand Down
5 changes: 4 additions & 1 deletion test/datapipes/test_healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,10 @@ def test_TimeSeriesDataModule_get_constants(
# open our test dataset
ds_path = Path(data_dir, dataset_name + ".zarr")
zarr_ds = xr.open_zarr(ds_path)
expected = np.transpose(zarr_ds.constants.values, axes=(1, 0, 2, 3))
expected = np.transpose(
zarr_ds.constants.sel(channel_c=list(constants.keys())).values,
axes=(1, 0, 2, 3),
)

assert np.array_equal(
timeseries_dm.get_constants(),
Expand Down
5 changes: 4 additions & 1 deletion test/datapipes/test_healpix_couple.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,10 @@ def test_CoupledTimeSeriesDataModule_get_constants(
# open our test dataset
ds_path = Path(data_dir, dataset_name + ".zarr")
zarr_ds = xr.open_zarr(ds_path)
expected = np.transpose(zarr_ds.constants.values, axes=(1, 0, 2, 3))
expected = np.transpose(
zarr_ds.constants.sel(channel_c=list(constants.keys())).values,
axes=(1, 0, 2, 3),
)

assert np.array_equal(
timeseries_dm.get_constants(),
Expand Down

0 comments on commit 1affed7

Please sign in to comment.