[Mosaic GPU] Allow multiple indexing on refs
PiperOrigin-RevId: 705978858
justinjfu authored and Google-ML-Automation committed Jan 8, 2025
1 parent c1a60c6 commit 222bbd6
Showing 2 changed files with 61 additions and 13 deletions.
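For context, here is a minimal sketch of the kind of chained ref indexing this commit enables inside a Pallas Mosaic GPU kernel body. Before this change, applying a second indexer to an already-indexed ref raised NotImplementedError("Only one level of indexing supported."). The ref names are illustrative stand-ins for the SMEM scratch and output refs that pl.pallas_call supplies; the pattern mirrors the new tests in the second file below.

    x_sliced = scratch_ref.at[0, :, :]                   # first indexer: take the leading slice
    o_ref[pl.ds(0, 32), :] = x_sliced[pl.ds(0, 32), :]   # second indexer applied to the sliced ref
    o_ref[pl.ds(32, 32), :] = x_sliced[pl.ds(32, 32), :]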
30 changes: 17 additions & 13 deletions jax/_src/pallas/mosaic_gpu/lowering.py
@@ -1012,19 +1012,23 @@ def _handle_indexing(
   ]
   if not indexer_idxs:
     return ref, transforms
-  if len(indexer_idxs) > 1:
-    raise NotImplementedError("Only one level of indexing supported.")
-  [indexer_idx] = indexer_idxs
-  indexer = cast(indexing.NDIndexer, transforms[indexer_idx])
-  if indexer.int_indexer_shape:
-    raise NotImplementedError("int_indexer_shape non-empty")
-  indices = _ndindexer_indices(indexer)
-  new_transforms_rev = []
-  for t in reversed(transforms[:indexer_idx]):
-    indices, new_t = t.untransform_index(indices)
-    new_transforms_rev.append(new_t)
-  new_transforms = [*reversed(new_transforms_rev), *transforms[indexer_idx + 1:]]
-  return mgpu.memref_slice(ref, indices), new_transforms
+  sliced_ref = ref
+  new_transforms = []
+  for t in transforms:
+    if not isinstance(t, indexing.NDIndexer):
+      new_transforms.append(t)
+      continue
+    indexer = cast(indexing.NDIndexer, t)
+    if indexer.int_indexer_shape:
+      raise NotImplementedError("int_indexer_shape non-empty")
+    indices = _ndindexer_indices(indexer)
+    new_transforms_rev = []
+    for t in reversed(new_transforms):
+      indices, new_t = t.untransform_index(indices)
+      new_transforms_rev.append(new_t)
+    sliced_ref = mgpu.memref_slice(sliced_ref, indices)
+    new_transforms = list(reversed(new_transforms_rev))
+  return sliced_ref, new_transforms
 
 
 def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ...]:
44 changes: 44 additions & 0 deletions tests/pallas/mosaic_gpu_test.py
@@ -301,6 +301,50 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
     x = jnp.arange(256).astype(jnp.float32)
     np.testing.assert_array_equal(kernel(x)[indexer], x[indexer] + 1.0)
 
+  def test_ref_with_multiple_indexers(self):
+    x = jax.random.uniform(jax.random.key(0), (2, 64, 64))
+    @functools.partial(
+        pl.pallas_call,
+        out_shape=jax.ShapeDtypeStruct([64, 64], jnp.float32),
+        in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
+        scratch_shapes=[
+            plgpu.SMEM(x.shape, jnp.float32),
+            plgpu.Barrier(num_arrivals=1),
+        ],
+    )
+    def extract_x0(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
+      plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier_ref)
+      plgpu.barrier_wait(barrier_ref)
+      x_sliced = scratch_ref.at[0, :, :]  # shape=(64, 64)
+      o_ref[pl.ds(0, 32), :] = x_sliced[pl.ds(0, 32), :]
+      o_ref[pl.ds(32, 32), :] = x_sliced[pl.ds(32, 32), :]
+    np.testing.assert_array_equal(extract_x0(x), x[0])
+
+  def test_smem_multiple_indexers_with_transforms(self):
+    x = jnp.arange(512 * 512).reshape(512, 512)
+    @functools.partial(
+        pl.pallas_call,
+        grid=(4, 4),
+        out_shape=jax.ShapeDtypeStruct((256, 128), jnp.int32),
+        in_specs=(plgpu.GPUBlockSpec(
+            block_shape=(128, 128),
+            index_map=lambda i, j: (i, j),
+            memory_space=plgpu.SMEM,
+            transforms=(plgpu.TilingTransform((64, 32)),
+                        plgpu.SwizzleTransform(128))),),
+        out_specs=(plgpu.GPUBlockSpec(
+            block_shape=(64, 32),
+            index_map=lambda i, j: (i, j),
+            memory_space=plgpu.SMEM,)),
+    )
+    def kernel(x_ref, o_ref):
+      x_sliced = x_ref.at[0:64, 32:96].at[:, 0:32]  # get x_ref[0:64, 32:64]
+      o_ref[...] = x_sliced[...]
+    ref = jnp.concatenate([x[blk:blk+64, :] for blk in range(0, 512, 128)])
+    ref = jnp.concatenate(
+        [ref[:, blk+32:blk+64] for blk in range(0, 512, 128)], axis=1)
+    np.testing.assert_array_equal(kernel(x), ref)
+
   @parameterized.product(indexer=[0, 1, 2, 3])
   def test_copy_gmem_to_smem_with_indexed_barrier(self, indexer):
     @functools.partial(
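As a quick aside (not part of the commit), a plain NumPy sketch of why the chained slices in test_smem_multiple_indexers_with_transforms are annotated as "get x_ref[0:64, 32:64]": slicing a (128, 128) block with [0:64, 32:96] and then [:, 0:32] composes into the single slice [0:64, 32:64], which is the composition _handle_indexing now performs by applying each indexer in turn.

    import numpy as np

    block = np.arange(128 * 128).reshape(128, 128)  # stand-in for one (128, 128) input block
    chained = block[0:64, 32:96][:, 0:32]           # two indexers applied in sequence
    direct = block[0:64, 32:64]                     # the equivalent single slice
    np.testing.assert_array_equal(chained, direct)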
