From 1c9feecd76db597e96fbea46f33f26aece028c5e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 23 Oct 2023 15:02:35 +0200 Subject: [PATCH] some fixes --- src/gluonts/core/serde/_dataclass.py | 4 ++-- src/gluonts/ext/rotbaum/_types.py | 2 +- test/dataset/artificial/test_recipe.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gluonts/core/serde/_dataclass.py b/src/gluonts/core/serde/_dataclass.py index 8cf583b32c..f03a47f527 100644 --- a/src/gluonts/core/serde/_dataclass.py +++ b/src/gluonts/core/serde/_dataclass.py @@ -115,7 +115,7 @@ def _call(self, env): def dataclass( cls=None, *, - init=True, + init=False, repr=True, eq=True, order=False, @@ -131,7 +131,7 @@ def dataclass( """ # assert frozen - assert init + assert init is False if cls is None: return _dataclass diff --git a/src/gluonts/ext/rotbaum/_types.py b/src/gluonts/ext/rotbaum/_types.py index 220c07eb04..3d612c0168 100644 --- a/src/gluonts/ext/rotbaum/_types.py +++ b/src/gluonts/ext/rotbaum/_types.py @@ -25,7 +25,7 @@ class FeatureImportanceResult(BaseModel): feat_dynamic_real: List[Union[List[float], float]] feat_dynamic_cat: List[Union[List[float], float]] - @model_validator() + @model_validator(mode="before") @classmethod def check_shape(cls, values): """ diff --git a/test/dataset/artificial/test_recipe.py b/test/dataset/artificial/test_recipe.py index 78971668ea..5774be6391 100644 --- a/test/dataset/artificial/test_recipe.py +++ b/test/dataset/artificial/test_recipe.py @@ -135,7 +135,7 @@ def test_recipe_dataset(recipe) -> None: freq="D", feat_static_real=[BasicFeatureInfo(name="feat_static_real_000")], feat_static_cat=[ - CategoricalFeatureInfo(name="foo", cardinality=10) + CategoricalFeatureInfo(name="foo", cardinality="10") ], feat_dynamic_real=[BasicFeatureInfo(name="binary_causal")], ),