Skip to content

Commit

Permalink
Improvements for xarray provider (geopython#1800)
Browse files Browse the repository at this point in the history
* Manage non-cf-compliant time dimension

* Manage datasets without a time dimension

* Allow reversed slices also for axes

* Convert also metadata to float64 for json output

* Use named temporary file to enable netcdf4 engine

* Make float64 conversion faster

* Add netcdf output to xarray provider

* Flake8 fixes

* Fix bug when no time axis in data

* Use new xarray interface

* Add test for zarr dataset without time dimension

* Avoid errors if missing long_name

* Manage zarr and netcdf output in the same way

* Revert "Manage zarr and netcdf output in the same way"

This reverts commit 0b09281.

* Revert "Add netcdf output to xarray provider"

This reverts commit 9f72bf7.
  • Loading branch information
barbuz authored and sjordan29 committed Oct 21, 2024
1 parent 0470ac9 commit 1871a49
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 55 deletions.
158 changes: 103 additions & 55 deletions pygeoapi/provider/xarray_.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,19 @@ def __init__(self, provider_def):
else:
data_to_open = self.data

self._data = open_func(data_to_open)
try:
self._data = open_func(data_to_open)
except ValueError as err:
# Manage non-cf-compliant time dimensions
if 'time' in str(err):
self._data = open_func(self.data, decode_times=False)
else:
raise err

self.storage_crs = self._parse_storage_crs(provider_def)
self._coverage_properties = self._get_coverage_properties()

self.axes = [self._coverage_properties['x_axis_label'],
self._coverage_properties['y_axis_label'],
self._coverage_properties['time_axis_label']]
self.axes = self._coverage_properties['axes']

self.get_fields()
except Exception as err:
Expand All @@ -101,15 +107,15 @@ def __init__(self, provider_def):
def get_fields(self):
if not self._fields:
for key, value in self._data.variables.items():
if len(value.shape) >= 3:
if key not in self._data.coords:
LOGGER.debug('Adding variable')
dtype = value.dtype
if dtype.name.startswith('float'):
dtype = 'number'

self._fields[key] = {
'type': dtype,
'title': value.attrs['long_name'],
'title': value.attrs.get('long_name'),
'x-ogc-unit': value.attrs.get('units')
}

Expand Down Expand Up @@ -142,9 +148,9 @@ def query(self, properties=[], subsets={}, bbox=[], bbox_crs=4326,

data = self._data[[*properties]]

if any([self._coverage_properties['x_axis_label'] in subsets,
self._coverage_properties['y_axis_label'] in subsets,
self._coverage_properties['time_axis_label'] in subsets,
if any([self._coverage_properties.get('x_axis_label') in subsets,
self._coverage_properties.get('y_axis_label') in subsets,
self._coverage_properties.get('time_axis_label') in subsets,
datetime_ is not None]):

LOGGER.debug('Creating spatio-temporal subset')
Expand All @@ -163,18 +169,36 @@ def query(self, properties=[], subsets={}, bbox=[], bbox_crs=4326,
self._coverage_properties['y_axis_label'] in subsets,
len(bbox) > 0]):
msg = 'bbox and subsetting by coordinates are exclusive'
LOGGER.warning(msg)
LOGGER.error(msg)
raise ProviderQueryError(msg)
else:
query_params[self._coverage_properties['x_axis_label']] = \
slice(bbox[0], bbox[2])
query_params[self._coverage_properties['y_axis_label']] = \
slice(bbox[1], bbox[3])
x_axis_label = self._coverage_properties['x_axis_label']
x_coords = data.coords[x_axis_label]
if x_coords.values[0] > x_coords.values[-1]:
LOGGER.debug(
'Reversing slicing of x axis from high to low'
)
query_params[x_axis_label] = slice(bbox[2], bbox[0])
else:
query_params[x_axis_label] = slice(bbox[0], bbox[2])
y_axis_label = self._coverage_properties['y_axis_label']
y_coords = data.coords[y_axis_label]
if y_coords.values[0] > y_coords.values[-1]:
LOGGER.debug(
'Reversing slicing of y axis from high to low'
)
query_params[y_axis_label] = slice(bbox[3], bbox[1])
else:
query_params[y_axis_label] = slice(bbox[1], bbox[3])

LOGGER.debug('bbox_crs is not currently handled')

if datetime_ is not None:
if self._coverage_properties['time_axis_label'] in subsets:
if self._coverage_properties['time_axis_label'] is None:
msg = 'Dataset does not contain a time axis'
LOGGER.error(msg)
raise ProviderQueryError(msg)
elif self._coverage_properties['time_axis_label'] in subsets:
msg = 'datetime and temporal subsetting are exclusive'
LOGGER.error(msg)
raise ProviderQueryError(msg)
Expand All @@ -196,32 +220,36 @@ def query(self, properties=[], subsets={}, bbox=[], bbox_crs=4326,
LOGGER.warning(err)
raise ProviderQueryError(err)

if (any([data.coords[self.x_field].size == 0,
data.coords[self.y_field].size == 0,
data.coords[self.time_field].size == 0])):
if any(size == 0 for size in data.sizes.values()):
msg = 'No data found'
LOGGER.warning(msg)
raise ProviderNoDataError(msg)

if format_ == 'json':
# json does not support float32
data = _convert_float32_to_float64(data)

out_meta = {
'bbox': [
data.coords[self.x_field].values[0],
data.coords[self.y_field].values[0],
data.coords[self.x_field].values[-1],
data.coords[self.y_field].values[-1]
],
"time": [
_to_datetime_string(data.coords[self.time_field].values[0]),
_to_datetime_string(data.coords[self.time_field].values[-1])
],
"driver": "xarray",
"height": data.sizes[self.y_field],
"width": data.sizes[self.x_field],
"time_steps": data.sizes[self.time_field],
"variables": {var_name: var.attrs
for var_name, var in data.variables.items()}
}

if self.time_field is not None:
out_meta['time'] = [
_to_datetime_string(data.coords[self.time_field].values[0]),
_to_datetime_string(data.coords[self.time_field].values[-1]),
]
out_meta["time_steps"] = data.sizes[self.time_field]

LOGGER.debug('Serializing data in memory')
if format_ == 'json':
LOGGER.debug('Creating output in CoverageJSON')
Expand All @@ -230,9 +258,11 @@ def query(self, properties=[], subsets={}, bbox=[], bbox_crs=4326,
LOGGER.debug('Returning data in native zarr format')
return _get_zarr_data(data)
else: # return data in native format
with tempfile.TemporaryFile() as fp:
with tempfile.NamedTemporaryFile() as fp:
LOGGER.debug('Returning data in native NetCDF format')
fp.write(data.to_netcdf())
data.to_netcdf(
fp.name
) # we need to pass a string to be able to use the "netcdf4" engine # noqa
fp.seek(0)
return fp.read()

Expand All @@ -249,7 +279,6 @@ def gen_covjson(self, metadata, data, fields):

LOGGER.debug('Creating CoverageJSON domain')
minx, miny, maxx, maxy = metadata['bbox']
mint, maxt = metadata['time']

selected_fields = {
key: value for key, value in self.fields.items()
Expand Down Expand Up @@ -285,11 +314,6 @@ def gen_covjson(self, metadata, data, fields):
'start': maxy,
'stop': miny,
'num': metadata['height']
},
self.time_field: {
'start': mint,
'stop': maxt,
'num': metadata['time_steps']
}
},
'referencing': [{
Expand All @@ -304,6 +328,14 @@ def gen_covjson(self, metadata, data, fields):
'ranges': {}
}

if self.time_field is not None:
mint, maxt = metadata['time']
cj['domain']['axes'][self.time_field] = {
'start': mint,
'stop': maxt,
'num': metadata['time_steps'],
}

for key, value in selected_fields.items():
parameter = {
'type': 'Parameter',
Expand All @@ -322,21 +354,25 @@ def gen_covjson(self, metadata, data, fields):
cj['parameters'][key] = parameter

data = data.fillna(None)
data = _convert_float32_to_float64(data)

try:
for key, value in selected_fields.items():
cj['ranges'][key] = {
'type': 'NdArray',
'dataType': value['type'],
'axisNames': [
'y', 'x', self._coverage_properties['time_axis_label']
'y', 'x'
],
'shape': [metadata['height'],
metadata['width'],
metadata['time_steps']]
metadata['width']]
}
cj['ranges'][key]['values'] = data[key].values.flatten().tolist() # noqa

if self.time_field is not None:
cj['ranges'][key]['axisNames'].append(
self._coverage_properties['time_axis_label']
)
cj['ranges'][key]['shape'].append(metadata['time_steps'])
except IndexError as err:
LOGGER.warning(err)
raise ProviderQueryError('Invalid query parameter')
Expand Down Expand Up @@ -382,31 +418,37 @@ def _get_coverage_properties(self):
self._data.coords[self.x_field].values[-1],
self._data.coords[self.y_field].values[-1],
],
'time_range': [
_to_datetime_string(
self._data.coords[self.time_field].values[0]
),
_to_datetime_string(
self._data.coords[self.time_field].values[-1]
)
],
'bbox_crs': 'http://www.opengis.net/def/crs/OGC/1.3/CRS84',
'crs_type': 'GeographicCRS',
'x_axis_label': self.x_field,
'y_axis_label': self.y_field,
'time_axis_label': self.time_field,
'width': self._data.sizes[self.x_field],
'height': self._data.sizes[self.y_field],
'time': self._data.sizes[self.time_field],
'time_duration': self.get_time_coverage_duration(),
'bbox_units': 'degrees',
'resx': np.abs(self._data.coords[self.x_field].values[1]
- self._data.coords[self.x_field].values[0]),
'resy': np.abs(self._data.coords[self.y_field].values[1]
- self._data.coords[self.y_field].values[0]),
'restime': self.get_time_resolution()
'resx': np.abs(
self._data.coords[self.x_field].values[1]
- self._data.coords[self.x_field].values[0]
),
'resy': np.abs(
self._data.coords[self.y_field].values[1]
- self._data.coords[self.y_field].values[0]
),
}

if self.time_field is not None:
properties['time_axis_label'] = self.time_field
properties['time_range'] = [
_to_datetime_string(
self._data.coords[self.time_field].values[0]
),
_to_datetime_string(
self._data.coords[self.time_field].values[-1]
),
]
properties['time'] = self._data.sizes[self.time_field]
properties['time_duration'] = self.get_time_coverage_duration()
properties['restime'] = self.get_time_resolution()

# Update properties based on the xarray's CRS
epsg_code = self.storage_crs.to_epsg()
LOGGER.debug(f'{epsg_code}')
Expand All @@ -425,10 +467,12 @@ def _get_coverage_properties(self):

properties['axes'] = [
properties['x_axis_label'],
properties['y_axis_label'],
properties['time_axis_label']
properties['y_axis_label']
]

if self.time_field is not None:
properties['axes'].append(properties['time_axis_label'])

return properties

@staticmethod
Expand All @@ -455,7 +499,8 @@ def get_time_resolution(self):
:returns: time resolution string
"""

if self._data[self.time_field].size > 1:
if self.time_field is not None \
and self._data[self.time_field].size > 1:
time_diff = (self._data[self.time_field][1] -
self._data[self.time_field][0])

Expand All @@ -472,6 +517,9 @@ def get_time_coverage_duration(self):
:returns: time coverage duration string
"""

if self.time_field is None:
return None

dur = self._data[self.time_field][-1] - self._data[self.time_field][0]
ms_difference = dur.values.astype('timedelta64[ms]').astype(np.double)

Expand Down Expand Up @@ -634,7 +682,7 @@ def _convert_float32_to_float64(data):
for var_name in data.variables:
if data[var_name].dtype == 'float32':
og_attrs = data[var_name].attrs
data[var_name] = data[var_name].astype('float64')
data[var_name] = data[var_name].astype('float64', copy=False)
data[var_name].attrs = og_attrs

return data
26 changes: 26 additions & 0 deletions tests/test_xarray_zarr_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from numpy import float64, int64

import pytest
import xarray as xr

from pygeoapi.provider.xarray_ import XarrayProvider
from pygeoapi.util import json_serial
Expand All @@ -53,6 +54,20 @@ def config():
}


@pytest.fixture()
def config_no_time(tmp_path):
ds = xr.open_zarr(path)
ds = ds.sel(time=ds.time[0])
ds = ds.drop_vars('time')
ds.to_zarr(tmp_path / 'no_time.zarr')
return {
'name': 'zarr',
'type': 'coverage',
'data': str(tmp_path / 'no_time.zarr'),
'format': {'name': 'zarr', 'mimetype': 'application/zip'},
}


def test_provider(config):
p = XarrayProvider(config)

Expand Down Expand Up @@ -85,3 +100,14 @@ def test_numpy_json_serial():

d = float64(500.00000005)
assert json_serial(d) == 500.00000005


def test_no_time(config_no_time):
p = XarrayProvider(config_no_time)

assert len(p.fields) == 4
assert p.axes == ['lon', 'lat']

coverage = p.query(format='json')

assert sorted(coverage['domain']['axes'].keys()) == ['x', 'y']

0 comments on commit 1871a49

Please sign in to comment.