Skip to content

Commit

Permalink
Added facade functions to_zarr and from_zarr (#2236)
Browse files Browse the repository at this point in the history
* Added facade functions `to_zarr` and `from_zarr`

* black

* added to changelog

* update PR with review comments

* fix rebase issues with changelog

* black

---------

Co-authored-by: Oriol (ZBook) <[email protected]>
  • Loading branch information
pSpitzner and OriolAbril authored Jul 11, 2023
1 parent 7fb2257 commit 8632186
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
- Add InferenceData<->DataTree conversion functions ([2253](https://github.com/arviz-devs/arviz/pull/2253))
- Bayes Factor plot: Use arviz's kde instead of the one from scipy ([2237](https://github.com/arviz-devs/arviz/pull/2237))
- InferenceData objects can now be appended to existing netCDF4 files and to specific groups within them ([2227](https://github.com/arviz-devs/arviz/pull/2227))
- Added facade functions `az.to_zarr` and `az.from_zarr` ([2236](https://github.com/arviz-devs/arviz/pull/2236))

### Maintenance and fixes
- Replace deprecated np.product with np.prod ([2249](https://github.com/arviz-devs/arviz/pull/2249))
- Fix numba deprecation warning ([2246](https://github.com/arviz-devs/arviz/pull/2246))
- Fixes for creating numpy object array ([2233](https://github.com/arviz-devs/arviz/pull/2233) and [2239](https://github.com/arviz-devs/arviz/pull/2239))
- Adapt histograms generated by plot_dist to input dtype ([2247](https://github.com/arviz-devs/arviz/pull/2247))


### Deprecation

### Documentation
Expand Down
3 changes: 3 additions & 0 deletions arviz/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .io_pyjags import from_pyjags
from .io_pyro import from_pyro
from .io_pystan import from_pystan
from .io_zarr import from_zarr, to_zarr
from .utils import extract, extract_dataset

__all__ = [
Expand Down Expand Up @@ -44,6 +45,8 @@
"to_datatree",
"to_json",
"to_netcdf",
"from_zarr",
"to_zarr",
"CoordSpec",
"DimSpec",
]
46 changes: 46 additions & 0 deletions arviz/data/io_zarr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Input and output support for zarr data."""

from .converters import convert_to_inference_data
from .inference_data import InferenceData


def from_zarr(store):
    # Module-level facade: all of the work is delegated to the
    # ``InferenceData.from_zarr`` constructor, ``store`` is forwarded untouched.
    inference_data = InferenceData.from_zarr(store)
    return inference_data


# Reuse the docstring of the wrapped method so the facade and the method
# cannot drift apart in their documentation.
from_zarr.__doc__ = InferenceData.from_zarr.__doc__


def to_zarr(data, store=None, **kwargs):
    """
    Convert data to zarr, optionally saving to disk if ``store`` is provided.

    The zarr storage is using the same group names as the InferenceData.

    Parameters
    ----------
    data : object
        Anything accepted by :py:func:`convert_to_inference_data`; it is
        converted to an InferenceData before being written.
    store : zarr.storage, MutableMapping or str, optional
        Zarr storage class or path to desired DirectoryStore.
        Default (None) a store is created in a temporary directory.
    **kwargs : dict, optional
        Passed to :py:func:`convert_to_inference_data`.

    Returns
    -------
    zarr.hierarchy.group
        A zarr hierarchy group containing the InferenceData.

    Raises
    ------
    TypeError
        If no valid store is found.

    References
    ----------
    https://zarr.readthedocs.io/
    """
    # Convert first, then let the resulting InferenceData write itself.
    return convert_to_inference_data(data, **kwargs).to_zarr(store=store)
39 changes: 39 additions & 0 deletions arviz/tests/base_tests/test_data_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest

from ... import InferenceData, from_dict
from ... import to_zarr, from_zarr

from ..helpers import ( # pylint: disable=unused-import
chains,
Expand Down Expand Up @@ -103,3 +104,41 @@ def test_io_method(self, data, eight_schools_params, store, fill_attrs):
assert inference_data2.attrs["test"] == 1
else:
assert "test" not in inference_data2.attrs

def test_io_function(self, data, eight_schools_params):
    """Round-trip an InferenceData through the ``to_zarr``/``from_zarr`` functions.

    Builds an InferenceData with ``fill_attrs=True``, writes it to a fresh
    path inside a temporary directory with ``to_zarr``, reads it back with
    ``from_zarr`` and checks that the expected groups/variables and the
    ``test`` attribute are present both before and after the round trip.
    """
    # create InferenceData and check it has been properly created
    inference_data = self.get_inference_data(
        data,
        eight_schools_params,
        fill_attrs=True,
    )
    test_dict = {
        "posterior": ["eta", "theta", "mu", "tau"],
        "posterior_predictive": ["eta", "theta", "mu", "tau"],
        "sample_stats": ["eta", "theta", "mu", "tau"],
        "prior": ["eta", "theta", "mu", "tau"],
        "prior_predictive": ["eta", "theta", "mu", "tau"],
        "sample_stats_prior": ["eta", "theta", "mu", "tau"],
        "observed_data": ["J", "y", "sigma"],
    }
    fails = check_multiple_attrs(test_dict, inference_data)
    assert not fails

    assert inference_data.attrs["test"] == 1

    with TemporaryDirectory(prefix="arviz_tests_") as tmp_dir:
        filepath = os.path.join(tmp_dir, "zarr")
        # The target must not exist beforehand, otherwise its presence after
        # the call would not prove that ``to_zarr`` created it.
        assert not os.path.exists(filepath)

        to_zarr(inference_data, store=filepath)
        # A string store is documented as a path to a DirectoryStore, i.e. a
        # directory. ``os.path.getsize`` is only meaningful for regular files
        # (a directory's size is platform-dependent), so check instead that
        # the directory exists and actually contains something.
        assert os.path.isdir(filepath)
        assert os.listdir(filepath)

        inference_data2 = from_zarr(filepath)

        # Everything in dict still available in inference_data2 ?
        fails = check_multiple_attrs(test_dict, inference_data2)
        assert not fails

        assert inference_data2.attrs["test"] == 1
2 changes: 2 additions & 0 deletions doc/source/api/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ IO / General conversion
to_datatree
to_json
to_netcdf
from_zarr
to_zarr


General functions
Expand Down

0 comments on commit 8632186

Please sign in to comment.