diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 85b51ca7efaa..2a6f42a45c86 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -318,24 +318,15 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.tensorOperands(operands, 9)) || binder.tensorResultType(resultType)) return failure(); - Value a = operands[0]; - Value aScale = operands[1]; - Value aZp = operands[2]; - Value b = operands[3]; - Value bScale = operands[4]; - Value bZp = operands[5]; - Value cScale = operands[6]; - Value cZp = operands[7]; - Value c = operands.size() == 9 ? operands[8] : nullptr; - - auto check = [](Value v) { - auto vTy = cast(v.getType()); - return llvm::all_of(vTy.getSizes(), [](int64_t d) { return d == 1; }); - }; - if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) || - !check(cScale) || !check(cScale)) - return rewriter.notifyMatchFailure( - binder.op, "not supported for non per-tensor quantization"); + Value input = operands[0]; + Value inputScale = operands[1]; + Value inputZp = operands[2]; + Value weight = operands[3]; + Value weightScale = operands[4]; + Value weightZp = operands[5]; + Value outputScale = operands[6]; + Value outputZp = operands[7]; + Value output = operands.size() == 9 ? operands[8] : nullptr; auto extract = [&rewriter, &binder](Value v) { auto vTy = cast(v.getType()); @@ -347,34 +338,61 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( v); }; - aZp = extract(aZp); - bZp = extract(bZp); - cZp = extract(cZp); - aScale = extract(aScale); - bScale = extract(bScale); - cScale = extract(cScale); - - auto make = [&rewriter, &binder](Value v, Value scale, - Value zp) -> Value { + inputZp = extract(inputZp); + outputZp = extract(outputZp); + inputScale = extract(inputScale); + outputScale = extract(outputScale); + auto makePerTensor = [&rewriter, &binder](Value v, Value scale, + Value zp) -> Value { auto ty = cast(v.getType()); auto newTy = getQTorchTypeFromTorchIntType(ty); return rewriter.create( binder.getLoc(), newTy, v, scale, zp); }; - a = make(a, aScale, aZp); - b = make(b, bScale, bZp); + auto makePerChannel = [&rewriter, &binder](Value v, Value scale, + Value zp, + Value axis) -> Value { + auto ty = cast(v.getType()); + auto newTy = getQTorchTypeFromTorchIntType(ty); + return rewriter.create( + binder.getLoc(), newTy, v, scale, zp, axis); + }; - auto cTy = rewriter.getType( + input = makePerTensor(input, inputScale, inputZp); + // The onnx's QLinearConv op expects per channel quantization only for + // the weight tensor for axis = 0. + auto weightTy = dyn_cast(weight.getType()); + auto weightScaleTy = + dyn_cast(weightScale.getType()); + if (!weightTy || !weightScaleTy || !weightTy.hasSizes() || + !weightScaleTy.hasSizes()) + return failure(); + auto weightShape = weightTy.getSizes(); + auto weightScaleShape = weightScaleTy.getSizes(); + Value weightScaleScalar = extract(weightScale); + if (weightScaleShape.size() == 1 && + weightScaleShape[0] != Torch::kUnknownSize && + weightScaleShape[0] == weightShape[0]) { + Value axis = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + weight = makePerChannel(weight, weightScale, weightZp, axis); + } else { + weightZp = extract(weightZp); + weight = makePerTensor(weight, weightScaleScalar, weightZp); + } + weightScale = weightScaleScalar; + + auto outputTy = rewriter.getType( resultType.getOptionalSizes(), rewriter.getIntegerType(32, /*issigned=*/true)); // TODO(suderman): insert convolution operator. - llvm::SmallVector newOperands = {a, b}; - if (c) - newOperands.push_back(c); + llvm::SmallVector newOperands = {input, weight}; + if (output) + newOperands.push_back(output); - cTy = rewriter.getType( + outputTy = rewriter.getType( resultType.getOptionalSizes(), rewriter.getType()); @@ -388,36 +406,36 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( newAttributes.push_back(namedAttr); } - c = rewriter - .create(binder.getLoc(), cTy, newOperands, - newAttributes, - binder.op->getRegions().size()) - .getResult(0); + output = rewriter + .create(binder.getLoc(), outputTy, + newOperands, newAttributes, + binder.op->getRegions().size()) + .getResult(0); Value outScale = rewriter.create( - binder.getLoc(), rewriter.getType(), aScale, - bScale); + binder.getLoc(), rewriter.getType(), inputScale, + weightScale); Value outZp = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - c = rewriter.create( - binder.getLoc(), cTy, c, outScale, outZp); - cTy = rewriter.getType( + output = rewriter.create( + binder.getLoc(), outputTy, output, outScale, outZp); + outputTy = rewriter.getType( resultType.getOptionalSizes(), rewriter.getF32Type()); - c = rewriter.create(binder.getLoc(), cTy, - c); - cTy = getQTorchTypeFromTorchIntType(resultType); + output = rewriter.create(binder.getLoc(), + outputTy, output); + outputTy = getQTorchTypeFromTorchIntType(resultType); Value dtyVal = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr( rewriter.getIntegerType(64), static_cast( - Torch::getScalarTypeForType(cTy.getDtype())))); - c = rewriter.create( - binder.getLoc(), cTy, c, cScale, cZp, dtyVal); + Torch::getScalarTypeForType(outputTy.getDtype())))); + output = rewriter.create( + binder.getLoc(), outputTy, output, outputScale, outputZp, dtyVal); rewriter.replaceOpWithNewOp(binder.op, resultType, - c); + output); return success(); }); patterns.onOp( diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 16c86218dbc8..80cea34818aa 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -65,15 +65,15 @@ func.func @test_quantizelinear_f8(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch. // ----- // CHECK-LABEL: @test_qlinearconv_nobias -func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8> +func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[],f32>, %arg5: !torch.vtensor<[],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8> // CHECK: %[[aZp:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int - // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[1],ui8> -> !torch.int // CHECK: %[[cZp:.+]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int // CHECK: %[[aScale:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float - // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float // CHECK: %[[cScale:.+]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8> + // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[],ui8> -> !torch.int // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8> // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_1:.+]] = torch.constant.int 0 @@ -103,17 +103,17 @@ func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: // ----- -// CHECK-LABEL: @test_qlinearconv_bias -func.func @test_qlinearconv_bias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>, %arg8 : !torch.vtensor<[7],si32>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK-LABEL: @test_qlinearconv_bias_weight_per_channel +func.func @test_qlinearconv_bias_weight_per_channel(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>, %arg8 : !torch.vtensor<[7],si32>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[7],si32>) -> !torch.vtensor<[1,1,7,7],ui8> // CHECK: %[[aZp:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int - // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[1],ui8> -> !torch.int // CHECK: %[[cZp:.+]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int // CHECK: %[[aScale:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float - // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float // CHECK: %[[cScale:.+]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8> - // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8> + // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[B:.+]] = torch.aten._make_per_channel_quantized_tensor %arg3, %arg4, %arg5, %[[INT0]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8> // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_1:.+]] = torch.constant.int 0 // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]]