diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index dc2d63f0b6ec..8e5fb236928e 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -1012,19 +1012,23 @@ def _handle_indexing(
   ]
   if not indexer_idxs:
     return ref, transforms
-  if len(indexer_idxs) > 1:
-    raise NotImplementedError("Only one level of indexing supported.")
-  [indexer_idx] = indexer_idxs
-  indexer = cast(indexing.NDIndexer, transforms[indexer_idx])
-  if indexer.int_indexer_shape:
-    raise NotImplementedError("int_indexer_shape non-empty")
-  indices = _ndindexer_indices(indexer)
-  new_transforms_rev = []
-  for t in reversed(transforms[:indexer_idx]):
-    indices, new_t = t.untransform_index(indices)
-    new_transforms_rev.append(new_t)
-  new_transforms = [*reversed(new_transforms_rev), *transforms[indexer_idx + 1:]]
-  return mgpu.memref_slice(ref, indices), new_transforms
+  sliced_ref = ref
+  new_transforms = []
+  for t in transforms:
+    if not isinstance(t, indexing.NDIndexer):
+      new_transforms.append(t)
+      continue
+    indexer = cast(indexing.NDIndexer, t)
+    if indexer.int_indexer_shape:
+      raise NotImplementedError("int_indexer_shape non-empty")
+    indices = _ndindexer_indices(indexer)
+    new_transforms_rev = []
+    for t in reversed(new_transforms):
+      indices, new_t = t.untransform_index(indices)
+      new_transforms_rev.append(new_t)
+    sliced_ref = mgpu.memref_slice(sliced_ref, indices)
+    new_transforms = list(reversed(new_transforms_rev))
+  return sliced_ref, new_transforms
 
 
 def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ...]:
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index a9d5361e7c10..b57cbd3d4f03 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -301,6 +301,50 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
     x = jnp.arange(256).astype(jnp.float32)
     np.testing.assert_array_equal(kernel(x)[indexer], x[indexer] + 1.0)
 
+  def test_ref_with_multiple_indexers(self):
+    x = jax.random.uniform(jax.random.key(0), (2, 64, 64))
+    @functools.partial(
+        pl.pallas_call,
+        out_shape=jax.ShapeDtypeStruct([64, 64], jnp.float32),
+        in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
+        scratch_shapes=[
+            plgpu.SMEM(x.shape, jnp.float32),
+            plgpu.Barrier(num_arrivals=1),
+        ],
+    )
+    def extract_x0(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
+      plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier_ref)
+      plgpu.barrier_wait(barrier_ref)
+      x_sliced = scratch_ref.at[0, :, :]  # shape=(64, 64)
+      o_ref[pl.ds(0, 32), :] = x_sliced[pl.ds(0, 32), :]
+      o_ref[pl.ds(32, 32), :] = x_sliced[pl.ds(32, 32), :]
+    np.testing.assert_array_equal(extract_x0(x), x[0])
+
+  def test_smem_multiple_indexers_with_transforms(self):
+    x = jnp.arange(512 * 512).reshape(512, 512)
+    @functools.partial(
+        pl.pallas_call,
+        grid=(4, 4),
+        out_shape=jax.ShapeDtypeStruct((256, 128), jnp.int32),
+        in_specs=(plgpu.GPUBlockSpec(
+            block_shape=(128, 128),
+            index_map=lambda i, j: (i, j),
+            memory_space=plgpu.SMEM,
+            transforms=(plgpu.TilingTransform((64, 32)),
+                        plgpu.SwizzleTransform(128))),),
+        out_specs=(plgpu.GPUBlockSpec(
+            block_shape=(64, 32),
+            index_map=lambda i, j: (i, j),
+            memory_space=plgpu.SMEM,)),
+    )
+    def kernel(x_ref, o_ref):
+      x_sliced = x_ref.at[0:64, 32:96].at[:, 0:32]  # get x_ref[0:64, 32:64]
+      o_ref[...] = x_sliced[...]
+    ref = jnp.concatenate([x[blk:blk+64, :] for blk in range(0, 512, 128)])
+    ref = jnp.concatenate(
+        [ref[:, blk+32:blk+64] for blk in range(0, 512, 128)], axis=1)
+    np.testing.assert_array_equal(kernel(x), ref)
+
   @parameterized.product(indexer=[0, 1, 2, 3])
   def test_copy_gmem_to_smem_with_indexed_barrier(self, indexer):
     @functools.partial(
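Reviewer note (not part of the patch): a standalone NumPy sketch of the two facts the new tests rely on. The chained-slice identity is what `test_smem_multiple_indexers_with_transforms` encodes in its `# get x_ref[0:64, 32:64]` comment; the `(m // tile_m, n // tile_n, tile_m, tile_n)` layout is assumed here only as a model of the bookkeeping that `plgpu.TilingTransform`'s `untransform_index` has to invert.

```python
import numpy as np

x = np.arange(128 * 128).reshape(128, 128)

# Chained indexers compose relative to each other:
# .at[0:64, 32:96].at[:, 0:32] denotes x[0:64, 32:64].
np.testing.assert_array_equal(x[0:64, 32:96][:, 0:32], x[0:64, 32:64])

# Model of the untransform step for TilingTransform((64, 32)): the tiled
# ref has shape (128 // 64, 128 // 32, 64, 32), so the logical slice
# [0:64, 32:96] maps to the tile-space slice [0:1, 1:3, :, :].
tiled = x.reshape(2, 64, 4, 32).transpose(0, 2, 1, 3)
untiled = tiled[0:1, 1:3].transpose(0, 2, 1, 3).reshape(64, 64)
np.testing.assert_array_equal(untiled, x[0:64, 32:96])
```

This is also why `_handle_indexing` walks `new_transforms` in reverse each time it reaches an indexer: the user's indices refer to the fully transformed view, so they are pushed back through the preceding transforms, nearest first, before `mgpu.memref_slice` is applied.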