From 585e436771f2d2bb4a8a10c8fbb6af8e4f006702 Mon Sep 17 00:00:00 2001 From: RYCKEBOER Thomas <37658138+thomasryck@users.noreply.github.com> Date: Fri, 27 Sep 2024 15:22:20 +0200 Subject: [PATCH] debug ci --- cyanure/estimators.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cyanure/estimators.py b/cyanure/estimators.py index 50284dd2c..d21f6300b 100644 --- a/cyanure/estimators.py +++ b/cyanure/estimators.py @@ -259,6 +259,11 @@ def __init__(self, loss='square', penalty='l2', fit_intercept=False, dual=None, self.n_threads = n_threads self.safe = safe + if (loss == "multiclass-logistic" or loss == "logistic") and self.lambda_1 == -10.0: + self.lambda_1 = 0 + elif (self.lambda_1 == -10.0): + self.lambda_1 = 0 + def fit(self, X, labels, le_parameter=None): """ Fit the parameters. @@ -298,11 +303,6 @@ def fit(self, X, labels, le_parameter=None): if loss is None: loss = self.loss - if (loss == "multiclass-logistic" or loss == "logistic") and self.lambda_1 == -10.0: - self.lambda_1 = 0 - elif (self.lambda_1 == -10.0): - self.lambda_1 = 0 - labels = np.squeeze(labels) initial_weight, yf, nclasses = self._initialize_weight(X, labels)