From 1c7e97163850126e28be91d228adc4e3a1efc7a9 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Fri, 14 Jun 2019 18:41:19 +0200 Subject: [PATCH] fix mypy complaints (#110) --- test/distribution/test_distribution_inference.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/distribution/test_distribution_inference.py b/test/distribution/test_distribution_inference.py index 1bea0be612..a96a2cb77d 100644 --- a/test/distribution/test_distribution_inference.py +++ b/test/distribution/test_distribution_inference.py @@ -62,8 +62,8 @@ def maximum_likelihood_estimate_sgd( distr_output: DistributionOutput, samples: mx.ndarray, init_biases: List[mx.ndarray.NDArray] = None, - num_epochs: PositiveInt = 5, - learning_rate: PositiveFloat = 1e-2, + num_epochs: PositiveInt = PositiveInt(5), + learning_rate: PositiveFloat = PositiveFloat(1e-2), hybridize: bool = True, ) -> Iterable[float]: model_ctx = mx.cpu() @@ -151,7 +151,7 @@ def test_studentT_likelihood( init_biases=init_bias, hybridize=hybridize, num_epochs=PositiveInt(10), - learning_rate=1e-2, + learning_rate=PositiveFloat(1e-2), ) assert ( @@ -352,7 +352,7 @@ def domain_map(cls, F, mu, b): LaplaceFixedVarianceOutput(), samples, init_biases=[3 * mu, 0.1], - learning_rate=1e-3, + learning_rate=PositiveFloat(1e-3), hybridize=hybridize, )