diff --git a/src/gluonts/testutil/shell.py b/src/gluonts/testutil/shell.py index 5e1815db6f..c723ae6712 100644 --- a/src/gluonts/testutil/shell.py +++ b/src/gluonts/testutil/shell.py @@ -21,7 +21,7 @@ import typing import waitress from contextlib import closing, contextmanager -from dataclasses import dataclass +from dataclasses import dataclass, field from multiprocessing.context import ForkContext from pathlib import Path from typing import Any, ContextManager, Dict, Iterable, List, Optional, Type @@ -119,7 +119,7 @@ def free_port() -> int: class Server: env: ServeEnv forecaster_type: Optional[Type[Predictor]] - settings: Settings = Settings() + settings: Settings = field(default_factory=Settings) def run(self): flask_app = make_flask_app( diff --git a/test/conftest.py b/test/conftest.py index b624259a15..8ec8649909 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -30,7 +30,7 @@ try: import mxnet as mx -except ImportError: +except (ImportError, OSError): mx = None try: diff --git a/test/core/test_serde_dataclass.py b/test/core/test_serde_dataclass.py index 324e33c746..f602699711 100644 --- a/test/core/test_serde_dataclass.py +++ b/test/core/test_serde_dataclass.py @@ -20,14 +20,13 @@ @serde.dataclass class Estimator: prediction_length: int - context_length: int = serde.OrElse( - lambda prediction_length: prediction_length * 2 - ) + context_length: int = serde.EVENTUAL use_feat_static_cat: bool = True cardinality: List[int] = serde.EVENTUAL - def __eventually__(self, cardinality): + def __eventually__(self, context_length, cardinality): + context_length.set_default(self.prediction_length * 2) if not self.use_feat_static_cat: cardinality.set([1]) else: