Skip to content

Commit

Permalink
address feedback: II
Browse files Browse the repository at this point in the history
  • Loading branch information
sdasgup3 committed Feb 7, 2024
1 parent 3d56368 commit 302cb9d
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 61 deletions.
69 changes: 31 additions & 38 deletions stablehlo/dialect/VhloBytecode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1225,41 +1225,35 @@ VhloBytecodeInterface::readUniformQuantizedPerAxisV1Type(
LOG_READ_CALL;
uint64_t flags;
Type storageType, expressedType;
FailureOr<APFloat> scale;
uint64_t quantizedDimension;
int64_t storageTypeMin, storageTypeMax;
SmallVector<APFloat> scales;
SmallVector<int64_t> zeroPoints;
int64_t quantizedDimension, numQuantizationParams, storageTypeMin,
storageTypeMax;
if (failed(reader.readVarInt(flags)) ||
failed(reader.readType(storageType)) ||
failed(reader.readType(expressedType)) ||
failed(reader.readSignedVarInt(quantizedDimension)) ||
failed(reader.readSignedVarInt(numQuantizationParams)))
return reader.emitError("invalid UniformQuantizedPerAxisType"),
UniformQuantizedPerAxisV1Type();

for (int64_t i = 0; i < numQuantizationParams; i++) {
if (failed(scale = reader.readAPFloatWithKnownSemantics(
llvm::APFloat::IEEEdouble())))
return reader.emitError("invalid UniformQuantizedPerAxisType"),
UniformQuantizedPerAxisV1Type();
scales.push_back(scale.value());
}

for (int64_t i = 0; i < numQuantizationParams; i++) {
if (failed(reader.readSignedVarInt(zeroPoints.emplace_back())))
return reader.emitError("invalid UniformQuantizedPerAxisType"),
UniformQuantizedPerAxisV1Type();
auto readScales = [&]() -> FailureOr<APFloat> {
return reader.readAPFloatWithKnownSemantics(llvm::APFloat::IEEEdouble());
};
auto readZeroPoints = [&]() -> FailureOr<int64_t> {
int64_t temp;
if (succeeded(reader.readSignedVarInt(temp))) {
return temp;
}
return failure();
};
if (succeeded(reader.readVarInt(flags)) &&
succeeded(reader.readType(storageType)) &&
succeeded(reader.readType(expressedType)) &&
succeeded(reader.readVarInt(quantizedDimension)) &&
succeeded(reader.readSignedVarInt(storageTypeMin)) &&
succeeded(reader.readSignedVarInt(storageTypeMax)) &&
succeeded(reader.readList(scales, readScales)) &&
succeeded(reader.readList(zeroPoints, readZeroPoints))) {
return UniformQuantizedPerAxisV1Type::get(
getContext(), flags, storageType, expressedType, quantizedDimension,
scales, zeroPoints, storageTypeMin, storageTypeMax);
}

if (failed(reader.readSignedVarInt(storageTypeMin)) ||
failed(reader.readSignedVarInt(storageTypeMax)))
return reader.emitError("invalid UniformQuantizedPerAxisType"),
UniformQuantizedPerAxisV1Type();

return UniformQuantizedPerAxisV1Type::get(
getContext(), flags, storageType, expressedType, quantizedDimension,
scales, zeroPoints, storageTypeMin, storageTypeMax);
return reader.emitError("invalid UniformQuantizedPerAxisType"),
UniformQuantizedPerAxisV1Type();
}

void VhloBytecodeInterface::write(UniformQuantizedPerAxisV1Type type,
Expand All @@ -1268,15 +1262,14 @@ void VhloBytecodeInterface::write(UniformQuantizedPerAxisV1Type type,
writer.writeVarInt(type.getFlags());
writer.writeType(type.getStorageType());
writer.writeType(type.getExpressedType());
writer.writeSignedVarInt(type.getQuantizedDimension());
int64_t numQuantizationParams = type.getScales().size();
writer.writeSignedVarInt(numQuantizationParams);
for (auto scale : type.getScales())
writer.writeAPFloatWithKnownSemantics(APFloat(scale));
for (auto zeroPoint : type.getZeroPoints())
writer.writeSignedVarInt(zeroPoint);
writer.writeVarInt(type.getQuantizedDimension());
writer.writeSignedVarInt(type.getStorageTypeMin());
writer.writeSignedVarInt(type.getStorageTypeMax());
writer.writeList(type.getScales(), [&](const APFloat &type) {
writer.writeAPFloatWithKnownSemantics(type);
});
writer.writeList(type.getZeroPoints(),
[&](int64_t type) { writer.writeSignedVarInt(type); });
}

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/dialect/VhloDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def VHLO_Dialect : Dialect {
0.15.0: MLIR bytecode version 5 => 6, use properties in VHLO.
0.16.0: Introduce `collective_broadcast` operation.
0.17.0: Allow reduce operations to promote to higher bitwidth.
0.18.0: Allow serialization of UniformQuantizedPerAxisType.
0.18.0: Introduce `UniformQuantizedPerAxisType` type.
}];

let useDefaultAttributePrinterParser = 0;
Expand Down
4 changes: 2 additions & 2 deletions stablehlo/dialect/VhloTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def VHLO_UniformQuantizedPerAxisV1 : VHLO_TypeDef<"UniformQuantizedPerAxisV1", "
"unsigned":$flags,
"::mlir::Type":$storageType,
"::mlir::Type":$expressedType,
"int64_t":$quantizedDimension,
"int32_t":$quantizedDimension,
VHLO_QuantizationScalesV1:$scales,
ArrayRefParameter<"int64_t">:$zeroPoints,
"int64_t":$storageTypeMin,
Expand All @@ -263,7 +263,7 @@ def VHLO_UniformQuantizedPerAxisV1 : VHLO_TypeDef<"UniformQuantizedPerAxisV1", "
LogicalResult UniformQuantizedPerAxisV1Type::verify(
llvm::function_ref<mlir::InFlightDiagnostic ()> errFn,
unsigned int, mlir::Type storageType, mlir::Type expressedType,
int64_t, ::llvm::ArrayRef<::llvm::APFloat>, ::llvm::ArrayRef<int64_t>, int64_t, int64_t) {
int32_t, ::llvm::ArrayRef<::llvm::APFloat>, ::llvm::ArrayRef<int64_t>, int64_t, int64_t) {
if (!isFromVhlo(storageType) || !isFromVhlo(expressedType))
return errFn() << "expected VHLO type";
return success();
Expand Down
Binary file not shown.
6 changes: 3 additions & 3 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1296,9 +1296,9 @@ func.func @op_dynamic_pad(%arg0: tensor<?xf32>, %arg1: tensor<f32>, %arg2: tenso
}

// CHECK-LABEL: "op_dynamic_reshape"
func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
// CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<?x!vhlo.index_v1>) -> !vhlo.tensor_v1<?x?x!vhlo.f32_v1>
%0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<?xindex>) -> tensor<?x?xf32>
func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor<?x?xf32> {
// CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1<?x?x!vhlo.f32_v1>
%0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor<?x?xf32>
func.return %0 : tensor<?x?xf32>
}

Expand Down
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir.bc
Binary file not shown.
42 changes: 25 additions & 17 deletions stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.0_16_0.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -103,21 +103,29 @@ func.func @select_and_scatter_with_promotable_types(
%0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>

// expected-error @+1 {{failed to legalize operation 'vhlo.select_and_scatter_v1' that was explicitly marked illegal}}
%1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "stablehlo.compare"(%arg3, %arg4) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"stablehlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f64>, %arg4: tensor<f64>):
%2 = stablehlo.add %arg3, %arg4 : tensor<f64>
"stablehlo.return"(%2) : (tensor<f64>) -> ()
}) {
window_dimensions = array<i64: 1, 2, 2, 1>,
window_strides = array<i64: 1, 2, 2, 1>,
padding = dense<0> : tensor<4x2xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf64>
func.return
%1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "stablehlo.compare"(%arg3, %arg4) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"stablehlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f64>, %arg4: tensor<f64>):
%2 = stablehlo.add %arg3, %arg4 : tensor<f64>
"stablehlo.return"(%2) : (tensor<f64>) -> ()
}) {
window_dimensions = array<i64: 1, 2, 2, 1>,
window_strides = array<i64: 1, 2, 2, 1>,
padding = dense<0> : tensor<4x2xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf64>
func.return
}

// -----

// expected-error @+1 {{failed to legalize operation 'vhlo.func_v1' that was explicitly marked illegal}}
func.func @type_per_axis_quantization(%arg0: tensor<2x!quant.uniform<i8:f32:0, {34.0:16, 34.0:16}>>) -> tensor<2x!quant.uniform<i8:f32:0, {34.0:16, 34.0:16}>> {
%0 = stablehlo.add %arg0, %arg0 : tensor<2x!quant.uniform<i8:f32:0, {34.0:16, 34.0:16}>>
func.return %0 : tensor<2x!quant.uniform<i8:f32:0, {34.0:16, 34.0:16}>>
}

0 comments on commit 302cb9d

Please sign in to comment.