add force_normalisation to predict_proba
ChrisMRuss committed Jun 30, 2024
1 parent ff9c723 commit 315e3dd
Showing 3 changed files with 66 additions and 5 deletions.
26 changes: 23 additions & 3 deletions src/oxonfair/learners/fair.py
@@ -695,7 +695,7 @@ def evaluate_groups(self, data=None, groups=None, metrics=None, fact=None, *,
out = pd.concat([original, updated], keys=['original', 'updated'])
return out

def predict_proba(self, data, *, transform_features=True):
def predict_proba(self, data, *, transform_features=True, force_normalization=False):
"""Duplicates the functionality of predictor.predict_proba for fairpredictor.
parameters
----------
@@ -704,6 +704,7 @@ def predict_proba(self, data, *, transform_features=True):
------
a pandas array of scores. Note: these scores are not probabilities, and are not guaranteed
to be non-negative or to sum to 1.
To make them non-negative and sum to 1, pass force_normalization=True.
"""
if self.groups is None and self.inferred_groups is False:
_guard_predictor_data_match(data, self.predictor)
@@ -735,12 +736,31 @@ def predict_proba(self, data, *, transform_features=True):
onehot = call_or_get_proba(self.inferred_groups, data)
if self.use_fast is True:
tmp = np.zeros_like(proba)
tmp[:, 1] = self.offset[self.infered_to_hard(onehot)]
cache = self.offset[self.infered_to_hard(onehot)]
if force_normalization:
tmp[:, 1] = np.maximum(cache, 0)  # element-wise clamp at zero, not a reduction along an axis
tmp[:, 0] = np.maximum(-cache, 0)
else:
tmp[:, 1] = cache
else:
tmp = onehot.dot(self.offset)
tmp2 = onehot.dot(self.offset)
if force_normalization:
tmp = np.zeros_like(proba)
tmp[:, 1] = np.maximum(tmp2[:, 1]-tmp2[:, 0], 0)
tmp[:, 0] = np.maximum(tmp2[:, 0]-tmp2[:, 1], 0)
else:
tmp = tmp2
if self.round is not False:
proba = np.around(proba / self.round) * self.round
proba += tmp
if force_normalization:
row_sums = proba.sum(1)  # avoid shadowing the built-in sum
if isinstance(proba, pd.DataFrame):
proba[proba.columns[0]] /= row_sums
proba[proba.columns[1]] /= row_sums
else:
proba /= row_sums[:, np.newaxis]

return proba

def predict(self, data, *, transform_features=True) -> pd.Series:
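The heart of the new option is a clamp-and-renormalise step: each per-group offset is split into its positive and negative parts, added to the two score columns, and every row is then rescaled to sum to one. A standalone sketch with made-up numbers (illustrative values only, not the library's internals) shows why the result is non-negative, sums to one per row, and never changes the predicted class:

import numpy as np

# Illustrative values only: two-column scores (rows sum to 1) and one
# per-row offset, standing in for what predict_proba handles internally.
proba = np.array([[0.7, 0.3],
                  [0.4, 0.6],
                  [0.9, 0.1]])
offsets = np.array([0.5, -0.3, 0.3])

# Default behaviour: the offset is added to the positive class only,
# so rows no longer sum to 1 and entries can become negative.
raw = proba.copy()
raw[:, 1] += offsets

# force_normalization behaviour: clamp the offset's positive and negative
# parts onto the two columns, then rescale each row to sum to 1.
shifted = proba.copy()
shifted[:, 1] += np.maximum(offsets, 0)
shifted[:, 0] += np.maximum(-offsets, 0)
normalised = shifted / shifted.sum(1, keepdims=True)

assert (normalised >= 0).all()
assert np.allclose(normalised.sum(1), 1)
# relu(c) - relu(-c) == c, so the sign of the column difference (and hence
# the predicted class) matches that of the raw shifted scores.
assert (normalised.argmax(1) == raw.argmax(1)).all()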
20 changes: 20 additions & 0 deletions tests/unittests/test_ag.py
@@ -13,6 +13,7 @@
import oxonfair as fair
from oxonfair import FairPredictor
from oxonfair.utils import group_metrics as gm
import numpy as np


def test_base_functionality():
@@ -201,3 +202,22 @@ def test_recall_diff_inferred(use_fast=True):

def test_recall_diff_inferred_slow():
test_recall_diff_inferred(False)


def test_normalized_classifier(fast=True):
fpred = FairPredictor(predictor, train_data, 'sex', use_fast=fast)
fpred.fit(gm.accuracy, gm.demographic_parity, 0.02)
response = fpred.predict_proba(test_data, force_normalization=True)
assert (response >= 0).all().all()
assert np.isclose(response.sum(1), 1).all()

response2 = fpred.predict_proba(test_data)
assert (np.asarray(response).argmax(1) == np.asarray(response2).argmax(1)).all()  # normalization must not change the predicted class


def test_normalized_classifier_slow():
test_normalized_classifier(False)


def test_normalized_classifier_hybrid():
test_normalized_classifier('hybrid')
25 changes: 23 additions & 2 deletions tests/unittests/test_scipy.py
@@ -1,8 +1,10 @@
"""Tests for FairPredictor"""

import pandas as pd
import sklearn.ensemble
import sklearn.tree
import oxonfair as fair
import numpy as np
from oxonfair import FairPredictor
from oxonfair.utils import group_metrics as gm

@@ -13,7 +15,7 @@
except ModuleNotFoundError:
PLT_EXISTS = False

classifier_type = sklearn.tree.DecisionTreeClassifier
classifier_type = sklearn.ensemble.RandomForestClassifier

train_data = pd.read_csv("https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv")
test_data = pd.read_csv("https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv")
@@ -181,7 +183,7 @@ def test_pathologoical2(use_fast=True):


def test_recall_diff(use_fast=True):
"""Maximize accuracy while enforcing weak equalized odds,
"""Maximize accuracy while enforcing weak equal opportunity,
such that the difference in recall between groups is less than 2.5%
This also tests the sign functionality on constraints and the objective"""

@@ -334,3 +336,22 @@ def test_recall_diff_inferred_slow():

def test_recall_diff_inferred_hybrid():
test_recall_diff_inferred('hybrid')


def test_normalized_classifier(fast=True):
fpred = FairPredictor(predictor, val_dict, 'sex_ Female', use_fast=fast)
fpred.fit(gm.accuracy, gm.demographic_parity, 0.02)
response = fpred.predict_proba(test_dict['data'], force_normalization=True)
assert (response >= 0).all().all()
assert np.isclose(response.sum(1), 1).all()

response2 = fpred.predict_proba(test_dict['data'])
assert (np.asarray(response).argmax(1) == np.asarray(response2).argmax(1)).all()  # normalization must not change the predicted class


def test_normalized_classifier_slow():
test_normalized_classifier(False)


def test_normalized_classifier_hybrid():
test_normalized_classifier('hybrid')
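For reference, the option is exercised through the public API exactly as in the new tests; a minimal usage sketch, assuming an already-fitted FairPredictor named fpred and held-out data test_data:

import numpy as np

# Assumes `fpred` has been constructed and fit as in the tests above.
scores = fpred.predict_proba(test_data)                            # may be negative; rows need not sum to 1
probs = fpred.predict_proba(test_data, force_normalization=True)   # non-negative; each row sums to 1

assert (np.asarray(probs) >= 0).all()
assert np.isclose(np.asarray(probs).sum(1), 1).all()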
