Skip to content

Commit

Permalink
normalized classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisMRuss committed Jun 30, 2024
1 parent 315e3dd commit e444d54
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions src/oxonfair/learners/fair.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,16 +738,16 @@ def predict_proba(self, data, *, transform_features=True, force_normalization=Fa
tmp = np.zeros_like(proba)
cache = self.offset[self.infered_to_hard(onehot)]
if force_normalization:
tmp[:, 1] = np.max(cache, 0)
tmp[:, 0] = np.max(-cache, 0)
tmp[:, 1] = np.maximum(cache, 0)
tmp[:, 0] = np.maximum(-cache, 0)
else:
tmp[:, 1] = cache
else:
tmp2 = onehot.dot(self.offset)
if force_normalization:
tmp = np.zeros_like(proba)
tmp[:, 1] = np.max(tmp2[:, 1]-tmp2[:, 0], 0)
tmp[:, 0] = np.max(tmp2[:, 0]-tmp2[:, 1], 0)
tmp[:, 1] = np.maximum(tmp2[:, 1] - tmp2[:, 0], 0)
tmp[:, 0] = np.maximum(tmp2[:, 0] - tmp2[:, 1], 0)
else:
tmp = tmp2
if self.round is not False:
Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/test_ag.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def test_normalized_classifier(fast=True):
assert np.isclose(response.sum(1), 1).all()

response2 = fpred.predict_proba(test_data)
assert (response.max(1) == response2.max(1)).all()
assert (np.argmax(response,1) == np.argmax(response2,1)).all()


def test_normalized_classifier_slow():
Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/test_scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def test_normalized_classifier(fast=True):
assert np.isclose(response.sum(1), 1).all()

response2 = fpred.predict_proba(test_dict['data'])
assert (response.max(1) == response2.max(1)).all()
assert (np.argmax(response, 1) == np.argmax(response2, 1)).all()


def test_normalized_classifier_slow():
Expand Down

0 comments on commit e444d54

Please sign in to comment.