[Mosaic GPU] Clean up imports in gpu_dialect_test.py.
PiperOrigin-RevId: 707549269
bchetioui authored and Google-ML-Automation committed Dec 18, 2024
1 parent 3d54d03 · commit 6a03ea3
Showing 1 changed file with 40 additions and 44 deletions.
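The change replaces five member imports from jax.experimental.mosaic.gpu with a single module import, so every symbol is reached through the mgpu namespace (mgpu.dialect, mgpu.lower_mgpu_dialect, and so on) and the per-line pylint suppressions become unnecessary. A minimal sketch of the new pattern, using only what the diff below shows (the ir import path is taken from the test file's existing imports):

from jax._src.lib.mlir import ir
import jax.experimental.mosaic.gpu as mgpu

# The dialect bindings can be absent in some builds, so code guards on
# mgpu.dialect before using it, exactly as the tests below do.
if mgpu.dialect is not None:
  context = ir.Context()
  mgpu.dialect.register_dialect(context)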
diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py
--- a/tests/mosaic/gpu_dialect_test.py
+++ b/tests/mosaic/gpu_dialect_test.py
@@ -29,13 +29,9 @@
 from jax._src.lib.mlir.dialects import nvvm
 from jax._src.lib.mlir.dialects import scf
 from jax._src.lib.mlir.dialects import vector
-from jax.experimental.mosaic.gpu import dialect as mgpu  # pylint: disable=g-importing-member
-from jax.experimental.mosaic.gpu import gpu_address_space_to_nvptx  # pylint: disable=g-importing-member,g-multiple-import
-from jax.experimental.mosaic.gpu import infer_layout  # pylint: disable=g-importing-member,g-multiple-import
-from jax.experimental.mosaic.gpu import lower_mgpu_dialect  # pylint: disable=g-importing-member,g-multiple-import
-from jax.experimental.mosaic.gpu import strided_fragmented_layout  # pylint: disable=g-importing-member
+import jax.experimental.mosaic.gpu as mgpu

-_cext = mgpu._cext if mgpu is not None else None
+_cext = mgpu.dialect._cext if mgpu.dialect is not None else None


 config.parse_flags_with_absl()
@@ -45,7 +41,7 @@ def _make_ir_context():
   context = ir.Context()
   context.append_dialect_registry(mlir_interpreter.upstream_dialects)
   context.load_all_available_dialects()
-  mgpu.register_dialect(context)
+  mgpu.dialect.register_dialect(context)
   return context

@@ -76,7 +72,7 @@ def is_mosaic_gpu_op(op: ir.OpView) -> bool:


 def workgroup_ptr_ty() -> ir.Type:
-  workgroup_nvptx_address_space = gpu_address_space_to_nvptx(
+  workgroup_nvptx_address_space = mgpu.gpu_address_space_to_nvptx(
       gpu.AddressSpace.Workgroup
   )
   return ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>")
@@ -85,7 +81,7 @@ def workgroup_ptr_ty() -> ir.Type:
 class MosaicGpuTest(parameterized.TestCase):

   def setUp(self):
-    if mgpu is None:
+    if mgpu.dialect is None:
       raise self.skipTest("Test requires Mosaic GPU dialect")
     super().setUp()
     self.enter_context(_make_ir_context())
@@ -100,7 +96,7 @@ def test_dialect_module_is_loaded(self):

   def test_initialize_barrier_op_result_memref_must_wrap_barriers(self):
     with ir.InsertionPoint(self.module.body):
-      mgpu.initialize_barrier(
+      mgpu.dialect.initialize_barrier(
           ir.MemRefType.get((1, 2), ir.F32Type.get()),
           llvm.UndefOp(workgroup_ptr_ty()),
           arrival_count=1,
@@ -112,7 +108,7 @@ def test_initialize_barrier_op_result_memref_must_wrap_barriers(self):

   def test_initialize_barrier_op_arrival_count_must_be_strictly_positive(self):
     with ir.InsertionPoint(self.module.body):
-      mgpu.initialize_barrier(
+      mgpu.dialect.initialize_barrier(
           ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
           llvm.UndefOp(workgroup_ptr_ty()),
           arrival_count=0,
@@ -122,7 +118,7 @@ def test_initialize_barrier_op_arrival_count_must_be_strictly_positive(self):

   def test_initialize_barrier_op_with_a_non_shared_base_pointer_fails(self):
     with ir.InsertionPoint(self.module.body):
-      mgpu.initialize_barrier(
+      mgpu.dialect.initialize_barrier(
           ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
           llvm.UndefOp(ir.Type.parse(f"!llvm.ptr<{0}>")),
           arrival_count=1,
@@ -132,14 +128,14 @@ def test_initialize_barrier_op_with_a_non_shared_base_pointer_fails(self):

   def test_initialize_barrier_op_with_a_positive_arrival_count_passes(self):
     with ir.InsertionPoint(self.module.body):
-      mgpu.initialize_barrier(
+      mgpu.dialect.initialize_barrier(
           ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
           llvm.UndefOp(workgroup_ptr_ty()),
           arrival_count=1,
       )
     self.assertTrue(self.module.operation.verify())
     self.assertIsInstance(
-        self.module.body.operations[1], mgpu.InitializeBarrierOp
+        self.module.body.operations[1], mgpu.dialect.InitializeBarrierOp
     )

   def test_async_load_op_dest_must_be_contiguous(self):
@@ -156,7 +152,7 @@ def test_async_load_op_dest_must_be_contiguous(self):
           ir.IntegerType.get_signless(32),
           name="async_load",
       )(
-          lambda source, destination, barrier, *indices: mgpu.async_load(
+          lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
              source,
              destination,
              barrier,
@@ -183,7 +179,7 @@ def test_async_load_op_source_and_dest_must_have_same_element_type(self):
           ir.IntegerType.get_signless(32),
           name="async_load",
       )(
-          lambda source, destination, barrier, *indices: mgpu.async_load(
+          lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
              source,
              destination,
              barrier,
@@ -210,7 +206,7 @@ def test_async_load_op_slice_lengths_must_be_larger_than_minus_two(self):
           ir.IntegerType.get_signless(32),
           name="async_load",
       )(
-          lambda source, destination, barrier, *indices: mgpu.async_load(
+          lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
              source,
              destination,
              barrier,
@@ -238,7 +234,7 @@ def test_async_load_op_source_and_dest_ranks_must_match_with_collapse(self):
           ir.IntegerType.get_signless(32),
           name="async_load",
       )(
-          lambda source, destination, barrier, *indices: mgpu.async_load(
+          lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
              source,
              destination,
              barrier,
@@ -264,7 +260,7 @@ def test_async_load_op_indices_size_must_match_source_rank(self):
           ir.IntegerType.get_signless(32),
           name="async_load",
       )(
-          lambda source, destination, barrier, *indices: mgpu.async_load(
+          lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
              source,
              destination,
              barrier,
@@ -290,7 +286,7 @@ def test_async_load_op_slice_lengths_size_must_match_source_rank(self):
           ir.IntegerType.get_signless(32),
           name="async_load",
       )(
-          lambda source, destination, barrier, *indices: mgpu.async_load(
+          lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
              source,
              destination,
              barrier,
@@ -316,7 +312,7 @@ def test_async_load_op_slice_collective_must_be_unique(self):
           ir.IntegerType.get_signless(32),
           name="async_load",
       )(
-          lambda source, destination, barrier, *indices: mgpu.async_load(
+          lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
              source,
              destination,
              barrier,
@@ -325,10 +321,10 @@ def test_async_load_op_slice_collective_must_be_unique(self):
              transforms=ir.ArrayAttr.get([]),
              collective=ir.ArrayAttr.get([
                  ir.Attribute.parse(
-                     f"#mosaic_gpu.dim<{mgpu.Dimension.x.name}>"
+                     f"#mosaic_gpu.dim<{mgpu.dialect.Dimension.x.name}>"
                  ),
                  ir.Attribute.parse(
-                     f"#mosaic_gpu.dim<{mgpu.Dimension.x.name}>"
+                     f"#mosaic_gpu.dim<{mgpu.dialect.Dimension.x.name}>"
                  ),
              ]),
          )
@@ -353,7 +349,7 @@ def test_async_store_op_source_must_be_contiguous(self):
           ir.IntegerType.get_signless(32),
           name="async_store",
       )(
-          lambda source, destination, *indices: mgpu.async_store(
+          lambda source, destination, *indices: mgpu.dialect.async_store(
              source,
              destination,
              indices,
@@ -377,7 +373,7 @@ def test_async_store_op_source_and_dest_must_have_same_element_type(self):
           ir.IntegerType.get_signless(32),
           name="async_store",
       )(
-          lambda source, destination, *indices: mgpu.async_store(
+          lambda source, destination, *indices: mgpu.dialect.async_store(
              source,
              destination,
              indices,
@@ -401,7 +397,7 @@ def test_async_store_op_slice_lengths_must_be_larger_than_minus_two(self):
           ir.IntegerType.get_signless(32),
           name="async_store",
       )(
-          lambda source, destination, *indices: mgpu.async_store(
+          lambda source, destination, *indices: mgpu.dialect.async_store(
              source,
              destination,
              indices,
@@ -426,7 +422,7 @@ def test_async_store_op_source_and_dest_ranks_must_match_with_collapse(self):
           ir.IntegerType.get_signless(32),
           name="async_store",
       )(
-          lambda source, destination, *indices: mgpu.async_store(
+          lambda source, destination, *indices: mgpu.dialect.async_store(
              source,
              destination,
              indices,
@@ -449,7 +445,7 @@ def test_async_store_op_indices_size_must_match_destination_rank(self):
           ir.IntegerType.get_signless(32),
           name="async_store",
       )(
-          lambda source, destination, *indices: mgpu.async_store(
+          lambda source, destination, *indices: mgpu.dialect.async_store(
              source,
              destination,
              indices,
@@ -472,7 +468,7 @@ def test_async_store_op_slice_lengths_size_must_match_source_rank(self):
           ir.IntegerType.get_signless(32),
           name="async_store",
       )(
-          lambda source, destination, *indices: mgpu.async_store(
+          lambda source, destination, *indices: mgpu.dialect.async_store(
              source,
              destination,
              indices,
@@ -496,7 +492,7 @@ def test_wgmma_types_match(self):
           ir.MemRefType.get([4, 5, 32, 32], ir.BF16Type.get()),
           name="wgmma",
       )(
-          lambda accumulator, a, b: mgpu.wgmma(
+          lambda accumulator, a, b: mgpu.dialect.wgmma(
              accumulator,
              a,
              b,
@@ -518,7 +514,7 @@ def test_wgmma_b_rank_is_4(self):
           ir.MemRefType.get([5, 32, 32], ir.BF16Type.get()),
           name="wgmma",
       )(
-          lambda accumulator, a, b: mgpu.wgmma(
+          lambda accumulator, a, b: mgpu.dialect.wgmma(
              accumulator,
              a,
              b,
@@ -540,7 +536,7 @@ def test_wgmma_b_shape_dim_3(self):
           ir.MemRefType.get([4, 5, 32, 16], ir.BF16Type.get()),
           name="wgmma",
       )(
-          lambda accumulator, a, b: mgpu.wgmma(
+          lambda accumulator, a, b: mgpu.dialect.wgmma(
              accumulator,
              a,
              b,
@@ -563,7 +559,7 @@ def test_wgmma_b_shape_dim_2(self):
           ir.MemRefType.get([4, 5, 64, 32], ir.BF16Type.get()),
           name="wgmma",
       )(
-          lambda accumulator, a, b: mgpu.wgmma(
+          lambda accumulator, a, b: mgpu.dialect.wgmma(
              accumulator,
              a,
              b,
@@ -585,12 +581,12 @@ class DialectLoweringTest(MosaicGpuTest):

   def test_lowering_removes_mosaic_gpu_ops(self):
     with ir.InsertionPoint(self.module.body):
-      mgpu.initialize_barrier(
+      mgpu.dialect.initialize_barrier(
           ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
           llvm.UndefOp(workgroup_ptr_ty()),
           arrival_count=1,
       )
-    lower_mgpu_dialect(self.module)
+    mgpu.lower_mgpu_dialect(self.module)

     self.assertEmpty(
         list(filter(is_mosaic_gpu_op, self.module.body.operations))
@@ -602,13 +598,13 @@ def test_lowering_traverses_regions_correctly(self):
       cst_true = arith.constant(bool_type, ir.IntegerAttr.get(bool_type, 1))
       if_op = scf.IfOp(cst_true)
       with ir.InsertionPoint(if_op.then_block):
-        mgpu.initialize_barrier(
+        mgpu.dialect.initialize_barrier(
             ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
             llvm.UndefOp(workgroup_ptr_ty()),
             arrival_count=1,
         )
         scf.yield_([])
-    lower_mgpu_dialect(self.module)
+    mgpu.lower_mgpu_dialect(self.module)

     self.assertEmpty(
         list(filter(is_mosaic_gpu_op, if_op.then_block.operations))
@@ -620,7 +616,7 @@ def test_initialize_barrier_op_lowering_rule(self):
     arrival_count = 1337

     with ir.InsertionPoint(self.module.body):
-      barriers_ref = mgpu.initialize_barrier(
+      barriers_ref = mgpu.dialect.initialize_barrier(
           ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier")),
           llvm.UndefOp(workgroup_ptr_ty()),
           arrival_count=arrival_count,
@@ -630,7 +626,7 @@ def test_initialize_barrier_op_lowering_rule(self):
       memref.copy(barriers_ref, barriers_ref)

     self.assertTrue(self.module.operation.verify())
-    lower_mgpu_dialect(self.module)
+    mgpu.lower_mgpu_dialect(self.module)
     self.assertTrue(self.module.operation.verify())

     all_mbarrier_init_shared_ops = find_if(
@@ -658,7 +654,7 @@ def test_lowering_vector_op_without_layout_fails(self):
     with self.assertRaisesRegex(
         ValueError, "missing a layout and can not be lowered"
     ):
-      lower_mgpu_dialect(self.module)
+      mgpu.lower_mgpu_dialect(self.module)

   def test_lowering_eliminates_layouts(self):
     shape = (4, 128)
@@ -669,10 +665,10 @@ def test_lowering_eliminates_layouts(self):
       ty = ir.VectorType.get(shape, elt_ty)
       load = vector.load(ty, ref, [zero_index, zero_index])
       load.owner.attributes["out_layouts"] = ir.ArrayAttr.get(
-          [strided_fragmented_layout()]
+          [mgpu.strided_fragmented_layout()]
       )

-    lower_mgpu_dialect(self.module)
+    mgpu.lower_mgpu_dialect(self.module)

     all_ops_with_layouts = find_if(
         self.module,
@@ -692,8 +688,8 @@ def test_lowering_vector_load_and_store_ops(self):
       array = vector.load(ty, ref, [zero_index, zero_index])
       vector.store(array, ref, [zero_index, zero_index])

-    infer_layout(self.module)
-    lower_mgpu_dialect(self.module)
+    mgpu.infer_layout(self.module)
+    mgpu.lower_mgpu_dialect(self.module)

     all_loads = find_if(
         self.module,
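The last hunk shows the lowering flow the tests rely on: infer layouts for vector ops first, then lower the Mosaic GPU ops. A minimal sketch of that call sequence under the new imports (module is an ir.Module already populated with ops; the helper name lower_module is illustrative, not a library API):

import jax.experimental.mosaic.gpu as mgpu

def lower_module(module):
  # Annotate ops with layouts, then rewrite Mosaic GPU ops in place,
  # mirroring test_lowering_vector_load_and_store_ops above.
  mgpu.infer_layout(module)
  mgpu.lower_mgpu_dialect(module)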
