diff --git a/rfcs/20230704-dynamism-101.md b/rfcs/20230704-dynamism-101.md new file mode 100644 index 00000000000..0b0942d2b0d --- /dev/null +++ b/rfcs/20230704-dynamism-101.md @@ -0,0 +1,662 @@ +# RFC: Dynamism 101 + +Status: Under review
+Initial version: 7/4/2023
+Last updated: 7/4/2023
+Discussion thread: [openxla-discuss](https://groups.google.com/a/openxla.org/g/openxla-discuss/c/HJRvFBum65k/m/7QtJxgB9AQAJ). + +## Summary + +This RFC aims to leverage existing support for dynamism in the StableHLO +dialect, discuss improvements to the existing design and then formalize the +improved design in the StableHLO opset. To that end, this document proposes +the following: + + +* [(P1) Move TensorFlow-specific operations out of StableHLO](#p1). +* [(P2) Ratify the existing convention for shape mismatches constituting undefined behavior](#p2). +* [(P3) Ratify the existing convention for relaxed constraints already implemented in the StableHLO dialect](#p3). +* [(P4) Enable shape-dependent dynamism for all size-related program elements but keep all axis-related program elements static](#p4). +* [(P5) Represent shape computations as StableHLO operations on variadic 0-dimensional tensors and drop support for unranked dynamism](#p5). + + +These proposals address a considerable chunk of community feedback on the +existing design, but some feedback is deliberately left out of scope for this +RFC to enable incremental progress while areas which require additional +alignment are developed in parallel. See sections +["Community feedback"](#community-feedback) and +["Out of scope"](#out-of-scope) for details. + +## Existing design + +This is an RFC that affects the entirety of the StableHLO opset, and there are +multiple groups of operations which are affected differently. Taking into +account all these operations appropriately was quite laborious, but +[the StableHLO Ops spreadsheet](https://docs.google.com/spreadsheets/d/1rvhxQMFUtCZ5DsY6X0_lJOCg9rVO2MdyeZlRorsc0UI/edit?resourcekey=0-5gMjnlkXDL6hCntv2yltaQ#gid=0) +was a big help. Referring to this spreadsheet will likely help in reviewing this +RFC as well. + +Before beginning, let's align on terminology. In discussions around MLIR, it is +fairly conventional to use the word "dimension" ambiguously - to refer to either +an actual dimension or the size of a dimension. Most of the time, the exact +meaning of the word "dimension" is clear from the context, but sometimes it's +ambiguous. For example, it is not obvious whether +`mlir::ShapedTypeComponents::getDims` returns dimension numbers or dimension +sizes (for what it's worth, it is the latter - it returns dimension sizes). + +To avoid ambiguity, this document will be using the following terminology: + +* **Dimension numbers** (or **axes**) to refer to actual dimensions, + e.g. "`tensor<16x?xf32>` has two axes". +* **Dimension sizes** (or **sizes**) to refer to sizes of dimensions, + e.g. "`tensor<16x?xf32>` has the following dimension sizes: 16 and unknown". +* Unqualified "dimension" will not be used at all. + +By the virtue of being bootstrapped from MHLO, the StableHLO dialect already +has support for **dynamism within types**: + +* **Unbounded dynamism**: In a tensor type, some or all dimension sizes may be + unknown (aka "dynamic"). In MLIR syntax, these dimension sizes are expressed + via question marks. +* **Bounded dynamism**: The same but some of dynamic dimensions sizes may have + known upper bounds. In MLIR syntax, these dimension sizes are expressed via + question marks, and bounds are expressed via `#stablehlo.bounds`. +* **Unranked dynamism**: In a tensor type, rank may be unknown. In MLIR + syntax, this fact is expressed via asterisks. + +```mlir +// Static shapes: +// All dimension sizes are known. 
+%0 = stablehlo.add %arg0, %arg1 : tensor<16x16xf32>
+
+// Unbounded dynamism:
+// First dimension size is unknown (?).
+// Second dimension size is known (16).
+%1 = stablehlo.add %arg0, %arg1 : tensor<?x16xf32>
+
+// Bounded dynamism:
+// First dimension size is unknown (?), but its bound is known (16).
+// Second dimension size is known (16), so its bound is N/A (?).
+%2 = stablehlo.add %arg0, %arg1 : tensor<?x16xf32, #stablehlo.bounds<16, ?>>
+
+// Unranked dynamism:
+// The rank is unknown (*).
+%3 = stablehlo.add %arg0, %arg1 : tensor<*xf32>
+```
+
+Similarly, the StableHLO dialect already has support for **dynamism within
+operations**, with some ops such as PadOp having dynamic counterparts such as
+DynamicPadOp. For example:
+
+```mlir
+// "vanilla" PadOp:
+// low, high and interior paddings are implemented via MLIR attributes,
+// i.e. they are static.
+%0 = stablehlo.pad %arg0, %arg1, low = [1, 1], high = [1, 1], interior = [0, 0]
+  : (tensor<16x16xf32>, tensor<f32>) -> tensor<18x18xf32>
+
+// DynamicPadOp:
+// low, high and interior paddings are implemented via MLIR operands,
+// i.e. they are dynamic.
+%1 = stablehlo.dynamic_pad %arg0, %arg1, %arg2, %arg3, %arg4
+  : (tensor<?x?xf32>, tensor<f32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<?x?xf32>
+```
+
+Finally, dynamism (both within types and within operations) creates new
+opportunities for errors. For example, the semantics of elementwise operations
+like `add` are only defined when inputs and outputs have the same shape.
+For the static shape case, this can be checked **at compile time** (i.e. before
+the execution of the StableHLO program). However, when dynamism is involved,
+this can only be checked **at run time**.
+
+If some operation doesn't make sense at run time because some unknown ranks
+and/or unknown dimension sizes turned out to be incompatible with the operation
+and/or with each other, an error condition called **"shape mismatch"** occurs.
+In order to guard against shape mismatches, StableHLO programs may employ
+**shape checks**.
+
+## Community feedback
+
+The current design of dynamism in MHLO and StableHLO has been practically
+useful. There are success stories of it connecting JAX, PyTorch and TensorFlow
+to a number of compilers, in a mix of research and production environments.
+However, this design can be improved. Over the years of using this design in
+practice, the community has provided the following feedback, which is
+summarized below (see [#8](https://github.com/openxla/stablehlo/issues/8) for
+the full version):
+
+**(F1)** Unranked dynamism introduces considerable complexity to
+StableHLO, but the only user of unranked dynamism in StableHLO/MHLO
+appears to be TensorFlow's KernelGen toolchain, and KernelGen can encapsulate
+handling unranked dynamism in a different layer. Unranked tensors shouldn't be
+needed in StableHLO.
+
+**(F2)** Having different code paths for producing/consuming
+static and dynamic ops (e.g. PadOp vs DynamicPadOp in the example above) is a
+testing/maintenance/cognitive burden.
+
+**(F3)** Dynamism within operations is modelled inconsistently.
+Some ops have two versions (e.g. PadOp and DynamicPadOp), some ops have three
+versions (e.g. SliceOp, DynamicSliceOp and RealDynamicSliceOp), and some ops
+don't have dynamic versions even though there are use cases for them (e.g.
+ReduceWindowOp).
+
+**(F4)** Specifying dynamic sizes as 1-dimensional tensors
+introduces considerable complexity because many adjacent abstractions and
+typical shape computations operate in terms of scalars.
+The tensor representation is needed to support unranked dynamism, but since
+unranked tensors shouldn't be needed in StableHLO (see above), shape tensors
+shouldn't be needed either.
+
+**(F5)** Shape computations are an essential part of dynamic
+programs, and there are currently multiple approaches to doing these
+computations in StableHLO programs. Some of these approaches involve only
+operations from the StableHLO dialect, some of these approaches use operations
+from other dialects including `arith` and `shape`. This inconsistency affects
+user experience (producers are uncertain which approach to use, consumers are
+uncertain which approaches to support) and presents compatibility issues
+(the StableHLO project can only provide compatibility guarantees for the
+StableHLO opset).
+
+**(F6)** Dimension numbers and dimension sizes are modelled
+inconsistently. For example, PadOp represents paddings as arrays of `i64`.
+However, DynamicPadOp takes all sorts of paddings - tensors of `index`,
+tensors of `i32`, tensors of `i64`, etc.
+
+**(F7)** There is alignment on shape mismatches being undefined
+behavior, but there are multiple schools of thought on how this undefined
+behavior should be guarded against (use the `shape` dialect, use asserts,
+don't do anything).
+
+**(F8)** Even though StableHLO dynamism is used fairly actively,
+it is not reflected in the StableHLO specification. As a result, dynamism
+semantics and especially constraints are underspecified.
+
+**(F9)** It is important for StableHLO programs to be hardware
+agnostic. As such, it is highly important to have machinery which can refine
+dynamic programs into static programs for compilation on backends that require
+static shapes.
+
+## Out of scope
+
+This RFC aims to address a considerable part of community feedback on the
+existing design of dynamism in MHLO and StableHLO, but some feedback is
+deliberately out of scope because further discussion is needed to obtain
+alignment in the corresponding areas.
+
+**(O1)** There is promising related work on alternative
+representations for dynamism, e.g. involving symbols and formulas in modeling
+dynamic sizes rather than representing dynamic sizes as question marks. These
+representations are more expressive than the design proposed in this RFC, and
+they have the potential to solve the problem of shape mismatches, to convey
+additional information to consumers improving overall performance, etc.
+
+However, this is a very new area for the OpenXLA stack, so a considerable
+amount of design exploration will be needed before these ideas can be turned
+into concrete proposals. Meanwhile, this RFC is focused on making design
+improvements which have already been socialized within the community.
+
+**(O2)** Shape computations are considerably improved by this RFC,
+but not all of the feedback from [F5](#f5)/[F6](#f6) is addressed.
+
+More specifically, this RFC proposes to represent shape computations as
+StableHLO operations on 0-dimensional tensors, which is a practically important
+simplification for StableHLO programs and their producers/consumers.
+However, interoperability with other dialects, including `arith` and `shape`
+which are oftentimes used by the community for shape computations, is out of
+scope, and [P5](#p5) goes into details about the rationale.
+
+**(O3)** Working out a common solution for shape checks would be nice.
Per [F7](#f7), there are multiple incompatible approaches, so a common +solution would result in a simplification of the overall stack. + +However, based on author's experience, alignment in this area requires +non-trivial effort (different approaches have different benefits, plus they have +limited interoperability), so this is left for future work. + + **(O4)** Shape inference is an essential part of the StableHLO +dialect, and dynamism non-trivially affects it by complicating the API (without +dynamism, almost all ops can unambiguously infer their output types from their +inputs; with dynamism, this is not the case) and complicating the implementation +(checking constraints become more involved). + +Much of this is already implemented in the StableHLO dialect (e.g. verifiers and +shape functions already support dynamism in a fairly robust manner, although +some open questions still remain), but formalizing this area is left for future +work, because that would require formalizing the general notion of shape +inference which represents a significant amount of work. + + **(O5)** [Value inference](https://github.com/tensorflow/tensorflow/blob/cc71ba69fa9d25a21009b6e377f3dc3d1323aa6c/tensorflow/compiler/xla/client/value_inference.h) +is an algorithm implemented in the XLA compiler. It "analyzes values in XlaOp +answers following questions: 1) What's the upper-bound of each value in a +tensor, 2) What's the lower-bound of each value in a tensor, 3) What's the +constant value of each tensor, 4) Whether or not each value in a tensor +is dynamic". This algorithm is used in the TF/XLA bridge to automatically +compute bounds for bounded types in HLO programs. + +Similarly to shape inference, this area is out of scope for this RFC. +From the implementation standpoint, value inference is currently not available +for the StableHLO dialect, but this may change in the future depending on +community feedback. + + **(O6)** Shape specialization is a practically important problem +because not all consumers support dynamic shapes. In the StableHLO repository, +there are `--stablehlo-refine-shapes` and `--stablehlo-canonicalize-dynamism` +passes which address this problem. Furthermore, based on a semi-formal proof and +the experience with JAX native serialization, the author believes that these +passes are guaranteed to fully specialize dynamic StableHLO programs which only +involve shape polymorphism and have static arguments to static StableHLO +programs. + +However, similarly to formalizing shape inference, formalizing shape +specialization is a significant amount of work which is not on the critical +path to [StableHLO v1.0](../docs/roadmap.md), so it is left for future work. In +the meantime this refinement machinery and verification which ensures that +programs are refinable will be made available to frameworks and hardware teams. +Per the feedback in [F9](#f9), this RFC aims to only support shape dependent +dynamism. Anything beyond shape dependent dynamism will be left to future RFCs. + + **(O7)** Unifying dynamic and static op definitions was initially +proposed in this RFC, but per reviewer feedback has been taken out of scope. + +There are several operations in the StableHLO dialect which provide dynamic +versions of existing StableHLO operations. For example, `PadOp` defines +`edge_padding_low`, `edge_padding_high` and `interior_padding` as static +attributes, whereas `DynamicPadOp` has the same contract except that those +arguments are dynamic values. 
Unifying these ops may be nice, but prior to that +significant investigation to determine the usability of unified ops, as well as +the impact of having unified ops on DRR users needs to be investigated. In this +RFC [F2](#f2) remains unaddressed. + + +## (P1) Move TensorFlow-specific operations out of StableHLO + + +**Baseline:** The MHLO and CHLO dialects have been initially co-designed with +the TensorFlow dialect, the MLIR-based TF/XLA bridge and TensorFlow's +[KernelGen toolchain](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir/tools/kernel_gen). +This has been inherited by the StableHLO repository when the StableHLO dialect +was bootstrapped from MHLO and the CHLO dialect was moved to the StableHLO +repository. + +As a result, there are 2 StableHLO ops (`compute_reshape_shape` and +`cstr_reshapable`) as well as 4 CHLO ops (`dynamic_reshape`, +`minimum_broadcast_shapes`, `rank_specialization_cluster` and +`rank_specialization_cluster_yield`) which appear specific to TensorFlow, i.e. +it looks like they are only used in [legalize_tf.cc](https://github.com/tensorflow/tensorflow/blob/cf2e180455065ce718b1c5328014dd953b1fddc9/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc) +and [rank_specialization.cc](https://github.com/tensorflow/mlir-hlo/blob/cb944f56130eab16b746e680772305b006743006/mhlo/transforms/rank_specialization/rank_specialization.cc) +passes within the TensorFlow ecosystem. This specific nature of these ops does +not meet the bar for inclusion in the StableHLO repository. + +**Proposal:** This RFC proposes to immediately remove these 6 operations from +their current dialects and move them into a new dialect called `chlo_legacy`. + +StableHLO [doesn't guarantee](../docs/compatibility.md#out-of-scope) +compatibility for these operations (see the "Unspecced features" clause). +However, this RFC proposes to provide compatibility guarantees on exceptional +basis given that they don't appear to involve a lot of work. More specifically, +the proposal is to keep the `chlo_legacy` dialect in the StableHLO repository +for 6 months from the date of its creation and only then remove it. + + +## (P2) Ratify the existing convention for shape mismatches constituting undefined behavior + + +**Baseline:** As mentioned in [F7](#f7), there is broad consensus that shape +mismatches in StableHLO operations should constitute undefined behavior, i.e. +that guarding against shape mismatches should be made a producer responsibility. +This has been a subject of many informal conversations as well as one of the +discussion items +[at the StableHLO dynamism breakout session](https://drive.google.com/drive/u/1/folders/1fGqq8Tcebhcwq1KJqAZPDgXksRdqRvM2) +during the OpenXLA Dev Summit in April 2023. + +**Proposal:** This RFC proposes to ratify this convention as a StableHLO design +principle, so that the folklore consensus becomes codified in the StableHLO +specification. + +**Discussion:** A) The rationale for this proposal is that lower-level +abstractions (e.g. HLO or Linalg) typically don't want to concern themselves +with shape checks - they are focused on generating high-performance code, so +they want shape checks to be handled at a higher level, i.e. at StableHLO +or above. + +Furthermore, different producers have different requirements and different +preferences for expressing shape checks (e.g. 
JAX's type system enables it to
+need fewer checks than TensorFlow's type system), so a specific way of
+performing shape checks doesn't look like something that should be standardized
+at the StableHLO level either.
+
+B\) From the implementation standpoint, this proposal would involve updating
+[the "Errors" section](../docs/spec.md#errors) of the specification, and
+changing the StableHLO dialect to implement the `ConditionallySpeculatable`
+interface, prohibiting speculation for StableHLO ops that involve dynamism
+([documentation](https://mlir.llvm.org/docs/Rationale/SideEffectsAndSpeculation/)).
+
+
+## (P3) Ratify the existing convention for relaxed constraints already implemented in the StableHLO dialect
+
+
+**Baseline:** At the moment, the StableHLO dialect supports
+[relaxed constraints](https://github.com/openxla/stablehlo/blob/8993ad54839add6648b88801f1d223b7f9bc2e58/stablehlo/dialect/Base.cpp#L102-L120)
+that were inherited from the MHLO dialect. For example, the code below is valid
+in the StableHLO dialect:
+
+```mlir
+func.func @main(%arg0: tensor<?xf32>, %arg1: tensor<1xf32>) {
+  %0 = stablehlo.add %arg0, %arg0 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %1 = stablehlo.add %arg0, %arg0 : (tensor<?xf32>, tensor<?xf32>) -> tensor<1xf32>
+  %2 = stablehlo.add %arg0, %arg1 : (tensor<?xf32>, tensor<1xf32>) -> tensor<?xf32>
+  %3 = stablehlo.add %arg0, %arg1 : (tensor<?xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %4 = stablehlo.add %arg1, %arg0 : (tensor<1xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %5 = stablehlo.add %arg1, %arg0 : (tensor<1xf32>, tensor<?xf32>) -> tensor<1xf32>
+  %6 = stablehlo.add %arg1, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<?xf32>
+  %7 = stablehlo.add %arg1, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  func.return
+}
+```
+
+More formally, these relaxed constraints generalize the constraints that are
+already documented in the StableHLO specification. If a constraint for a
+specific operation cannot be evaluated at compile time because it involves an
+unknown rank or an unknown dimension size, it gets deferred until the run time
+of the operation. If at that point the constraint fails, a shape mismatch
+occurs, and [P2](#p2) discusses what should happen in that case.
+
+```mlir
+// Static case - constraints I1-I5 and C1-C4 can be evaluated at compile time.
+%0 = stablehlo.pad %arg0, %arg1, low = [1, 1], high = [1, 1], interior = [0, 0]
+  : (tensor<16x16xf32>, tensor<f32>) -> tensor<18x18xf32>
+
+// Dynamic case - constraints C2 and C4 cannot be evaluated at compile time.
+// C2 depends on rank(operand) which is unknown.
+// C4 depends on shape(operand) and shape(result) which are both unknown.
+// These constraints are deferred until run time.
+%1 = stablehlo.pad %arg0, %arg1, low = [1, 1], high = [1, 1], interior = [0, 0]
+  : (tensor<*xf32>, tensor<f32>) -> tensor<?x?xf32>
+
+// Dynamic case - constraints C3 and C4 cannot be evaluated at compile time.
+// C3 depends on the value of interior_padding which is unknown.
+// C4 depends on a number of shapes and values which are all unknown.
+// Note that C2 can be evaluated at compile time - even though the values of
+// edge_padding_low, edge_padding_high, interior_padding and operand are
+// unknown, their size and rank are actually known.
+%2 = stablehlo.dynamic_pad %arg0, %arg1, %arg2, %arg3, %arg4
+  : (tensor<?x?xf32>, tensor<f32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<?x?xf32>
+```
+
+**Proposal:** A) This RFC proposes to ratify this convention as a StableHLO
+design principle, given that: I) it allows the flexibility for producers to mix
+static and dynamic elements in StableHLO programs at their convenience (which
+has been proven to be practically useful, e.g. for JAX native serialization),
+II\) it has a concise formulation that is demonstrably sound (it only affects
+constraints, and no constraints are disregarded - they just move from compile
+time to run time).
+
+B\) From the implementation standpoint, this proposal would involve updating
+[the "Notation" section](../docs/spec.md#values) of the specification, and
+auditing the existing verifiers and shape functions in the StableHLO dialect to
+identify specification compliance issues, update [status.md](../docs/status.md)
+accordingly and file
+[Specification Compliance](https://github.com/orgs/openxla/projects/9) tickets.
+
+
+## (P4) Enable shape-dependent dynamism for all size-related program elements but keep all axis-related program elements static
+
+
+**Baseline:** The StableHLO dialect has inherited a considerable degree of
+support for dynamism from the MHLO dialect. All MLIR operands can have dynamic
+shapes, almost all MLIR results can have dynamic shapes and many
+size-related MLIR attributes have dynamic counterparts.
+
+1\) As far as MLIR results go, for 7 operations in the StableHLO dialect -
+`BroadcastInDimOp`, `ConstantOp`, `InfeedOp`, `IotaOp`, `RecvOp`, `ReshapeOp`
+and `RngBitGeneratorOp` - the static shape of the results is "load-bearing",
+i.e. allowing dynamic shapes there would not make sense as is. Operations that
+do not have "load-bearing" result shapes can infer result shapes from static
+operand shapes.
+
+```mlir
+// Static result type - makes sense.
+%0 = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<1xf32>) -> tensor<1x2xf32>
+
+// Dynamic result type - doesn't make sense as is.
+// How does the operation know what result to produce? 1x1xf32? 1x2xf32? etc.
+// Resolving this would need an additional argument - see below.
+%1 = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<1xf32>) -> tensor<1x?xf32>
+```
+
+2\) As far as MLIR attributes go, there are 9 operations in the StableHLO
+dialect which have size-related attributes (operations are grouped together
+using the categories from
+[the StableHLO Ops spreadsheet](https://docs.google.com/spreadsheets/d/1rvhxQMFUtCZ5DsY6X0_lJOCg9rVO2MdyeZlRorsc0UI/edit?resourcekey=0-5gMjnlkXDL6hCntv2yltaQ#gid=0);
+operations from the Dynamism category are analyzed below; operations from the
+Not in HLO category are not included because they are on their way out of the
+StableHLO dialect):
+
+* Data Movement: `DynamicSliceOp`, `GatherOp`, `PadOp`, `ScatterOp`,
+  `SliceOp`.
+* Miscellaneous: `FftOp`.
+* Reduction: `ConvolutionOp`, `ReduceWindowOp`, `SelectAndScatterOp`.
+
+Furthermore, there are 18 operations in the StableHLO dialect which have
+axis-related attributes (some of these operations overlap with the operations
+from the previous section):
+
+* Data Movement: `BroadcastInDimOp`, `ConcatenateOp`, `GatherOp`, `ReverseOp`,
+  `ScatterOp`, `SortOp`, `TransposeOp`.
+* Distribution: `AllGatherOp`, `AllToAllOp`, `ReduceScatterOp`.
+* Elementwise: `MapOp`.
+* Miscellaneous: `BatchNormGradOp`, `BatchNormInferenceOp`,
+  `BatchNormTrainingOp`, `IotaOp`.
+* Reduction: `ConvolutionOp`, `DotGeneralOp`, `ReduceOp`.
+
+3\) Finally, there are 7 operations in the StableHLO
+dialect which provide dynamic versions of existing StableHLO operations:
+`DynamicBroadcastInDimOp`, `DynamicConvOp`, `DynamicGatherOp`, `DynamicIotaOp`,
+`DynamicPadOp`, `DynamicReshapeOp` and `RealDynamicSliceOp`. Not all StableHLO
+operations have dynamic counterparts, e.g. there is no `DynamicReduceWindowOp`.
+Per the feedback in [F9](#f9), this RFC proposes to support shape-dependent uses
+of these operations, which are refinable and can therefore be used by the
+entire StableHLO ecosystem. In PyTorch and JAX, only shape-dependent uses of
+these operations exist, meaning that the dynamic operand which specifies the
+result shape is computed from the shape of another operation. With this
+principle, all dynamic StableHLO programs are refinable to static programs by
+providing concrete input types. A refinement verification pass will be offered
+so that frameworks can ensure that a generated program is shape-dependent,
+i.e. refinable.
+
+```mlir
+// The StableHLO dialect resolves the conundrum mentioned in the previous
+// example by providing a dynamic version of BroadcastInDimOp which takes
+// the shape of the result as an operand.
+%0 = stablehlo.dynamic_broadcast_in_dim %arg0, %arg1, dims = [0] :
+  (tensor<1xf32>, tensor<2xi64>) -> tensor<1x?xf32>
+```
+
+Overall, as mentioned in [F3](#f3) and further elaborated above in this section,
+dynamism within operations is modelled inconsistently. Some ops have two
+versions (e.g. PadOp and DynamicPadOp), some ops have three versions (e.g.
+SliceOp, DynamicSliceOp and RealDynamicSliceOp), and some ops don't have dynamic
+versions even though there are use cases for them (e.g. ReduceWindowOp).
+
+**Proposal:** This RFC proposes to address the inconsistencies in the existing
+support for dynamism in StableHLO summarized in
+["Status of dynamic versions"](https://docs.google.com/spreadsheets/d/1rvhxQMFUtCZ5DsY6X0_lJOCg9rVO2MdyeZlRorsc0UI/edit?resourcekey=0-5gMjnlkXDL6hCntv2yltaQ#gid=335520762&fvid=75321273)
+by establishing a concise design principle:
+
+1) If a program element is related to sizes, then it should be possible
+   to express it both statically and dynamically.
+2) If a program element is related to axes, then it should only be possible
+   to express it statically.
+
+**Discussion:** A) This proposal builds on a broad consensus to treat dimension
+sizes dynamically and axes statically, which has come up both in informal
+conversations and at the OpenXLA Dev Summit.
+
+For 1), there's already a significant precedent to provide both static and
+dynamic representations - which started in HLO and continued in MHLO
+(plus, we have several feature requests from JAX:
+[for ConvolutionOp](https://github.com/openxla/stablehlo/issues/1268),
+[for FftOp](https://github.com/openxla/stablehlo/issues/1366),
+[for ReduceWindowOp](https://github.com/openxla/stablehlo/issues/1258) and
+[for RngBitGeneratorOp](https://github.com/openxla/stablehlo/issues/1344)).
+This RFC proposes to extend this precedent to the entire opset, with the
+rationale that the burden of adding a few extensions that were not previously
+supported or requested (for ConstantOp, for InfeedOp and for RecvOp) is smaller
+than the burden of having to define and maintain carveouts in the specification.
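+
+As an illustration of that precedent, IotaOp already has a dynamic counterpart
+which takes the result shape as an operand. The snippet below is only a sketch
+with representative types (`%shape` is a placeholder SSA value, not one of the
+values used in the examples above):
+
+```mlir
+// Static version: the result shape is fully encoded in the result type.
+%0 = stablehlo.iota dim = 0 : tensor<4xi64>
+
+// Dynamic counterpart: the result shape is supplied as an operand.
+%1 = stablehlo.dynamic_iota %shape, dim = 0 : (tensor<1xi64>) -> tensor<?xi64>
+```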
+
+For 2), the proposal is to keep them static because: A) there doesn't seem
+to be precedent or feature requests to make them dynamic, B) when this topic
+came up with several StableHLO consumers, there was considerable pushback since
+that would considerably complicate code generation.
+
+B\) From the implementation standpoint, this proposal would involve updating
+[the "Types" section](../docs/spec.md#types) of the specification to extend
+the definition of `TensorType` to express **dynamism within types** and
+accommodate both unbounded and bounded dynamism (but not unranked dynamism -
+see [P5](#p5) for details).
+
+However, this section doesn't provide an opinion on what representations should
+be used to express **dynamism within operations**. As mentioned in [F2](#f2),
+the existing design where static ops like `PadOp` are accompanied by dynamic
+ops like `DynamicPadOp` has some drawbacks, which are denoted in [O7](#o7) as
+out of scope for this RFC.
+
+
+## (P5) Represent shape computations as StableHLO operations on variadic 0-dimensional tensors and drop support for unranked dynamism
+
+
+**Baseline:** The StableHLO dialect represents sizes as 1-dimensional tensors,
+with static sizes expressed as `DenseI64ArrayAttr` and dynamic sizes
+expressed as `1DTensorOf<[HLO_DimensionValue]>`.
+
+```c++
+class PadOp ... {
+  // This is how static sizes are represented in TableGen:
+  // DenseI64ArrayAttr:$edge_padding_low
+  // DenseI64ArrayAttr:$edge_padding_high
+  // DenseI64ArrayAttr:$interior_padding
+  DenseI64ArrayAttr getEdgePaddingLow();
+  DenseI64ArrayAttr getEdgePaddingHigh();
+  DenseI64ArrayAttr getInteriorPadding();
+};
+
+class DynamicPadOp ... {
+  // This is how dynamic sizes are represented in TableGen:
+  // HLO_DimensionTensor:$edge_padding_low
+  // HLO_DimensionTensor:$edge_padding_high
+  // HLO_DimensionTensor:$interior_padding
+  TypedValue<RankedTensorType> getEdgePaddingLow();
+  TypedValue<RankedTensorType> getEdgePaddingHigh();
+  TypedValue<RankedTensorType> getInteriorPadding();
+};
+```
+
+As mentioned in [F5](#f5), shape computations in StableHLO programs produce
+these 1-dimensional tensors using one of several approaches, including ops from
+Arith, Shape, StableHLO as well as other dialects. E.g. addition of dimension
+sizes can be represented via `stablehlo::AddOp`, `arith::AddIOp`, `shape::AddOp`
+and in a few other ways.
+
+Furthermore, as mentioned in [F4](#f4), using 1-dimensional tensors introduces
+considerable complexity because typical shape computations operate in terms of
+scalars. As a result, in code that constructs dynamic StableHLO ops, it is not
+unusual to encounter computations on 0-dimensional tensors whose results are
+then reshaped to 1-dimensional tensors and then concatenated together
+([example](https://github.com/tensorflow/mlir-hlo/blob/1fd13ef28d3f423363e59b1a80eb2c26e0c0979d/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L1577-L1587)).
+
+**Proposal:** The RFC proposes to standardize on representing shape computations
+as StableHLO operations working on 0-dimensional tensors, changing from a
+single 1-dimensional tensor operand to `N` operands of 0-dimensional tensors.
+More specifically, the proposal is to:
+
+1) Affirm that only shape computations that use StableHLO operations are
+   supported in StableHLO portable artifacts (this can be changed in future
+   RFCs, but is out of scope for this one).
+2) Use `Variadic<0DTensorOf<[HLO_DimensionValue]>>` instead of
+   `HLO_DimensionTensor` in the StableHLO dialect.
+3) Drop support for unranked dynamism in StableHLO.
+
+```tablegen
+def StableHLO_DynamicPadOp ... {
+  let arguments = (ins
+    HLO_Tensor:$operand,
+    HLO_Tensor:$padding_value,
+    Variadic<0DTensorOf<[HLO_DimensionValue]>>:$edge_padding_low,
+    Variadic<0DTensorOf<[HLO_DimensionValue]>>:$edge_padding_high,
+    Variadic<0DTensorOf<[HLO_DimensionValue]>>:$interior_padding
+  );
+}
+```
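+
+To illustrate how this might surface in the IR, a DynamicPadOp with variadic
+0-dimensional operands could be printed roughly as follows. This is only a
+sketch: the SSA names are placeholders and the exact assembly format would be
+settled during implementation.
+
+```mlir
+// Hypothetical syntax: each padding size is an individual 0-dimensional
+// tensor operand instead of an element of a 1-dimensional shape operand.
+%0 = stablehlo.dynamic_pad %operand, %padding_value,
+    low = [%l0, %l1], high = [%h0, %h1], interior = [%i0, %i1]
+  : (tensor<?x?xf32>, tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>,
+     tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?x?xf32>
+```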
+
+In terms of a shape computation in a program, the following example roughly
+correlates to CHLO's broadcasting of unbounded dimensions. Currently, there are
+additional `reshape` and `concatenate` operations, which can be avoided by
+changing to use 0D tensor values in shape computations.
+
+CHLO `broadcast_add(tensor<?x?xf32>, tensor<?x?xf32>)` with output shape
+specified using 1D tensors:
+
+```mlir
+func.func @same_rank_broadcast_1D(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?x?xf32>) -> tensor<i32>
+  %1 = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor<?x?xf32>) -> tensor<i32>
+  %2 = stablehlo.get_dimension_size %arg1, dim = 0 : (tensor<?x?xf32>) -> tensor<i32>
+  %3 = stablehlo.get_dimension_size %arg1, dim = 1 : (tensor<?x?xf32>) -> tensor<i32>
+  %4 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
+  %5 = stablehlo.reshape %1 : (tensor<i32>) -> tensor<1xi32>
+  %6 = stablehlo.reshape %2 : (tensor<i32>) -> tensor<1xi32>
+  %7 = stablehlo.reshape %3 : (tensor<i32>) -> tensor<1xi32>
+  %8 = stablehlo.concatenate %4, %5, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
+  %9 = stablehlo.concatenate %6, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
+  %10 = stablehlo.maximum %8, %9 : tensor<2xi32>
+  %11 = stablehlo.dynamic_broadcast_in_dim %arg0, %10, dims = [0, 1] : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+  %12 = stablehlo.dynamic_broadcast_in_dim %arg1, %10, dims = [0, 1] : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+  %13 = stablehlo.add %11, %12 : tensor<?x?xf32>
+  return %13 : tensor<?x?xf32>
+}
+```
+
+The same computation with output shape specified using a variadic number of
+0D tensors:
+
+```mlir
+func.func @same_rank_broadcast_0D(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?x?xf32>) -> tensor<i32>
+  %1 = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor<?x?xf32>) -> tensor<i32>
+  %2 = stablehlo.get_dimension_size %arg1, dim = 0 : (tensor<?x?xf32>) -> tensor<i32>
+  %3 = stablehlo.get_dimension_size %arg1, dim = 1 : (tensor<?x?xf32>) -> tensor<i32>
+  %4 = stablehlo.maximum %0, %2 : tensor<i32>
+  %5 = stablehlo.maximum %1, %3 : tensor<i32>
+  %6 = stablehlo.dynamic_broadcast_in_dim %arg0, shape = [%4, %5], dims = [0, 1] : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+  %7 = stablehlo.dynamic_broadcast_in_dim %arg1, shape = [%4, %5], dims = [0, 1] : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+  %8 = stablehlo.add %6, %7 : tensor<?x?xf32>
+  return %8 : tensor<?x?xf32>
+}
+```
+
+**Discussion:** A) This proposal reduces the expressiveness of shape
+computations in StableHLO, since it removes the capability to pass around
+shapes of dynamic size. Based on conversations with StableHLO/MHLO users,
+support for unranked dynamism appears to be the only usage of this capability.
+
+As mentioned in [F1](#f1), the only user of unranked dynamism in StableHLO/MHLO
+appears to be TensorFlow, and TensorFlow can encapsulate handling unranked
+dynamism in a different layer, so overall this proposal looks like an
+improvement - it does require a non-trivial refactoring in one user, but in
+return it simplifies the StableHLO specification and implementation for
+everyone. The ops of concern in TF are `StridedSlice`, `Reshape`, and `Squeeze`,
+and it appears that interop between these ops and StableHLO is not required.
+
+B\) As discussed in [O2](#o2), the community is oftentimes using other dialects,
+including `arith` and `shape`, for shape computations.
+These dialects can be used with StableHLO programs, but there are
+interoperability issues (the glue code currently required is sketched below)
+which could be improved if: I) StableHLO used scalars instead of 0-dimensional
+tensors, II) StableHLO used `index` instead of allowing integer and index
+types.
+
+However, both of these changes need a significant amount of work, so they
+are not included in this RFC. I) requires extending the StableHLO opset with
+scalar operations, which involves defining semantics for these operations within
+StableHLO programs and getting buy-in from producers to support them. This is a
+fairly novel notion for the StableHLO ecosystem, so some further design
+exploration is needed here. II) is conceptually straightforward but will require
+updating a lot of code.
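+
+For context, the following is a minimal sketch (not part of this proposal) of
+the glue currently needed to feed a StableHLO dimension size into the
+`index`-based world of the `arith` and `shape` dialects, assuming the `tensor`
+and `arith` dialects are loaded alongside StableHLO: the scalar has to be
+extracted from the 0-dimensional tensor and then cast from an integer to
+`index`.
+
+```mlir
+func.func @dim_to_index(%arg0: tensor<?x?xf32>) -> index {
+  // StableHLO produces the dimension size as a 0-dimensional i32 tensor.
+  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?x?xf32>) -> tensor<i32>
+  // Extract the scalar from the 0-dimensional tensor.
+  %1 = tensor.extract %0[] : tensor<i32>
+  // Cast the integer to `index` for use with `arith`/`shape` operations.
+  %2 = arith.index_cast %1 : i32 to index
+  return %2 : index
+}
+```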