Skip to content

Commit

Permalink
Merge branch 'reducelogsumexp' of https://github.com/archana-ramaling…
Browse files Browse the repository at this point in the history
…am/torch-mlir into reducelogsumexp
  • Loading branch information
archana-ramalingam committed Apr 26, 2024
2 parents 7ac151c + 74db4a7 commit c461b86
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 0 deletions.
51 changes: 51 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,57 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
return success();
});
patterns.onOp(
"ReduceLogSumExp", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value data;
int64_t keepDims, noop_with_empty_axes;
if (binder.tensorOperandAtIndex(data, 0) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
0))
return failure();

// out = Log(reducesum(exp(data)))
Value castDType = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(/*Float64Type*/ 7));
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value constFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
auto size = data.getType()
.dyn_cast<Torch::ValueTensorType>()
.getOptionalSizes();
auto f64ResultType = rewriter.getType<Torch::ValueTensorType>(
size, rewriter.getF64Type());
Value dataCast = rewriter.create<Torch::AtenToDtypeOp>(
binder.getLoc(), f64ResultType, data, castDType,
/*non_blocking=*/constFalse, /*copy=*/constFalse,
/*memory_format=*/noneVal);

Value dataExp = rewriter.create<Torch::AtenExpOp>(
binder.getLoc(), f64ResultType, dataCast);
auto reducedSumBool = reducedSumImpl(
binder, rewriter, dataExp, f64ResultType,
/*storeValue=*/data, keepDims, noop_with_empty_axes, true);

if (failed(reducedSumBool))
return rewriter.notifyMatchFailure(
binder.op,
"Failed to perform sum operation on square of operand");

Value finalResult = rewriter.create<Torch::AtenLogOp>(
binder.getLoc(), f64ResultType, data);

Value resultDtype = Torch::getDtypeIntValueForType(
rewriter, binder.getLoc(), resultType.getDtype());
rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
binder.op, resultType, finalResult, resultDtype,
/*non_blocking=*/constFalse, /*copy=*/constFalse,
/*memory_format=*/noneVal);
return success();
});
patterns.onOp(
"ReduceMean", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Expand Down
76 changes: 76 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1069,6 +1069,82 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor<

// -----

// CHECK-LABEL: func.func @test_reduce_log_sum_exp_default_axes_keepdims_example
func.func @test_reduce_log_sum_exp_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT7:.+]] = torch.constant.int 7
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,1,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
// CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[1,1,1],f32>
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
// CHECK: %[[LOG:.+]] = torch.aten.log1p %[[SUM]] : !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
// CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,1,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
// CHECK: return %[[CASTLIKE]] : !torch.vtensor<[1,1,1],f32>
%0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
return %0 : !torch.vtensor<[1,1,1],f32>
}

// -----

// CHECK-LABEL: func.func @test_reduce_log_sum_exp_do_not_keepdims_example_expanded
func.func @test_reduce_log_sum_exp_do_not_keepdims_example_expanded(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT7:.+]] = torch.constant.int 7
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f32>
// CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
// CHECK: %[[LOG:.+]] = torch.aten.log1p %[[SUM]] : !torch.vtensor<[3,2],f32> -> !torch.vtensor<[3,2],f32>
// CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
// CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2],f32>
%0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
return %0 : !torch.vtensor<[3,2],f32>
}

// -----

// CHECK-LABEL: func.func @test_reduce_log_sum_exp_keep_dims_example
func.func @test_reduce_log_sum_exp_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT7:.+]] = torch.constant.int 7
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f32>
// CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
// CHECK: %[[LOG:.+]] = torch.aten.log1p %[[SUM]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32>
// CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
// CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2,1],f32>
%0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32>
return %0 : !torch.vtensor<[3,2,1],f32>
}

// -----

// CHECK-LABEL: func.func @test_reduce_log_sum_exp_keep_dims_int_input_example
func.func @test_reduce_log_sum_exp_keep_dims_int_input_example(%arg0: !torch.vtensor<[3,2,2],si64>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT7:.+]] = torch.constant.int 7
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],si64>
// CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],si64> -> !torch.vtensor<[3,2,2],si64>
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],si64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
// CHECK: %[[LOG:.+]] = torch.aten.log1p %[[SUM]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32>
// CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
// CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2,1],f32>
%0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32>
return %0 : !torch.vtensor<[3,2,1],f32>
}

// -----

// CHECK-LABEL: @test_reduce_mean_negative_axes_keepdims_example
func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
// CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<-2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
Expand Down

0 comments on commit c461b86

Please sign in to comment.