Commit
Add mesh_shape to the lowering context. This allows custom partitioning to return NamedShardings without depending on the mesh context manager, even when the arguments already have NamedShardings on them.

Since `shardy`, the sharding-in-types work, and world 2 dagger are all moving in the direction of making Mesh and PartitionSpec first-class sharding types, let's pull the trigger now and start fixing these bad user interactions.

Something that will break due to this change: previously, passing a NamedSharding and then an equivalent PositionalSharding to the same jitted function would produce a lowering cache hit; now it will be a cache miss. In other words: `f(ns); f(ps)  # cache hit before, cache miss now`
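A minimal sketch of that interaction, assuming at least two local devices; the names `f`, `ns`, and `ps` mirror the snippet above and are illustrative only:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PositionalSharding, PartitionSpec as P

# Two equivalent shardings over the same two devices, expressed differently.
mesh = Mesh(np.array(jax.devices()[:2]), ('x',))
ns = NamedSharding(mesh, P('x'))
ps = PositionalSharding(jax.devices()[:2])

f = jax.jit(lambda x: x * 2)
x = np.arange(8, dtype=np.float32)

f(jax.device_put(x, ns))  # lowers and compiles
f(jax.device_put(x, ps))  # lowering cache hit before this change; cache miss now,
                          # since only the NamedSharding call carries a mesh shape
```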

In follow-up CLs, we will make the tracing cache aware of the mesh shape as well, to fix some other issues related to tracing and lowering cache misses.

PiperOrigin-RevId: 660177423
yashk2810 authored and jax authors committed Aug 7, 2024
1 parent 7f44edc commit dd958ad
Showing 6 changed files with 66 additions and 38 deletions.
6 changes: 5 additions & 1 deletion jax/_src/custom_partitioning.py
@@ -24,6 +24,7 @@
from typing import Any
import weakref

import numpy as np
import jax
from jax import tree_util
from jax._src import api_util
@@ -481,17 +482,20 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
infer_sharding_from_operands,
decode_shardings,
static_args):
mesh = mesh_lib.thread_resources.env.physical_mesh
axis_context = ctx.module_context.axis_context
if (isinstance(axis_context, sharding_impls.SPMDAxisContext) and
set(axis_context.manual_axes) == set(axis_context.mesh.axis_names)):
return mlir.lower_fun(core.jaxpr_as_fun(call), multiple_results=True)(ctx, *values)

mesh = mesh_lib.thread_resources.env.physical_mesh
if isinstance(axis_context, sharding_impls.ShardingContext):
devices = axis_context.device_assignment
if devices is None:
raise AssertionError(
'Please file a bug at https://github.com/google/jax/issues')
if axis_context.mesh_shape is not None:
ma, ms = list(zip(*axis_context.mesh_shape))
mesh = mesh_lib.Mesh(np.array(devices).reshape(ms), ma)
elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
devices = axis_context.mesh._flat_devices_tuple
else:
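The hunk above rebuilds a Mesh from the `mesh_shape` recorded on the ShardingContext instead of reading the mesh context manager. A standalone sketch of that zip-and-reshape step, assuming four devices and a hypothetical 2x2 mesh:

```python
import jax
import numpy as np
from jax.sharding import Mesh

mesh_shape = (('x', 2), ('y', 2))          # as stored on ShardingContext.mesh_shape
devices = jax.devices()[:4]                # the context's device_assignment

axis_names, axis_sizes = zip(*mesh_shape)  # ('x', 'y'), (2, 2)
mesh = Mesh(np.array(devices).reshape(axis_sizes), axis_names)
assert mesh.shape_tuple == mesh_shape
```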
6 changes: 3 additions & 3 deletions jax/_src/interpreters/mlir.py
@@ -954,7 +954,6 @@ def lower_jaxpr_to_module(
input_output_aliases: None | tuple[int | None, ...] = None,
propagated_out_mem_kinds: tuple[None | str, ...] | None = None,
lowering_parameters: LoweringParameters,
mesh_shape_tuple: tuple[tuple[str, int], ...] | None = None,
) -> LoweringResult:
"""Lowers a top-level jaxpr to an MLIR module.
@@ -1041,13 +1040,14 @@
# XLA computation preserves the module name.
attrs = ctx.module.operation.attributes
if config.use_shardy_partitioner.value:
assert mesh_shape_tuple is not None
assert (isinstance(axis_context, sharding_impls.ShardingContext) and
axis_context.mesh_shape is not None)
ctx.module.body.append(
dialects.sdy.MeshOp(
"mesh",
dialects.sdy.MeshAttr.get(
[dialects.sdy.MeshAxisAttr.get(name, size)
for name, size in mesh_shape_tuple])))
for name, size in axis_context.mesh_shape])))
module_name = _module_name_regex.sub("_", module_name)
attrs["sym_name"] = ir.StringAttr.get(module_name)
attrs["mhlo.num_replicas"] = i32_attr(num_replicas)
23 changes: 12 additions & 11 deletions jax/_src/interpreters/pxla.py
@@ -1881,7 +1881,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
propagated_out_mem_kinds: tuple[None | str, ...],
platforms: tuple[str, ...],
lowering_parameters: mlir.LoweringParameters,
mesh_shape_tuple: tuple[tuple[str, int], ...]):
mesh_shape_tuple: tuple[tuple[str, int], ...] | None):
jaxpr = closed_jaxpr.jaxpr
in_shardings = semantic_in_shardings.shardings
out_shardings = semantic_out_shardings.shardings
@@ -1911,7 +1911,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
in_mlir_shardings = map(_to_logical_sharding, global_in_avals, in_shardings)
out_mlir_shardings = map(_to_logical_sharding, global_out_avals, out_shardings)
replicated_args = [False] * len(global_in_avals)
axis_ctx = sharding_impls.ShardingContext(num_devices, device_assignment)
axis_ctx = sharding_impls.ShardingContext(num_devices, device_assignment,
mesh_shape=mesh_shape_tuple)
num_partitions = num_devices
else:
# This path is triggered for `jit(pmap)` cases.
@@ -1957,8 +1958,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
all_default_mem_kind=all_default_mem_kind,
input_output_aliases=inout_aliases,
propagated_out_mem_kinds=propagated_out_mem_kinds,
lowering_parameters=lowering_parameters,
mesh_shape_tuple=mesh_shape_tuple)
lowering_parameters=lowering_parameters)
tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
unordered_effects = list(
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
@@ -2203,14 +2203,15 @@ def lower_sharding_computation(
semantic_out_shardings = SemanticallyEqualShardings(
out_shardings, global_out_avals) # type: ignore
prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)

# TODO(yashkatariya): Initialize with context_mesh here?
mesh_shape_tuple = None
if config.use_shardy_partitioner.value:
for sharding in it.chain(
in_shardings, out_shardings,
[js for js, _ in unique_intermediate_shardings]):
if isinstance(sharding, sharding_impls.NamedSharding):
mesh_shape_tuple = sharding.mesh.shape_tuple
break
for sharding in it.chain(
in_shardings, out_shardings,
[js for js, _ in unique_intermediate_shardings]):
if isinstance(sharding, sharding_impls.NamedSharding):
mesh_shape_tuple = sharding.mesh.shape_tuple
break

(module, keepalive, host_callbacks, unordered_effects, ordered_effects,
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
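The new loop above takes the mesh shape from the first NamedSharding it finds among the input, output, and intermediate shardings. A minimal sketch of that extraction with hypothetical shardings:

```python
import itertools as it

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PositionalSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:2]), ('x',))
in_shardings = (PositionalSharding(jax.devices()[:2]),)
out_shardings = (NamedSharding(mesh, P('x')),)

mesh_shape_tuple = None
for s in it.chain(in_shardings, out_shardings):
  if isinstance(s, NamedSharding):
    mesh_shape_tuple = s.mesh.shape_tuple
    break

print(mesh_shape_tuple)  # (('x', 2),)
```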
1 change: 1 addition & 0 deletions jax/_src/sharding_impls.py
@@ -1163,6 +1163,7 @@ class ShardingContext:
"""
num_devices: int
device_assignment: tuple[xc.Device, ...] | None = None
mesh_shape: tuple[tuple[str, int], ...] | None = None

def __post_init__(self):
if self.device_assignment is not None:
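ShardingContext is an internal dataclass, so the following is illustrative only: it shows how the lowering path in pxla.py now threads the mesh shape through the axis context.

```python
import jax
from jax._src import sharding_impls

devices = tuple(jax.devices()[:2])
axis_ctx = sharding_impls.ShardingContext(
    num_devices=len(devices),
    device_assignment=devices,
    mesh_shape=(('x', 2),),
)
assert axis_ctx.mesh_shape == (('x', 2),)
```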
17 changes: 0 additions & 17 deletions tests/memories_test.py
@@ -1178,23 +1178,6 @@ def test_jit_cpp_cache_hit(self):
self.assertArraysEqual(out, np_inp @ np_inp.T)
self.assertArraysEqual(out2, np_inp @ np_inp.T)

def test_jit_compilation_cache_hit(self):
mesh, s, np_inp, inp = _create_inputs((8, 2), P("x", "y"))
inp2 = jax.device_put(
np_inp, GSPMDSharding(tuple(mesh.devices.flat),
s._to_xla_hlo_sharding(inp.ndim),
memory_kind="device")
)

f = jax.jit(lambda x: x @ x.T)

with (jtu.count_pjit_cpp_cache_miss() as cpp_count,
jtu.count_jit_and_pmap_compiles() as compile_count):
f(inp)
f(inp2)
self.assertEqual(cpp_count[0], 2)
self.assertEqual(compile_count[0], 1)

def test_jit_cpp_cache_output_hit(self):
_, _, _, inp = _create_inputs((8, 2), P("x"), mem_kind="device")

51 changes: 45 additions & 6 deletions tests/pjit_test.py
@@ -1483,6 +1483,45 @@ def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
pjit_f = pjit(jit_f, in_shardings=(P('x')), out_shardings=P('x'))
self.assertArraysEqual(x, pjit_f(x))

def test_custom_partitioning_no_mesh_context(self):
self.skip_if_custom_partitioning_not_supported()

@custom_partitioning
def f(x):
return x

def partition(mesh, arg_shapes, result_shape):
def lower_fn(x):
@jax.jit
def g(y):
return y

return g(x)

x_shard = arg_shapes[0].sharding
return (
mesh,
lower_fn,
NamedSharding(x_shard.mesh, P('x')),
(NamedSharding(x_shard.mesh, P('x')),),
)

def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
x_shard = arg_shapes[0].sharding
return NamedSharding(x_shard.mesh, P('x'))

f.def_partition(
infer_sharding_from_operands=infer_sharding_from_operands,
partition=partition,
)

mesh = jtu.create_global_mesh((4,), ('x',))
x = np.asarray(np.random.randint(0, 20, (32,)), dtype=np.float32)
s = NamedSharding(mesh, P('x'))

pjit_f = jax.jit(f, in_shardings=s, out_shardings=s)
self.assertArraysEqual(x, pjit_f(x))

@jtu.with_mesh([('x', 4)])
def test_custom_partitioner_with_scan(self):
self.skip_if_custom_partitioning_not_supported()
@@ -3409,8 +3448,8 @@ def mul(x):
cache_info4 = pxla._cached_compilation.cache_info()
self.assertIsInstance(out4.sharding, PositionalSharding)

self.assertEqual(cache_info4.hits, cache_info3.hits + 1)
self.assertEqual(cache_info4.misses, cache_info3.misses)
self.assertEqual(cache_info4.hits, cache_info3.hits)
self.assertEqual(cache_info4.misses, cache_info3.misses + 1)

def test_cache_hit_pjit_lower_with_cpp_cache_miss(self):
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
@@ -3521,8 +3560,8 @@ def test_jit_mul_sum_sharding_preserved(self):
self.assertIsInstance(out3.sharding, PositionalSharding)
self.assertEqual(count[0], 1)

self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
self.assertEqual(cache_info2.misses, cache_info1.misses)
self.assertEqual(cache_info2.hits, cache_info1.hits)
self.assertEqual(cache_info2.misses, cache_info1.misses + 1)

self.assertEqual(pl_cache_info2.hits, pl_cache_info1.hits)
self.assertEqual(pl_cache_info2.misses, pl_cache_info1.misses + 1)
@@ -3813,7 +3852,7 @@ def test_lowering_cache_hit_different_devices(self):
self.skipTest('Requires >=4 devices')

mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'x')
mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'y')
mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'x')

@jax.jit
def f(x):
@@ -3824,7 +3863,7 @@ def g(a):
out_a = f(a) # lowering cached

# same num_devices but different devices.
b = jax.device_put(out_a, NamedSharding(mesh2, P('y')))
b = jax.device_put(out_a, NamedSharding(mesh2, P('x')))
f(b) # lowering cache *hit*

with jtu.count_jit_and_pmap_compiles() as count:
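The test tweak above keeps the lowering-cache hit by giving both meshes the same axis name; with the mesh shape now recorded during lowering, differing axis names would miss. A sketch of that behavior, assuming at least four devices:

```python
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')   # same axis name and size, different devices

f = jax.jit(lambda x: x * 2)

a = jax.device_put(jax.numpy.arange(8.), NamedSharding(mesh1, P('x')))
b = jax.device_put(jax.numpy.arange(8.), NamedSharding(mesh2, P('x')))
f(a)  # lowers and compiles
f(b)  # still a lowering cache hit: mesh_shape is (('x', 2),) in both cases
```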
