Skip to content

Commit

Permalink
[TOSA] Add lowering for aten.expm1 (#3949)
Browse files Browse the repository at this point in the history
* Add Torch to TOSA legalization for aten.expm1
* Update xfail_sets with new test results
* Add new LIT tests


Change-Id: I834d0c7416341f884612053aebf9fcc90bcb3b53

Signed-off-by: Justin Ngo <[email protected]>
  • Loading branch information
justin-ngo-arm authored Jan 10, 2025
1 parent a45356e commit 98e4eb2
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 4 deletions.
42 changes: 42 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8212,6 +8212,47 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
return success();
}

// Legalization for aten.expm1
//
// Lowers aten.expm1 to `tosa.sub(tosa.exp(x), 1)`. TOSA has no native expm1
// op, so this composition may lose precision for inputs near zero compared
// to a dedicated expm1 implementation.
template <>
LogicalResult ConvertAtenOp<AtenExpm1Op>::matchAndRewrite(
    AtenExpm1Op op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // expm1 formula:
  //   y = exp(x) - 1
  auto self = adaptor.getSelf();

  // Only tensor inputs are supported.
  auto selfType = dyn_cast<TensorType>(self.getType());
  if (!selfType)
    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

  auto resultType =
      dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
  // Guard against a failed or non-tensor type conversion before
  // dereferencing resultType below.
  if (!resultType)
    return rewriter.notifyMatchFailure(
        op, "Only tensor result types are supported");
  auto resultElemTy = resultType.getElementType();

  if (!isa<mlir::FloatType>(resultElemTy))
    return rewriter.notifyMatchFailure(
        op, "Only floating-point datatype result types are supported");

  // If input is not a float type then cast it to result element type, since
  // tosa.exp only operates on floating-point tensors.
  auto selfElemTy = selfType.getElementType();
  if (!isa<mlir::FloatType>(selfElemTy))
    self = tosa::promoteType(rewriter, self, resultType);

  // Rank-0 splat constant 1.0 in the result element type; broadcast by
  // tosa.sub below.
  auto one =
      tosa::getConstTensor<float>(rewriter, op, 1.0f, {}, resultElemTy).value();

  auto expOp = rewriter.create<tosa::ExpOp>(op->getLoc(), resultType, self);

  auto result = rewriter.create<tosa::SubOp>(op->getLoc(), resultType,
                                             expOp.getResult(), one);

  rewriter.replaceOp(op, {result.getResult()});

  return success();
}

// Legalization for aten.tan
template <>
LogicalResult ConvertAtenOp<AtenTanOp>::matchAndRewrite(
Expand Down Expand Up @@ -8805,6 +8846,7 @@ std::set<StringRef> torch::populateTorchToTosaConversionPatternsAndIllegalOps(
INSERT_ATENOP_PATTERN(AtenLogitOp);
INSERT_ATENOP_PATTERN(AtenLog1pOp);
INSERT_ATENOP_PATTERN(AtenLog10Op);
INSERT_ATENOP_PATTERN(AtenExpm1Op);
INSERT_ATENOP_PATTERN(AtenTanOp);
INSERT_ATENOP_PATTERN(AtenUnfoldOp);
#undef INSERT_ATENOP_PATTERN
Expand Down
8 changes: 4 additions & 4 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1709,8 +1709,12 @@
"Unfold_Module_Rank_Zero_basic",
"Unfold_Module_basic",
"ElementwiseErfIntModule_basic",
"ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic",
"ElementwiseIntTensorLtFloatScalarModule_basic",
"ElementwiseSigmoidIntModule_basic",
"ElementwiseSpecialExpm1IntModule_basic",
"ElementwiseSpecialExpm1Module_basic",
"ElementwiseTanIntModule_basic",
"ElementwiseTanModule_basic",
"ElementwiseUnaryIntModule_basic",
Expand Down Expand Up @@ -3668,16 +3672,12 @@
"ElementwiseCoshModule_basic",
"ElementwiseDequantizePerChannelModule_basic",
"ElementwiseDequantizePerTensorModule_basic",
"ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic",
"ElementwiseMulTensorComplexDiffModule_basic",
"ElementwiseMulTensorComplexModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorUIntModule_basic",
"ElementwiseSinhIntModule_basic",
"ElementwiseSinhModule_basic",
"ElementwiseSpecialExpm1IntModule_basic",
"ElementwiseSpecialExpm1Module_basic",
"ElementwiseToDtypeF32ToI64Module_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
"ElementwiseWhereScalarOtherStaticModule_basic",
Expand Down
33 changes: 33 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3024,3 +3024,36 @@ func.func @torch.aten.unfold$rank_zero(%arg0: !torch.vtensor<[],f32>) -> !torch.
}

// -----

// Checks the f32 path: aten.expm1 lowers to tosa.exp followed by tosa.sub
// with a rank-0 constant 1.0 (TOSA has no native expm1 op).
// CHECK-LABEL: func.func @torch.aten.expm1$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_3:.*]] = tosa.exp %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_4:.*]] = tosa.sub %[[VAL_3]], %[[VAL_2]] : (tensor<3x4xf32>, tensor<f32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32>
// CHECK: }
func.func @torch.aten.expm1$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
%0 = torch.aten.expm1 %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
return %0 : !torch.vtensor<[3,4],f32>
}

// -----

// Checks the integer-input path: the si32 operand is first cast to the f32
// result element type (tosa.cast) before the exp/sub lowering is applied.
// CHECK-LABEL: func.func @torch.aten.expm1$int(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32>
// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_4:.*]] = tosa.exp %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_5:.*]] = tosa.sub %[[VAL_4]], %[[VAL_3]] : (tensor<3x4xf32>, tensor<f32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32>
// CHECK: }
func.func @torch.aten.expm1$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> {
%0 = torch.aten.expm1 %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32>
return %0 : !torch.vtensor<[3,4],f32>
}

// -----

0 comments on commit 98e4eb2

Please sign in to comment.