[jax:custom_partitioning] Support SdyShardingRule with multiple leading batching dimension groups.

Previously, we allowed the use of an ellipsis ... in the Einsum-like notation to
represent the leading batching dimensions of a single group of operands and
results. We now allow an ellipsis optionally followed by a number, such as ...2,
to represent leading batching dimensions for multiple groups of operands and
results.
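
For illustration only (a hedged sketch, not part of this change), the extended
notation can be exercised through str_to_sdy_sharding_rule; the import path
below assumes the module touched by this commit, and the rule models a
batched-matmul-like op whose two operands carry different numbers of leading
batching dimensions:

  from jax._src.custom_partitioning_sharding_rule import (
      str_to_sdy_sharding_rule)

  # "..." and "...0" both name batching group 0; "...1" names a second group.
  # Values tagged with the same group must share the same number of leading
  # batching dimensions, while different groups may have different ranks.
  rule = str_to_sdy_sharding_rule("...0 i j, ...1 j k -> ...0 i k")
  print(rule)  # an SdyShardingRule carrying both batching groups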

Add tests.

PiperOrigin-RevId: 715515514
bixia1 authored and Google-ML-Automation committed Jan 21, 2025
1 parent c4643c6 commit 2859778
Showing 2 changed files with 74 additions and 30 deletions.
72 changes: 49 additions & 23 deletions jax/_src/custom_partitioning_sharding_rule.py
@@ -42,6 +42,20 @@ def _check_factor(factor:str):
if char != "_" and not char.isdigit() and not char.isalpha():
raise ValueError(f"Unknown character '{char}'")

def _is_batching(factor: str) -> bool:
"""Checks if a factor is a representation for leading batching dimensions.
Leading batching dimensions are represented by a factor consisting of ...
optionally followed by digits; ... is equivalent to ...0.
"""
if len(factor) < 1 or factor[0] != BATCHING:
return False
return len(factor) == 1 or factor[1:].isdigit()

def _get_batching_group(factor: str) -> str:
"""Extracts the batching group from a factor for leading batching dimensions."""
return factor[1:] if len(factor) > 1 else "0"

class CompoundFactor(tuple):
"""Describes the factors for a compound factor.
@@ -54,7 +68,7 @@ def __init__(self, *factors):
for factor in factors:
if not isinstance(factor, str):
raise ValueError(f"Each element of CompoundFactor must be a str, but got {type(factor)}")
if factor == BATCHING:
if _is_batching(factor):
raise ValueError("Ellipsis can't be used in a compound factor")
else:
_check_factor(factor)
@@ -80,7 +94,7 @@ def __init__(self, *dim_mappings):
"Each element of ArrayMapping must be a str or CompoundFactor, but"
f" got {type(d)}")
if isinstance(d, str):
if d == BATCHING:
if _is_batching(d):
if i != 0:
raise ValueError("Ellipsis can only be used at the beginning of a dimension")
else:
@@ -141,15 +155,15 @@ def __str__(self):
return f"SdyShardingRule({self.operand_mappings}, {self.result_mappings}, {self.factor_sizes})"


def _get_batching_dim_factor_name(batch_dim_order : int):
def _get_batching_dim_factor_name(batch_group: str, batch_dim_order: int):
"""Constructs a factor name for a batching dimension.
We expand the leading ... into factors representing the batching dimensions
to support building the MLIR representation for the sharding rule. For this
reason, we construct a factor name that won't be used by users for the
batching dimensions.
"""
return f"{_BATCHING_DIM_FACTOR_PREFIX}{batch_dim_order}"
return f"{_BATCHING_DIM_FACTOR_PREFIX}{batch_group}_{batch_dim_order}"

def _parse_values(
rule: str,
@@ -194,13 +208,26 @@ def add_factor(x):
else:
current_compound_dim.append(x)

for char in rule:
rule_len = len(rule)
rule_index = 0
while rule_index < rule_len:
char = rule[rule_index]
rule_index += 1
if char == BATCHING:
if (current_factor is not None or current_compound_dim is not None
or value):
raise ValueError(
"Ellipsis can only be used at the beginning of a dimension")
add_factor(BATCHING)
if rule_index < rule_len and rule[rule_index].isdigit():
batching_group_str = ""
while rule_index < rule_len and rule[rule_index].isdigit():
batching_group_str += rule[rule_index]
rule_index += 1
batching_group = str(int(batching_group_str))
else:
batching_group = "0"

add_factor(f"{BATCHING}{batching_group}")
continue
if char in "(), ":
if current_factor is not None:
@@ -342,9 +369,8 @@ def add_factor(factor, size):
factor_index = len(factors_to_indices_sizes)
factors_to_indices_sizes[factor] = [factor_index, size]

def add_batching_dim_factor(batch_dim_order, factor_size):
ellipsis_batch_dim_name = _get_batching_dim_factor_name(batch_dim_order)
add_factor(ellipsis_batch_dim_name, factor_size)
def add_batching_dim_factor(batch_grp, batch_dim_order, factor_size):
add_factor(_get_batching_dim_factor_name(batch_grp, batch_dim_order), factor_size)

def build_dim_mapping_for_compound_factors(i, j, factors):
accumulated_size = 1
@@ -365,23 +391,25 @@ def build_dim_mapping_for_compound_factors(i, j, factors):

# Add factors and their sizes in the order they appear in the rule,
# including the batching dimensions represented by ellipsis.
ellipsis_rank = None
batching_group_to_rank = {}
for i, mapping in enumerate(rule.operand_mappings + rule.result_mappings):
value = tuple(mapping)
if value and value[0] == BATCHING:
has_batching = True
if value and _is_batching(value[0]):
batching_group = _get_batching_group(value[0])
value = value[1:]
else:
has_batching = False
batching_group = None
rule_rank = len(value)
op_rank = get_rank_for_value(i)
# The number of dimensions represented by ellipsis.
current_batching_rank = 0
if has_batching and op_rank >= rule_rank:
if batching_group is not None and op_rank >= rule_rank:
current_batching_rank = op_rank - rule_rank
if has_batching:
if batching_group is not None:
ellipsis_rank = batching_group_to_rank.get(batching_group, None)
if ellipsis_rank is None:
ellipsis_rank = current_batching_rank
batching_group_to_rank[batching_group] = ellipsis_rank
elif ellipsis_rank != current_batching_rank:
raise ValueError(
"Ellipsis represents different number of leading dimensions"
@@ -394,7 +422,7 @@ def build_dim_mapping_for_compound_factors(i, j, factors):
f" {msg} has rank {op_rank}")

for j in range(current_batching_rank):
add_batching_dim_factor(j, get_size_for_value_dim(i, j))
add_batching_dim_factor(batching_group, j, get_size_for_value_dim(i, j))

for j, dim in enumerate(value):
if isinstance(dim, str):
@@ -408,20 +436,18 @@ def build_dim_mapping_for_compound_factors(i, j, factors):
for i, mapping in enumerate(rule.operand_mappings + rule.result_mappings):
value = tuple(mapping)
dim_mappings = []

if value and value[0] == BATCHING:
if value and _is_batching(value[0]):
batching_group = _get_batching_group(value[0])
value = value[1:]
if ellipsis_rank is None:
current_batching_rank = 0
else:
current_batching_rank = ellipsis_rank
current_batching_rank = batching_group_to_rank.get(batching_group)
else:
current_batching_rank = 0
batching_group = None

for j in range(current_batching_rank):
dim_mappings.append(
sdy.DimMappingAttr.get(factor_indices=[
factors_to_indices_sizes[_get_batching_dim_factor_name(j)][0]]))
factors_to_indices_sizes[_get_batching_dim_factor_name(batching_group, j)][0]]))

for j, dim in enumerate(value):
if isinstance(dim, str):
32 changes: 25 additions & 7 deletions tests/custom_partitioning_sharding_rule_test.py
@@ -50,8 +50,8 @@ def test_value_mapping_ellipsis_not_first(self):
ArrayMapping("i_j", BATCHING)

def test_value_mapping_str(self):
v = ArrayMapping(BATCHING, "m", CompoundFactor("i", "j"), "k")
self.assertEqual(str(v), f"('{BATCHING}', 'm', ('i', 'j'), 'k')")
v = ArrayMapping(f"{BATCHING}2", "m", CompoundFactor("i", "j"), "k")
self.assertEqual(str(v), f"('{BATCHING}2', 'm', ('i', 'j'), 'k')")

def test_sdy_sharding_rule_factor_size_not_used(self):
with self.assertRaisesRegex(ValueError, "Factor k is not used"):
@@ -158,17 +158,18 @@ def test_sharding_rule_one_scalar_operand(self):
str(rule), "SdyShardingRule((('i', 'j'), (), ('k',)), (('j',),), {})")

def test_sharding_rule_factor_elementwise_add(self):
rule = str_to_sdy_sharding_rule("... i j, ...i j -> ...i j")
# An ellipsis without a number ... is treated as the same as ...0.
rule = str_to_sdy_sharding_rule("...0 i j, ...1 i j -> ...i j")
self.assertEqual(
str(rule),
"SdyShardingRule((('…', 'i', 'j'), ('…', 'i', 'j')), (('…', 'i',"
"SdyShardingRule((('…0', 'i', 'j'), ('…1', 'i', 'j')), (('…0', 'i',"
" 'j'),), {})")

def test_sharding_rule_factor_vector_scalar_add(self):
rule = str_to_sdy_sharding_rule("...i, -> ...i")
rule = str_to_sdy_sharding_rule("...87 i, -> ...87 i")
self.assertEqual(
str(rule),
"SdyShardingRule((('…', 'i'), ()), (('…', 'i'),), {})")
"SdyShardingRule((('…87', 'i'), ()), (('…87', 'i'),), {})")

def test_sharding_rule_factor_reshape_combining(self):
rule = str_to_sdy_sharding_rule("i j -> (i j)")
@@ -316,7 +317,7 @@ def test_conversion_batching_dim_has_two_sizes(self):
rule = str_to_sdy_sharding_rule("..., ... -> ...")
with self.assertRaisesRegex(
ValueError,
"Batching dimension 1 corresponds to two sizes: 32 and 64"):
"Batching dimension 0_1 corresponds to two sizes: 32 and 64"):
sdy_sharding_rule_to_mlir(rule,
[result.operands[0].type, result.operands[1].type],
[result.result.type,],)
@@ -464,5 +465,22 @@ def test_conversion_contracting_dim_matmul(self):
"#sdy.op_sharding_rule<([i, j], [j, k])->([i, k]) {i=16, j=32, k=8}>")


def test_conversion_multiple_batching_groups(self):
opnd0 = self.create_tensor_value((4, 5, 16, 32))
opnd1 = self.create_tensor_value((6, 7, 8, 32, 16))
result = ir.Operation.create(
"stablehlo.custom_call",
results=[self.get_tensor_type((4, 5, 32, 16))],
operands=[opnd0, opnd1,],
attributes=dict(call_target_name=ir.StringAttr.get("foo")))
rule = str_to_sdy_sharding_rule("... j i, ...1 i j -> ...i j")
mlir_rule = sdy_sharding_rule_to_mlir(rule,
[result.operands[0].type, result.operands[1].type],
[result.result.type,])
self.assertEqual(
str(mlir_rule),
"#sdy.op_sharding_rule<([i, j, k, l], [m, n, o, l, k])->([i, j, l, k]) {i=4, j=5, k=16, l=32, m=6, n=7, o=8}>")


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
