Skip to content

Commit

Permalink
fix: multiclass case
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasryck committed Sep 26, 2024
1 parent f5de7f3 commit 634c514
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions cyanure/estimators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Contain the different estimators of the library."""

from abc import abstractmethod, ABC

import math
import inspect
import warnings
Expand Down Expand Up @@ -50,6 +49,9 @@ def _warm_start(self, X, initial_weight, nclasses):
else:
initial_weight = np.squeeze(self.coef_)


initial_weight = np.asfortranarray(initial_weight, X.dtype)

if self.warm_start and self.solver in ('auto', 'miso', 'catalyst-miso', 'qning-miso'):
n = X.shape[0]
# TODO Ecrire test pour dual surtout défensif
Expand Down Expand Up @@ -286,17 +288,19 @@ def fit(self, X, labels, le_parameter=None):

if (self.multi_class == "multinomial" or
(self.multi_class == "auto" and not self._binary_problem)) and self.loss == "logistic":
if self.multi_class == "multinomial":
if len(np.unique(labels)) != 2:
self._binary_problem = False
if len(np.unique(labels)) != 2:
self._binary_problem = False

loss = "multiclass-logistic"
logger.info(
"Loss has been set to multiclass-logistic because "
"the multiclass parameter is set to multinomial!")
loss = "multiclass-logistic"
logger.info(
"Loss has been set to multiclass-logistic because "
"the multiclass parameter is set to multinomial!")

if loss is None:
loss = self.loss

if (loss == "multiclass-logistic" or loss == "logistic") and self.lambda_1 == 1.0:
self.lambda_1 = 1 / len(X)

labels = np.squeeze(labels)
initial_weight, yf, nclasses = self._initialize_weight(X, labels)
Expand Down Expand Up @@ -1175,7 +1179,7 @@ class LogisticRegression(Classifier):
_estimator_type = "classifier"

def __init__(self, penalty='l2', loss='logistic', fit_intercept=True,
verbose=False, lambda_1=0, lambda_2=0, lambda_3=0,
verbose=False, lambda_1=1.0, lambda_2=0, lambda_3=0,
solver='auto', tol=1e-3, duality_gap_interval=10,
max_iter=500, limited_memory_qning=20,
fista_restart=50, warm_start=False, n_threads=-1,
Expand Down

0 comments on commit 634c514

Please sign in to comment.