diff --git a/pygeoapi/provider/xarray_.py b/pygeoapi/provider/xarray_.py
index 585879282..ba835f033 100644
--- a/pygeoapi/provider/xarray_.py
+++ b/pygeoapi/provider/xarray_.py
@@ -85,13 +85,19 @@ def __init__(self, provider_def):
             else:
                 data_to_open = self.data

-            self._data = open_func(data_to_open)
+            try:
+                self._data = open_func(data_to_open)
+            except ValueError as err:
+                # Manage non-cf-compliant time dimensions
+                if 'time' in str(err):
+                    self._data = open_func(self.data, decode_times=False)
+                else:
+                    raise err
+
             self.storage_crs = self._parse_storage_crs(provider_def)
             self._coverage_properties = self._get_coverage_properties()
-            self.axes = [self._coverage_properties['x_axis_label'],
-                         self._coverage_properties['y_axis_label'],
-                         self._coverage_properties['time_axis_label']]
+            self.axes = self._coverage_properties['axes']

             self.get_fields()
         except Exception as err:
@@ -101,7 +107,7 @@ def get_fields(self):
         if not self._fields:
             for key, value in self._data.variables.items():
-                if len(value.shape) >= 3:
+                if key not in self._data.coords:
                     LOGGER.debug('Adding variable')
                     dtype = value.dtype
                     if dtype.name.startswith('float'):
@@ -109,7 +115,7 @@
                     self._fields[key] = {
                         'type': dtype,
-                        'title': value.attrs['long_name'],
+                        'title': value.attrs.get('long_name'),
                         'x-ogc-unit': value.attrs.get('units')
                     }
@@ -142,9 +148,9 @@ def query(self, properties=[], subsets={}, bbox=[], bbox_crs=4326,
         data = self._data[[*properties]]

-        if any([self._coverage_properties['x_axis_label'] in subsets,
-                self._coverage_properties['y_axis_label'] in subsets,
-                self._coverage_properties['time_axis_label'] in subsets,
+        if any([self._coverage_properties.get('x_axis_label') in subsets,
+                self._coverage_properties.get('y_axis_label') in subsets,
+                self._coverage_properties.get('time_axis_label') in subsets,
                 datetime_ is not None]):

             LOGGER.debug('Creating spatio-temporal subset')
@@ -163,18 +169,36 @@ def query(self, properties=[], subsets={}, bbox=[], bbox_crs=4326,
                         self._coverage_properties['y_axis_label'] in subsets,
                         len(bbox) > 0]):
                     msg = 'bbox and subsetting by coordinates are exclusive'
-                    LOGGER.warning(msg)
+                    LOGGER.error(msg)
                     raise ProviderQueryError(msg)
                 else:
-                    query_params[self._coverage_properties['x_axis_label']] = \
-                        slice(bbox[0], bbox[2])
-                    query_params[self._coverage_properties['y_axis_label']] = \
-                        slice(bbox[1], bbox[3])
+                    x_axis_label = self._coverage_properties['x_axis_label']
+                    x_coords = data.coords[x_axis_label]
+                    if x_coords.values[0] > x_coords.values[-1]:
+                        LOGGER.debug(
+                            'Reversing slicing of x axis from high to low'
+                        )
+                        query_params[x_axis_label] = slice(bbox[2], bbox[0])
+                    else:
+                        query_params[x_axis_label] = slice(bbox[0], bbox[2])
+                    y_axis_label = self._coverage_properties['y_axis_label']
+                    y_coords = data.coords[y_axis_label]
+                    if y_coords.values[0] > y_coords.values[-1]:
+                        LOGGER.debug(
+                            'Reversing slicing of y axis from high to low'
+                        )
+                        query_params[y_axis_label] = slice(bbox[3], bbox[1])
+                    else:
+                        query_params[y_axis_label] = slice(bbox[1], bbox[3])

                 LOGGER.debug('bbox_crs is not currently handled')

             if datetime_ is not None:
-                if self._coverage_properties['time_axis_label'] in subsets:
+                if self._coverage_properties['time_axis_label'] is None:
+                    msg = 'Dataset does not contain a time axis'
+                    LOGGER.error(msg)
+                    raise ProviderQueryError(msg)
+                elif self._coverage_properties['time_axis_label'] in subsets:
                     msg = 'datetime and temporal subsetting are exclusive'
                     LOGGER.error(msg)
                     raise ProviderQueryError(msg)
@@ -196,13 +220,15 @@ def query(self, properties=[], subsets={}, bbox=[], bbox_crs=4326,
                 LOGGER.warning(err)
                 raise ProviderQueryError(err)

-        if (any([data.coords[self.x_field].size == 0,
-                 data.coords[self.y_field].size == 0,
-                 data.coords[self.time_field].size == 0])):
+        if any(size == 0 for size in data.sizes.values()):
             msg = 'No data found'
             LOGGER.warning(msg)
             raise ProviderNoDataError(msg)

+        if format_ == 'json':
+            # json does not support float32
+            data = _convert_float32_to_float64(data)
+
         out_meta = {
             'bbox': [
                 data.coords[self.x_field].values[0],
@@ -210,18 +236,20 @@ def query(self, properties=[], subsets={}, bbox=[], bbox_crs=4326,
                 data.coords[self.x_field].values[-1],
                 data.coords[self.y_field].values[-1]
             ],
-            "time": [
-                _to_datetime_string(data.coords[self.time_field].values[0]),
-                _to_datetime_string(data.coords[self.time_field].values[-1])
-            ],
             "driver": "xarray",
             "height": data.sizes[self.y_field],
             "width": data.sizes[self.x_field],
-            "time_steps": data.sizes[self.time_field],
             "variables": {var_name: var.attrs
                           for var_name, var in data.variables.items()}
         }

+        if self.time_field is not None:
+            out_meta['time'] = [
+                _to_datetime_string(data.coords[self.time_field].values[0]),
+                _to_datetime_string(data.coords[self.time_field].values[-1]),
+            ]
+            out_meta["time_steps"] = data.sizes[self.time_field]
+
         LOGGER.debug('Serializing data in memory')
         if format_ == 'json':
             LOGGER.debug('Creating output in CoverageJSON')
@@ -230,9 +258,11 @@ def query(self, properties=[], subsets={}, bbox=[], bbox_crs=4326,
             LOGGER.debug('Returning data in native zarr format')
             return _get_zarr_data(data)
         else:  # return data in native format
-            with tempfile.TemporaryFile() as fp:
+            with tempfile.NamedTemporaryFile() as fp:
                 LOGGER.debug('Returning data in native NetCDF format')
-                fp.write(data.to_netcdf())
+                data.to_netcdf(
+                    fp.name
+                )  # we need to pass a string to be able to use the "netcdf4" engine # noqa
                 fp.seek(0)
                 return fp.read()
@@ -249,7 +279,6 @@ def gen_covjson(self, metadata, data, fields):

         LOGGER.debug('Creating CoverageJSON domain')
         minx, miny, maxx, maxy = metadata['bbox']
-        mint, maxt = metadata['time']

         selected_fields = {
             key: value for key, value in self.fields.items()
@@ -285,11 +314,6 @@ def gen_covjson(self, metadata, data, fields):
                         'start': maxy,
                         'stop': miny,
                         'num': metadata['height']
-                    },
-                    self.time_field: {
-                        'start': mint,
-                        'stop': maxt,
-                        'num': metadata['time_steps']
                     }
                 },
                 'referencing': [{
@@ -304,6 +328,14 @@ def gen_covjson(self, metadata, data, fields):
            'ranges': {}
         }

+        if self.time_field is not None:
+            mint, maxt = metadata['time']
+            cj['domain']['axes'][self.time_field] = {
+                'start': mint,
+                'stop': maxt,
+                'num': metadata['time_steps'],
+            }
+
         for key, value in selected_fields.items():
             parameter = {
                 'type': 'Parameter',
@@ -322,7 +354,6 @@ def gen_covjson(self, metadata, data, fields):
             cj['parameters'][key] = parameter

         data = data.fillna(None)
-        data = _convert_float32_to_float64(data)

         try:
             for key, value in selected_fields.items():
@@ -330,13 +361,18 @@ def gen_covjson(self, metadata, data, fields):
                     'type': 'NdArray',
                     'dataType': value['type'],
                     'axisNames': [
-                        'y', 'x', self._coverage_properties['time_axis_label']
+                        'y', 'x'
                     ],
                     'shape': [metadata['height'],
-                              metadata['width'],
-                              metadata['time_steps']]
+                              metadata['width']]
                 }
                 cj['ranges'][key]['values'] = data[key].values.flatten().tolist()  # noqa
+
+                if self.time_field is not None:
+                    cj['ranges'][key]['axisNames'].append(
+                        self._coverage_properties['time_axis_label']
+                    )
+                    cj['ranges'][key]['shape'].append(metadata['time_steps'])
         except IndexError as err:
             LOGGER.warning(err)
             raise ProviderQueryError('Invalid query parameter')
@@ -382,31 +418,37 @@ def _get_coverage_properties(self):
                 self._data.coords[self.x_field].values[-1],
                 self._data.coords[self.y_field].values[-1],
             ],
-            'time_range': [
-                _to_datetime_string(
-                    self._data.coords[self.time_field].values[0]
-                ),
-                _to_datetime_string(
-                    self._data.coords[self.time_field].values[-1]
-                )
-            ],
             'bbox_crs': 'http://www.opengis.net/def/crs/OGC/1.3/CRS84',
             'crs_type': 'GeographicCRS',
             'x_axis_label': self.x_field,
             'y_axis_label': self.y_field,
-            'time_axis_label': self.time_field,
             'width': self._data.sizes[self.x_field],
             'height': self._data.sizes[self.y_field],
-            'time': self._data.sizes[self.time_field],
-            'time_duration': self.get_time_coverage_duration(),
             'bbox_units': 'degrees',
-            'resx': np.abs(self._data.coords[self.x_field].values[1] -
-                           self._data.coords[self.x_field].values[0]),
-            'resy': np.abs(self._data.coords[self.y_field].values[1] -
-                           self._data.coords[self.y_field].values[0]),
-            'restime': self.get_time_resolution()
+            'resx': np.abs(
+                self._data.coords[self.x_field].values[1]
+                - self._data.coords[self.x_field].values[0]
+            ),
+            'resy': np.abs(
+                self._data.coords[self.y_field].values[1]
+                - self._data.coords[self.y_field].values[0]
+            ),
         }

+        if self.time_field is not None:
+            properties['time_axis_label'] = self.time_field
+            properties['time_range'] = [
+                _to_datetime_string(
+                    self._data.coords[self.time_field].values[0]
+                ),
+                _to_datetime_string(
+                    self._data.coords[self.time_field].values[-1]
+                ),
+            ]
+            properties['time'] = self._data.sizes[self.time_field]
+            properties['time_duration'] = self.get_time_coverage_duration()
+            properties['restime'] = self.get_time_resolution()
+
         # Update properties based on the xarray's CRS
         epsg_code = self.storage_crs.to_epsg()
         LOGGER.debug(f'{epsg_code}')
@@ -425,10 +467,12 @@ def _get_coverage_properties(self):

         properties['axes'] = [
             properties['x_axis_label'],
-            properties['y_axis_label'],
-            properties['time_axis_label']
+            properties['y_axis_label']
         ]

+        if self.time_field is not None:
+            properties['axes'].append(properties['time_axis_label'])
+
         return properties

     @staticmethod
@@ -455,7 +499,8 @@ def get_time_resolution(self):
         :returns: time resolution string
         """

-        if self._data[self.time_field].size > 1:
+        if self.time_field is not None \
+                and self._data[self.time_field].size > 1:
             time_diff = (self._data[self.time_field][1] -
                          self._data[self.time_field][0])
@@ -472,6 +517,9 @@ def get_time_coverage_duration(self):
         :returns: time coverage duration string
         """

+        if self.time_field is None:
+            return None
+
         dur = self._data[self.time_field][-1] - self._data[self.time_field][0]
         ms_difference = dur.values.astype('timedelta64[ms]').astype(np.double)
@@ -634,7 +682,7 @@ def _convert_float32_to_float64(data):
     for var_name in data.variables:
         if data[var_name].dtype == 'float32':
             og_attrs = data[var_name].attrs
-            data[var_name] = data[var_name].astype('float64')
+            data[var_name] = data[var_name].astype('float64', copy=False)
             data[var_name].attrs = og_attrs

     return data
diff --git a/tests/test_xarray_zarr_provider.py b/tests/test_xarray_zarr_provider.py
index 5163b32a6..ec014e655 100644
--- a/tests/test_xarray_zarr_provider.py
+++ b/tests/test_xarray_zarr_provider.py
@@ -30,6 +30,7 @@
 from numpy import float64, int64
 import pytest
+import xarray as xr

 from pygeoapi.provider.xarray_ import XarrayProvider
 from pygeoapi.util import json_serial
@@ -53,6 +54,20 @@ def config():
     }


+@pytest.fixture()
+def config_no_time(tmp_path):
+    ds = xr.open_zarr(path)
+    ds = ds.sel(time=ds.time[0])
+    ds = ds.drop_vars('time')
+    ds.to_zarr(tmp_path / 'no_time.zarr')
+    return {
+        'name': 'zarr',
+        'type': 'coverage',
+        'data': str(tmp_path / 'no_time.zarr'),
+        'format': {'name': 'zarr', 'mimetype': 'application/zip'},
+    }
+
+
 def test_provider(config):
     p = XarrayProvider(config)

@@ -85,3 +100,14 @@ def test_numpy_json_serial():
     d = float64(500.00000005)

     assert json_serial(d) == 500.00000005
+
+
+def test_no_time(config_no_time):
+    p = XarrayProvider(config_no_time)
+
+    assert len(p.fields) == 4
+    assert p.axes == ['lon', 'lat']
+
+    coverage = p.query(format='json')
+
+    assert sorted(coverage['domain']['axes'].keys()) == ['x', 'y']
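A minimal sketch of exercising the time-optional code path outside pytest, modelled on the config_no_time fixture above. It is not part of the patch: the dataset contents, the variable and coordinate names, the CF-style units attributes, and the temporary output path are illustrative assumptions; the real test derives its data from the repository's test zarr store, and the printed values depend on how the provider detects the spatial coordinates.

    # sketch only: synthetic dataset and paths are assumptions, not repo fixtures
    import tempfile
    from pathlib import Path

    import numpy as np
    import xarray as xr

    from pygeoapi.provider.xarray_ import XarrayProvider

    # build a small dataset with no time coordinate and write it to zarr;
    # CF-style units on lat/lon help the provider identify the spatial axes
    out = Path(tempfile.mkdtemp()) / 'no_time.zarr'
    ds = xr.Dataset(
        {'tas': (('lat', 'lon'), np.arange(12, dtype='float32').reshape(3, 4))},
        coords={
            'lat': ('lat', [10.0, 20.0, 30.0], {'units': 'degrees_north'}),
            'lon': ('lon', [0.0, 1.0, 2.0, 3.0], {'units': 'degrees_east'}),
        },
    )
    ds.to_zarr(out)

    # provider definition mirroring the config_no_time fixture; x_field/y_field
    # are set explicitly here as a belt-and-braces illustration
    provider_def = {
        'name': 'zarr',
        'type': 'coverage',
        'data': str(out),
        'x_field': 'lon',
        'y_field': 'lat',
        'format': {'name': 'zarr', 'mimetype': 'application/zip'},
    }

    p = XarrayProvider(provider_def)
    print(p.axes)  # no time axis appended, e.g. ['lon', 'lat']

    coverage = p.query(format_='json')
    print(sorted(coverage['domain']['axes'].keys()))  # ['x', 'y'], no time axis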