Skip to content

Commit

Permalink
[Pallas TPU] Better error message for lowering sp.broadcast_to_p
Browse files Browse the repository at this point in the history
`sp.broadcast_to_p` is a GPU-specific primitive, but it mistakenly appears in TPU lowerings. This PR improves the error message to reflect this.

As an example, currently, users will hit this error when doing:

```
def kernel(x_ref, o_ref):
    m, n = 32, 8
    x = pl.load(x_ref, (jnp.arange(m, dtype=jnp.int32)[:, None], jnp.arange(n, dtype=jnp.int32)[None]))
    o_ref[...] = x
```

PiperOrigin-RevId: 700290975
  • Loading branch information
ayaka14732 authored and Google-ML-Automation committed Nov 26, 2024
1 parent 231967f commit dc11d40
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1545,6 +1545,18 @@ def _proxy_reduce(arg, *, axes):
lowering_rules[lax.reduce_or_p] = _reduce_or_lowering_rule


def _broadcast_to_lowering_rule(
ctx: LoweringRuleContext, x, shape: Sequence[int]
):
raise RuntimeError(
"`broadcast_to` is a Triton-specific primitive. Please consider using"
" `jnp.broadcast_to` instead."
)


lowering_rules[state_primitives.broadcast_to_p] = _broadcast_to_lowering_rule


def _broadcast_in_dim_lowering_rule(
ctx: LoweringRuleContext, val, *, shape, broadcast_dimensions, sharding
):
Expand Down

0 comments on commit dc11d40

Please sign in to comment.