From ec73bd74a2305c494991123777ad649e7f790386 Mon Sep 17 00:00:00 2001 From: CamDavidsonPilon Date: Tue, 7 Jan 2020 09:05:01 -0500 Subject: [PATCH] fix some tests and use initial conditions --- lifelines/fitters/piecewise_exponential_fitter.py | 2 ++ lifelines/fitters/spline_fitter.py | 3 +++ lifelines/tests/test_estimation.py | 2 +- 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/lifelines/fitters/piecewise_exponential_fitter.py b/lifelines/fitters/piecewise_exponential_fitter.py index 75990ce88..ed9673f16 100644 --- a/lifelines/fitters/piecewise_exponential_fitter.py +++ b/lifelines/fitters/piecewise_exponential_fitter.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import autograd.numpy as np from lifelines.fitters import KnownModelParametricUnivariateFitter +from lifelines import utils class PiecewiseExponentialFitter(KnownModelParametricUnivariateFitter): @@ -81,6 +82,7 @@ def __init__(self, breakpoints, *args, **kwargs): super(PiecewiseExponentialFitter, self).__init__(*args, **kwargs) def _cumulative_hazard(self, params, times): + times = np.atleast_1d(times) n = times.shape[0] times = times.reshape((n, 1)) bp = self.breakpoints diff --git a/lifelines/fitters/spline_fitter.py b/lifelines/fitters/spline_fitter.py index 506bfd7ff..40c47b92a 100644 --- a/lifelines/fitters/spline_fitter.py +++ b/lifelines/fitters/spline_fitter.py @@ -90,6 +90,9 @@ def __init__(self, knot_locations: np.ndarray, *args, **kwargs): self._bounds = [(None, None)] * (self.n_knots) super(SplineFitter, self).__init__(*args, **kwargs) + def _create_initial_point(self, Ts, E, entry, weights): + return 0.1 * np.ones(self.n_knots) + def _cumulative_hazard(self, params, t): phis = params lT = np.log(t) diff --git a/lifelines/tests/test_estimation.py b/lifelines/tests/test_estimation.py index 62eac4402..c86dbf1b1 100644 --- a/lifelines/tests/test_estimation.py +++ b/lifelines/tests/test_estimation.py @@ -118,7 +118,7 @@ def __init__(self, *args, **kwargs): class SplineFitterTesting(SplineFitter): def __init__(self, *args, **kwargs): - super(SplineFitterTesting, self).__init__([0.0, 50.0], *args, **kwargs) + super(SplineFitterTesting, self).__init__([0.0, 40.0], *args, **kwargs) class CustomRegressionModelTesting(ParametricRegressionFitter):