From 6cd66a9d5de0778b6ecb93ab7e44b3222d75feb8 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 16 Dec 2024 08:35:37 -0800 Subject: [PATCH] Fix `save_from_both_policies` in presence of `save_and_offload_only_these_names` by comparing the enum PiperOrigin-RevId: 706720875 --- jax/_src/ad_checkpoint.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 93376c7bd170..3a81ecc85af9 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -142,8 +142,14 @@ def policy(prim, *_, **params): def save_from_both_policies(policy_1, policy_2): def policy(prim, *args, **params): - return policy_1(prim, *args, **params) or policy_2(prim, *args, **params) - + out1 = policy_1(prim, *args, **params) + out2 = policy_2(prim, *args, **params) + if not (isinstance(out1, bool) and isinstance(out2, bool)): + raise ValueError( + "The return value of the policies should be a boolean. Got:" + f" {out1} and {out2}. Please write a custom policy function directly," + " rather than using this helper function.") + return out1 or out2 return policy