diff --git a/stablehlo/dialect/VhloBytecode.cpp b/stablehlo/dialect/VhloBytecode.cpp index f730ada631c..aee3fb65063 100644 --- a/stablehlo/dialect/VhloBytecode.cpp +++ b/stablehlo/dialect/VhloBytecode.cpp @@ -1225,41 +1225,35 @@ VhloBytecodeInterface::readUniformQuantizedPerAxisV1Type( LOG_READ_CALL; uint64_t flags; Type storageType, expressedType; - FailureOr scale; + uint64_t quantizedDimension; + int64_t storageTypeMin, storageTypeMax; SmallVector scales; SmallVector zeroPoints; - int64_t quantizedDimension, numQuantizationParams, storageTypeMin, - storageTypeMax; - if (failed(reader.readVarInt(flags)) || - failed(reader.readType(storageType)) || - failed(reader.readType(expressedType)) || - failed(reader.readSignedVarInt(quantizedDimension)) || - failed(reader.readSignedVarInt(numQuantizationParams))) - return reader.emitError("invalid UniformQuantizedPerAxisType"), - UniformQuantizedPerAxisV1Type(); - - for (int64_t i = 0; i < numQuantizationParams; i++) { - if (failed(scale = reader.readAPFloatWithKnownSemantics( - llvm::APFloat::IEEEdouble()))) - return reader.emitError("invalid UniformQuantizedPerAxisType"), - UniformQuantizedPerAxisV1Type(); - scales.push_back(scale.value()); - } - - for (int64_t i = 0; i < numQuantizationParams; i++) { - if (failed(reader.readSignedVarInt(zeroPoints.emplace_back()))) - return reader.emitError("invalid UniformQuantizedPerAxisType"), - UniformQuantizedPerAxisV1Type(); + auto readScales = [&]() -> FailureOr { + return reader.readAPFloatWithKnownSemantics(llvm::APFloat::IEEEdouble()); + }; + auto readZeroPoints = [&]() -> FailureOr { + int64_t temp; + if (succeeded(reader.readSignedVarInt(temp))) { + return temp; + } + return failure(); + }; + if (succeeded(reader.readVarInt(flags)) && + succeeded(reader.readType(storageType)) && + succeeded(reader.readType(expressedType)) && + succeeded(reader.readVarInt(quantizedDimension)) && + succeeded(reader.readSignedVarInt(storageTypeMin)) && + succeeded(reader.readSignedVarInt(storageTypeMax)) && + succeeded(reader.readList(scales, readScales)) && + succeeded(reader.readList(zeroPoints, readZeroPoints))) { + return UniformQuantizedPerAxisV1Type::get( + getContext(), flags, storageType, expressedType, quantizedDimension, + scales, zeroPoints, storageTypeMin, storageTypeMax); } - if (failed(reader.readSignedVarInt(storageTypeMin)) || - failed(reader.readSignedVarInt(storageTypeMax))) - return reader.emitError("invalid UniformQuantizedPerAxisType"), - UniformQuantizedPerAxisV1Type(); - - return UniformQuantizedPerAxisV1Type::get( - getContext(), flags, storageType, expressedType, quantizedDimension, - scales, zeroPoints, storageTypeMin, storageTypeMax); + return reader.emitError("invalid UniformQuantizedPerAxisType"), + UniformQuantizedPerAxisV1Type(); } void VhloBytecodeInterface::write(UniformQuantizedPerAxisV1Type type, @@ -1268,15 +1262,14 @@ void VhloBytecodeInterface::write(UniformQuantizedPerAxisV1Type type, writer.writeVarInt(type.getFlags()); writer.writeType(type.getStorageType()); writer.writeType(type.getExpressedType()); - writer.writeSignedVarInt(type.getQuantizedDimension()); - int64_t numQuantizationParams = type.getScales().size(); - writer.writeSignedVarInt(numQuantizationParams); - for (auto scale : type.getScales()) - writer.writeAPFloatWithKnownSemantics(APFloat(scale)); - for (auto zeroPoint : type.getZeroPoints()) - writer.writeSignedVarInt(zeroPoint); + writer.writeVarInt(type.getQuantizedDimension()); writer.writeSignedVarInt(type.getStorageTypeMin()); writer.writeSignedVarInt(type.getStorageTypeMax()); + writer.writeList(type.getScales(), [&](const APFloat &type) { + writer.writeAPFloatWithKnownSemantics(type); + }); + writer.writeList(type.getZeroPoints(), + [&](int64_t type) { writer.writeSignedVarInt(type); }); } //===----------------------------------------------------------------------===// diff --git a/stablehlo/dialect/VhloDialect.td b/stablehlo/dialect/VhloDialect.td index cc2c2a13599..cdb2c18531d 100644 --- a/stablehlo/dialect/VhloDialect.td +++ b/stablehlo/dialect/VhloDialect.td @@ -35,7 +35,7 @@ def VHLO_Dialect : Dialect { 0.15.0: MLIR bytecode version 5 => 6, use properties in VHLO. 0.16.0: Introduce `collective_broadcast` operation. 0.17.0: Allow reduce operations to promote to higher bitwidth. - 0.18.0: Allow serialization of UniformQuantizedPerAxisType. + 0.18.0: Introduce `UniformQuantizedPerAxisType` type. }]; let useDefaultAttributePrinterParser = 0; diff --git a/stablehlo/dialect/VhloTypes.td b/stablehlo/dialect/VhloTypes.td index be4406bfb9c..bdb6c1519a6 100644 --- a/stablehlo/dialect/VhloTypes.td +++ b/stablehlo/dialect/VhloTypes.td @@ -252,7 +252,7 @@ def VHLO_UniformQuantizedPerAxisV1 : VHLO_TypeDef<"UniformQuantizedPerAxisV1", " "unsigned":$flags, "::mlir::Type":$storageType, "::mlir::Type":$expressedType, - "int64_t":$quantizedDimension, + "int32_t":$quantizedDimension, VHLO_QuantizationScalesV1:$scales, ArrayRefParameter<"int64_t">:$zeroPoints, "int64_t":$storageTypeMin, @@ -263,7 +263,7 @@ def VHLO_UniformQuantizedPerAxisV1 : VHLO_TypeDef<"UniformQuantizedPerAxisV1", " LogicalResult UniformQuantizedPerAxisV1Type::verify( llvm::function_ref errFn, unsigned int, mlir::Type storageType, mlir::Type expressedType, - int64_t, ::llvm::ArrayRef<::llvm::APFloat>, ::llvm::ArrayRef, int64_t, int64_t) { + int32_t, ::llvm::ArrayRef<::llvm::APFloat>, ::llvm::ArrayRef, int64_t, int64_t) { if (!isFromVhlo(storageType) || !isFromVhlo(expressedType)) return errFn() << "expected VHLO type"; return success(); diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc new file mode 100644 index 00000000000..f6b993b17c6 Binary files /dev/null and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir index 5695d1a9fb1..70e8d5e80dc 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir @@ -1296,9 +1296,9 @@ func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tenso } // CHECK-LABEL: "op_dynamic_reshape" -func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { - // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1 - %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor) -> tensor +func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor func.return %0 : tensor } diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir.bc index ccd9c1d914d..b0fa6509965 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.0_16_0.mlir b/stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.0_16_0.mlir index 0ac5297f1c3..6d1a12ad99e 100644 --- a/stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.0_16_0.mlir +++ b/stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.0_16_0.mlir @@ -103,21 +103,29 @@ func.func @select_and_scatter_with_promotable_types( %0 = stablehlo.constant dense<0.000000e+00> : tensor // expected-error @+1 {{failed to legalize operation 'vhlo.select_and_scatter_v1' that was explicitly marked illegal}} - %1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %0) ({ - ^bb0(%arg3: tensor, %arg4: tensor): - %2 = "stablehlo.compare"(%arg3, %arg4) { - comparison_direction = #stablehlo - } : (tensor, tensor) -> tensor - "stablehlo.return"(%2) : (tensor) -> () - }, { - ^bb0(%arg3: tensor, %arg4: tensor): - %2 = stablehlo.add %arg3, %arg4 : tensor - "stablehlo.return"(%2) : (tensor) -> () - }) { - window_dimensions = array, - window_strides = array, - padding = dense<0> : tensor<4x2xi64> - } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> - tensor<10x24x24x64xf64> - func.return + %1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %0) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = "stablehlo.compare"(%arg3, %arg4) { + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + "stablehlo.return"(%2) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = stablehlo.add %arg3, %arg4 : tensor + "stablehlo.return"(%2) : (tensor) -> () + }) { + window_dimensions = array, + window_strides = array, + padding = dense<0> : tensor<4x2xi64> + } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> + tensor<10x24x24x64xf64> + func.return +} + +// ----- + +// expected-error @+1 {{failed to legalize operation 'vhlo.func_v1' that was explicitly marked illegal}} +func.func @type_per_axis_quantization(%arg0: tensor<2x!quant.uniform>) -> tensor<2x!quant.uniform> { + %0 = stablehlo.add %arg0, %arg0 : tensor<2x!quant.uniform> + func.return %0 : tensor<2x!quant.uniform> }