Skip to content

Commit

Permalink
Updating pydantic to 0.28. (#108)
Browse files Browse the repository at this point in the history
Fixes #107.

Between 0.12 (the last version we used) and 0.28, pydantic

* changed `get_validators` to `__get_validators__`
* prevents us from using the `fields` attribute on models

Fixing doctest.

Also, increasing TOL for flaky tests.
  • Loading branch information
Jasper Schulz authored Jun 14, 2019
1 parent 23bb7ef commit 04c4735
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 22 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ matplotlib==3.*
mxnet>=1.3.1
numpy==1.14.*
pandas>=0.22.0
pydantic==0.12.*
pydantic==0.28.*
tqdm>=4.23.0
ujson>=1.35
8 changes: 4 additions & 4 deletions src/gluonts/core/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def validated(base_model=None):
accessed through the ``Model`` attribute of the decorated initiazlier.
>>> ComplexNumber.__init__.Model
<class 'abc.ComplexNumberModel'>
<class 'ComplexNumberModel'>
The Pydantic model is synthesized automatically from on the parameter
names and types of the decorated initializer. In the ``ComplexNumber``
Expand Down Expand Up @@ -427,12 +427,12 @@ def validate(cls, v: Union[str, mx.Context]) -> mx.Context:
)

@classmethod
def get_validators(cls) -> mx.Context:
def __get_validators__(cls) -> mx.Context:
yield cls.validate


mx.Context.validate = MXContext.validate
mx.Context.get_validators = MXContext.get_validators
mx.Context.__get_validators__ = MXContext.__get_validators__


def has_gpu_support() -> bool:
Expand Down Expand Up @@ -467,7 +467,7 @@ class DType:
"""

@classmethod
def get_validators(cls):
def __get_validators__(cls):
yield cls.validate

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/dataset/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class Timestamp(pd.Timestamp):
# we need to sublcass, since pydantic otherwise converts the value into
# datetime.datetime instead of using pd.Timestamp
@classmethod
def get_validators(cls):
def __get_validators__(cls):
def conv(val):
if isinstance(val, pd.Timestamp):
return val
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/model/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ def __hash__(self):
return hash(str(self))

@classmethod
def get_validators(cls):
def __get_validators__(cls):
yield cls.validate

@classmethod
Expand Down
16 changes: 8 additions & 8 deletions src/gluonts/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,19 +569,19 @@ class SwapAxes(SimpleTransformation):
Parameters
----------
fields
input_fields
Field to apply to
axes
Axes to use
"""

@validated()
def __init__(self, fields: List[str], axes: Tuple[int, int]) -> None:
self.fields = fields
def __init__(self, input_fields: List[str], axes: Tuple[int, int]) -> None:
self.input_fields = input_fields
self.axis1, self.axis2 = axes

def transform(self, data: DataEntry) -> DataEntry:
for field in self.fields:
for field in self.input_fields:
data[field] = self.swap(data[field])
return data

Expand Down Expand Up @@ -1220,16 +1220,16 @@ class SelectFields(MapTransformation):
Parameters
----------
fields
input_fields
List of fields to keep.
"""

@validated()
def __init__(self, fields: List[str]) -> None:
self.fields = fields
def __init__(self, input_fields: List[str]) -> None:
self.input_fields = input_fields

def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
return {f: data[f] for f in self.fields}
return {f: data[f] for f in self.input_fields}


class TransformedDataset(Dataset):
Expand Down
12 changes: 6 additions & 6 deletions test/core/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ def __init__(self, a: int, b: float, c: Complex, d: int) -> None:
class Bar:
@validated()
def __init__(
self, x_list: List[Foo], x_dict: Dict[int, Foo], fields: List[Foo]
self, x_list: List[Foo], x_dict: Dict[int, Foo], input_fields: List[Foo]
) -> None:
self.x_list = x_list
self.x_dict = x_dict
self.fields = fields
self.input_fields = input_fields


# define test.test_components.X as alias of X within the scope of this
Expand Down Expand Up @@ -104,7 +104,7 @@ def test_component_ctor():
for i in range(6)
}

bar01 = Bar(x_list, fields=fields, x_dict=x_dict)
bar01 = Bar(x_list, input_fields=fields, x_dict=x_dict)
bar02 = load_code(dump_code(bar01))
bar03 = load_json(dump_json(bar02))

Expand All @@ -116,15 +116,15 @@ def compare_vals(x, y, z):

compare_tpes(bar02.x_list, bar02.x_list, bar03.x_list, tpe=list)
compare_tpes(bar02.x_dict, bar02.x_dict, bar03.x_dict, tpe=dict)
compare_tpes(bar02.fields, bar02.fields, bar03.fields, tpe=list)
compare_tpes(bar02.input_fields, bar02.input_fields, bar03.input_fields, tpe=list)

compare_vals(len(bar02.x_list), len(bar02.x_list), len(bar03.x_list))
compare_vals(len(bar02.x_dict), len(bar02.x_dict), len(bar03.x_dict))
compare_vals(len(bar02.fields), len(bar02.fields), len(bar03.fields))
compare_vals(len(bar02.input_fields), len(bar02.input_fields), len(bar03.input_fields))

compare_vals(bar02.x_list, bar02.x_list, bar03.x_list)
compare_vals(bar02.x_dict, bar02.x_dict, bar03.x_dict)
compare_vals(bar02.fields, bar02.fields, bar03.fields)
compare_vals(bar02.input_fields, bar02.input_fields, bar03.input_fields)

baz01 = Baz(a="0", b="9", c=Complex(x="1", y="2"), d="42")
baz02 = load_json(dump_json(baz01))
Expand Down
2 changes: 1 addition & 1 deletion test/distribution/test_distribution_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@

NUM_SAMPLES = 2000
BATCH_SIZE = 32
TOL = 0.2
TOL = 0.3
START_TOL_MULTIPLE = 1

np.random.seed(1)
Expand Down

0 comments on commit 04c4735

Please sign in to comment.