Skip to content

Commit

Permalink
[jax] Canonicalize dtypes when checking if dtypes present in target d…
Browse files Browse the repository at this point in the history
…types list.

PiperOrigin-RevId: 701961663
  • Loading branch information
chr1sj0nes authored and Google-ML-Automation committed Dec 2, 2024
1 parent 7b32d88 commit 5d5b06c
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3704,12 +3704,14 @@ def maybe_convert_dtype(input_dtype, target_dtype):
return input_dtype
if not isinstance(target_dtype, tuple):
target_dtype = (target_dtype,)
return input_dtype if input_dtype in target_dtype else target_dtype[0]
if np.dtype(input_dtype) in map(np.dtype, target_dtype):
return input_dtype
return target_dtype[0]

if algorithm == DotAlgorithmPreset.BF16_BF16_F32:
lhs_dtype = maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type)
rhs_dtype = maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type)
if lhs_dtype == dtypes.bfloat16:
if np.dtype(lhs_dtype) == dtypes.bfloat16:
out_dtype = maybe_convert_dtype(out_dtype,
(np.float32, dtypes.bfloat16))
else:
Expand Down

0 comments on commit 5d5b06c

Please sign in to comment.