Skip to content

Commit

Permalink
VHLO legalization of StableHLO UniformQuantizedPerAxisType (#1986)
Browse files Browse the repository at this point in the history
The PR enables  
1. VHLO legalization of StableHLO UniformQuantizedPerAxisType.
2. Writing/reading from bytecode format. 
 
The mlir::quant dialect has two different Quantized types,
[one](https://github.com/llvm/llvm-project/blob/e2bb91b25c8740625fecd127c1d908a2fabd0102/mlir/include/mlir/Dialect/Quant/QuantTypes.h#L255)
for per-tensor and the
[other](https://github.com/llvm/llvm-project/blob/e2bb91b25c8740625fecd127c1d908a2fabd0102/mlir/include/mlir/Dialect/Quant/QuantTypes.h#L315)
for per-axis. By analogy, VHLO now has two corresponding types,
`UniformQuantizedV1Type` and `UniformQuantizedPerAxisV1Type`, respectively.
Also added the legalization of StableHLO -> VHLO and VHLO -> StableHLO
for the corresponding types.

Regarding compatibility testing, we only added this feature (the new
per-axis type) to
[stablehlo_legalize_to_vhlo.0_17_0.mlir](https://github.com/openxla/stablehlo/compare/main...sdasgup3:serialize-per-axis-quantization-type?expand=1#diff-d78cbc82314e64545ee7b9f6a66a0b910f4a8ac2e7f27e2ad55c4b33d2ea409c).
@GleasonK Please let me know if the feature needs to be added to earlier
versions as well.
  • Loading branch information
sdasgup3 authored Feb 10, 2024
1 parent e191eb4 commit 5405149
Show file tree
Hide file tree
Showing 11 changed files with 2,627 additions and 31 deletions.
4 changes: 2 additions & 2 deletions build_tools/github_actions/lint_version.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
VERSION_H="stablehlo/dialect/Version.h"
set_version_var() {
# getCurrentVersion() { Version(0, X, Y); }
VERSION_STR=$(cat $VERSION_H | grep getCurrentVersion -A1 | grep -o 'Version([0-9], .*)')
REGEX="Version\(([0-9]+), ([0-9]+), ([0-9]+)\)"
VERSION_STR=$(cat $VERSION_H | grep getCurrentVersion -A1 | grep -o 'Version(.*[0-9])')
REGEX="Version\(/\*.*=\*/([0-9]+), /\*.*=\*/([0-9]+), /\*.*=\*/[^0-9]*([0-9]+)\)"
if [[ $VERSION_STR =~ $REGEX ]]; then
VERSION=("${BASH_REMATCH[1]}" "${BASH_REMATCH[2]}" "${BASH_REMATCH[3]}")
else
Expand Down
8 changes: 6 additions & 2 deletions stablehlo/dialect/Version.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,14 @@ class Version {
static FailureOr<Version> fromString(llvm::StringRef versionRef);

/// Return a Version representing the current VHLO dialect version.
static Version getCurrentVersion() { return Version(0, 17, 8); }
static Version getCurrentVersion() {
return Version(/*major=*/0, /*minor=*/18, /*patch=*/0);
}

/// Return a Version representing the minimum supported VHLO dialect version.
static Version getMinimumVersion() { return Version(0, 9, 0); }
static Version getMinimumVersion() {
return Version(/*major=*/0, /*minor=*/9, /*patch=*/0);
}

/// Return the MLIR Bytecode Format associated with the version instance.
/// Returns failure if version is not in compatibility window.
Expand Down
96 changes: 88 additions & 8 deletions stablehlo/dialect/VhloBytecode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,18 @@ enum TypeCode {
/// FloatF8E4M3B11FNUZV1Type {
/// }
kFloatF8E4M3B11FNUZV1Type = 29,

/// UniformQuantizedPerAxisV1Type {
/// flags: varint
/// storageType: Type
/// expressedType: Type
/// quantizedDimension: varint
/// storageTypeMin: svarint
/// storageTypeMax: svarint
/// scales: list of APFloat
/// zeroPoints: list of svarint
/// }
kUniformQuantizedPerAxisV1Type = 30,
};

} // namespace vhlo_encoding
Expand Down Expand Up @@ -419,6 +431,8 @@ class VhloBytecodeInterface : public BytecodeDialectInterface {
bool hasEncoding) const;
TokenV1Type readTokenV1Type(DialectBytecodeReader &reader) const;
TupleV1Type readTupleV1Type(DialectBytecodeReader &reader) const;
UniformQuantizedPerAxisV1Type readUniformQuantizedPerAxisV1Type(
DialectBytecodeReader &reader) const;
UniformQuantizedV1Type readUniformQuantizedV1Type(
DialectBytecodeReader &reader) const;
UnrankedTensorV1Type readUnrankedTensorV1Type(
Expand All @@ -431,6 +445,8 @@ class VhloBytecodeInterface : public BytecodeDialectInterface {
void write(RankedTensorV1Type type, DialectBytecodeWriter &writer) const;
void write(TokenV1Type type, DialectBytecodeWriter &writer) const;
void write(TupleV1Type type, DialectBytecodeWriter &writer) const;
void write(UniformQuantizedPerAxisV1Type type,
DialectBytecodeWriter &writer) const;
void write(UniformQuantizedV1Type type, DialectBytecodeWriter &writer) const;
void write(UnrankedTensorV1Type type, DialectBytecodeWriter &writer) const;
};
Expand Down Expand Up @@ -971,6 +987,8 @@ Type VhloBytecodeInterface::readType(DialectBytecodeReader &reader) const {
return readTokenV1Type(reader);
case vhlo_encoding::kTupleV1Type:
return readTupleV1Type(reader);
case vhlo_encoding::kUniformQuantizedPerAxisV1Type:
return readUniformQuantizedPerAxisV1Type(reader);
case vhlo_encoding::kUniformQuantizedV1Type:
return readUniformQuantizedV1Type(reader);
case vhlo_encoding::kUnrankedTensorV1Type:
Expand All @@ -988,11 +1006,11 @@ LogicalResult VhloBytecodeInterface::writeType(
Type type, DialectBytecodeWriter &writer) const {
return TypeSwitch<Type, LogicalResult>(type)
.Case<ComplexV1Type, FunctionV1Type, RankedTensorV1Type, TokenV1Type,
TupleV1Type, UnrankedTensorV1Type, UniformQuantizedV1Type>(
[&](auto type) {
LOG_WRITE_CALL;
return write(type, writer), success();
})
TupleV1Type, UnrankedTensorV1Type, UniformQuantizedPerAxisV1Type,
UniformQuantizedV1Type>([&](auto type) {
LOG_WRITE_CALL;
return write(type, writer), success();
})
.Case([&](BooleanV1Type) {
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kBooleanV1Type), success();
Expand Down Expand Up @@ -1197,17 +1215,79 @@ void VhloBytecodeInterface::write(TupleV1Type type,
writer.writeTypes(type.getTypes());
}

//===----------------------------------------------------------------------===//
// UniformQuantizedPerAxisV1Type
//===----------------------------------------------------------------------===//

/// Deserializes a UniformQuantizedPerAxisV1Type from `reader`.
///
/// Fields are consumed in this order: flags (varint), storageType (Type),
/// expressedType (Type), quantizedDimension (varint), storageTypeMin (signed
/// varint), storageTypeMax (signed varint), scales (list of APFloat read with
/// IEEE-double semantics), zeroPoints (list of signed varint).
///
/// If any read fails, the remaining fields are not read, an error is emitted
/// through `reader`, and a default-constructed type is returned.
UniformQuantizedPerAxisV1Type
VhloBytecodeInterface::readUniformQuantizedPerAxisV1Type(
    DialectBytecodeReader &reader) const {
  LOG_READ_CALL;

  // Per-element decoders handed to `readList` for the two list fields.
  auto decodeScale = [&]() -> FailureOr<APFloat> {
    return reader.readAPFloatWithKnownSemantics(llvm::APFloat::IEEEdouble());
  };
  auto decodeZeroPoint = [&]() -> FailureOr<int64_t> {
    int64_t value = 0;
    if (failed(reader.readSignedVarInt(value))) return failure();
    return value;
  };

  uint64_t flags = 0;
  Type storageType;
  Type expressedType;
  uint64_t quantizedDimension = 0;
  int64_t storageTypeMin = 0;
  int64_t storageTypeMax = 0;
  SmallVector<APFloat> scales;
  SmallVector<int64_t> zeroPoints;

  // Short-circuits on the first failing field; the order below is the wire
  // order and must not be rearranged.
  if (failed(reader.readVarInt(flags)) ||
      failed(reader.readType(storageType)) ||
      failed(reader.readType(expressedType)) ||
      failed(reader.readVarInt(quantizedDimension)) ||
      failed(reader.readSignedVarInt(storageTypeMin)) ||
      failed(reader.readSignedVarInt(storageTypeMax)) ||
      failed(reader.readList(scales, decodeScale)) ||
      failed(reader.readList(zeroPoints, decodeZeroPoint))) {
    return reader.emitError("invalid UniformQuantizedPerAxisType"),
           UniformQuantizedPerAxisV1Type();
  }

  return UniformQuantizedPerAxisV1Type::get(
      getContext(), flags, storageType, expressedType, quantizedDimension,
      scales, zeroPoints, storageTypeMin, storageTypeMax);
}

/// Serializes a UniformQuantizedPerAxisV1Type to `writer`.
///
/// Emits the kUniformQuantizedPerAxisV1Type type code, then the fields in this
/// order: flags (varint), storageType, expressedType, quantizedDimension
/// (varint), storageTypeMin (signed varint), storageTypeMax (signed varint),
/// scales (list of APFloat), zeroPoints (list of signed varint).
void VhloBytecodeInterface::write(UniformQuantizedPerAxisV1Type type,
                                  DialectBytecodeWriter &writer) const {
  // Per-element encoders for the two list-valued fields.
  auto encodeScale = [&writer](const APFloat &scale) {
    writer.writeAPFloatWithKnownSemantics(scale);
  };
  auto encodeZeroPoint = [&writer](int64_t zeroPoint) {
    writer.writeSignedVarInt(zeroPoint);
  };

  // Type code first, then the payload; this sequence is the wire order.
  writer.writeVarInt(vhlo_encoding::kUniformQuantizedPerAxisV1Type);
  writer.writeVarInt(type.getFlags());
  writer.writeType(type.getStorageType());
  writer.writeType(type.getExpressedType());
  writer.writeVarInt(type.getQuantizedDimension());
  writer.writeSignedVarInt(type.getStorageTypeMin());
  writer.writeSignedVarInt(type.getStorageTypeMax());
  writer.writeList(type.getScales(), encodeScale);
  writer.writeList(type.getZeroPoints(), encodeZeroPoint);
}

//===----------------------------------------------------------------------===//
// UniformQuantizedV1Type
//===----------------------------------------------------------------------===//

UniformQuantizedV1Type VhloBytecodeInterface::readUniformQuantizedV1Type(
DialectBytecodeReader &reader) const {
LOG_READ_CALL;
uint64_t flags;
Type storageType, expressedType;
uint64_t flags = 0;
Type storageType;
Type expressedType;
FailureOr<APFloat> scale;
int64_t zeroPoint, storageTypeMin, storageTypeMax;
int64_t zeroPoint = 0;
int64_t storageTypeMin = 0;
int64_t storageTypeMax = 0;
if (failed(reader.readVarInt(flags)) ||
failed(reader.readType(storageType)) ||
failed(reader.readType(expressedType)) ||
Expand Down
1 change: 1 addition & 0 deletions stablehlo/dialect/VhloDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def VHLO_Dialect : Dialect {
0.15.0: MLIR bytecode version 5 => 6, use properties in VHLO.
0.16.0: Introduce `collective_broadcast` operation.
0.17.0: Allow reduce operations to promote to higher bitwidth.
0.18.0: Introduce `UniformQuantizedPerAxisType` type.
}];

let useDefaultAttributePrinterParser = 0;
Expand Down
24 changes: 24 additions & 0 deletions stablehlo/dialect/VhloTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,18 @@ void VhloTypeConverter::addBuiltinToVhloConversions() {
convertedExpressedType, APFloat(type.getScale()), type.getZeroPoint(),
type.getStorageTypeMin(), type.getStorageTypeMax());
});
addConversion([&](quant::UniformQuantizedPerAxisType type) -> Type {
Type convertedStorageType = convertType(type.getStorageType());
Type convertedExpressedType = convertType(type.getExpressedType());
if (!convertedStorageType || !convertedExpressedType) return {};
SmallVector<APFloat> scales = llvm::to_vector(llvm::map_range(
type.getScales(), [](double scale) { return APFloat(scale); }));
return vhlo::UniformQuantizedPerAxisV1Type::get(
type.getContext(), type.getFlags(), convertedStorageType,
convertedExpressedType, type.getQuantizedDimension(), scales,
type.getZeroPoints(), type.getStorageTypeMin(),
type.getStorageTypeMax());
});
addConversion([&](UnrankedTensorType type) -> Type {
auto convertedElementType = convertType(type.getElementType());
if (!convertedElementType) return {};
Expand Down Expand Up @@ -223,6 +235,18 @@ void VhloTypeConverter::addVhloToBuiltinConversions() {
type.getScale().convertToDouble(), type.getZeroPoint(),
type.getStorageTypeMin(), type.getStorageTypeMax());
});
addConversion([&](UniformQuantizedPerAxisV1Type type) -> Type {
Type convertedStorageType = convertType(type.getStorageType());
Type convertedExpressedType = convertType(type.getExpressedType());
if (!convertedStorageType || !convertedExpressedType) return {};
SmallVector<double> scales = llvm::to_vector(llvm::map_range(
type.getScales(),
[](const APFloat& scale) { return scale.convertToDouble(); }));
return quant::UniformQuantizedPerAxisType::get(
type.getFlags(), convertedStorageType, convertedExpressedType, scales,
type.getZeroPoints(), type.getQuantizedDimension(),
type.getStorageTypeMin(), type.getStorageTypeMax());
});
addConversion([&](UnrankedTensorV1Type type) -> Type {
auto convertedElementType = convertType(type.getElementType());
if (!convertedElementType) return {};
Expand Down
43 changes: 43 additions & 0 deletions stablehlo/dialect/VhloTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,49 @@ def VHLO_UniformQuantizedV1 : VHLO_TypeDef<"UniformQuantizedV1", "quant_v1", "0.
}];
let assemblyFormat = "`<` $storageType `` `:` `` $expressedType `,` $scale `` `:` `` $zeroPoint `,` $storageTypeMin `` `:` `` $storageTypeMax `,` $flags `>`";
}
// Type parameter holding an array of quantization scales, stored as
// ::llvm::APFloat values, with a custom textual parser and printer.
def VHLO_QuantizationScalesV1 : ArrayRefParameter<"::llvm::APFloat", "array of double scales"> {
// Parses a comma-separated list of floats delimited by square brackets
// (Delimiter::Square) into doubles, then converts each one to an APFloat.
// Returns failure if the list cannot be parsed.
let parser = [{
[&]() -> FailureOr<llvm::SmallVector<::llvm::APFloat>> {
::llvm::SmallVector<double> scales;

auto parseResult = $_parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() {
return $_parser.parseFloat(scales.emplace_back());
});
if(failed(parseResult)) return failure();
return llvm::to_vector(llvm::map_range(
scales, [](double scale) { return APFloat(scale); }));
}()
}];
// Prints the scales separated by commas.
// NOTE(review): the parser above requires a `[...]`-delimited list, while this
// printer emits no surrounding brackets -- confirm that the printed form
// round-trips through the parser.
let printer = [{
llvm::interleaveComma($_self, $_printer, [&](APFloat scale) {
$_printer << scale;
});
}];
}
// VHLO type definition for a per-axis uniform quantized type, with mnemonic
// `quant_per_axis_v1`.
// NOTE(review): "0.18.0" / "current" are presumably the minimum and maximum
// VHLO versions supporting this type -- confirm against VHLO_TypeDef.
def VHLO_UniformQuantizedPerAxisV1 : VHLO_TypeDef<"UniformQuantizedPerAxisV1", "quant_per_axis_v1", "0.18.0", "current"> {
// One entry in `scales` / `zeroPoints` is expected per slice along
// `quantizedDimension` -- TODO confirm; this is not checked by the verifier
// defined below.
let parameters = (ins
"unsigned":$flags,
"::mlir::Type":$storageType,
"::mlir::Type":$expressedType,
"int32_t":$quantizedDimension,
VHLO_QuantizationScalesV1:$scales,
ArrayRefParameter<"int64_t">:$zeroPoints,
"int64_t":$storageTypeMin,
"int64_t":$storageTypeMax
);
// Request a `verify` declaration; its definition is supplied below.
let genVerifyDecl = 1;
// The verifier only inspects `storageType` and `expressedType`, rejecting the
// type with "expected VHLO type" when `isFromVhlo` is false for either. All
// other parameters are left unnamed and are not examined.
let extraClassDefinition = [{
LogicalResult UniformQuantizedPerAxisV1Type::verify(
llvm::function_ref<mlir::InFlightDiagnostic ()> errFn,
unsigned int, mlir::Type storageType, mlir::Type expressedType,
int32_t, ::llvm::ArrayRef<::llvm::APFloat>, ::llvm::ArrayRef<int64_t>, int64_t, int64_t) {
if (!isFromVhlo(storageType) || !isFromVhlo(expressedType))
return errFn() << "expected VHLO type";
return success();
}
}];
// Textual form:
//   <storageType:expressedType, quantizedDimension, scales, zeroPoints,
//    storageTypeMin:storageTypeMax, flags>
let assemblyFormat = "`<` $storageType `` `:` `` $expressedType `,` $quantizedDimension `,` $scales `,` $zeroPoints `,` $storageTypeMin `` `:` `` $storageTypeMax `,` $flags `>`";
}

// TODO(#8): UnrankedTensor is not part of the StableHLO spec.
// At the moment, it is used to represent unranked dynamism, and we will likely
Expand Down
Loading

0 comments on commit 5405149

Please sign in to comment.