From 285977860dfbb446f5137ff942efa460ac8f1b22 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Tue, 14 Jan 2025 13:46:05 -0800 Subject: [PATCH] [jax:custom_partitioning] Support SdyShardingRule with multiple leading batching dimension groups. Previously, we allow the use of ellipsis ... in the Einsum like notation to represent leading batching dimensions in one group of operands and results. We now allow the use of ellipsis optionally followed by a single digit, such as ...2, to represent leading batching dimensions for multiple groups of operands and results. Add tests. PiperOrigin-RevId: 715515514 --- jax/_src/custom_partitioning_sharding_rule.py | 72 +++++++++++++------ .../custom_partitioning_sharding_rule_test.py | 32 +++++++-- 2 files changed, 74 insertions(+), 30 deletions(-) diff --git a/jax/_src/custom_partitioning_sharding_rule.py b/jax/_src/custom_partitioning_sharding_rule.py index 68294b5913cc..f461410cb4c9 100644 --- a/jax/_src/custom_partitioning_sharding_rule.py +++ b/jax/_src/custom_partitioning_sharding_rule.py @@ -42,6 +42,20 @@ def _check_factor(factor:str): if char != "_" and not char.isdigit() and not char.isalpha(): raise ValueError(f"Unknown character '{char}'") +def _is_batching(factor: str) -> bool: + """Checks if a factor is a representation for leading batching dimensions. + + Leading batching dimensions is represented by a factor containing ... and + optionally followed by a digit, and ... is equivalent to ...0. + """ + if len(factor) < 1 or factor[0] != BATCHING: + return False + return len(factor) == 1 or factor[1:].isdigit() + +def _get_batching_group(factor: str) -> str: + """Extracts the batching group from a factor for leading batching dimensions.""" + return factor[1:] if len(factor) > 1 else "0" + class CompoundFactor(tuple): """Describes the factors for a compound factor. @@ -54,7 +68,7 @@ def __init__(self, *factors): for factor in factors: if not isinstance(factor, str): raise ValueError(f"Each element of CompoundFactor must be a str, but got {type(factor)}") - if factor == BATCHING: + if _is_batching(factor): raise ValueError("Ellipsis can't be used in a compound factor") else: _check_factor(factor) @@ -80,7 +94,7 @@ def __init__(self, *dim_mappings): "Each element of ArrayMapping must be a str or CompoundFactor, but" f" got {type(d)}") if isinstance(d, str): - if d == BATCHING: + if _is_batching(d): if i != 0: raise ValueError("Ellipsis can only be used at the beginning of a dimension") else: @@ -141,7 +155,7 @@ def __str__(self): return f"SdyShardingRule({self.operand_mappings}, {self.result_mappings}, {self.factor_sizes})" -def _get_batching_dim_factor_name(batch_dim_order : int): +def _get_batching_dim_factor_name(batch_group: str,batch_dim_order : int): """Constructs a factor name for a batching dimension. We expand the leading ... into factors representing the batching dimensions @@ -149,7 +163,7 @@ def _get_batching_dim_factor_name(batch_dim_order : int): reason, we construct a factor name that won't be used by users for the batching dimensions. """ - return f"{_BATCHING_DIM_FACTOR_PREFIX}{batch_dim_order}" + return f"{_BATCHING_DIM_FACTOR_PREFIX}{batch_group}_{batch_dim_order}" def _parse_values( rule: str, @@ -194,13 +208,26 @@ def add_factor(x): else: current_compound_dim.append(x) - for char in rule: + rule_len = len(rule) + rule_index = 0 + while rule_index < rule_len: + char = rule[rule_index] + rule_index += 1 if char == BATCHING: if (current_factor is not None or current_compound_dim is not None or value): raise ValueError( "Ellipsis can only be used at the beginning of a dimension") - add_factor(BATCHING) + if rule_index < rule_len and rule[rule_index].isdigit(): + batching_group_str = "" + while rule_index < rule_len and rule[rule_index].isdigit(): + batching_group_str += rule[rule_index] + rule_index += 1 + batching_group = str(int(batching_group_str)) + else: + batching_group = "0" + + add_factor(f"{BATCHING}{batching_group}") continue if char in "(), ": if current_factor is not None: @@ -342,9 +369,8 @@ def add_factor(factor, size): factor_index = len(factors_to_indices_sizes) factors_to_indices_sizes[factor] = [factor_index, size] - def add_batching_dim_factor(batch_dim_order, factor_size): - ellipsis_batch_dim_name = _get_batching_dim_factor_name(batch_dim_order) - add_factor(ellipsis_batch_dim_name, factor_size) + def add_batching_dim_factor(batch_grp, batch_dim_order, factor_size): + add_factor(_get_batching_dim_factor_name(batch_grp, batch_dim_order), factor_size) def build_dim_mapping_for_compound_factors(i, j, factors): accumulated_size = 1 @@ -365,23 +391,25 @@ def build_dim_mapping_for_compound_factors(i, j, factors): # Add factors and their sizes in the order they appear in the rule, # including the batching dimensions represented by ellipsis. - ellipsis_rank = None + batching_group_to_rank = {} for i, mapping in enumerate(rule.operand_mappings + rule.result_mappings): value = tuple(mapping) - if value and value[0] == BATCHING: - has_batching = True + if value and _is_batching(value[0]): + batching_group = _get_batching_group(value[0]) value = value[1:] else: - has_batching = False + batching_group = None rule_rank = len(value) op_rank = get_rank_for_value(i) # The number of dimensions represented by ellipsis. current_batching_rank = 0 - if has_batching and op_rank >= rule_rank: + if batching_group is not None and op_rank >= rule_rank: current_batching_rank = op_rank - rule_rank - if has_batching: + if batching_group is not None: + ellipsis_rank = batching_group_to_rank.get(batching_group, None) if ellipsis_rank is None: ellipsis_rank = current_batching_rank + batching_group_to_rank[batching_group] = ellipsis_rank elif ellipsis_rank != current_batching_rank: raise ValueError( "Ellipsis represents different number of leading dimensions" @@ -394,7 +422,7 @@ def build_dim_mapping_for_compound_factors(i, j, factors): f" {msg} has rank {op_rank}") for j in range(current_batching_rank): - add_batching_dim_factor(j, get_size_for_value_dim(i, j)) + add_batching_dim_factor(batching_group, j, get_size_for_value_dim(i, j)) for j, dim in enumerate(value): if isinstance(dim, str): @@ -408,20 +436,18 @@ def build_dim_mapping_for_compound_factors(i, j, factors): for i, mapping in enumerate(rule.operand_mappings + rule.result_mappings): value = tuple(mapping) dim_mappings = [] - - if value and value[0] == BATCHING: + if value and _is_batching(value[0]): + batching_group = _get_batching_group(value[0]) value = value[1:] - if ellipsis_rank is None: - current_batching_rank = 0 - else: - current_batching_rank = ellipsis_rank + current_batching_rank = batching_group_to_rank.get(batching_group) else: current_batching_rank = 0 + batching_group = None for j in range(current_batching_rank): dim_mappings.append( sdy.DimMappingAttr.get(factor_indices=[ - factors_to_indices_sizes[_get_batching_dim_factor_name(j)][0]])) + factors_to_indices_sizes[_get_batching_dim_factor_name(batching_group, j)][0]])) for j, dim in enumerate(value): if isinstance(dim, str): diff --git a/tests/custom_partitioning_sharding_rule_test.py b/tests/custom_partitioning_sharding_rule_test.py index 3aed16510a4f..f22721910408 100644 --- a/tests/custom_partitioning_sharding_rule_test.py +++ b/tests/custom_partitioning_sharding_rule_test.py @@ -50,8 +50,8 @@ def test_value_mapping_ellipsis_not_first(self): ArrayMapping("i_j", BATCHING) def test_value_mapping_str(self): - v = ArrayMapping(BATCHING, "m", CompoundFactor("i", "j"), "k") - self.assertEqual(str(v), f"('{BATCHING}', 'm', ('i', 'j'), 'k')") + v = ArrayMapping(f"{BATCHING}2", "m", CompoundFactor("i", "j"), "k") + self.assertEqual(str(v), f"('{BATCHING}2', 'm', ('i', 'j'), 'k')") def test_sdy_sharding_rule_factor_size_not_used(self): with self.assertRaisesRegex(ValueError, "Factor k is not used"): @@ -158,17 +158,18 @@ def test_sharding_rule_one_scalar_operand(self): str(rule), "SdyShardingRule((('i', 'j'), (), ('k',)), (('j',),), {})") def test_sharding_rule_factor_elementwise_add(self): - rule = str_to_sdy_sharding_rule("... i j, ...i j -> ...i j") + # An ellipsis without a number ... is treated as the same as ...0. + rule = str_to_sdy_sharding_rule("...0 i j, ...1 i j -> ...i j") self.assertEqual( str(rule), - "SdyShardingRule((('…', 'i', 'j'), ('…', 'i', 'j')), (('…', 'i'," + "SdyShardingRule((('…0', 'i', 'j'), ('…1', 'i', 'j')), (('…0', 'i'," " 'j'),), {})") def test_sharding_rule_factor_vector_scalar_add(self): - rule = str_to_sdy_sharding_rule("...i, -> ...i") + rule = str_to_sdy_sharding_rule("...87 i, -> ...87 i") self.assertEqual( str(rule), - "SdyShardingRule((('…', 'i'), ()), (('…', 'i'),), {})") + "SdyShardingRule((('…87', 'i'), ()), (('…87', 'i'),), {})") def test_sharding_rule_factor_reshape_combining(self): rule = str_to_sdy_sharding_rule("i j -> (i j)") @@ -316,7 +317,7 @@ def test_conversion_batching_dim_has_two_sizes(self): rule = str_to_sdy_sharding_rule("..., ... -> ...") with self.assertRaisesRegex( ValueError, - "Batching dimension 1 corresponds to two sizes: 32 and 64"): + "Batching dimension 0_1 corresponds to two sizes: 32 and 64"): sdy_sharding_rule_to_mlir(rule, [result.operands[0].type, result.operands[1].type], [result.result.type,],) @@ -464,5 +465,22 @@ def test_conversion_contracting_dim_matmul(self): "#sdy.op_sharding_rule<([i, j], [j, k])->([i, k]) {i=16, j=32, k=8}>") + def test_conversion_multiple_batching_groups(self): + opnd0 = self.create_tensor_value((4, 5, 16, 32)) + opnd1 = self.create_tensor_value((6, 7, 8, 32, 16)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((4, 5, 32, 16))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = str_to_sdy_sharding_rule("... j i, ...1 i j -> ...i j") + mlir_rule = sdy_sharding_rule_to_mlir(rule, + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + self.assertEqual( + str(mlir_rule), + "#sdy.op_sharding_rule<([i, j, k, l], [m, n, o, l, k])->([i, j, l, k]) {i=4, j=5, k=16, l=32, m=6, n=7, o=8}>") + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())