diff --git a/tests/unittests/test_ag.py b/tests/unittests/test_ag.py index 767ecb8..dd647c9 100644 --- a/tests/unittests/test_ag.py +++ b/tests/unittests/test_ag.py @@ -91,19 +91,20 @@ def test_recall_diff(use_fast=True): fpredictor = fair.FairPredictor(predictor, test_data, "sex", use_fast=use_fast) - fpredictor.fit(gm.accuracy, gm.recall.diff, 0.025) + limit =0.01 + fpredictor.fit(gm.accuracy, gm.recall.diff, limit) # Evaluate the change in fairness (recall difference corresponds to EO) measures = fpredictor.evaluate_fairness(verbose=False) - assert measures["updated"]["recall.diff"] < 0.025 + assert measures["updated"]["recall.diff"] < limit measures = fpredictor.evaluate() acc = measures["updated"]["Accuracy"] - fpredictor.fit(gm.accuracy, gm.recall.diff, 0.025, greater_is_better_const=True) + fpredictor.fit(gm.accuracy, gm.recall.diff, limit, greater_is_better_const=True) measures = fpredictor.evaluate_fairness(verbose=False) - assert measures["original"]["recall.diff"] > 0.025 + assert measures["original"]["recall.diff"] > limit - fpredictor.fit(gm.accuracy, gm.recall.diff, 0.01, greater_is_better_obj=False) + fpredictor.fit(gm.accuracy, gm.recall.diff, limit/2, greater_is_better_obj=False) assert acc >= fpredictor.evaluate()["updated"]["Accuracy"] @@ -117,11 +118,11 @@ def test_subset(use_fast=True): # Check that metrics computed over a subset of the data is consistent with metrics over all data for group in (" White", " Black", " Amer-Indian-Eskimo"): - assert all(full_group_metrics.loc[group] == partial_group_metrics.loc[group]) + assert all(full_group_metrics.loc[('original', group)] == partial_group_metrics.loc[('original', group)]) assert all( - full_group_metrics.loc["Maximum difference"] - >= partial_group_metrics.loc["Maximum difference"] + full_group_metrics.loc[('original', "Maximum difference")] + >= partial_group_metrics.loc[('original',"Maximum difference")] ) diff --git a/tests/unittests/test_scipy.py b/tests/unittests/test_scipy.py index 58cc52f..3d9e2ed 100644 --- a/tests/unittests/test_scipy.py +++ b/tests/unittests/test_scipy.py @@ -102,9 +102,11 @@ def test_conflict_groups(): def test_fit_creates_updated(use_fast=True): """eval should return 'updated' iff fit has been called""" fpredictor = FairPredictor(predictor, val_dict, use_fast=use_fast) - assert isinstance(fpredictor.evaluate(), pd.Series) + assert not isinstance(fpredictor.evaluate(), pd.Series) + assert 'original' in fpredictor.evaluate().columns fpredictor.fit(gm.accuracy, gm.recall, 0) # constraint is intentionally slack assert not isinstance(fpredictor.evaluate(), pd.Series) + assert 'original' in fpredictor.evaluate().columns assert 'updated' in fpredictor.evaluate().columns