Commit 82e3688

Merge pull request #47 from oxfordinternetinstitute/frontier

Quality of life improvements

ChrisMRuss authored Nov 26, 2024
2 parents 55fb1f5 + 8d87613
Showing 5 changed files with 37 additions and 14 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@
 
 FAIR = "oxonfair"
 
-version = "0.2.1.8"
+version = "0.2.1.9"
 
 PYTHON_REQUIRES = ">=3.8"
 
16 changes: 14 additions & 2 deletions src/oxonfair/learners/fair.py
@@ -8,7 +8,7 @@
 from sklearn.preprocessing import OneHotEncoder
 from ..utils import group_metrics
 from .. utils.scipy_metrics_cont_wrapper import ScorerRequiresContPred
-from ..utils.group_metric_classes import BaseGroupMetric
+from ..utils.group_metric_classes import BaseGroupMetric, Overall
 
 from ..utils import performance as perf
 from . import efficient_compute, fair_frontier
@@ -720,6 +720,9 @@ def evaluate_fairness(self, data=None, groups=None, factor=None, *,
 
             collect = pd.concat([collect, new_pd], axis='columns')
             collect.columns = ['original', 'updated']
+        else:
+            collect = pd.concat([collect,], axis='columns')
+            collect.columns = ['original']
 
         return collect
 
@@ -822,7 +825,9 @@ def evaluate_groups(self, data=None, groups=None, metrics=None, fact=None, *,
                                  verbose=verbose)
 
         out = updated
-        if return_original:
+        if self.frontier is None:
+            out = pd.concat([updated, ], keys=['original', ])
+        elif return_original:
             out = pd.concat([original, updated], keys=['original', 'updated'])
         return out
 
@@ -1093,6 +1098,9 @@ def fix_groups(metric, groups):
 
     groups = groups_to_masks(groups)
 
+    if isinstance(metric, Overall):  # Performance hack: an Overall metric ignores groups, so assign every sample to group 1.
+        groups = np.ones(groups.shape[0])
+
     def new_metric(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
         return metric(y_true, y_pred, groups)
     return new_metric
@@ -1146,6 +1154,10 @@ def fix_groups_and_conditioning(metric, groups, conditioning_factor, y_true):
     weights = metric.cond_weights(conditioning_factor, groups, y_true)
     groups = groups_to_masks(groups)
 
+    if isinstance(metric, Overall):  # Performance hack: an Overall metric ignores groups, so assign every sample to group 1.
+        groups = np.ones(groups.shape[0])
+
+
     def new_metric(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
         return metric(y_true, y_pred, groups, weights)
     return new_metric
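
Note: a minimal sketch of why the Overall shortcut above is safe (illustrative names, not the library's API). A metric that pools every sample returns the same value under any group assignment, so collapsing the masks to a single all-ones group skips the per-group bookkeeping without changing the result:

    import numpy as np

    def overall_accuracy(y_true, y_pred, groups):
        # Pools all samples; groups is accepted for signature compatibility but unused,
        # which is what makes the all-ones shortcut safe.
        return np.mean(y_true == y_pred)

    y_true = np.array([1, 0, 1, 1])
    y_pred = np.array([1, 0, 0, 1])
    assert (overall_accuracy(y_true, y_pred, np.array([0, 0, 1, 1]))
            == overall_accuracy(y_true, y_pred, np.ones(4)))
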
10 changes: 9 additions & 1 deletion tests/test_check_style.py
@@ -55,9 +55,17 @@ def test_check_style_examples():
 
 def test_md_links():
     missing_links = lc.check_links('./', ext='.md', recurse=True, use_async=False)
+    missing_links_eg = lc.check_links('./examples/', ext='.md', recurse=True)
 
     for link in missing_links:
         warnings.warn(link)
-    assert missing_links == []
+
+    for link in missing_links_eg:
+        warnings.warn(link)
+
+    assert missing_links_eg == []
+    assert missing_links == [('README.md', 'https://papers.ssrn.com/sol3/papers.cfm?abstract_id=4331652', 429),] or missing_links == []
+    # SSRN thinks we're crawling and blocks exactly one paper.
+
 
 def test_run_notebooks_without_errors():
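
Note: the paired assertion above accepts exactly two outcomes: no missing links at all, or SSRN rate-limiting (HTTP 429) that single paper link. An allow-list formulation is one alternative; a sketch only, assuming check_links reports (file, url, status) tuples as in the test above:

    # Known false positives; any other failure should still fail the test.
    ALLOWED_FAILURES = {
        ('README.md', 'https://papers.ssrn.com/sol3/papers.cfm?abstract_id=4331652', 429),
    }

    def links_ok(missing_links):
        # Passes when every reported failure is on the allow-list (including none at all).
        return set(missing_links) <= ALLOWED_FAILURES

    assert links_ok([])
    assert not links_ok([('docs.md', 'https://example.com/dead', 404)])
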
17 changes: 9 additions & 8 deletions tests/unittests/test_ag.py
@@ -91,19 +91,20 @@ def test_recall_diff(use_fast=True):
 
     fpredictor = fair.FairPredictor(predictor, test_data, "sex", use_fast=use_fast)
 
-    fpredictor.fit(gm.accuracy, gm.recall.diff, 0.025)
+    limit = 0.01
+    fpredictor.fit(gm.accuracy, gm.recall.diff, limit)
 
     # Evaluate the change in fairness (recall difference corresponds to EO)
     measures = fpredictor.evaluate_fairness(verbose=False)
 
-    assert measures["updated"]["recall.diff"] < 0.025
+    assert measures["updated"]["recall.diff"] < limit
     measures = fpredictor.evaluate()
     acc = measures["updated"]["Accuracy"]
-    fpredictor.fit(gm.accuracy, gm.recall.diff, 0.025, greater_is_better_const=True)
+    fpredictor.fit(gm.accuracy, gm.recall.diff, limit, greater_is_better_const=True)
     measures = fpredictor.evaluate_fairness(verbose=False)
-    assert measures["original"]["recall.diff"] > 0.025
+    assert measures["original"]["recall.diff"] > limit
 
-    fpredictor.fit(gm.accuracy, gm.recall.diff, 0.01, greater_is_better_obj=False)
+    fpredictor.fit(gm.accuracy, gm.recall.diff, limit / 2, greater_is_better_obj=False)
     assert acc >= fpredictor.evaluate()["updated"]["Accuracy"]
 
 
@@ -117,11 +118,11 @@ def test_subset(use_fast=True):
 
     # Check that metrics computed over a subset of the data are consistent with metrics over all data
     for group in (" White", " Black", " Amer-Indian-Eskimo"):
-        assert all(full_group_metrics.loc[group] == partial_group_metrics.loc[group])
+        assert all(full_group_metrics.loc[('original', group)] == partial_group_metrics.loc[('original', group)])
 
     assert all(
-        full_group_metrics.loc["Maximum difference"]
-        >= partial_group_metrics.loc["Maximum difference"]
+        full_group_metrics.loc[('original', "Maximum difference")]
+        >= partial_group_metrics.loc[('original', "Maximum difference")]
     )
 
 
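
Note: the ('original', group) tuples are needed because evaluate_groups output is now always concatenated under an outer key (see the fair.py change above), turning the row index into a pandas MultiIndex. A minimal illustration with a synthetic frame, not real OxonFair output:

    import pandas as pd

    frame = pd.DataFrame({'Accuracy': [0.90, 0.82]}, index=[' White', ' Black'])
    keyed = pd.concat([frame], keys=['original'])  # adds the outer 'original' index level

    # Rows are now addressed by (key, group) tuples.
    assert all(keyed.loc[('original', ' White')] == frame.loc[' White'])
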
6 changes: 4 additions & 2 deletions tests/unittests/test_scipy.py
@@ -102,9 +102,11 @@ def test_conflict_groups():
 def test_fit_creates_updated(use_fast=True):
     """eval should return 'updated' iff fit has been called"""
     fpredictor = FairPredictor(predictor, val_dict, use_fast=use_fast)
-    assert isinstance(fpredictor.evaluate(), pd.Series)
+    assert not isinstance(fpredictor.evaluate(), pd.Series)
+    assert 'original' in fpredictor.evaluate().columns
     fpredictor.fit(gm.accuracy, gm.recall, 0)  # constraint is intentionally slack
     assert not isinstance(fpredictor.evaluate(), pd.Series)
+    assert 'original' in fpredictor.evaluate().columns
     assert 'updated' in fpredictor.evaluate().columns
 
 
@@ -460,4 +462,4 @@ def test_selection_rate_diff_levelling_up_slow():
 
 
 def test_selection_rate_diff_levelling_up_hybrid():
-    test_selection_rate_diff_levelling_up(use_fast='hybrid')
\ No newline at end of file
+    test_selection_rate_diff_levelling_up(use_fast='hybrid')
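
Note: the revised test above pins down the contract introduced by the evaluate_fairness else-branch: evaluate output is a one-column DataFrame labelled 'original' even before fit is called, never a bare Series. A toy version of that branch in plain pandas, standing in for the library internals:

    import pandas as pd

    collect = pd.Series({'Accuracy': 0.84, 'recall.diff': 0.12})
    collect = pd.concat([collect, ], axis='columns')  # promotes the Series to a one-column DataFrame
    collect.columns = ['original']

    assert not isinstance(collect, pd.Series)
    assert 'original' in collect.columns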
