Improving eqx.nn.Embedding's performance #920
So I'm guessing that you're running this on GPU? I recognise this […]. What surprises me is that the jaxprs for the two are very similar:

import jax
import jax.numpy as jnp
def f(x, y):
    return x[y]

def g(x, y):
    return jnp.take(x, y, axis=0)

ff = jax.vmap(f, in_axes=(None, 0))
gg = jax.vmap(g, in_axes=(None, 0))

x = jnp.arange(12.).reshape(3, 4)
y = jnp.array([[0, 1], [2, 3]])
print(jax.make_jaxpr(ff)(x, y))
# { lambda ; a:f32[3,4] b:i32[2,2]. let
# c:bool[2,2] = lt b 0
# d:i32[2,2] = add b 3
# e:i32[2,2] = select_n c b d
# f:i32[2,2,1] = broadcast_in_dim[
# broadcast_dimensions=(0, np.int64(1))
# shape=(2, 2, 1)
# sharding=None
# ] e
# g:f32[2,2,4] = gather[
# dimension_numbers=GatherDimensionNumbers(offset_dims=(2,), collapsed_slice_dims=(0,), start_index_map=(0,), operand_batching_dims=(), start_indices_batching_dims=())
# fill_value=None
# indices_are_sorted=False
# mode=GatherScatterMode.PROMISE_IN_BOUNDS
# slice_sizes=(1, 4)
# unique_indices=False
# ] a f
# in (g,) }
print(jax.make_jaxpr(gg)(x, y))
# { lambda ; a:f32[3,4] b:i32[2,2]. let
# c:f32[2,2,4] = pjit[
# name=_take
# jaxpr={ lambda ; d:f32[3,4] e:i32[2,2]. let
# f:bool[2,2] = lt e 0
# g:i32[2,2] = add e 3
# h:i32[2,2] = pjit[
# name=_where
# jaxpr={ lambda ; i:bool[2,2] j:i32[2,2] k:i32[2,2]. let
# l:i32[2,2] = select_n i k j
# in (l,) }
# ] f g e
# m:i32[2,2,1] = broadcast_in_dim[
# broadcast_dimensions=(0, np.int64(1))
# shape=(2, 2, 1)
# sharding=None
# ] h
# n:f32[2,2,4] = gather[
# dimension_numbers=GatherDimensionNumbers(offset_dims=(2,), collapsed_slice_dims=(0,), start_index_map=(0,), operand_batching_dims=(), start_indices_batching_dims=())
# fill_value=nan
# indices_are_sorted=False
# mode=GatherScatterMode.FILL_OR_DROP
# slice_sizes=(1, 4)
# unique_indices=False
# ] d m
# in (n,) }
# ] a b
# in (c,) }

The second one has some nested JITs, which I assume don't do anything, and it tweaks the gather: fill_value=None vs fill_value=nan, and mode=PROMISE_IN_BOUNDS vs mode=FILL_OR_DROP. Do you know what is causing this (de-)optimization pass to trigger or not trigger, or can you try reproducing this with slightly different jaxprs and see what causes it? And did you know […]? Happy to adjust things in Equinox if need be, but I'd like to understand the issue a bit better first.
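One way to poke at exactly this mode difference is via JAX's .at[...].get() API, which lets you request a gather mode explicitly; a minimal sketch for experimenting (illustrative, not from the thread):

import jax
import jax.numpy as jnp

x = jnp.arange(12.).reshape(3, 4)
y = jnp.array([[0, 1], [2, 3]])

def h(x, y):
    # Explicitly request the same out-of-bounds mode that naive indexing
    # lowers to, so the two jaxprs can be compared like-for-like.
    return x.at[y].get(mode="promise_in_bounds")

print(jax.make_jaxpr(jax.vmap(h, in_axes=(None, 0)))(x, y))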
No, I'm on TPUs (v4-32 to be specific)
That's interesting. I did a bit of exploring around in a colab, and sharding the […]. One thing I notice is that naive indexing is still missing a […]. Note that the […]
When I saw the traces, my guess was that JAX was treating each update of the embedding serially instead of fusing. One immediately thinks of using JAX primitives that XLA is more likely to pattern-match against; I think this is the common problem of not giving enough information to the compiler - think of how autovectorization in traditional compilers often breaks. I assume […]
Got it. I assume XLA:TPU has the same pass then.
FWIW I think this is arguably also a bug at the XLA level, in that it is failing to optimize things as well as it could. You might consider reporting this upstream. Anyway, it's still a little mysterious to me exactly what triggers this pessimisation, but as you've offered, I'd be happy to take a PR on this! (Including a comment back to this thread to explain.)
The jnp.take causes the indices to be assumed to be in bounds. This assumption will be faster on chip. jnp.take also seems to fuse into four kernels, while the naive indexing is two kernels.
Hi @patrick-toulme! Thanks for taking a look at this. At least at the jaxpr level, I think it is naive indexing that is promised to be in-bounds:

mode=GatherScatterMode.PROMISE_IN_BOUNDS

whilst jnp.take is the one that guards against out-of-bounds indices:

mode=GatherScatterMode.FILL_OR_DROP

i.e. the other way around to your own investigation? (These are pulled from the jaxprs above.)
I agree with Kidger - it does seem to be the other way around: the jnp.take path is the one guarding out-of-bounds indices, selecting a NaN fill value in its fusion:

fused_computation.1 {
param_1.18 = pred[16384]{0:T(1024)(128)(4,1)S(1)} parameter(1)
broadcast.8 = pred[16384,32]{1,0:T(8,128)(4,1)} broadcast(param_1.18), dimensions={0}, metadata={op_name="jit(_take)/jit(main)/broadcast_in_dim" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
param_0.4 = f32[16384,32]{1,0:T(8,128)S(1)} parameter(0)
constant.17 = f32[]{:T(128)} constant(nan)
broadcast.7 = f32[16384,32]{1,0:T(8,128)} broadcast(constant.17), dimensions={}
ROOT select.1 = f32[16384,32]{1,0:T(8,128)S(1)} select(broadcast.8, param_0.4, broadcast.7), metadata={op_name="jit(_take)/jit(main)/select_n" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
} // fused_computation.1

You can easily trace back how XLA does the padding statically using […]. As for the kernels: while it's usually a good rule of thumb that fewer kernels indicates more fusion, it doesn't guarantee that those few kernels won't be run multiple times, which would fail to amortize the overhead; so I think it's a bit of a red herring in this case :)

The naive indexing fusions are a bit... weird. There are a couple of no-ops, like this clamp:

...
clamp.11 = s32[16384,1]{0,1:T(1,128)} clamp(broadcast.15, slice.13, broadcast.14), sharding={replicated}, metadata={op_name="args[1]"}

Which simply […]. From a quick analysis, it seems that naive indexing is doing some sort of fused loop where each lookup involves a […]. Overall this does seem like a weird way to optimize lookups; my guess is that it's due to some optimization passes introduced for […].

Not sure 🤷‍♂️ - what do you guys think? Should we CC in someone at Google?
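For reproducing this kind of HLO inspection, the backend-optimized module can be dumped via JAX's ahead-of-time APIs; a minimal sketch, reusing the gg repro from earlier in the thread:

import jax
import jax.numpy as jnp

def g(x, y):
    return jnp.take(x, y, axis=0)

gg = jax.vmap(g, in_axes=(None, 0))
x = jnp.arange(12.).reshape(3, 4)
y = jnp.array([[0, 1], [2, 3]])

# Lower and compile ahead-of-time, then print the optimized HLO text;
# the fusions discussed above (fused_computation.1 etc.) show up here.
print(jax.jit(gg).lower(x, y).compile().as_text())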
equinox/equinox/nn/_embedding.py, line 100 (at commit 7ee4ca9):
Internally, eqx.nn.Embedding is just naively indexing. However, this is subpar, as XLA is unable to fuse vmap(embed_layer) calls, instead doing hundreds of thousands of dynamic slice updates over the weight array:

[profiler trace screenshot]

Zooming in, we see this repetitive block pattern repeated thousands of times:

[zoomed-in trace screenshot]
Instead, we can force XLA to fuse by using jnp.take, which fixes the issue :)
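A sketch of the kind of change being proposed (a hypothetical TakeEmbedding module standing in for the actual patch to eqx.nn.Embedding):

import equinox as eqx
import jax
import jax.numpy as jnp

class TakeEmbedding(eqx.Module):
    # Hypothetical minimal module, standing in for eqx.nn.Embedding.
    weight: jax.Array

    def __call__(self, x):
        # jnp.take lowers to a single fusable gather, instead of the many
        # dynamic slices produced by naive self.weight[x] under vmap.
        return jnp.take(self.weight, x, axis=0)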
On my test runs, it yielded a ~25% improvement in throughput.

Happy to PR. This is backwards compatible with how users already use this (doing vmap(nn.Embedding)(...)), but perhaps one could explicitly prevent users from providing a non-scalar index (x here) to keep the API consistent - or optionally allow both, since they're probably equivalent for XLA. What do you think @patrick-kidger?
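For reference, the two call patterns in question, assuming the hypothetical TakeEmbedding sketch above (which the comment suggests are probably equivalent for XLA):

ids = jnp.array([3, 1, 4])
embed = TakeEmbedding(weight=jnp.ones((10, 8)))

out_vmap = jax.vmap(embed)(ids)  # existing usage: one scalar index per call
out_direct = embed(ids)          # non-scalar index passed directly
assert out_vmap.shape == out_direct.shape == (3, 8)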