
Improving eqx.nn.Embedding's performance #920

Open
neel04 opened this issue Dec 24, 2024 · 6 comments

neel04 commented Dec 24, 2024

return self.weight[x]

Internally, eqx.nn.Embedding just indexes naively. However, this is subpar, as XLA is unable to fuse vmap(embed_layer) calls and instead performs hundreds of thousands of dynamic slice updates over the weight array:

[screenshot: profiler trace showing the dynamic-slice updates]

Zooming in, we see this block pattern repeated thousands of times:

[screenshot: zoomed-in profiler trace]

Instead, we can force XLA to fuse by:

- return self.weight[x]
+ return jnp.take(self.weight, x, axis=0)

[screenshot: profiler trace after the change]

Which fixes the issue :)

On my test runs, this yielded a ~25% improvement in throughput.

Happy to PR. This is backwards compatible with how users already call the layer (via vmap(nn.Embedding)(...)), but perhaps one could explicitly prevent users from passing a non-scalar index (x here) to keep the API consistent - or optionally allow both, since they're probably equivalent for XLA.
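For concreteness, a minimal sketch of the usage pattern I mean (the sizes and key below are arbitrary placeholders):

import equinox as eqx
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# Hypothetical layer sizes, purely for illustration.
embed = eqx.nn.Embedding(num_embeddings=8192, embedding_size=32, key=key)

tokens = jnp.array([3, 1, 4, 1, 5])  # a sequence of token ids
# __call__ takes a scalar index, so users batch it with vmap:
vectors = jax.vmap(embed)(tokens)    # shape (5, 32)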

What do you think @patrick-kidger?

@patrick-kidger
Owner

So I'm guessing that you're running this on GPU? I recognise this gather-to-dynamic_slice rewrite as being an optimization that XLA:GPU (and to my knowledge, nothing else) performs.

What surprises me is that the jaxprs for the two are very similar:

import jax
import jax.numpy as jnp

def f(x, y):
    return x[y]

def g(x, y):
    return jnp.take(x, y, axis=0)

ff = jax.vmap(f, in_axes=(None, 0))
gg = jax.vmap(g, in_axes=(None, 0))

x = jnp.arange(12.).reshape(3, 4)
y = jnp.array([[0, 1], [2, 3]])

print(jax.make_jaxpr(ff)(x, y))
# { lambda ; a:f32[3,4] b:i32[2,2]. let
#     c:bool[2,2] = lt b 0
#     d:i32[2,2] = add b 3
#     e:i32[2,2] = select_n c b d
#     f:i32[2,2,1] = broadcast_in_dim[
#       broadcast_dimensions=(0, np.int64(1))
#       shape=(2, 2, 1)
#       sharding=None
#     ] e
#     g:f32[2,2,4] = gather[
#       dimension_numbers=GatherDimensionNumbers(offset_dims=(2,), collapsed_slice_dims=(0,), start_index_map=(0,), operand_batching_dims=(), start_indices_batching_dims=())
#       fill_value=None
#       indices_are_sorted=False
#       mode=GatherScatterMode.PROMISE_IN_BOUNDS
#       slice_sizes=(1, 4)
#       unique_indices=False
#     ] a f
#   in (g,) }

print(jax.make_jaxpr(gg)(x, y))
# { lambda ; a:f32[3,4] b:i32[2,2]. let
#     c:f32[2,2,4] = pjit[
#       name=_take
#       jaxpr={ lambda ; d:f32[3,4] e:i32[2,2]. let
#           f:bool[2,2] = lt e 0
#           g:i32[2,2] = add e 3
#           h:i32[2,2] = pjit[
#             name=_where
#             jaxpr={ lambda ; i:bool[2,2] j:i32[2,2] k:i32[2,2]. let
#                 l:i32[2,2] = select_n i k j
#               in (l,) }
#           ] f g e
#           m:i32[2,2,1] = broadcast_in_dim[
#             broadcast_dimensions=(0, np.int64(1))
#             shape=(2, 2, 1)
#             sharding=None
#           ] h
#           n:f32[2,2,4] = gather[
#             dimension_numbers=GatherDimensionNumbers(offset_dims=(2,), collapsed_slice_dims=(0,), start_index_map=(0,), operand_batching_dims=(), start_indices_batching_dims=())
#             fill_value=nan
#             indices_are_sorted=False
#             mode=GatherScatterMode.FILL_OR_DROP
#             slice_sizes=(1, 4)
#             unique_indices=False
#           ] d m
#         in (n,) }
#     ] a b
#   in (c,) }

The second one has some nested JITs, which I assume don't do anything, and it tweaks gather[mode=...], but I think that's it.

Do you know what is causing this (de-)optimization pass to trigger or not-trigger / can you try reproducing this with slightly different jaxprs and see what causes it? And did you know jnp.take would resolve things or were you just trying things out?
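For example (an untested sketch, reusing x and y from above), the gather mode alone can be tweaked via the .at[].get(...) API, without going through jnp.take:

def h(x, y):
    # Plain indexing, but with jnp.take-style out-of-bounds handling.
    return x.at[y].get(mode="fill", fill_value=jnp.nan)

hh = jax.vmap(h, in_axes=(None, 0))
print(jax.make_jaxpr(hh)(x, y))  # should show mode=FILL_OR_DROP, as with jnp.take

If that alone changes the behaviour you're seeing, it would point at the gather mode rather than anything else jnp.take does.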

Happy to adjust things in Equinox if need be but I'd like to understand the issue a bit better first.


neel04 commented Dec 25, 2024

So I'm guessing that you're running this on GPU?

No, I'm on TPUs (v4-32 to be specific)

What surprises me is that the jaxprs for the two are very similar:

That's interesting. I did a bit of exploring in a Colab, and with the y array sharded, I think there are definite qualitative differences in the jaxprs that probably significantly influence how the XLA passes lower and optimize them.

One thing I notice is that naive indexing is still missing a PJIT in the jaxpr, so those ops aren't getting parallelized. I recommend looking at the op breakdown in the linked TensorBoard, as well as the traces, to get a better idea.

Note that the PJIT wraps the construction of c - so it's directly JIT-ing that lambda, which is a possible hint as to why XLA fuses jnp.take but not naive indexing.
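If it helps, the post-optimization HLO for both can also be dumped directly (a sketch, reusing the ff/gg/x/y definitions from your example), which is roughly what the TensorBoard op-breakdown reflects:

print(jax.jit(ff).lower(x, y).compile().as_text())  # naive indexing
print(jax.jit(gg).lower(x, y).compile().as_text())  # jnp.take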

And did you know jnp.take would resolve things or were you just trying things out?

When I saw the traces, my guess was that JAX was treating each update of the embedding serially instead of fusing. One immediately thinks of using JAX primitives that XLA is more likely to pattern-match against; lax.gather comes to mind but the docs recommend not using it directly - hence bubbling up to jnp.take.

I think this is the common problem of not giving the compiler enough information - think of how autovectorization in traditional compilers often breaks. I assume lax.gather/jnp.take is more informative to XLA - telling it that Embedding is just plucking values out of an array at a set of indices - and allows it to fuse, whereas naive indexing confuses it and prevents it from parallelizing the op correctly.
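Something along these lines should reproduce the throughput gap (a sketch only - the sizes are placeholders and the numbers will depend heavily on the backend and shapes):

import time
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
weight = jax.random.normal(key, (8192, 32))
idx = jax.random.randint(key, (1_000_000,), 0, 8192)

index_fn = jax.jit(jax.vmap(lambda i: weight[i]))                   # current Embedding behaviour
take_fn = jax.jit(jax.vmap(lambda i: jnp.take(weight, i, axis=0)))  # proposed change

for name, fn in [("indexing", index_fn), ("take", take_fn)]:
    fn(idx).block_until_ready()  # compile + warm up
    t0 = time.perf_counter()
    for _ in range(100):
        out = fn(idx)
    out.block_until_ready()
    print(name, time.perf_counter() - t0)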

@patrick-kidger
Owner

No, I'm on TPUs (v4-32 to be specific)

Got it. I assume XLA:TPU has the same pass then.

I assume lax.gather/jnp.take is more informative to XLA - telling it that Embedding is just plucking values out of an array at a set of indices - and allows it to fuse, whereas naive indexing confuses it and prevents it from parallelizing the op correctly.

FWIW I think this is arguably also a bug at the XLA level, that it is failing to optimize things as well as it could. You might consider reporting this upstream.


Anyway, it's still a little mysterious to me exactly what triggers this pessimisation, but since you've offered, I'd be happy to take a PR on this! (Including a comment back to this thread to explain.)

@patrick-toulme

The jnp.take causes the indices to be assumed to be in bounds. This assumption will be faster on chip.
See the IR here - openxla/xla#20899 (comment)

The jnp.take also seems to fuse into four kernels while the naive indexing is two kernels.

@patrick-kidger
Owner

Hi @patrick-toulme !

Thanks for taking a look at this. At least at the jaxpr level, I think it is naive indexing that is promised to be in-bounds:

mode=GatherScatterMode.PROMISE_IN_BOUNDS

Whilst take has a mode that I assume implies more work (putting in the fill value if required):

mode=GatherScatterMode.FILL_OR_DROP

i.e. the other way around from your own investigation?

(These are pulled from the jaxprs above.)


neel04 commented Jan 4, 2025

I agree with Kidger - it does seem that take specifically has the ability to pad. You can see that in the HLO for fused_computation.1. More specifically, if you look up the semantics of XLA::Select, it compares against a predicate (boolean) mask to select the correct values. Here is the corresponding HLO:

fused_computation.1 {
  param_1.18 = pred[16384]{0:T(1024)(128)(4,1)S(1)} parameter(1)
  broadcast.8 = pred[16384,32]{1,0:T(8,128)(4,1)} broadcast(param_1.18), dimensions={0}, metadata={op_name="jit(_take)/jit(main)/broadcast_in_dim" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
  param_0.4 = f32[16384,32]{1,0:T(8,128)S(1)} parameter(0)
  constant.17 = f32[]{:T(128)} constant(nan)
  broadcast.7 = f32[16384,32]{1,0:T(8,128)} broadcast(constant.17), dimensions={}
  ROOT select.1 = f32[16384,32]{1,0:T(8,128)S(1)} select(broadcast.8, param_0.4, broadcast.7), metadata={op_name="jit(_take)/jit(main)/select_n" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
} // fused_computation.1

You can easily trace back how XLA does the padding statically, using Select and a broadcasted array of the padding constant NaN - which checks out with the defaults outlined in jnp.take's documentation.

As for the kernels: while it's usually a good rule of thumb that fewer kernels indicate more fusion, it doesn't guarantee that those few kernels won't be run multiple times, which would fail to amortize the overhead - so I think it's a bit of a red herring in this case :)


The naive indexing fusions are a bit... weird. There are a couple of no-ops (like transpose.5, and shift-right-logical.3/shift-left.7, which are right/left shifts by 0?). It also doesn't do padding - looking at fused_computation.3, we have:

  ...
  clamp.11 = s32[16384,1]{0,1:T(1,128)} clamp(broadcast.15, slice.13, broadcast.14), sharding={replicated}, metadata={op_name="args[1]"}

Which simply xla::clamps the indices between 0 and 8191 (inclusive).
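The two out-of-bounds behaviours are easy to see from Python too (a small sketch with a deliberately invalid index):

import jax.numpy as jnp

w = jnp.arange(12.).reshape(3, 4)
bad = jnp.array([0, 10])  # 10 is out of bounds

print(w[bad])                    # out-of-bounds index is clamped to the last row
print(jnp.take(w, bad, axis=0))  # out-of-bounds row is filled with NaN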

From a quick analysis, it seems that naive indexing is doing some sort of fused loop where each lookup involves a Gather, which does make sense in the context of the TensorBoard trace. It may be using dynamic slices because it operates mostly at the bit level (it extracts the lower 13 bits, which corresponds to 8192), so due to some bit logic that I find hard to follow, it's requesting different slices each time - which may be why it's not running as a single fused kernel? Or maybe it's due to the excessive concatenation on each iteration of the loop? 🤔

Overall this does seem like a weird way to optimize lookups; my guess is that, due to some optimization passes introduced for XLA:GPU, it's somehow reverting to this rather strange implementation - possibly because it's quite efficient on GPUs for a small number of array accesses, but it ends up being suboptimal on TPUs...

Not sure 🤷‍♂️ - what do you guys think? Should we CC in someone at Google?
