Skip to content

Commit

Permalink
Cast in/out shardings to tuple before passing to Exported ctor.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 586951567
  • Loading branch information
jax authors committed Dec 1, 2023
1 parent e60aa3b commit 54fee48
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jax/experimental/export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,8 +527,8 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
out_tree=lowered.out_tree,
in_avals=tuple(args_avals_flat),
out_avals=tuple(out_avals_flat),
in_shardings=lowering.compile_args["in_shardings"],
out_shardings=lowering.compile_args["out_shardings"],
in_shardings=tuple(lowering.compile_args["in_shardings"]),
out_shardings=tuple(lowering.compile_args["out_shardings"]),
lowering_platforms=actual_lowering_platforms,
ordered_effects=ordered_effects,
unordered_effects=unordered_effects,
Expand Down

0 comments on commit 54fee48

Please sign in to comment.