diff --git a/test/distribution/test_distribution_inference.py b/test/distribution/test_distribution_inference.py index 1bea0be612..a96a2cb77d 100644 --- a/test/distribution/test_distribution_inference.py +++ b/test/distribution/test_distribution_inference.py @@ -62,8 +62,8 @@ def maximum_likelihood_estimate_sgd( distr_output: DistributionOutput, samples: mx.ndarray, init_biases: List[mx.ndarray.NDArray] = None, - num_epochs: PositiveInt = 5, - learning_rate: PositiveFloat = 1e-2, + num_epochs: PositiveInt = PositiveInt(5), + learning_rate: PositiveFloat = PositiveFloat(1e-2), hybridize: bool = True, ) -> Iterable[float]: model_ctx = mx.cpu() @@ -151,7 +151,7 @@ def test_studentT_likelihood( init_biases=init_bias, hybridize=hybridize, num_epochs=PositiveInt(10), - learning_rate=1e-2, + learning_rate=PositiveFloat(1e-2), ) assert ( @@ -352,7 +352,7 @@ def domain_map(cls, F, mu, b): LaplaceFixedVarianceOutput(), samples, init_biases=[3 * mu, 0.1], - learning_rate=1e-3, + learning_rate=PositiveFloat(1e-3), hybridize=hybridize, )