From e444d548956d060d8f253a25812feffa72e1557d Mon Sep 17 00:00:00 2001 From: Chris Russell Date: Sun, 30 Jun 2024 23:02:05 +0100 Subject: [PATCH] normalized classifier --- src/oxonfair/learners/fair.py | 8 ++++---- tests/unittests/test_ag.py | 2 +- tests/unittests/test_scipy.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/oxonfair/learners/fair.py b/src/oxonfair/learners/fair.py index e3b1400..88b0a79 100644 --- a/src/oxonfair/learners/fair.py +++ b/src/oxonfair/learners/fair.py @@ -738,16 +738,16 @@ def predict_proba(self, data, *, transform_features=True, force_normalization=Fa tmp = np.zeros_like(proba) cache = self.offset[self.infered_to_hard(onehot)] if force_normalization: - tmp[:, 1] = np.max(cache, 0) - tmp[:, 0] = np.max(-cache, 0) + tmp[:, 1] = np.maximum(cache, 0) + tmp[:, 0] = np.maximum(-cache, 0) else: tmp[:, 1] = cache else: tmp2 = onehot.dot(self.offset) if force_normalization: tmp = np.zeros_like(proba) - tmp[:, 1] = np.max(tmp2[:, 1]-tmp2[:, 0], 0) - tmp[:, 0] = np.max(tmp2[:, 0]-tmp2[:, 1], 0) + tmp[:, 1] = np.maximum(tmp2[:, 1] - tmp2[:, 0], 0) + tmp[:, 0] = np.maximum(tmp2[:, 0] - tmp2[:, 1], 0) else: tmp = tmp2 if self.round is not False: diff --git a/tests/unittests/test_ag.py b/tests/unittests/test_ag.py index e61bbee..dff3454 100644 --- a/tests/unittests/test_ag.py +++ b/tests/unittests/test_ag.py @@ -212,7 +212,7 @@ def test_normalized_classifier(fast=True): assert np.isclose(response.sum(1), 1).all() response2 = fpred.predict_proba(test_data) - assert (response.max(1) == response2.max(1)).all() + assert (np.argmax(response,1) == np.argmax(response2,1)).all() def test_normalized_classifier_slow(): diff --git a/tests/unittests/test_scipy.py b/tests/unittests/test_scipy.py index c4e9b53..d8a75af 100644 --- a/tests/unittests/test_scipy.py +++ b/tests/unittests/test_scipy.py @@ -346,7 +346,7 @@ def test_normalized_classifier(fast=True): assert np.isclose(response.sum(1), 1).all() response2 = fpred.predict_proba(test_dict['data']) - assert (response.max(1) == response2.max(1)).all() + assert (np.argmax(response, 1) == np.argmax(response2, 1)).all() def test_normalized_classifier_slow():