diff --git a/tests/unittests/test_scipy.py b/tests/unittests/test_scipy.py index 185aaa0..18c9182 100644 --- a/tests/unittests/test_scipy.py +++ b/tests/unittests/test_scipy.py @@ -279,7 +279,7 @@ def test_recall_diff_hybrid(): test_recall_diff('hybrid') -# """ too slow and disabled +""" too slow -- keep disabled unless hunting non-deterministic bugs. def test_many_recall_diff_hybrid(many=2000): count = 0 for i in range(many): @@ -421,3 +421,43 @@ def test_frontier_slow(): def test_frontier_hybrid(): test_frontier('hybrid') + + +def test_selection_rate_diff_levelling_up(use_fast=True): + """Maximize accuracy while enforcing demographic parity with levelling up.""" + + fpredictor = fair.FairPredictor( + predictor, test_dict, "sex_ Female", use_fast=use_fast) + + fpredictor.fit(gm.accuracy, gm.pos_pred_rate.diff, 0.025, force_levelling_up=True) + + rate = fpredictor.evaluate_groups(metrics={1: gm.pos_pred_rate}, verbose=False)[1] + assert (rate['updated'].drop('Maximum difference') >= rate['original'].drop('Maximum difference')).all() + assert (fpredictor.evaluate_groups() == fpredictor.evaluate_groups(test_dict)).all().all() + # Evaluate the change in fairness (recall difference corresponds to EO) + measures = fpredictor.evaluate_fairness(verbose=False) + + assert measures["original"]["statistical_parity"] > 0.025 + + assert measures["updated"]["statistical_parity"] <= 0.025 + + fpredictor.fit(gm.accuracy, gm.pos_pred_rate.diff, 0.025, force_levelling_up='-') + rate = fpredictor.evaluate_groups(metrics={1: gm.pos_pred_rate}, verbose=False)[1] + assert (rate['updated'].drop('Maximum difference') <= rate['original'].drop('Maximum difference')).all() + # Evaluate the change in fairness (recall difference corresponds to EO) + measures = fpredictor.evaluate_fairness(verbose=False) + + # Evaluate the change in fairness (recall difference corresponds to EO) + measures = fpredictor.evaluate_fairness(verbose=False) + + assert measures["original"]["statistical_parity"] > 0.025 + + assert measures["updated"]["statistical_parity"] <= 0.025 + + +def test_selection_rate_diff_levelling_up_slow(): + test_selection_rate_diff_levelling_up(use_fast=False) + + +def test_selection_rate_diff_levelling_up_hybrid(): + test_selection_rate_diff_levelling_up(use_fast='hybrid') \ No newline at end of file