From e581b33f9694204e12213f26b5de7b4b5126d6ea Mon Sep 17 00:00:00 2001 From: lonely eagle <75576166+linuxlonelyeagle@users.noreply.github.com> Date: Thu, 25 Jan 2024 10:44:08 +0800 Subject: [PATCH] [Stablehlo]fix CumsumInputDtypeInt32Module_basic on stablehlo backend. (#2797) Code used for testing.For the location of CumsumInputDtypeInt32Module in the repo you can see [here](https://github.com/llvm/torch-mlir/blob/311b6b0286bfa016346bc7fd8b441bbd50216060/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py#L4148). ```python import torch import torch_mlir class CumsumInputDtypeInt32Module(torch.nn.Module): def __init__(self): super().__init__() def forward(self, val): return torch.ops.aten.cumsum(val, 1) module = torch_mlir.compile(CumsumInputDtypeInt32Module(), [torch.randn(2, 7, 4).to(torch.int32)], output_type="stablehlo") print(module.operation.get_asm()) ``` After fixing the bugs. ``` module attributes {torch.debug_module_name = "CumsumInputDtypeInt32Module"} { func.func @forward(%arg0: tensor<2x7x4xi32>) -> tensor<2x7x4xi64> { %0 = stablehlo.constant dense<0> : tensor %1 = stablehlo.convert %arg0 : (tensor<2x7x4xi32>) -> tensor<2x7x4xi64> %2 = "stablehlo.reduce_window"(%1, %0) ({ ^bb0(%arg1: tensor, %arg2: tensor): %3 = stablehlo.add %arg1, %arg2 : tensor stablehlo.return %3 : tensor }) {padding = dense<[[0, 0], [6, 0], [0, 0]]> : tensor<3x2xi64>, window_dilations = dense<1> : tensor<3xi64>, window_dimensions = dense<[1, 7, 1]> : tensor<3xi64>, window_strides = dense<1> : tensor<3xi64>} : (tensor<2x7x4xi64>, tensor) -> tensor<2x7x4xi64> return %2 : tensor<2x7x4xi64> } } ``` --- lib/Conversion/TorchToStablehlo/Pooling.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index 7c28a2fd3004..e90f231c74f5 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -569,11 +569,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); auto inputTy = input.getType().cast(); + auto outTy = + getTypeConverter()->convertType(op.getType()).cast(); + input = hlo::promoteType(rewriter, op.getLoc(), input, outTy); + inputTy = input.getType().cast(); auto inputElemTy = inputTy.getElementType(); auto inputRank = inputTy.getRank(); auto inputShape = inputTy.getShape(); - auto outTy = - getTypeConverter()->convertType(op.getType()).cast(); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {