Skip to content

Commit

Permalink
[Mosaic GPU] Do not use mgpu in wgmma.py
Browse files Browse the repository at this point in the history
This enables the dialect lowering to depend on `wgmma.py` without creating a circular dependency. I need this in a follow up CL that implements the lowering of the WGMMA dialect op.

PiperOrigin-RevId: 717784498
  • Loading branch information
dimitar-asenov authored and Google-ML-Automation committed Jan 21, 2025
1 parent bba5ada commit 296a89d
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions jax/experimental/mosaic/gpu/wgmma.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@
from jaxlib.mlir.dialects import nvvm
import numpy as np

import jax.experimental.mosaic.gpu as mgpu
from . import utils
from . import fragmented_array as fa

# mypy: ignore-errors

c = mgpu.c
bytewidth = mgpu.bytewidth
c = utils.c
bytewidth = utils.bytewidth


@jax.tree_util.register_pytree_node_class
Expand All @@ -44,10 +44,10 @@ class WGMMAAccumulator:
as a WGMMA accumulator. In particular, when created from a
FragmentedArray, the necessary synchronization is inserted at construction.
"""
value: mgpu.FragmentedArray
value: fa.FragmentedArray

def __init__(self, *, _value: mgpu.FragmentedArray, _sync: bool = True):
if _value.layout != mgpu.WGMMA_LAYOUT:
def __init__(self, *, _value: fa.FragmentedArray, _sync: bool = True):
if _value.layout != fa.WGMMA_LAYOUT:
raise ValueError("Only WGMMA layouts supported in WGMMAAccumulator")
self.value = _value
if _sync:
Expand All @@ -64,8 +64,8 @@ def zero(cls, m, n, dtype=None, *, is_signed: bool | None = None):
dtype = f32
zero = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.0))
return cls(
_value=mgpu.FragmentedArray.splat(
zero, (m, n), mgpu.WGMMA_LAYOUT, is_signed=is_signed
_value=fa.FragmentedArray.splat(
zero, (m, n), fa.WGMMA_LAYOUT, is_signed=is_signed
)
)

Expand Down Expand Up @@ -171,11 +171,11 @@ def wgmma_m64(
supports_transpose = bytewidth(element_type) == 2
if not supports_transpose and (a_transpose or b_transpose):
raise ValueError("Only f16 WGMMA supports transposes")
if a_in_regs := isinstance(a, mgpu.FragmentedArray):
if a_in_regs := isinstance(a, fa.FragmentedArray):
if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get():
raise ValueError(f"Unsupported A register array dtype: {a.mlir_dtype}")
# Column count must be equal to swizzle // bytewidth.
if a.layout != mgpu.WGMMA_LAYOUT or a.shape != (64, swizzle // 2):
if a.layout != fa.WGMMA_LAYOUT or a.shape != (64, swizzle // 2):
raise ValueError("Unsupported A register array layout")
if a_k_stride is not None or a_transpose is not None:
raise ValueError("Unsupported WGMMA features with A in registers")
Expand Down Expand Up @@ -310,7 +310,7 @@ def wgmma(
a_order: WGMMALayout | None = None,
b_order: WGMMALayout = WGMMALayout.ROW_MAJOR,
):
if a_in_regs := isinstance(a, mgpu.FragmentedArray):
if a_in_regs := isinstance(a, fa.FragmentedArray):
a_element_type = a.mlir_dtype
a_shape = a.shape
else:
Expand Down Expand Up @@ -434,23 +434,23 @@ def wgmma(
new_acc_regs[mi : mi + 1], a_mk, b_k, **wgmma_params
)
return WGMMAAccumulator(
_value=mgpu.FragmentedArray(
_value=fa.FragmentedArray(
_registers=new_acc_regs,
_layout=mgpu.WGMMA_LAYOUT,
_layout=fa.WGMMA_LAYOUT,
_is_signed=acc.value.is_signed,
),
_sync=False,
)


def wgmma_fence(array: mgpu.FragmentedArray):
def wgmma_fence(array: fa.FragmentedArray):
"""Fences the array construction from WGMMA instructions.
LLVM treats in-register computation as pure and can move it after the fence,
which is explicitly disallowed by the PTX programming model. For that reason,
we insert an LLVM optimization barrier before the fence.
"""
array = mgpu.optimization_barrier(array)
array = fa.optimization_barrier(array)
nvvm.wgmma_fence_aligned()
return array

Expand Down

0 comments on commit 296a89d

Please sign in to comment.