Skip to content

Commit

Permalink
align test with change of interface
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisMRuss committed Nov 26, 2024
1 parent 01a4f94 commit ae117c4
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
17 changes: 9 additions & 8 deletions tests/unittests/test_ag.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,19 +91,20 @@ def test_recall_diff(use_fast=True):

fpredictor = fair.FairPredictor(predictor, test_data, "sex", use_fast=use_fast)

fpredictor.fit(gm.accuracy, gm.recall.diff, 0.025)
limit =0.01
fpredictor.fit(gm.accuracy, gm.recall.diff, limit)

# Evaluate the change in fairness (recall difference corresponds to EO)
measures = fpredictor.evaluate_fairness(verbose=False)

assert measures["updated"]["recall.diff"] < 0.025
assert measures["updated"]["recall.diff"] < limit
measures = fpredictor.evaluate()
acc = measures["updated"]["Accuracy"]
fpredictor.fit(gm.accuracy, gm.recall.diff, 0.025, greater_is_better_const=True)
fpredictor.fit(gm.accuracy, gm.recall.diff, limit, greater_is_better_const=True)
measures = fpredictor.evaluate_fairness(verbose=False)
assert measures["original"]["recall.diff"] > 0.025
assert measures["original"]["recall.diff"] > limit

fpredictor.fit(gm.accuracy, gm.recall.diff, 0.01, greater_is_better_obj=False)
fpredictor.fit(gm.accuracy, gm.recall.diff, limit/2, greater_is_better_obj=False)
assert acc >= fpredictor.evaluate()["updated"]["Accuracy"]


Expand All @@ -117,11 +118,11 @@ def test_subset(use_fast=True):

# Check that metrics computed over a subset of the data is consistent with metrics over all data
for group in (" White", " Black", " Amer-Indian-Eskimo"):
assert all(full_group_metrics.loc[group] == partial_group_metrics.loc[group])
assert all(full_group_metrics.loc[('original', group)] == partial_group_metrics.loc[('original', group)])

assert all(
full_group_metrics.loc["Maximum difference"]
>= partial_group_metrics.loc["Maximum difference"]
full_group_metrics.loc[('original', "Maximum difference")]
>= partial_group_metrics.loc[('original',"Maximum difference")]
)


Expand Down
4 changes: 3 additions & 1 deletion tests/unittests/test_scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,11 @@ def test_conflict_groups():
def test_fit_creates_updated(use_fast=True):
"""eval should return 'updated' iff fit has been called"""
fpredictor = FairPredictor(predictor, val_dict, use_fast=use_fast)
assert isinstance(fpredictor.evaluate(), pd.Series)
assert not isinstance(fpredictor.evaluate(), pd.Series)
assert 'original' in fpredictor.evaluate().columns
fpredictor.fit(gm.accuracy, gm.recall, 0) # constraint is intentionally slack
assert not isinstance(fpredictor.evaluate(), pd.Series)
assert 'original' in fpredictor.evaluate().columns
assert 'updated' in fpredictor.evaluate().columns


Expand Down

0 comments on commit ae117c4

Please sign in to comment.