Add tests/stablehlo_complex_math_expander.mlir. Update docs and clean up.
pearu committed Dec 17, 2024
1 parent d8f6721 commit 771b6e0
Showing 6 changed files with 124 additions and 130 deletions.
11 changes: 11 additions & 0 deletions build_tools/math/README.md
@@ -77,6 +77,13 @@ build/bin/stablehlo-opt --chlo-legalize-to-stablehlo --split-input-file --verify

and copy relevant checks to `chlo_legalize_to_stablehlo.mlir`.

A similar procedure is applied for updating
`stablehlo/tests/stablehlo_complex_math_expander.mlir`:
```sh
build/bin/stablehlo-opt --stablehlo-complex-math-expander --split-input-file --verify-diagnostics \
stablehlo/tests/stablehlo_complex_math_expander.mlir | python llvm-project/mlir/utils/generate-test-checks.py | less
```
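
For orientation, each test case in that file pairs a `func.func` applying a
StableHLO operation to complex operands with the `// CHECK:` lines emitted by
the command above. A minimal sketch (the function name and operation here are
illustrative, and the generated CHECK lines are omitted):
```mlir
// Illustrative test case only; the CHECK lines produced by
// generate-test-checks.py would precede this function.
func.func @log_plus_one_complex_f32(%arg : tensor<complex<f32>>) -> tensor<complex<f32>> {
  %result = stablehlo.log_plus_one %arg : tensor<complex<f32>>
  func.return %result : tensor<complex<f32>>
}
```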

## A procedure for adding a new algorithm to an existing operation

1. Implement a new algorithm in
@@ -98,6 +105,10 @@ and copy relevant checks to `chlo_legalize_to_stablehlo.mlir`.
7. Add a record of the operation to
   `generate_ChloDecompositionPatternsMath.py`; see the for-loop in the
   `main` function.
   - If the operation is a StableHLO operation on complex inputs, also add
     it to the `stablehlo-complex-math-expander` pass: update the
     `populateStablehloComplexMathExpanderPatterns` function in
     `stablehlo/transforms/StablehloComplexMathExpander.cpp` (see the
     registration sketch below).
8. Generate new implementations by running
`generate_ChloDecompositionPatternsMath.py` and remove existing
implementations in
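For the sub-bullet in step 7, registering an expansion in
`populateStablehloComplexMathExpanderPatterns` amounts to adding the pattern to
the pass's rewrite-pattern set. The following is a sketch only, not the
contents of `StablehloComplexMathExpander.cpp`: the pattern class, its stubbed
body, and the exact function signature are assumptions, and the real pattern
implementation is the one emitted by `generate_ChloDecompositionPatternsMath.py`.
```cpp
// Sketch only: names and the populate-function signature below are
// assumptions, not the real stablehlo/transforms/StablehloComplexMathExpander.cpp.
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace mlir {
namespace stablehlo {
namespace {

// Stand-in for a pattern generated by generate_ChloDecompositionPatternsMath.py;
// the generated version rewrites log_plus_one on complex operands into
// real-valued StableHLO ops.
struct Log1pComplexExpansionSketch : OpRewritePattern<Log1pOp> {
  using OpRewritePattern<Log1pOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(Log1pOp op,
                                PatternRewriter &rewriter) const override {
    // The generated decomposition would be constructed here; left
    // unimplemented in this sketch.
    return failure();
  }
};

}  // namespace

// Each StableHLO-on-complex expansion is registered here so the
// stablehlo-complex-math-expander pass picks it up.
void populateStablehloComplexMathExpanderPatterns(RewritePatternSet *patterns,
                                                  MLIRContext *context) {
  patterns->add<Log1pComplexExpansionSketch>(context);
}

}  // namespace stablehlo
}  // namespace mlir
```
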
121 changes: 1 addition & 120 deletions stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir
@@ -3924,123 +3924,4 @@ func.func @square_complex_f32(%arg : tensor<complex<f32>>) -> tensor<complex<f32
func.func @square_f32(%arg : tensor<f32>) -> tensor<f32> {
%result = "chlo.square"(%arg) : (tensor<f32>) -> tensor<f32>
func.return %result : tensor<f32>
}

// CHECK-LABEL: @log1p_complex_f32(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<complex<f32>>) -> tensor<complex<f32>> {
// CHECK: %[[VAL_1:.*]] = stablehlo.real %[[VAL_0]] : (tensor<complex<f32>>) -> tensor<f32>
// CHECK: %[[VAL_2:.*]] = stablehlo.abs %[[VAL_1]] : tensor<f32>
// CHECK: %[[VAL_3:.*]] = stablehlo.imag %[[VAL_0]] : (tensor<complex<f32>>) -> tensor<f32>
// CHECK: %[[VAL_4:.*]] = stablehlo.abs %[[VAL_3]] : tensor<f32>
// CHECK: %[[VAL_5:.*]] = stablehlo.maximum %[[VAL_2]], %[[VAL_4]] : tensor<f32>
// CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<3.40282347E+38> : tensor<f32>
// CHECK: %[[VAL_7:.*]] = stablehlo.sqrt %[[VAL_6]] : tensor<f32>
// CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.00999999977> : tensor<f32>
// CHECK: %[[VAL_9:.*]] = stablehlo.multiply %[[VAL_7]], %[[VAL_8]] : tensor<f32>
// CHECK: %[[VAL_10:.*]] = stablehlo.compare GT, %[[VAL_5]], %[[VAL_9]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_11:.*]] = stablehlo.log %[[VAL_5]] : tensor<f32>
// CHECK: %[[VAL_12:.*]] = stablehlo.constant dense<5.000000e-01> : tensor<f32>
// CHECK: %[[VAL_13:.*]] = stablehlo.minimum %[[VAL_2]], %[[VAL_4]] : tensor<f32>
// CHECK: %[[VAL_14:.*]] = stablehlo.compare EQ, %[[VAL_13]], %[[VAL_5]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_15:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[VAL_16:.*]] = stablehlo.divide %[[VAL_13]], %[[VAL_5]] : tensor<f32>
// CHECK: %[[VAL_17:.*]] = stablehlo.multiply %[[VAL_16]], %[[VAL_16]] : tensor<f32>
// CHECK: %[[VAL_18:.*]] = stablehlo.select %[[VAL_14]], %[[VAL_15]], %[[VAL_17]] : tensor<i1>, tensor<f32>
// CHECK: %[[VAL_19:.*]] = stablehlo.log_plus_one %[[VAL_18]] : tensor<f32>
// CHECK: %[[VAL_20:.*]] = stablehlo.multiply %[[VAL_12]], %[[VAL_19]] : tensor<f32>
// CHECK: %[[VAL_21:.*]] = stablehlo.add %[[VAL_11]], %[[VAL_20]] : tensor<f32>
// CHECK: %[[VAL_22:.*]] = stablehlo.add %[[VAL_1]], %[[VAL_15]] : tensor<f32>
// CHECK: %[[VAL_23:.*]] = stablehlo.abs %[[VAL_22]] : tensor<f32>
// CHECK: %[[VAL_24:.*]] = stablehlo.add %[[VAL_23]], %[[VAL_4]] : tensor<f32>
// CHECK: %[[VAL_25:.*]] = stablehlo.constant dense<2.000000e-01> : tensor<f32>
// CHECK: %[[VAL_26:.*]] = stablehlo.compare LT, %[[VAL_24]], %[[VAL_25]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_27:.*]] = stablehlo.multiply %[[VAL_22]], %[[VAL_22]] : tensor<f32>
// CHECK: %[[VAL_28:.*]] = stablehlo.multiply %[[VAL_3]], %[[VAL_3]] : tensor<f32>
// CHECK: %[[VAL_29:.*]] = stablehlo.add %[[VAL_27]], %[[VAL_28]] : tensor<f32>
// CHECK: %[[VAL_30:.*]] = stablehlo.log %[[VAL_29]] : tensor<f32>
// CHECK: %[[VAL_31:.*]] = stablehlo.multiply %[[VAL_12]], %[[VAL_30]] : tensor<f32>
// CHECK: %[[VAL_32:.*]] = stablehlo.add %[[VAL_1]], %[[VAL_1]] : tensor<f32>
// CHECK: %[[VAL_33:.*]] = stablehlo.add %[[VAL_32]], %[[VAL_28]] : tensor<f32>
// CHECK: %[[VAL_34:.*]] = stablehlo.multiply %[[VAL_1]], %[[VAL_1]] : tensor<f32>
// CHECK: %[[VAL_35:.*]] = stablehlo.add %[[VAL_33]], %[[VAL_34]] : tensor<f32>
// CHECK: %[[VAL_36:.*]] = stablehlo.negate %[[VAL_28]] : tensor<f32>
// CHECK: %[[VAL_37:.*]] = stablehlo.constant dense<0x7F800000> : tensor<f32>
// CHECK: %[[VAL_38:.*]] = stablehlo.compare GT, %[[VAL_6]], %[[VAL_37]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_39:.*]] = stablehlo.constant dense<0x4D000000> : tensor<f32>
// CHECK: %[[VAL_40:.*]] = stablehlo.constant dense<9.99999968E+37> : tensor<f32>
// CHECK: %[[VAL_41:.*]] = stablehlo.compare GT, %[[VAL_6]], %[[VAL_40]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_42:.*]] = stablehlo.constant dense<4.097000e+03> : tensor<f32>
// CHECK: %[[VAL_43:.*]] = stablehlo.constant dense<6.500000e+01> : tensor<f32>
// CHECK: %[[VAL_44:.*]] = stablehlo.select %[[VAL_41]], %[[VAL_42]], %[[VAL_43]] : tensor<i1>, tensor<f32>
// CHECK: %[[VAL_45:.*]] = stablehlo.select %[[VAL_38]], %[[VAL_39]], %[[VAL_44]] : tensor<i1>, tensor<f32>
// CHECK: %[[VAL_46:.*]] = stablehlo.multiply %[[VAL_45]], %[[VAL_3]] : tensor<f32>
// CHECK: %[[VAL_47:.*]] = stablehlo.subtract %[[VAL_3]], %[[VAL_46]] : tensor<f32>
// CHECK: %[[VAL_48:.*]] = stablehlo.add %[[VAL_46]], %[[VAL_47]] : tensor<f32>
// CHECK: %[[VAL_49:.*]] = stablehlo.multiply %[[VAL_48]], %[[VAL_48]] : tensor<f32>
// CHECK: %[[VAL_50:.*]] = stablehlo.add %[[VAL_36]], %[[VAL_49]] : tensor<f32>
// CHECK: %[[VAL_51:.*]] = stablehlo.subtract %[[VAL_3]], %[[VAL_48]] : tensor<f32>
// CHECK: %[[VAL_52:.*]] = stablehlo.multiply %[[VAL_48]], %[[VAL_51]] : tensor<f32>
// CHECK: %[[VAL_53:.*]] = stablehlo.add %[[VAL_50]], %[[VAL_52]] : tensor<f32>
// CHECK: %[[VAL_54:.*]] = stablehlo.add %[[VAL_53]], %[[VAL_52]] : tensor<f32>
// CHECK: %[[VAL_55:.*]] = stablehlo.multiply %[[VAL_51]], %[[VAL_51]] : tensor<f32>
// CHECK: %[[VAL_56:.*]] = stablehlo.add %[[VAL_54]], %[[VAL_55]] : tensor<f32>
// CHECK: %[[VAL_57:.*]] = stablehlo.add %[[VAL_35]], %[[VAL_56]] : tensor<f32>
// CHECK: %[[VAL_58:.*]] = stablehlo.negate %[[VAL_34]] : tensor<f32>
// CHECK: %[[VAL_59:.*]] = stablehlo.multiply %[[VAL_45]], %[[VAL_1]] : tensor<f32>
// CHECK: %[[VAL_60:.*]] = stablehlo.subtract %[[VAL_1]], %[[VAL_59]] : tensor<f32>
// CHECK: %[[VAL_61:.*]] = stablehlo.add %[[VAL_59]], %[[VAL_60]] : tensor<f32>
// CHECK: %[[VAL_62:.*]] = stablehlo.multiply %[[VAL_61]], %[[VAL_61]] : tensor<f32>
// CHECK: %[[VAL_63:.*]] = stablehlo.add %[[VAL_58]], %[[VAL_62]] : tensor<f32>
// CHECK: %[[VAL_64:.*]] = stablehlo.subtract %[[VAL_1]], %[[VAL_61]] : tensor<f32>
// CHECK: %[[VAL_65:.*]] = stablehlo.multiply %[[VAL_61]], %[[VAL_64]] : tensor<f32>
// CHECK: %[[VAL_66:.*]] = stablehlo.add %[[VAL_63]], %[[VAL_65]] : tensor<f32>
// CHECK: %[[VAL_67:.*]] = stablehlo.add %[[VAL_66]], %[[VAL_65]] : tensor<f32>
// CHECK: %[[VAL_68:.*]] = stablehlo.multiply %[[VAL_64]], %[[VAL_64]] : tensor<f32>
// CHECK: %[[VAL_69:.*]] = stablehlo.add %[[VAL_67]], %[[VAL_68]] : tensor<f32>
// CHECK: %[[VAL_70:.*]] = stablehlo.add %[[VAL_57]], %[[VAL_69]] : tensor<f32>
// CHECK: %[[VAL_71:.*]] = stablehlo.subtract %[[VAL_33]], %[[VAL_32]] : tensor<f32>
// CHECK: %[[VAL_72:.*]] = stablehlo.subtract %[[VAL_33]], %[[VAL_71]] : tensor<f32>
// CHECK: %[[VAL_73:.*]] = stablehlo.subtract %[[VAL_32]], %[[VAL_72]] : tensor<f32>
// CHECK: %[[VAL_74:.*]] = stablehlo.subtract %[[VAL_28]], %[[VAL_71]] : tensor<f32>
// CHECK: %[[VAL_75:.*]] = stablehlo.add %[[VAL_73]], %[[VAL_74]] : tensor<f32>
// CHECK: %[[VAL_76:.*]] = stablehlo.subtract %[[VAL_35]], %[[VAL_33]] : tensor<f32>
// CHECK: %[[VAL_77:.*]] = stablehlo.subtract %[[VAL_35]], %[[VAL_76]] : tensor<f32>
// CHECK: %[[VAL_78:.*]] = stablehlo.subtract %[[VAL_33]], %[[VAL_77]] : tensor<f32>
// CHECK: %[[VAL_79:.*]] = stablehlo.subtract %[[VAL_34]], %[[VAL_76]] : tensor<f32>
// CHECK: %[[VAL_80:.*]] = stablehlo.add %[[VAL_78]], %[[VAL_79]] : tensor<f32>
// CHECK: %[[VAL_81:.*]] = stablehlo.add %[[VAL_75]], %[[VAL_80]] : tensor<f32>
// CHECK: %[[VAL_82:.*]] = stablehlo.subtract %[[VAL_57]], %[[VAL_35]] : tensor<f32>
// CHECK: %[[VAL_83:.*]] = stablehlo.subtract %[[VAL_57]], %[[VAL_82]] : tensor<f32>
// CHECK: %[[VAL_84:.*]] = stablehlo.subtract %[[VAL_35]], %[[VAL_83]] : tensor<f32>
// CHECK: %[[VAL_85:.*]] = stablehlo.subtract %[[VAL_56]], %[[VAL_82]] : tensor<f32>
// CHECK: %[[VAL_86:.*]] = stablehlo.add %[[VAL_84]], %[[VAL_85]] : tensor<f32>
// CHECK: %[[VAL_87:.*]] = stablehlo.add %[[VAL_81]], %[[VAL_86]] : tensor<f32>
// CHECK: %[[VAL_88:.*]] = stablehlo.subtract %[[VAL_70]], %[[VAL_57]] : tensor<f32>
// CHECK: %[[VAL_89:.*]] = stablehlo.subtract %[[VAL_70]], %[[VAL_88]] : tensor<f32>
// CHECK: %[[VAL_90:.*]] = stablehlo.subtract %[[VAL_57]], %[[VAL_89]] : tensor<f32>
// CHECK: %[[VAL_91:.*]] = stablehlo.subtract %[[VAL_69]], %[[VAL_88]] : tensor<f32>
// CHECK: %[[VAL_92:.*]] = stablehlo.add %[[VAL_90]], %[[VAL_91]] : tensor<f32>
// CHECK: %[[VAL_93:.*]] = stablehlo.add %[[VAL_87]], %[[VAL_92]] : tensor<f32>
// CHECK: %[[VAL_94:.*]] = stablehlo.add %[[VAL_70]], %[[VAL_93]] : tensor<f32>
// CHECK: %[[VAL_95:.*]] = stablehlo.log_plus_one %[[VAL_94]] : tensor<f32>
// CHECK: %[[VAL_96:.*]] = stablehlo.multiply %[[VAL_12]], %[[VAL_95]] : tensor<f32>
// CHECK: %[[VAL_97:.*]] = stablehlo.select %[[VAL_26]], %[[VAL_31]], %[[VAL_96]] : tensor<i1>, tensor<f32>
// CHECK: %[[VAL_98:.*]] = stablehlo.select %[[VAL_10]], %[[VAL_21]], %[[VAL_97]] : tensor<i1>, tensor<f32>
// CHECK: %[[VAL_99:.*]] = stablehlo.atan2 %[[VAL_3]], %[[VAL_22]] : tensor<f32>
// CHECK: %[[VAL_100:.*]] = stablehlo.complex %[[VAL_98]], %[[VAL_99]] : tensor<complex<f32>>
// CHECK: return %[[VAL_100]] : tensor<complex<f32>>
// CHECK: }
func.func @log1p_complex_f32(%arg : tensor<complex<f32>>) -> tensor<complex<f32>> {
%result = "chlo.log1p"(%arg) : (tensor<complex<f32>>) -> tensor<complex<f32>>
func.return %result : tensor<complex<f32>>
}

// CHECK-LABEL: @log1p_f32(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<f32>) -> tensor<f32> {
// CHECK: %[[VAL_1:.*]] = stablehlo.log_plus_one %[[VAL_0]] : tensor<f32>
// CHECK: return %[[VAL_1]] : tensor<f32>
// CHECK: }
func.func @log1p_f32(%arg : tensor<f32>) -> tensor<f32> {
%result = "chlo.log1p"(%arg) : (tensor<f32>) -> tensor<f32>
func.return %result : tensor<f32>
}
}