diff --git a/modulus/datapipes/healpix/data_modules.py b/modulus/datapipes/healpix/data_modules.py index fdabb9f8d..665dc77bf 100644 --- a/modulus/datapipes/healpix/data_modules.py +++ b/modulus/datapipes/healpix/data_modules.py @@ -613,6 +613,9 @@ def setup(self) -> None: batch_size=self.batch_size, ) + if self.constants is not None: + dataset = dataset.sel(channel_c=list(self.constants.values())) + if self.splits is not None and self.forecast_init_times is None: self.train_dataset = TimeSeriesDataset( dataset.sel( @@ -1033,6 +1036,9 @@ def setup(self) -> None: batch_size=self.batch_size, ) + if self.constants is not None: + dataset = dataset.sel(channel_c=list(self.constants.values())) + if self.splits is not None and self.forecast_init_times is None: self.train_dataset = CoupledTimeSeriesDataset( dataset.sel( diff --git a/test/datapipes/test_healpix.py b/test/datapipes/test_healpix.py index 73216be18..3667cb650 100644 --- a/test/datapipes/test_healpix.py +++ b/test/datapipes/test_healpix.py @@ -550,7 +550,10 @@ def test_TimeSeriesDataModule_get_constants( # open our test dataset ds_path = Path(data_dir, dataset_name + ".zarr") zarr_ds = xr.open_zarr(ds_path) - expected = np.transpose(zarr_ds.constants.values, axes=(1, 0, 2, 3)) + expected = np.transpose( + zarr_ds.constants.sel(channel_c=list(constants.keys())).values, + axes=(1, 0, 2, 3), + ) assert np.array_equal( timeseries_dm.get_constants(), diff --git a/test/datapipes/test_healpix_couple.py b/test/datapipes/test_healpix_couple.py index 833827c9e..681a943bf 100644 --- a/test/datapipes/test_healpix_couple.py +++ b/test/datapipes/test_healpix_couple.py @@ -759,7 +759,10 @@ def test_CoupledTimeSeriesDataModule_get_constants( # open our test dataset ds_path = Path(data_dir, dataset_name + ".zarr") zarr_ds = xr.open_zarr(ds_path) - expected = np.transpose(zarr_ds.constants.values, axes=(1, 0, 2, 3)) + expected = np.transpose( + zarr_ds.constants.sel(channel_c=list(constants.keys())).values, + axes=(1, 0, 2, 3), + ) assert np.array_equal( timeseries_dm.get_constants(),