
[TPU] XLA fails to fuse embedding lookup / array indexing #20899

Open
neel04 opened this issue Dec 27, 2024 · 4 comments

neel04 commented Dec 27, 2024

https://github.com/patrick-kidger/equinox/blob/7ee4ca944d75c33d1403122f7ccf141bc390a55e/equinox/nn/_embedding.py#L100

I'm using Equinox, and internally eqx.nn.Embedding just indexes into the weight array naively (as shown in the link above). However, this is suboptimal: XLA is unable to fuse the vmap(embed_layer) calls and instead performs hundreds of thousands of dynamic-slice updates over the weight array:

[screenshot: TPU profiler trace showing thousands of tiny dynamic-slice ops]

Zooming in, we see this block pattern repeated thousands of times:
[screenshot: zoomed-in view of the repeated block pattern]
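For reference, a minimal stand-in for the pattern being profiled (a sketch with hypothetical names; shapes borrowed from the HLO dumps later in this thread, not the actual Equinox code):

import jax
import jax.numpy as jnp

# Hypothetical stand-in for eqx.nn.Embedding.__call__ with naive indexing;
# an 8192 x 32 table and 16384 lookups, matching the dumps below.
weight = jnp.zeros((8192, 32), dtype=jnp.float32)

def embed_naive(x):
    # scalar index -> one embedding row, i.e. `self.weight[x]`
    return weight[x]

ids = jnp.zeros((16384,), dtype=jnp.int32)
out = jax.jit(jax.vmap(embed_naive))(ids)   # shape (16384, 32)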

Instead, we can force XLA to fuse by swapping the naive indexing for jnp.take:

- return self.weight[x]
+ return jnp.take(self.weight, x, axis=0)

[screenshot: trace after switching to jnp.take, showing fused ops]

This fixes the issue and yields a ~25% improvement in throughput.

Here's a simple Colab repro that records two TensorBoard traces; note that the blocks for the naive lookup are very small, so you may have to zoom into the trace.
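A minimal sketch of what such a repro might look like (hypothetical shapes and trace directories; not the actual Colab):

import jax
import jax.numpy as jnp

weight = jnp.zeros((8192, 32), dtype=jnp.float32)
ids = jnp.zeros((16384,), dtype=jnp.int32)

naive = jax.jit(jax.vmap(lambda x: weight[x]))
fused = jax.jit(lambda x: jnp.take(weight, x, axis=0))

# Placeholder log dirs; point TensorBoard at them to compare the two traces.
with jax.profiler.trace("/tmp/trace_naive"):
    naive(ids).block_until_ready()
with jax.profiler.trace("/tmp/trace_take"):
    fused(ids).block_until_ready()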

Why does XLA fail to fuse/parallelize the naive indexing, but manages to with jnp.take?
Why does the jaxpr generated by jnp.take contain a pjit, while the one for naive indexing does not?

If these ops are equivalent, surely XLA should be able to optimize them the same way? 🤔
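For what it's worth, the two jaxprs can be compared directly with a quick sketch like this (arbitrary shapes):

import jax
import jax.numpy as jnp

weight = jnp.zeros((8192, 32), dtype=jnp.float32)
ids = jnp.zeros((16384,), dtype=jnp.int32)

# Plain indexing lowers straight to a gather primitive; jnp.take shows up as
# a nested pjit(_take) call because it is wrapped in jit inside jax.numpy.
print(jax.make_jaxpr(jax.vmap(lambda x: weight[x]))(ids))
print(jax.make_jaxpr(lambda x: jnp.take(weight, x, axis=0))(ids))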

@patrick-toulme (Contributor) commented

Can you paste the HLO snippet around the embedding operation with and without your change? You can get the HLO IR using this:

export XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=${HLO_DUMP_PATH} --xla_dump_hlo_pass_re='.*'"
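(If it's more convenient in a notebook, the same flags can be set from Python, as long as this runs before JAX initializes its backends; /tmp/hlo_dump below is a placeholder for ${HLO_DUMP_PATH}.)

import os

os.environ["XLA_FLAGS"] = (
    "--xla_dump_hlo_as_text "
    "--xla_dump_to=/tmp/hlo_dump "
    "--xla_dump_hlo_pass_re=.*"
)

import jax  # import only after the flags are set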


neel04 commented Jan 4, 2025

@patrick-toulme The JAXPRs are linked in this equinox issue.

I have attached the HLO dumps for both. Note that there's a huge amount of duplication between the files, but I've zipped up everything for completeness' sake. Let me know if you want any other details.

  1. hlo_dump_jnp_take.zip
  2. hlo_dump_naive_indexing.zip

@patrick-toulme (Contributor) commented

Here is jnp.take after all passes:

HloModule jit__take, is_scheduled=true, entry_computation_layout={(f32[8192,32]{0,1:T(8,128)}, s8[16384]{0:T(1024)(128)(4,1)})->f32[16384,32]{0,1:T(8,128)}}, allow_spmd_sharding_propagation_to_parameters={true,false}, allow_spmd_sharding_propagation_to_output={true}

fused_computation {
  param_0.2 = f32[8192,32]{1,0:T(8,128)S(1)} parameter(0)
  param_1.4 = s32[16384]{0:S(1)} parameter(1)
  custom-call.1 = s32[16384]{0} custom-call(param_1.4), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[16384]{0}}, metadata={op_name="jit(_take)/jit(main)/gather" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
  slice.1 = s32[16384]{0:T(1024)} slice(custom-call.1), slice={[0:16384]}, metadata={op_name="jit(_take)/jit(main)/gather" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
  reshape.22 = s32[16384]{0:T(1024)} reshape(slice.1), metadata={op_name="jit(_take)/jit(main)/convert_element_type" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
  transpose.6 = s32[16384]{0:T(1024)} transpose(reshape.22), dimensions={0}, metadata={op_name="jit(_take)/jit(main)/convert_element_type" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
  gather.1 = f32[16384,32]{1,0:T(8,128)} gather(param_0.2, transpose.6), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,32}, metadata={op_name="jit(_take)/jit(main)/gather" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
  transpose.5 = f32[16384,32]{1,0:T(8,128)} transpose(gather.1), dimensions={0,1}, metadata={op_name="jit(_take)/jit(main)/gather" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
  ROOT reshape.21 = f32[16384,32]{1,0:T(8,128)S(1)} reshape(transpose.5), metadata={op_name="jit(_take)/jit(main)/gather" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
} // fused_computation

fused_computation.1 {
  param_1.18 = pred[16384]{0:T(1024)(128)(4,1)S(1)} parameter(1)
  broadcast.8 = pred[16384,32]{1,0:T(8,128)(4,1)} broadcast(param_1.18), dimensions={0}, metadata={op_name="jit(_take)/jit(main)/broadcast_in_dim" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
  param_0.4 = f32[16384,32]{1,0:T(8,128)S(1)} parameter(0)
  constant.17 = f32[]{:T(128)} constant(nan)
  broadcast.7 = f32[16384,32]{1,0:T(8,128)} broadcast(constant.17), dimensions={}
  ROOT select.1 = f32[16384,32]{1,0:T(8,128)S(1)} select(broadcast.8, param_0.4, broadcast.7), metadata={op_name="jit(_take)/jit(main)/select_n" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
} // fused_computation.1

fused_computation.2 {
  constant.14 = s32[] constant(0), metadata={op_name="jit(_take)/jit(main)/gather" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
  broadcast.10 = s32[16384]{0} broadcast(constant.14), dimensions={}, metadata={op_name="jit(_take)/jit(main)/gather" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
  param_0.8 = s32[16384]{0:T(1024)S(1)} parameter(0)
  pad.1 = s32[16384]{0} pad(param_0.8, constant.14), padding=0_0, metadata={op_name="jit(_take)/jit(main)/gather" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
  constant.13 = s32[] constant(8191), metadata={op_name="jit(_take)/jit(main)/gather" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
  broadcast.9 = s32[16384]{0} broadcast(constant.13), dimensions={}, metadata={op_name="jit(_take)/jit(main)/gather" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
  ROOT clamp.1 = s32[16384]{0:S(1)} clamp(broadcast.10, pad.1, broadcast.9), metadata={op_name="jit(_take)/jit(main)/gather" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
} // fused_computation.2

fused_computation.3 {
  param_0.11 = s8[16384]{0:T(1024)(128)(4,1)S(1)} parameter(0)
  convert.5 = s32[16384]{0:T(1024)} convert(param_0.11), metadata={op_name="jit(_take)/jit(main)/convert_element_type" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
  constant.16 = s32[]{:T(128)} constant(0)
  broadcast.13 = s32[16384]{0:T(1024)} broadcast(constant.16), dimensions={}
  compare.3 = pred[16384]{0:T(1024)(128)(4,1)} compare(convert.5, broadcast.13), direction=GE, metadata={op_name="jit(_take)/jit(main)/ge" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
  constant.15 = s32[]{:T(128)} constant(8191), metadata={op_name="jit(_take)/jit(main)/sub" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
  broadcast.11 = s32[16384]{0:T(1024)} broadcast(constant.15), dimensions={}, metadata={op_name="jit(_take)/jit(main)/le" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
  compare.2 = pred[16384]{0:T(1024)(128)(4,1)} compare(convert.5, broadcast.11), direction=LE, metadata={op_name="jit(_take)/jit(main)/le" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
  ROOT and.1 = pred[16384]{0:T(1024)(128)(4,1)S(1)} and(compare.3, compare.2), metadata={op_name="jit(_take)/jit(main)/and" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
} // fused_computation.3

ENTRY main.40 {
  Arg_0.1 = f32[8192,32]{0,1:T(8,128)} parameter(0)
  Arg_1.2 = s8[16384]{0:T(1024)(128)(4,1)} parameter(1), sharding={maximal device=0}, metadata={op_name="jit(_take)/jit(main)/jit(_where)/select_n" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}
  copy = f32[8192,32]{1,0:T(8,128)S(1)} copy(Arg_0.1), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["512","1"],"input_window_bounds":["4","32"],"estimated_cycles":"5932","iteration_bounds":["2","1"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"5242880"}],"retry_config":{"retry_count":"0"}}
  convert.0 = s32[16384]{0:T(1024)S(1)} convert(Arg_1.2), metadata={op_name="jit(_take)/jit(main)/convert_element_type" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"1857","iteration_bounds":["1"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"16384"}],"retry_config":{"retry_count":"0"}}
  fusion.2 = s32[16384]{0:S(1)} fusion(convert.0), kind=kLoop, calls=fused_computation.2, metadata={op_name="jit(_take)/jit(main)/gather" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"1969","iteration_bounds":["1"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"0"}],"retry_config":{"retry_count":"0"}}
  copy-start = (s8[16384]{0:T(1024)(128)(4,1)S(1)}, s8[16384]{0:T(1024)(128)(4,1)}, u32[]{:S(2)}) copy-start(Arg_1.2)
  fusion = f32[16384,32]{1,0:T(8,128)S(1)} fusion(copy, fusion.2), kind=kCustom, calls=fused_computation, metadata={op_name="jit(_take)/jit(main)/gather" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}, backend_config={"flag_configs":[],"integer_config":{"integer":"0"},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"131072"}],"retry_config":{"retry_count":"0"}}
  copy-done = s8[16384]{0:T(1024)(128)(4,1)S(1)} copy-done(copy-start)
  fusion.3 = pred[16384]{0:T(1024)(128)(4,1)S(1)} fusion(copy-done), kind=kLoop, calls=fused_computation.3, metadata={op_name="jit(_take)/jit(main)/and" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"1881","iteration_bounds":["1"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"0"}],"retry_config":{"retry_count":"0"}}
  fusion.1 = f32[16384,32]{1,0:T(8,128)S(1)} fusion(fusion, fusion.3), kind=kLoop, calls=fused_computation.1, metadata={op_name="jit(_take)/jit(main)/select_n" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["228","1"],"input_window_bounds":[],"estimated_cycles":"38214","iteration_bounds":["9","1"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"0"}],"retry_config":{"retry_count":"0"}}
  ROOT copy.1 = f32[16384,32]{0,1:T(8,128)} copy(fusion.1), metadata={op_name="jit(_take)/jit(main)/select_n" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["4","32"],"input_window_bounds":["512","1"],"estimated_cycles":"11864","iteration_bounds":["1","4"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"5242880"}],"retry_config":{"retry_count":"0"}}
} // main.40

This is naive indexing after all passes:

HloModule jit_gather, is_scheduled=true, entry_computation_layout={(f32[8192,32]{0,1:T(8,128)}, s8[16384,2]{0,1:T(4,128)(4,1)})->f32[16384,1,32]{0,2,1:T(8,128)}}, allow_spmd_sharding_propagation_to_parameters={true,false}, allow_spmd_sharding_propagation_to_output={true}

fused_computation {
  param_0.2 = f32[8192,32]{1,0:T(8,128)S(1)} parameter(0)
  constant.12 = s32[]{:T(128)} constant(8191), metadata={op_name="args[1]"}
  broadcast.12 = s32[16384,1]{1,0:T(8,128)} broadcast(constant.12), dimensions={}, metadata={op_name="args[1]"}
  param_1.5 = s32[16384]{0:T(1024)S(1)} parameter(1)
  reshape.6 = s32[16384,1]{1,0:T(8,128)} reshape(param_1.5), metadata={op_name="args[1]"}
  transpose.5 = s32[16384,1]{1,0:T(8,128)} transpose(reshape.6), dimensions={0,1}, metadata={op_name="args[1]"}
  and.3 = s32[16384,1]{1,0:T(8,128)} and(broadcast.12, transpose.5), metadata={op_name="args[1]"}
  constant.11 = s32[]{:T(128)} constant(0), metadata={op_name="args[1]"}
  broadcast.11 = s32[16384,1]{1,0:T(8,128)} broadcast(constant.11), dimensions={}, metadata={op_name="args[1]"}
  shift-right-logical.3 = s32[16384,1]{1,0:T(8,128)} shift-right-logical(and.3, broadcast.11), metadata={op_name="args[1]"}
  constant.10 = s32[]{:T(128)} constant(253952), metadata={op_name="args[1]"}
  broadcast.10 = s32[16384,1]{1,0:T(8,128)} broadcast(constant.10), dimensions={}, metadata={op_name="args[1]"}
  and.2 = s32[16384,1]{1,0:T(8,128)} and(broadcast.10, transpose.5), metadata={op_name="args[1]"}
  constant.9 = s32[]{:T(128)} constant(13), metadata={op_name="args[1]"}
  broadcast.9 = s32[16384,1]{1,0:T(8,128)} broadcast(constant.9), dimensions={}, metadata={op_name="args[1]"}
  shift-right-logical.2 = s32[16384,1]{1,0:T(8,128)} shift-right-logical(and.2, broadcast.9), metadata={op_name="args[1]"}
  concatenate.1 = s32[16384,2]{1,0:T(8,128)} concatenate(shift-right-logical.3, shift-right-logical.2), dimensions={1}, metadata={op_name="args[1]"}
  custom-call.1 = s32[16384,2]{1,0:T(8,128)} custom-call(concatenate.1), custom_call_target="GatherScatterIndicesBitpacked", operand_layout_constraints={s32[16384,2]{1,0}}, metadata={op_name="args[1]"}
  gather.1 = f32[16384,32]{1,0:T(8,128)} gather(param_0.2, custom-call.1), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,32}, metadata={op_name="jit(gather)/jit(main)/gather" source_file="<ipython-input-2-509333df85b8>" source_line=8}
  transpose.4 = f32[16384,32]{1,0:T(8,128)} transpose(gather.1), dimensions={0,1}, metadata={op_name="jit(gather)/jit(main)/gather" source_file="<ipython-input-2-509333df85b8>" source_line=8}
  ROOT reshape.5 = f32[16384,32]{1,0:T(8,128)S(1)} reshape(transpose.4), metadata={op_name="jit(gather)/jit(main)/gather" source_file="<ipython-input-2-509333df85b8>" source_line=8}
} // fused_computation

fused_computation.3 {
  constant.19 = s32[]{:T(128)} constant(0), metadata={op_name="args[1]"}
  broadcast.15 = s32[16384,1]{0,1:T(1,128)} broadcast(constant.19), dimensions={}, sharding={replicated}, metadata={op_name="args[1]"}
  param_0.16 = s32[16384,2]{0,1:T(2,128)S(1)} parameter(0)
  slice.13 = s32[16384,1]{0,1:T(1,128)} slice(param_0.16), slice={[0:16384], [0:1]}, sharding={replicated}, metadata={op_name="args[1]"}
  constant.18 = s32[]{:T(128)} constant(8191), metadata={op_name="args[1]"}
  broadcast.14 = s32[16384,1]{0,1:T(1,128)} broadcast(constant.18), dimensions={}, sharding={replicated}, metadata={op_name="args[1]"}
  clamp.11 = s32[16384,1]{0,1:T(1,128)} clamp(broadcast.15, slice.13, broadcast.14), sharding={replicated}, metadata={op_name="args[1]"}
  shift-left.7 = s32[16384,1]{0,1:T(1,128)} shift-left(clamp.11, broadcast.15), sharding={replicated}, metadata={op_name="args[1]"}
  slice.8 = s32[16384,1]{0,1:T(1,128)} slice(param_0.16), slice={[0:16384], [1:2]}, sharding={replicated}, metadata={op_name="args[1]"}
  clamp.6 = s32[16384,1]{0,1:T(1,128)} clamp(broadcast.15, slice.8, broadcast.15), sharding={replicated}, metadata={op_name="args[1]"}
  constant.17 = s32[]{:T(128)} constant(13), metadata={op_name="args[1]"}
  broadcast.13 = s32[16384,1]{0,1:T(1,128)} broadcast(constant.17), dimensions={}, sharding={replicated}, metadata={op_name="args[1]"}
  shift-left.4 = s32[16384,1]{0,1:T(1,128)} shift-left(clamp.6, broadcast.13), sharding={replicated}, metadata={op_name="args[1]"}
  ROOT or.1 = s32[16384,1]{0,1:T(1,128)S(1)} or(shift-left.7, shift-left.4), sharding={replicated}, metadata={op_name="args[1]"}
} // fused_computation.3

ENTRY main.4 {
  Arg_1.2 = s8[16384,2]{0,1:T(4,128)(4,1)} parameter(1), sharding={replicated}, metadata={op_name="args[1]"}
  Arg_0.1 = f32[8192,32]{0,1:T(8,128)} parameter(0), metadata={op_name="args[0]"}
  convert = s32[16384,2]{0,1:T(2,128)S(1)} convert(Arg_1.2), sharding={replicated}, metadata={op_name="args[1]"}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1","128"],"input_window_bounds":[],"estimated_cycles":"2013","iteration_bounds":["1","1"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"65536"}],"retry_config":{"retry_count":"0"}}
  fusion.3 = s32[16384,1]{0,1:T(1,128)S(1)} fusion(convert), kind=kLoop, calls=fused_computation.3, sharding={replicated}, metadata={op_name="args[1]"}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1","128"],"input_window_bounds":[],"estimated_cycles":"2733","iteration_bounds":["1","1"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"0"}],"retry_config":{"retry_count":"0"}}
  copy.2 = f32[8192,32]{1,0:T(8,128)S(1)} copy(Arg_0.1), metadata={op_name="args[0]"}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["512","1"],"input_window_bounds":["4","32"],"estimated_cycles":"5932","iteration_bounds":["2","1"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"5242880"}],"retry_config":{"retry_count":"0"}}
  bitcast.2 = s32[16384]{0:T(1024)S(1)} bitcast(fusion.3)
  fusion = f32[16384,32]{1,0:T(8,128)S(1)} fusion(copy.2, bitcast.2), kind=kCustom, calls=fused_computation, metadata={op_name="jit(gather)/jit(main)/gather" source_file="<ipython-input-2-509333df85b8>" source_line=8}, backend_config={"flag_configs":[],"integer_config":{"integer":"128"},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"196608"}],"retry_config":{"retry_count":"0"}}
  bitcast.3 = f32[16384,1,32]{2,0,1:T(8,128)S(1)} bitcast(fusion)
  ROOT copy.3 = f32[16384,1,32]{0,2,1:T(8,128)} copy(bitcast.3), metadata={op_name="jit(gather)/jit(main)/gather" source_file="<ipython-input-2-509333df85b8>" source_line=8}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["1","4","32"],"input_window_bounds":["1","512","1"],"estimated_cycles":"11864","iteration_bounds":["1","1","4"]},"scoped_memory_configs":[],"used_scoped_memory_configs":[{"memory_space":"1","offset":"0","size":"5242880"}],"retry_config":{"retry_count":"0"}}
} // main.4

The main difference seems to be that in the jnp.take version the indices are assumed to be in bounds:

custom-call.1 = s32[16384]{0} custom-call(param_1.4), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[16384]{0}}, metadata={op_name="jit(_take)/jit(main)/gather" source_file="<ipython-input-5-78efdeb3c2eb>" source_line=11}

In the naive-indexing version it is different; the indices go through a bit-packing custom-call instead:

  custom-call.1 = s32[16384,2]{1,0:T(8,128)} custom-call(concatenate.1), custom_call_target="GatherScatterIndicesBitpacked", operand_layout_constraints={s32[16384,2]{1,0}}, metadata={op_name="args[1]"}
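For anyone following along, both post-compilation modules can also be pulled straight from JAX to spot these custom-calls (a sketch; the TPU-specific custom-calls only show up when compiling for a TPU backend):

import jax
import jax.numpy as jnp

weight = jnp.zeros((8192, 32), dtype=jnp.float32)
ids = jnp.zeros((16384,), dtype=jnp.int32)

take_hlo = jax.jit(lambda x: jnp.take(weight, x, axis=0)).lower(ids).compile().as_text()
naive_hlo = jax.jit(jax.vmap(lambda x: weight[x])).lower(ids).compile().as_text()

# Per the dumps above, the first should mention AssumeGatherIndicesInBound
# and the second GatherScatterIndicesBitpacked.
print("AssumeGatherIndicesInBound" in take_hlo)
print("GatherScatterIndicesBitpacked" in naive_hlo)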

@patrick-toulme (Contributor) commented

The jnp.take version is broken into four kernels, while the naive indexing version is broken into two.
