diff --git a/datalad_next/constraints/compound.py b/datalad_next/constraints/compound.py
index 7ff5e1009..deb517867 100644
--- a/datalad_next/constraints/compound.py
+++ b/datalad_next/constraints/compound.py
@@ -214,6 +214,90 @@ def for_dataset(self, dataset: DatasetParameter) -> Constraint:
         )
 
 
+class EnsureMappings(Constraint):
+    """Ensure several mappings of keys to values of a specific nature"""
+
+    def __init__(self,
+                 key: Constraint,
+                 value: Constraint,
+                 delimiter: str = ':',
+                 pair_delimiter: str = ',',
+                 allow_length2_sequence: bool = True
+                 ):
+        """
+        Parameters
+        ----------
+        key:
+          Key constraint instance.
+        value:
+          Value constraint instance.
+        delimiter:
+          Delimiter to use for splitting a key from a value for a `str` input.
+        pair_delimiter:
+          Delimiter to use for splitting multiple key-value pairs in a `str`
+          input.
+        allow_length2_sequence:
+          Whether to accept a flat sequence of alternating keys and values.
+        """
+        super().__init__()
+        self._key_constraint = key
+        self._value_constraint = value
+        self._delimiter = delimiter
+        self._pair_delimiter = pair_delimiter
+        self._allow_length2_sequence = allow_length2_sequence
+
+    def short_description(self):
+        return 'mapping of {} -> {}'.format(
+            self._key_constraint.short_description(),
+            self._value_constraint.short_description(),
+        )
+
+    def _get_keys_values(self, value) -> tuple:
+        # determine keys and values from various kinds of input
+        if isinstance(value, str):
+            # split into pairs
+            # TODO: what about whitespace? e.g., 'key: val', 'morekey : moreval'
+            pairs = value.split(sep=self._pair_delimiter)
+            keys_n_values = [p.split(sep=self._delimiter, maxsplit=1)
+                             for p in pairs]
+            try:
+                keys = [p[0] for p in keys_n_values]
+                vals = [p[1] for p in keys_n_values]
+            except IndexError as e:
+                raise ValueError(
+                    'Could not cast input to key value pairs') from e
+        elif isinstance(value, dict):
+            if not len(value):
+                raise ValueError('dict does not contain a key')
+            keys = list(value.keys())
+            vals = list(value.values())
+        elif self._allow_length2_sequence and isinstance(value, (list, tuple)):
+            if not len(value) % 2 == 0 or len(value) < 2:
+                raise ValueError('sequence can not be chunked into pairs by 2')
+            keys = value[::2]
+            vals = value[1::2]
+        else:
+            raise ValueError(f'Unsupported data type for mapping: {value!r}')
+
+        return keys, vals
+
+    def __call__(self, value) -> Dict:
+        keys, vals = self._get_keys_values(value)
+        keys = [self._key_constraint(k) for k in keys]
+        vals = [self._value_constraint(v) for v in vals]
+        return dict(zip(keys, vals))
+
+    def for_dataset(self, dataset: DatasetParameter) -> Constraint:
+        # tailor both constraints to the dataset and reuse the delimiters
+        return EnsureMappings(
+            key=self._key_constraint.for_dataset(dataset),
+            value=self._value_constraint.for_dataset(dataset),
+            delimiter=self._delimiter,
+            pair_delimiter=self._pair_delimiter,
+            allow_length2_sequence=self._allow_length2_sequence,
+        )
+
+
 class EnsureGeneratorFromFileLike(Constraint):
     """Ensure a constraint for each item read from a file-like.
 
diff --git a/datalad_next/constraints/tests/test_compound.py b/datalad_next/constraints/tests/test_compound.py
index 8244d4a67..a6ddf5bc8 100644
--- a/datalad_next/constraints/tests/test_compound.py
+++ b/datalad_next/constraints/tests/test_compound.py
@@ -3,13 +3,18 @@
 import pytest
 from tempfile import NamedTemporaryFile
 from unittest.mock import patch
+from pathlib import Path
 
 from datalad_next.datasets import Dataset
 from datalad_next.utils import on_windows
 
+from ..base import DatasetParameter
+
 from ..basic import (
     EnsureInt,
     EnsureBool,
+    EnsurePath,
+    EnsureStr,
 )
 from ..compound import (
     ConstraintWithPassthrough,
@@ -17,6 +22,7 @@
     EnsureListOf,
     EnsureTupleOf,
     EnsureMapping,
+    EnsureMappings,
     EnsureGeneratorFromFileLike,
 )
 
@@ -121,6 +127,59 @@ def test_EnsureMapping(tmp_path):
         constraint._value_constraint.for_dataset(ds)
 
 
+def test_EnsureMappings(tmp_path):
+    # test scenarios that should work
+    true_keys = ['one', 'two', 'three']
+    true_vals = [1, 2, 3]
+    constraint = EnsureMappings(key=EnsureStr(), value=EnsureInt())
+    assert 'mapping of str -> int' in constraint.short_description()
+
+    for v in ('one:1,two:2,three:3',  # string input
+              {'one': 1, 'two': 2, 'three': 3},  # dict input
+              dict(one=1, two=2, three=3),
+              ['one', 1, 'two', 2, 'three', 3]  # sequence input
+              ):
+        d = constraint(v)
+        assert isinstance(d, dict)
+        assert len(d) == 3
+        assert list(d.keys()) == true_keys
+        assert list(d.values()) == true_vals
+
+    # test scenarios that should crash
+    for v in (true_keys,  # odd-length sequence
+              '5',  # string without a key-value pair
+              [],  # too short sequence
+              tuple(),
+              {},
+              [5, False, False],
+              set('wtf')  # wrong data type
+              ):
+        with pytest.raises(ValueError):
+            d = constraint(v)
+
+    # test different delimiters
+    constraint = EnsureMappings(key=EnsureStr(), value=EnsureInt(),
+                                delimiter='=', pair_delimiter='.')
+    d = constraint('one=1.two=2.three=3')
+    assert isinstance(d, dict)
+    assert len(d) == 3
+    assert list(d.keys()) == true_keys
+    assert list(d.values()) == true_vals
+    # test that the paths are resolved for the dataset
+    ds = Dataset(tmp_path)
+    pathconstraint = \
+        EnsureMappings(key=EnsurePath(), value=EnsureInt()).for_dataset(
+            DatasetParameter(tmp_path, ds))
+    assert pathconstraint('some:5,somemore:6') == \
+        {(Path.cwd() / 'some'): 5, Path.cwd() / 'somemore': 6}
+    pathconstraint = \
+        EnsureMappings(key=EnsurePath(), value=EnsurePath()).for_dataset(
+            DatasetParameter(ds, ds))
+    assert pathconstraint('some:other,something:more') == \
+        {(ds.pathobj / 'some'): (ds.pathobj / 'other'),
+         (ds.pathobj / 'something'): (ds.pathobj / 'more')}
+
+
 def test_EnsureGeneratorFromFileLike():
     item_constraint = EnsureMapping(EnsureInt(), EnsureBool(), delimiter='::')
     constraint = EnsureGeneratorFromFileLike(item_constraint)
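
For orientation, a minimal usage sketch of the `EnsureMappings` constraint introduced by this patch (not part of the diff; it assumes a `datalad_next` checkout with the patch applied, and uses the same import paths as the test module above):

```python
# Sketch: the three input forms accepted by EnsureMappings all yield the same
# validated dict, with values coerced by the value constraint (here EnsureInt).
from datalad_next.constraints.basic import EnsureInt, EnsureStr
from datalad_next.constraints.compound import EnsureMappings

constraint = EnsureMappings(key=EnsureStr(), value=EnsureInt())

assert constraint('a:1,b:2') == {'a': 1, 'b': 2}           # delimited string
assert constraint({'a': '1', 'b': 2}) == {'a': 1, 'b': 2}  # mapping
assert constraint(['a', 1, 'b', 2]) == {'a': 1, 'b': 2}    # flat key/value sequence
```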