From 229b9a3f3340f0d02a36fa5900fc9849d57c4252 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Mon, 5 Feb 2024 15:05:49 -0600 Subject: [PATCH] Dynamism RFC (#1881) The current design of dynamism in MHLO and StableHLO has been practically useful. There are success stories of it connecting JAX, PyTorch and TensorFlow to a number of compilers, in a mix of research and production environments. This RFC aims to leverage existing support for dynamism in the StableHLO dialect, discuss improvements to the existing design and then formalize the improved design in the StableHLO opset. The main challenge with writing this RFC was that it affects the entire opset. The current design involves a lot of corner cases, so it took about a year of practical evaluation by the author - primarily within JAX native serialization, but also correlating with other ongoing and prior projects - to distill the design into just a few general design principles. I think I'm happy with the outcome, but please take a look at the "Summary" section for what that entails. This RFC addresses a considerable chunk of community feedback on the existing design, but some feedback is deliberately left out of scope for this RFC to enable incremental progress while areas which require additional alignment are developed in parallel. See sections "Community feedback" and "Out of scope" for details. The only open question at the moment is interoperability with HLO with respect to bounded dynamism. The proposed representation for bounded dynamic types in StableHLO is in parity with the representation for bounded dynamic types in HLO. However, I'm not sure whether this parity covers 100% of bounded dynamism functionality. For example: 1) there appears to be a mismatch in how broadcasts are represented, 2) there is misalignment in representations of dynamic windows (HLO has a high-level representation: VALID and SAME, whereas StableHLO expects the producers to dynamically compute window sizes). Nonetheless, I think that this shouldn't block the initial review of the RFC, since there's a lot of stuff to discuss - and in the meanwhile, I'll be working on confirming interoperability with HLO. Finally, I'd like to acknowledge Smit Hinsu's work on the [Bounded Dynamism RFC](https://github.com/openxla/stablehlo/pull/194) from Q4 2022, which was superseded by this work. The representation for bounded dynamic types in the StableHLO dialect was designed and implemented by Smit, and Smit's proposal to allow bounded dynamic types everywhere is compatible with the more general proposal from this RFC to enable dynamism for all size-related program elements. Furthermore, Smit contributed the formal spec for get_dimension_size as well as the informal spec for set_dimension_size. --- rfcs/20230704-dynamism-101.md | 662 ++++++++++++++++++++++++++++++++++ 1 file changed, 662 insertions(+) create mode 100644 rfcs/20230704-dynamism-101.md diff --git a/rfcs/20230704-dynamism-101.md b/rfcs/20230704-dynamism-101.md new file mode 100644 index 00000000000..0b0942d2b0d --- /dev/null +++ b/rfcs/20230704-dynamism-101.md @@ -0,0 +1,662 @@ +# RFC: Dynamism 101 + +Status: Under review
+Initial version: 7/4/2023
+Last updated: 7/4/2023
+Discussion thread: [openxla-discuss](https://groups.google.com/a/openxla.org/g/openxla-discuss/c/HJRvFBum65k/m/7QtJxgB9AQAJ).
+
+## Summary
+
+This RFC aims to leverage existing support for dynamism in the StableHLO
+dialect, discuss improvements to the existing design and then formalize the
+improved design in the StableHLO opset. To that end, this document proposes
+the following:
+
+* [(P1) Move TensorFlow-specific operations out of StableHLO](#p1).
+* [(P2) Ratify the existing convention for shape mismatches constituting undefined behavior](#p2).
+* [(P3) Ratify the existing convention for relaxed constraints already implemented in the StableHLO dialect](#p3).
+* [(P4) Enable shape-dependent dynamism for all size-related program elements but keep all axis-related program elements static](#p4).
+* [(P5) Represent shape computations as StableHLO operations on variadic 0-dimensional tensors and drop support for unranked dynamism](#p5).
+
+These proposals address a considerable chunk of community feedback on the
+existing design, but some feedback is deliberately left out of scope for this
+RFC to enable incremental progress while areas which require additional
+alignment are developed in parallel. See sections
+["Community feedback"](#community-feedback) and
+["Out of scope"](#out-of-scope) for details.
+
+## Existing design
+
+This is an RFC that affects the entirety of the StableHLO opset, and there are
+multiple groups of operations which are affected differently. Taking into
+account all these operations appropriately was quite laborious, but
+[the StableHLO Ops spreadsheet](https://docs.google.com/spreadsheets/d/1rvhxQMFUtCZ5DsY6X0_lJOCg9rVO2MdyeZlRorsc0UI/edit?resourcekey=0-5gMjnlkXDL6hCntv2yltaQ#gid=0)
+was a big help. Referring to this spreadsheet will likely help in reviewing
+this RFC as well.
+
+Before beginning, let's align on terminology. In discussions around MLIR, it is
+fairly conventional to use the word "dimension" ambiguously - to refer to
+either an actual dimension or the size of a dimension. Most of the time, the
+exact meaning of the word "dimension" is clear from the context, but sometimes
+it's ambiguous. For example, it is not obvious whether
+`mlir::ShapedTypeComponents::getDims` returns dimension numbers or dimension
+sizes (for what it's worth, it is the latter - it returns dimension sizes).
+
+To avoid ambiguity, this document will be using the following terminology:
+
+* **Dimension numbers** (or **axes**) to refer to actual dimensions,
+  e.g. "`tensor<16x?xf32>` has two axes".
+* **Dimension sizes** (or **sizes**) to refer to sizes of dimensions,
+  e.g. "`tensor<16x?xf32>` has the following dimension sizes: 16 and unknown".
+* Unqualified "dimension" will not be used at all.
+
+By virtue of being bootstrapped from MHLO, the StableHLO dialect already
+has support for **dynamism within types**:
+
+* **Unbounded dynamism**: In a tensor type, some or all dimension sizes may be
+  unknown (aka "dynamic"). In MLIR syntax, these dimension sizes are expressed
+  via question marks.
+* **Bounded dynamism**: The same, but some of the dynamic dimension sizes may
+  have known upper bounds. In MLIR syntax, these dimension sizes are expressed
+  via question marks, and bounds are expressed via `#stablehlo.bounds`.
+* **Unranked dynamism**: In a tensor type, rank may be unknown. In MLIR
+  syntax, this fact is expressed via asterisks.
+
+```mlir
+// Static shapes:
+// All dimension sizes are known.
+%0 = stablehlo.add %arg0, %arg1 : tensor<16x16xf32>
+
+// Unbounded dynamism:
+// First dimension size is unknown (?).
+// Second dimension size is known (16).
+%1 = stablehlo.add %arg0, %arg1 : tensor<?x16xf32>
+
+// Bounded dynamism:
+// First dimension size is unknown (?), but its bound is known (16).
+// Second dimension size is known (16), so its bound is N/A (?).
+%2 = stablehlo.add %arg0, %arg1 : tensor<?x16xf32, #stablehlo.bounds<16, ?>>
+
+// Unranked dynamism:
+// The rank is unknown (*).
+%3 = stablehlo.add %arg0, %arg1 : tensor<*xf32>
+```
+
+Similarly, the StableHLO dialect already has support for **dynamism within
+operations**, with some ops such as PadOp having dynamic counterparts such as
+DynamicPadOp. For example:
+
+```mlir
+// "vanilla" PadOp:
+// low, high and interior paddings are implemented via MLIR attributes,
+// i.e. they are static.
+%0 = stablehlo.pad %arg0, %arg1, low = [1, 1], high = [1, 1], interior = [0, 0]
+  : (tensor<16x16xf32>, tensor<f32>) -> tensor<18x18xf32>
+
+// DynamicPadOp:
+// low, high and interior paddings are implemented via MLIR operands,
+// i.e. they are dynamic.
+%1 = stablehlo.dynamic_pad %arg0, %arg1, %arg2, %arg3, %arg4
+  : (tensor<?x?xf32>, tensor<f32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<?x?xf32>
+```
+
+Finally, dynamism (both within types and within operations) creates new
+opportunities for errors. For example, the semantics of elementwise operations
+like `add` are only defined when inputs and outputs have the same shape.
+For the static shape case, this can be checked **at compile time** (i.e. before
+the execution of the StableHLO program). However, when dynamism is involved,
+this can only be checked **at run time**.
+
+If some operation doesn't make sense at run time because some unknown ranks
+and/or unknown dimension sizes turned out to be incompatible with the operation
+and/or with each other, an error condition called **"shape mismatch"** occurs.
+In order to guard against shape mismatches, StableHLO programs may employ
+**shape checks**.
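+For example, the following sketch contrasts the two situations (the check at
+the end is purely illustrative - as discussed later in this RFC, how shape
+checks are expressed is producer-specific and is not standardized here):
+
+```mlir
+// Accepted at compile time, because the unknown sizes might match at run time.
+// If %arg0 turns out to be tensor<3xf32> and %arg1 turns out to be
+// tensor<4xf32>, a shape mismatch occurs.
+%0 = stablehlo.add %arg0, %arg1 : tensor<?xf32>
+
+// One way a producer might check the sizes before the operation.
+%d0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
+%d1 = stablehlo.get_dimension_size %arg1, dim = 0 : (tensor<?xf32>) -> tensor<i32>
+%ok = stablehlo.compare EQ, %d0, %d1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+```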
+## Community feedback
+
+The current design of dynamism in MHLO and StableHLO has been practically
+useful. There are success stories of it connecting JAX, PyTorch and TensorFlow
+to a number of compilers, in a mix of research and production environments.
+However, this design can be improved. Over the years of using this design in
+practice, the community has provided the feedback summarized below
+(see [#8](https://github.com/openxla/stablehlo/issues/8) for the full version):
+
+**(F1)** Unranked dynamism introduces considerable complexity to
+StableHLO, but the only user of unranked dynamism in StableHLO/MHLO
+appears to be TensorFlow's KernelGen toolchain, and KernelGen can encapsulate
+handling unranked dynamism in a different layer. Unranked tensors shouldn't be
+needed in StableHLO.
+
+**(F2)** Having different code paths for producing/consuming
+static and dynamic ops (e.g. PadOp vs DynamicPadOp in the example above) is a
+testing/maintenance/cognitive burden.
+
+**(F3)** Dynamism within operations is modelled inconsistently.
+Some ops have two versions (e.g. PadOp and DynamicPadOp), some ops have three
+versions (e.g. SliceOp, DynamicSliceOp and RealDynamicSliceOp), and some ops
+don't have dynamic versions even though there are use cases for them (e.g.
+ReduceWindowOp).
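+For illustration, here is roughly what the three slice variants look like
+(a sketch with abridged assembly syntax; the SSA values are hypothetical):
+
+```mlir
+// SliceOp: start indices, limit indices and strides are all attributes.
+%0 = stablehlo.slice %arg0 [0:2, 0:4]
+  : (tensor<4x8xf32>) -> tensor<2x4xf32>
+
+// DynamicSliceOp: start indices are operands, slice sizes are attributes.
+%1 = stablehlo.dynamic_slice %arg0, %i, %j, sizes = [2, 4]
+  : (tensor<4x8xf32>, tensor<i64>, tensor<i64>) -> tensor<2x4xf32>
+
+// RealDynamicSliceOp: start indices, limit indices and strides are all operands.
+%2 = stablehlo.real_dynamic_slice %arg0, %start, %limit, %strides
+  : (tensor<?x?xf32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<?x?xf32>
+```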
+**(F4)** Specifying dynamic sizes as 1-dimensional tensors
+introduces considerable complexity because many adjacent abstractions and
+typical shape computations operate in terms of scalars. The tensor
+representation is needed to support unranked dynamism, but since unranked
+tensors shouldn't be needed in StableHLO (see above), shape tensors shouldn't
+be needed either.
+
+**(F5)** Shape computations are an essential part of dynamic
+programs, and there are currently multiple approaches to doing these
+computations in StableHLO programs. Some of these approaches involve only
+operations from the StableHLO dialect, while others use operations from other
+dialects, including `arith` and `shape`. This inconsistency affects
+user experience (producers are uncertain which approach to use, consumers are
+uncertain which approaches to support) and presents compatibility issues
+(the StableHLO project can only provide compatibility guarantees for the
+StableHLO opset).
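+For example, the same dimension-size addition can show up in at least the
+following forms today (a rough sketch; the SSA values are hypothetical
+dimension sizes of the corresponding types):
+
+```mlir
+// StableHLO: dimension sizes as 0-dimensional tensors.
+%sum0 = stablehlo.add %d0, %d1 : tensor<i32>
+
+// arith: dimension sizes as scalars.
+%sum1 = arith.addi %i0, %i1 : index
+
+// shape: dimension sizes as !shape.size values.
+%sum2 = shape.add %s0, %s1 : !shape.size, !shape.size -> !shape.size
+```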
+**(F6)** Dimension numbers and dimension sizes are modelled
+inconsistently. For example, PadOp represents paddings as arrays of `i64`.
+However, DynamicPadOp takes all sorts of paddings - tensors of `index`,
+tensors of `i32`, tensors of `i64`, etc.
+
+**(F7)** There is alignment on shape mismatches being undefined
+behavior, but there are multiple schools of thought on how this undefined
+behavior should be guarded against (use the `shape` dialect, use asserts,
+don't do anything).
+
+**(F8)** Even though StableHLO dynamism is used fairly actively,
+it is not reflected in the StableHLO specification. As a result, dynamism
+semantics and especially constraints are underspecified.
+
+**(F9)** It is important for StableHLO programs to be hardware
+agnostic. As such, it is highly important to have machinery which can refine
+dynamic programs into static programs, so that they can be compiled for
+backends that require static shapes.
+
+## Out of scope
+
+This RFC aims to address a considerable part of community feedback on the
+existing design of dynamism in MHLO and StableHLO, but some feedback is
+deliberately out of scope because further discussion is needed to obtain
+alignment in the corresponding areas.
+
+**(O1)** There is promising related work on alternative
+representations for dynamism, e.g. involving symbols and formulas in modeling
+dynamic sizes rather than representing dynamic sizes as question marks. These
+representations are more expressive than the design proposed in this RFC, and
+they have the potential to solve the problem of shape mismatches, to convey
+additional information to consumers improving overall performance, etc.
+
+However, this is a very new area for the OpenXLA stack, so a considerable
+amount of design exploration will be needed before these ideas can be turned
+into concrete proposals. Meanwhile, this RFC is focused on making design
+improvements which have already been socialized within the community.
+
+**(O2)** Shape computations are considerably improved by this RFC,
+but not all of the feedback from [F5](#f5)/[F6](#f6) is addressed.
+
+More specifically, this RFC proposes to represent shape computations as
+StableHLO operations on 0-dimensional tensors, which is a practically important
+simplification for StableHLO programs and their producers/consumers.
+However, interoperability with other dialects, including `arith` and `shape`
+which are oftentimes used by the community for shape computations, is out of
+scope, and [P5](#p5) goes into details about the rationale.
+
+**(O3)** Working out a common solution for shape checks would be
+nice. Per [F7](#f7), there are multiple incompatible approaches, so a common
+solution would result in a simplification of the overall stack.
+
+However, based on the author's experience, alignment in this area requires
+non-trivial effort (different approaches have different benefits, plus they
+have limited interoperability), so this is left for future work.
+
+**(O4)** Shape inference is an essential part of the StableHLO
+dialect, and dynamism non-trivially affects it by complicating the API (without
+dynamism, almost all ops can unambiguously infer their output types from their
+inputs; with dynamism, this is not the case) and complicating the
+implementation (checking constraints becomes more involved).
+
+Much of this is already implemented in the StableHLO dialect (e.g. verifiers
+and shape functions already support dynamism in a fairly robust manner,
+although some open questions still remain), but formalizing this area is left
+for future work, because that would require formalizing the general notion of
+shape inference, which represents a significant amount of work.
+
+**(O5)** [Value inference](https://github.com/tensorflow/tensorflow/blob/cc71ba69fa9d25a21009b6e377f3dc3d1323aa6c/tensorflow/compiler/xla/client/value_inference.h)
+is an algorithm implemented in the XLA compiler. It "analyzes values in XlaOp
+answers following questions: 1) What's the upper-bound of each value in a
+tensor, 2) What's the lower-bound of each value in a tensor, 3) What's the
+constant value of each tensor, 4) Whether or not each value in a tensor
+is dynamic". This algorithm is used in the TF/XLA bridge to automatically
+compute bounds for bounded types in HLO programs.
+
+Similarly to shape inference, this area is out of scope for this RFC.
+From the implementation standpoint, value inference is currently not available
+for the StableHLO dialect, but this may change in the future depending on
+community feedback.
+
+**(O6)** Shape specialization is a practically important problem
+because not all consumers support dynamic shapes. In the StableHLO repository,
+there are `--stablehlo-refine-shapes` and `--stablehlo-canonicalize-dynamism`
+passes which address this problem. Furthermore, based on a semi-formal proof
+and the experience with JAX native serialization, the author believes that
+these passes are guaranteed to fully specialize dynamic StableHLO programs into
+static StableHLO programs, as long as those programs only involve shape
+polymorphism and have static arguments.
+
+However, similarly to formalizing shape inference, formalizing shape
+specialization is a significant amount of work which is not on the critical
+path to [StableHLO v1.0](../docs/roadmap.md), so it is left for future work. In
+the meantime, this refinement machinery, along with verification which ensures
+that programs are refinable, will be made available to frameworks and hardware
+teams. Per the feedback in [F9](#f9), this RFC aims to only support
+shape-dependent dynamism. Anything beyond shape-dependent dynamism will be left
+to future RFCs.
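+For illustration, a rough sketch of the specialization machinery mentioned
+above: once `--stablehlo-refine-shapes` has propagated static input types and
+shape operands have been folded to constants, `--stablehlo-canonicalize-dynamism`
+can replace dynamic ops with their static counterparts (details simplified):
+
+```mlir
+// Before: a dynamic op whose shape operand has become a constant.
+%shape = stablehlo.constant dense<[3, 4]> : tensor<2xi64>
+%0 = stablehlo.dynamic_broadcast_in_dim %arg0, %shape, dims = [1]
+  : (tensor<4xf32>, tensor<2xi64>) -> tensor<3x4xf32>
+
+// After: the op is canonicalized to its static counterpart.
+%1 = stablehlo.broadcast_in_dim %arg0, dims = [1]
+  : (tensor<4xf32>) -> tensor<3x4xf32>
+```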
+**(O7)** Unifying dynamic and static op definitions was initially
+proposed in this RFC, but per reviewer feedback it has been taken out of scope.
+
+There are several operations in the StableHLO dialect which provide dynamic
+versions of existing StableHLO operations. For example, `PadOp` defines
+`edge_padding_low`, `edge_padding_high` and `interior_padding` as static
+attributes, whereas `DynamicPadOp` has the same contract except that those
+arguments are dynamic values. Unifying these ops may be nice, but it first
+requires significant investigation to determine the usability of unified ops,
+as well as the impact of unified ops on DRR users. In this RFC, [F2](#f2)
+remains unaddressed.
+
+## (P1) Move TensorFlow-specific operations out of StableHLO
+
+**Baseline:** The MHLO and CHLO dialects were initially co-designed with
+the TensorFlow dialect, the MLIR-based TF/XLA bridge and TensorFlow's
+[KernelGen toolchain](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir/tools/kernel_gen).
+This was inherited by the StableHLO repository when the StableHLO dialect
+was bootstrapped from MHLO and the CHLO dialect was moved to the StableHLO
+repository.
+
+As a result, there are 2 StableHLO ops (`compute_reshape_shape` and
+`cstr_reshapable`) as well as 4 CHLO ops (`dynamic_reshape`,
+`minimum_broadcast_shapes`, `rank_specialization_cluster` and
+`rank_specialization_cluster_yield`) which appear specific to TensorFlow, i.e.
+it looks like they are only used in the [legalize_tf.cc](https://github.com/tensorflow/tensorflow/blob/cf2e180455065ce718b1c5328014dd953b1fddc9/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc)
+and [rank_specialization.cc](https://github.com/tensorflow/mlir-hlo/blob/cb944f56130eab16b746e680772305b006743006/mhlo/transforms/rank_specialization/rank_specialization.cc)
+passes within the TensorFlow ecosystem. Given this TensorFlow-specific nature,
+these ops do not meet the bar for inclusion in the StableHLO repository.
+
+**Proposal:** This RFC proposes to immediately remove these 6 operations from
+their current dialects and move them into a new dialect called `chlo_legacy`.
+
+StableHLO [doesn't guarantee](../docs/compatibility.md#out-of-scope)
+compatibility for these operations (see the "Unspecced features" clause).
+However, this RFC proposes to provide compatibility guarantees on an
+exceptional basis, given that these guarantees don't appear to involve a lot
+of work. More specifically, the proposal is to keep the `chlo_legacy` dialect
+in the StableHLO repository for 6 months from the date of its creation and
+only then remove it.
+
+## (P2) Ratify the existing convention for shape mismatches constituting undefined behavior
+
+**Baseline:** As mentioned in [F7](#f7), there is broad consensus that shape
+mismatches in StableHLO operations should constitute undefined behavior, i.e.
+that guarding against shape mismatches should be made a producer
+responsibility. This has been a subject of many informal conversations as well
+as one of the discussion items
+[at the StableHLO dynamism breakout session](https://drive.google.com/drive/u/1/folders/1fGqq8Tcebhcwq1KJqAZPDgXksRdqRvM2)
+during the OpenXLA Dev Summit in April 2023.
+
+**Proposal:** This RFC proposes to ratify this convention as a StableHLO design
+principle, so that the folklore consensus becomes codified in the StableHLO
+specification.
+
+**Discussion:** A) The rationale for this proposal is that lower-level
+abstractions (e.g. HLO or Linalg) typically don't want to concern themselves
+with shape checks - they are focused on generating high-performance code, so
+they want shape checks to be handled at a higher level, i.e. at StableHLO
+or above.
+Furthermore, different producers have different requirements and different
+preferences for expressing shape checks (e.g. JAX's type system enables it to
+need fewer checks than TensorFlow's type system), so a specific way of
+performing shape checks doesn't look like something that should be
+standardized at the StableHLO level either.
+
+B\) From the implementation standpoint, this proposal would involve updating
+[the "Errors" section](../docs/spec.md#errors) of the specification, and
+changing the StableHLO dialect to implement the `ConditionallySpeculatable`
+interface, prohibiting speculation for StableHLO ops that involve dynamism
+([documentation](https://mlir.llvm.org/docs/Rationale/SideEffectsAndSpeculation/)).
+
+## (P3) Ratify the existing convention for relaxed constraints already implemented in the StableHLO dialect
+
+**Baseline:** At the moment, the StableHLO dialect supports
+[relaxed constraints](https://github.com/openxla/stablehlo/blob/8993ad54839add6648b88801f1d223b7f9bc2e58/stablehlo/dialect/Base.cpp#L102-L120)
+that were inherited from the MHLO dialect. For example, the code below is valid
+in the StableHLO dialect:
+
+```mlir
+func.func @main(%arg0: tensor<?xf32>, %arg1: tensor<1xf32>) {
+  %0 = stablehlo.add %arg0, %arg0 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %1 = stablehlo.add %arg0, %arg0 : (tensor<?xf32>, tensor<?xf32>) -> tensor<1xf32>
+  %2 = stablehlo.add %arg0, %arg1 : (tensor<?xf32>, tensor<1xf32>) -> tensor<?xf32>
+  %3 = stablehlo.add %arg0, %arg1 : (tensor<?xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %4 = stablehlo.add %arg1, %arg0 : (tensor<1xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %5 = stablehlo.add %arg1, %arg0 : (tensor<1xf32>, tensor<?xf32>) -> tensor<1xf32>
+  %6 = stablehlo.add %arg1, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<?xf32>
+  %7 = stablehlo.add %arg1, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  func.return
+}
+```
+
+More formally, these relaxed constraints generalize the constraints that are
+already documented in the StableHLO specification. If a constraint for a
+specific operation cannot be evaluated at compile time because it involves an
+unknown rank or an unknown dimension size, it gets deferred until the run time
+of the operation. If at that point the constraint fails, a shape mismatch
+occurs, and [P2](#p2) discusses what should happen in that case.
+
+```mlir
+// Static case - constraints I1-I5 and C1-C4 can be evaluated at compile time.
+%0 = stablehlo.pad %arg0, %arg1, low = [1, 1], high = [1, 1], interior = [0, 0]
+  : (tensor<16x16xf32>, tensor<f32>) -> tensor<18x18xf32>
+
+// Dynamic case - constraints C2 and C4 cannot be evaluated at compile time.
+// C2 depends on rank(operand) which is unknown.
+// C4 depends on shape(operand) and shape(result) which are both unknown.
+// These constraints are deferred until run time.
+%1 = stablehlo.pad %arg0, %arg1, low = [1, 1], high = [1, 1], interior = [0, 0]
+  : (tensor<*xf32>, tensor<f32>) -> tensor<?x?xf32>
+
+// Dynamic case - constraints C3 and C4 cannot be evaluated at compile time.
+// C3 depends on the value of interior_padding which is unknown.
+// C4 depends on a number of shapes and values which are all unknown.
+// Note that C2 can be evaluated at compile time - even though the values of
+// edge_padding_low, edge_padding_high, interior_padding and operand are
+// unknown, their size and rank are actually known.
+%2 = stablehlo.dynamic_pad %arg0, %arg1, %arg2, %arg3, %arg4
+  : (tensor<?x?xf32>, tensor<f32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<?x?xf32>
+```
+
+**Proposal:** A) This RFC proposes to ratify this convention as a StableHLO
+design principle, given that: I) it gives producers the flexibility to mix
+static and dynamic elements in StableHLO programs at their convenience (which
+has been proven to be practically useful, e.g. for JAX native serialization),
+II\) it has a concise formulation that is demonstrably sound (it only affects
+constraints, and no constraints are disregarded - they just move from compile
+time to run time).
+
+B\) From the implementation standpoint, this proposal would involve updating
+[the "Notation" section](../docs/spec.md#values) of the specification, and
+auditing the existing verifiers and shape functions in the StableHLO dialect
+to identify specification compliance issues, updating
+[status.md](../docs/status.md) accordingly and filing
+[Specification Compliance](https://github.com/orgs/openxla/projects/9) tickets.
+
+## (P4) Enable shape-dependent dynamism for all size-related program elements but keep all axis-related program elements static
+
+**Baseline:** The StableHLO dialect has inherited a considerable degree of
+support for dynamism from the MHLO dialect. All MLIR operands can have dynamic
+shapes, almost all MLIR results can have dynamic shapes and many
+size-related MLIR attributes have dynamic counterparts.
+
+1\) As far as MLIR results go, for 7 operations in the StableHLO dialect -
+`BroadcastInDimOp`, `ConstantOp`, `InfeedOp`, `IotaOp`, `RecvOp`, `ReshapeOp`
+and `RngBitGeneratorOp` - the static shape of the results is "load-bearing",
+i.e. allowing dynamic shapes there would not make sense as is. Operations that
+do not have "load-bearing" result shapes can infer their result shapes from
+static operand shapes.
+
+```mlir
+// Static result type - makes sense.
+%0 = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<1xf32>) -> tensor<1x2xf32>
+
+// Dynamic result type - doesn't make sense as is.
+// How does the operation know what result to produce? 1x1xf32? 1x2xf32? etc.
+// Resolving this would need an additional argument - see below.
+%1 = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<1xf32>) -> tensor<1x?xf32>
+```
+
+2\) As far as MLIR attributes go, there are 9 operations in the StableHLO
+dialect which have size-related attributes (operations are grouped together
+using the categories from
+[the StableHLO Ops spreadsheet](https://docs.google.com/spreadsheets/d/1rvhxQMFUtCZ5DsY6X0_lJOCg9rVO2MdyeZlRorsc0UI/edit?resourcekey=0-5gMjnlkXDL6hCntv2yltaQ#gid=0);
+operations from the Dynamism category are analyzed below; operations from the
+Not in HLO category are not included because they are on their way out of the
+StableHLO dialect):
+
+* Data Movement: `DynamicSliceOp`, `GatherOp`, `PadOp`, `ScatterOp`,
+  `SliceOp`.
+* Miscellaneous: `FftOp`.
+* Reduction: `ConvolutionOp`, `ReduceWindowOp`, `SelectAndScatterOp`.
+
+Furthermore, there are 18 operations in the StableHLO dialect which have
+axis-related attributes (some of these operations overlap with the operations
+from the previous section):
+
+* Data Movement: `BroadcastInDimOp`, `ConcatenateOp`, `GatherOp`, `ReverseOp`,
+  `ScatterOp`, `SortOp`, `TransposeOp`.
+* Distribution: `AllGatherOp`, `AllToAllOp`, `ReduceScatterOp`.
+* Elementwise: `MapOp`.
+* Miscellaneous: `BatchNormGradOp`, `BatchNormInferenceOp`,
+  `BatchNormTrainingOp`, `IotaOp`.
+* Reduction: `ConvolutionOp`, `DotGeneralOp`, `ReduceOp`.
+
+3\) Finally, there are 7 operations in the StableHLO
+dialect which provide dynamic versions of existing StableHLO operations:
+`DynamicBroadcastInDimOp`, `DynamicConvOp`, `DynamicGatherOp`, `DynamicIotaOp`,
+`DynamicPadOp`, `DynamicReshapeOp` and `RealDynamicSliceOp`. Not all StableHLO
+operations have dynamic counterparts, e.g. there is no `DynamicReduceWindowOp`.
+Per the feedback in [F9](#f9), this RFC proposes to support shape-dependent
+uses of these operations, which are refinable and can therefore be used by the
+entire StableHLO ecosystem. In PyTorch and JAX, only shape-dependent uses of
+these operations exist. That is, the dynamic operand which specifies the result
+shape is computed from the shapes of other values in the program. With this
+principle, all StableHLO programs with dynamism are refinable to static
+programs by providing concrete input types. A refinement verification pass
+will be offered so that frameworks can ensure that a generated program is
+shape-dependent / refinable.
+
+```mlir
+// The StableHLO dialect resolves the conundrum mentioned in the previous
+// example by providing a dynamic version of BroadcastInDimOp which takes
+// the shape of the result as an operand.
+%0 = stablehlo.dynamic_broadcast_in_dim %arg0, %arg1, dims = [0] :
+  (tensor<1xf32>, tensor<2xi64>) -> tensor<1x?xf32>
+```
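+For illustration, here is a hypothetical shape-dependent use of
+`DynamicBroadcastInDimOp`: the shape operand is computed from the shape of
+another value in the program (via `GetDimensionSizeOp`), rather than coming
+from an opaque source, so the program can be refined to a static program once
+the input types are known:
+
+```mlir
+%d0 = stablehlo.get_dimension_size %arg1, dim = 0 : (tensor<?x?xf32>) -> tensor<i32>
+%d1 = stablehlo.get_dimension_size %arg1, dim = 1 : (tensor<?x?xf32>) -> tensor<i32>
+%r0 = stablehlo.reshape %d0 : (tensor<i32>) -> tensor<1xi32>
+%r1 = stablehlo.reshape %d1 : (tensor<i32>) -> tensor<1xi32>
+%shape = stablehlo.concatenate %r0, %r1, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
+%0 = stablehlo.dynamic_broadcast_in_dim %arg0, %shape, dims = [1]
+  : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+```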
+Overall, as mentioned in [F3](#f3) and further elaborated above in this
+section, dynamism within operations is modelled inconsistently. Some ops have
+two versions (e.g. PadOp and DynamicPadOp), some ops have three versions (e.g.
+SliceOp, DynamicSliceOp and RealDynamicSliceOp), and some ops don't have
+dynamic versions even though there are use cases for them (e.g.
+ReduceWindowOp).
+
+**Proposal:** This RFC proposes to address the inconsistencies in the existing
+support for dynamism in StableHLO summarized in
+["Status of dynamic versions"](https://docs.google.com/spreadsheets/d/1rvhxQMFUtCZ5DsY6X0_lJOCg9rVO2MdyeZlRorsc0UI/edit?resourcekey=0-5gMjnlkXDL6hCntv2yltaQ#gid=335520762&fvid=75321273)
+by establishing a concise design principle:
+
+1) If a program element is related to sizes, then it should be possible
+   to express it both statically and dynamically.
+2) If a program element is related to axes, then it should only be possible
+   to express it statically.
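+For example, with existing ops (an illustrative sketch; the SSA values are
+hypothetical):
+
+```mlir
+// Sizes may be expressed dynamically: the paddings of DynamicPadOp are
+// operands.
+%0 = stablehlo.dynamic_pad %arg0, %arg1, %low, %high, %interior
+  : (tensor<?x?xf32>, tensor<f32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<?x?xf32>
+
+// Axes stay static: the permutation of TransposeOp remains an attribute,
+// even when the operand has dynamic sizes.
+%1 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<?x?xf32>) -> tensor<?x?xf32>
+```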
+**Discussion:** A) This proposal builds on a broad consensus to treat dimension
+sizes dynamically and axes statically, which has come up both in informal
+conversations and at the OpenXLA Dev Summit.
+
+For 1), there's already a significant precedent to provide both static and
+dynamic representations - which started in HLO and continued in MHLO
+(plus, we have several feature requests from JAX:
+[for ConvolutionOp](https://github.com/openxla/stablehlo/issues/1268),
+[for FftOp](https://github.com/openxla/stablehlo/issues/1366),
+[for ReduceWindowOp](https://github.com/openxla/stablehlo/issues/1258) and
+[for RngBitGeneratorOp](https://github.com/openxla/stablehlo/issues/1344)).
+This RFC proposes to extend this precedent to the entire opset, with the
+rationale that the burden of adding a few extensions that were not previously
+supported or requested (for ConstantOp, for InfeedOp and for RecvOp) is
+smaller than the burden of having to define and maintain carveouts in the
+specification.
+
+For 2), the proposal is to keep them static because: A) there doesn't seem
+to be precedent or feature requests to make them dynamic, B) when this topic
+came up with several StableHLO consumers, there was considerable pushback
+since that would considerably complicate code generation.
+
+B\) From the implementation standpoint, this proposal would involve updating
+[the "Types" section](../docs/spec.md#types) of the specification to extend
+the definition of `TensorType` to express **dynamism within types** and
+accommodate both unbounded and bounded dynamism (but not unranked dynamism -
+see [P5](#p5) for details).
+
+However, this section doesn't provide an opinion on what representations
+should be used to express **dynamism within operations**. As mentioned in
+[F2](#f2), the existing design where static ops like `PadOp` are accompanied
+with dynamic ops like `DynamicPadOp` has some drawbacks, which are denoted in
+[O7](#o7) as out of scope for this RFC.
+
+## (P5) Represent shape computations as StableHLO operations on variadic 0-dimensional tensors and drop support for unranked dynamism
+
+**Baseline:** The StableHLO dialect represents sizes as 1-dimensional tensors,
+with static sizes expressed as `DenseI64ArrayAttr` and dynamic sizes
+expressed as `1DTensorOf<[HLO_DimensionValue]>`.
+
+```c++
+class PadOp ... {
+  // This is how static sizes are represented in TableGen:
+  //   DenseI64ArrayAttr:$edge_padding_low
+  //   DenseI64ArrayAttr:$edge_padding_high
+  //   DenseI64ArrayAttr:$interior_padding
+  DenseI64ArrayAttr getEdgePaddingLow();
+  DenseI64ArrayAttr getEdgePaddingHigh();
+  DenseI64ArrayAttr getInteriorPadding();
+}
+
+class DynamicPadOp ... {
+  // This is how dynamic sizes are represented in TableGen:
+  //   HLO_DimensionTensor:$edge_padding_low
+  //   HLO_DimensionTensor:$edge_padding_high
+  //   HLO_DimensionTensor:$interior_padding
+  TypedValue<TensorType> getEdgePaddingLow();
+  TypedValue<TensorType> getEdgePaddingHigh();
+  TypedValue<TensorType> getInteriorPadding();
+}
+```
+
+As mentioned in [F5](#f5), shape computations in StableHLO programs produce
+these 1-dimensional tensors using one of several approaches, including ops
+from Arith, Shape and StableHLO, as well as from other dialects. E.g. addition
+of dimension sizes can be represented via `stablehlo::AddOp`, `arith::AddIOp`,
+`shape::AddOp` and in a few other ways.
+
+Furthermore, as mentioned in [F4](#f4), using 1-dimensional tensors introduces
+considerable complexity because typical shape computations operate in terms of
+scalars. As a result, in code that constructs dynamic StableHLO ops, it is not
+unusual to encounter computations on 0-dimensional tensors whose results are
+then reshaped to 1-dimensional tensors and then concatenated together
+([example](https://github.com/tensorflow/mlir-hlo/blob/1fd13ef28d3f423363e59b1a80eb2c26e0c0979d/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L1577-L1587)).
+
+**Proposal:** The RFC proposes to standardize on representing shape
+computations as StableHLO operations working on 0-dimensional tensors,
+changing from using a single 1-dimensional operand (e.g. `tensor<Nxindex>`) to
+using `N` operands of 0-dimensional tensors (e.g. `tensor<index>`). More
+specifically, the proposal is to:
+
+1) Affirm that only shape computations that use StableHLO operations are
+   supported in StableHLO portable artifacts (this can be changed in future
+   RFCs, but is out of scope for this one).
+2) Use `Variadic<0DTensorOf<[HLO_DimensionValue]>>` instead of
+   `HLO_DimensionTensor` in the StableHLO dialect.
+3) Drop support for unranked dynamism in StableHLO.
+```tablegen
+def StableHLO_DynamicPadOp ... {
+  let arguments = (ins
+    HLO_Tensor:$operand,
+    HLO_Tensor:$padding_value,
+    Variadic<0DTensorOf<[HLO_DimensionValue]>>:$edge_padding_low,
+    Variadic<0DTensorOf<[HLO_DimensionValue]>>:$edge_padding_high,
+    Variadic<0DTensorOf<[HLO_DimensionValue]>>:$interior_padding
+  );
+}
+```
+
+In terms of shape computations in a program, the following example roughly
+corresponds to CHLO's broadcasting of unbounded dimensions. Currently, there
+are additional `reshape` and `concatenate` operations that can be avoided by
+switching to 0D tensor values in shape computations.
+
+CHLO `broadcast_add(tensor<?x?xf32>, tensor<?x?xf32>)` with output shape
+specified using 1D tensors:
+
+```mlir
+func.func @same_rank_broadcast_1D(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?x?xf32>) -> tensor<i32>
+  %1 = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor<?x?xf32>) -> tensor<i32>
+  %2 = stablehlo.get_dimension_size %arg1, dim = 0 : (tensor<?x?xf32>) -> tensor<i32>
+  %3 = stablehlo.get_dimension_size %arg1, dim = 1 : (tensor<?x?xf32>) -> tensor<i32>
+  %4 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
+  %5 = stablehlo.reshape %1 : (tensor<i32>) -> tensor<1xi32>
+  %6 = stablehlo.reshape %2 : (tensor<i32>) -> tensor<1xi32>
+  %7 = stablehlo.reshape %3 : (tensor<i32>) -> tensor<1xi32>
+  %8 = stablehlo.concatenate %4, %5, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
+  %9 = stablehlo.concatenate %6, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
+  %10 = stablehlo.maximum %8, %9 : tensor<2xi32>
+  %11 = stablehlo.dynamic_broadcast_in_dim %arg0, %10, dims = [0, 1] : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+  %12 = stablehlo.dynamic_broadcast_in_dim %arg1, %10, dims = [0, 1] : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+  %13 = stablehlo.add %11, %12 : tensor<?x?xf32>
+  return %13 : tensor<?x?xf32>
+}
+```
+
+The same computation with output shape specified using a variadic number of
+0D tensors:
+
+```mlir
+func.func @same_rank_broadcast_0D(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?x?xf32>) -> tensor<i32>
+  %1 = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor<?x?xf32>) -> tensor<i32>
+  %2 = stablehlo.get_dimension_size %arg1, dim = 0 : (tensor<?x?xf32>) -> tensor<i32>
+  %3 = stablehlo.get_dimension_size %arg1, dim = 1 : (tensor<?x?xf32>) -> tensor<i32>
+  %4 = stablehlo.maximum %0, %2 : tensor<i32>
+  %5 = stablehlo.maximum %1, %3 : tensor<i32>
+  %6 = stablehlo.dynamic_broadcast_in_dim %arg0, shape = [%4, %5], dims = [0, 1] : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+  %7 = stablehlo.dynamic_broadcast_in_dim %arg1, shape = [%4, %5], dims = [0, 1] : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+  %8 = stablehlo.add %6, %7 : tensor<?x?xf32>
+  return %8 : tensor<?x?xf32>
+}
+```
+
+**Discussion:** A) This proposal reduces the expressiveness of shape
+computations in StableHLO, since it removes the capability to pass around
+shapes of dynamic size. Based on conversations with StableHLO/MHLO users,
+support for unranked dynamism appears to be the only usage of this capability.
+
+As mentioned in [F1](#f1), the only user of unranked dynamism in StableHLO/MHLO
+appears to be TensorFlow, and TensorFlow can encapsulate handling unranked
+dynamism in a different layer, so overall this proposal looks like an
+improvement - it does require a non-trivial refactoring in one user, but in
+return it simplifies the StableHLO specification and implementation for
+everyone. The ops of concern in TF are `StridedSlice`, `Reshape`, and
+`Squeeze`, and it appears that interop between these ops and StableHLO is not
+required.
+
+B\) As discussed in [O2](#o2), the community is oftentimes using other
+dialects, including `arith` and `shape`, for shape computations. These
+dialects can be used with StableHLO programs, but there are interoperability
+issues which could be improved if: I) StableHLO used scalars instead of
+0-dimensional tensors, II\) StableHLO used `index` instead of allowing integer
+and index types.
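+For example, moving a dimension size between a StableHLO 0-dimensional tensor
+and the scalar world of `arith`/`shape` currently requires explicit glue
+(a sketch, assuming the usual upstream `tensor` and `arith` ops):
+
+```mlir
+%size = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
+%size_scalar = tensor.extract %size[] : tensor<i32>
+%size_index = arith.index_cast %size_scalar : i32 to index
+```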
+However, both of these changes need a significant amount of work, so they
+are not included in this RFC. I) requires extending the StableHLO opset with
+scalar operations, which involves defining semantics for these operations
+within StableHLO programs and getting buy-in from producers to support them.
+This is a fairly novel notion for the StableHLO ecosystem, so some further
+design exploration is needed here. II) is conceptually straightforward but
+would require updating a lot of code.