Skip to content

Commit

Permalink
Fix save_from_both_policies in presence of `save_and_offload_only_t…
Browse files Browse the repository at this point in the history
…hese_names`

PiperOrigin-RevId: 706720875
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Dec 16, 2024
1 parent 2b06f93 commit 4d63445
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,9 @@ def policy(prim, *_, **params):
def save_from_both_policies(policy_1, policy_2):

def policy(prim, *args, **params):
return policy_1(prim, *args, **params) or policy_2(prim, *args, **params)

out1 = pe.ensure_enum(policy_1(prim, *args, **params))
out2 = pe.ensure_enum(policy_2(prim, *args, **params))
return isinstance(out1, pe.SaveableType) or isinstance(out2, pe.SaveableType)
return policy


Expand Down

0 comments on commit 4d63445

Please sign in to comment.