Quantization Verifiers based on T2x set of Traits #2041

Merged
merged 14 commits on Mar 5, 2024
35 changes: 20 additions & 15 deletions stablehlo/dialect/Base.cpp
@@ -33,13 +33,6 @@ limitations under the License.
namespace mlir {
namespace hlo {

namespace {
Type getExpressedTypeOrSelf(Type type) {
auto quantType = type.dyn_cast<quant::QuantizedType>();
return quantType ? quantType.getExpressedType() : type;
}
} // namespace

LogicalResult verifyCompatibleShapeWithBounds(Type type1, Type type2) {
if (failed(verifyCompatibleShape(type1, type2))) return failure();

@@ -72,20 +65,32 @@ bool isCompatibleElementTypeForHloTypeInference(Type tp1, Type tp2) {
tp1 = getElementTypeOrSelf(tp1);
tp2 = getElementTypeOrSelf(tp2);

// Quantization: In the most general case, we allow any combination of
// quantized/non-quantized across any combination of operands/results,
// and some differences in quantization parameters across operands/results.
// Individual ops may introduce additional constraints.
// For quantized types:
// a. both `tp1` and `tp2` should be quantized types
// b. with similar quantization granularity (i.e. both per-tensor or both
// per-axis)
// c. with equal storage_type, storage_type_min, storage_type_max, and
// expressed_type
auto qtp1 = tp1.dyn_cast<quant::QuantizedType>();
auto qtp2 = tp2.dyn_cast<quant::QuantizedType>();
if (qtp1 && qtp2) {
if (qtp1.getStorageType() != qtp2.getStorageType() ||
qtp1.getStorageTypeMin() != qtp2.getStorageTypeMin() ||
qtp1.getStorageTypeMax() != qtp2.getStorageTypeMax())
qtp1.getStorageTypeMax() != qtp2.getStorageTypeMax() ||
qtp1.getExpressedType() != qtp2.getExpressedType()) {
return false;
}

auto qpatp1 = qtp1.dyn_cast<quant::UniformQuantizedPerAxisType>();
auto qpatp2 = qtp2.dyn_cast<quant::UniformQuantizedPerAxisType>();
bool quantizationGranularityMatches =
(qpatp1 && qpatp2) || (!qpatp1 && !qpatp2);

return quantizationGranularityMatches;
}
auto etp1 = getExpressedTypeOrSelf(tp1);
auto etp2 = getExpressedTypeOrSelf(tp2);

// Return false if only one of the two types is quantized.
if (qtp1 || qtp2) return false;

// Sparsity: In the most general case, we allow any combination of
// sparsity/denseness across any combination of operands/results, as well as
@@ -96,7 +101,7 @@ bool isCompatibleElementTypeForHloTypeInference(Type tp1, Type tp2) {

// Default case: Unless dynamism, quantization and/or sparsity are involved,
// the types are required to be exactly equal.
return etp1 == etp2;
return tp1 == tp2;
}

bool isCompatibleForHloTypeInference(Type tp1, Type tp2) {
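For illustration, a minimal MLIR sketch of the tightened element-type compatibility rule (hypothetical IR and SSA names, not taken from the PR's tests):

// Accepted: both element types are per-tensor quantized with identical
// storage type, storage bounds, and expressed type.
%ok = "stablehlo.add"(%q0, %q1) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i8:f32, 1.0:17>>
// Rejected: only one side is quantized, so the element types are not compatible.
%bad = "stablehlo.add"(%f0, %q0) : (tensor<1xf32>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i8:f32, 1.0:17>>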
23 changes: 23 additions & 0 deletions stablehlo/dialect/Base.h
@@ -306,6 +306,29 @@ class CompatibleOperandsAndResultElementType
}
};

template <typename ConcreteType>
class CompatibleOperandsElementType
: public mlir::OpTrait::TraitBase<ConcreteType,
CompatibleOperandsElementType> {
public:
static LogicalResult verifyTrait(Operation *op) {
if (failed(mlir::OpTrait::impl::verifyAtLeastNOperands(op, 1)))
return failure();

Type expected = op->getOperand(0).getType();
auto typeMatch = [&](Type actual) {
return isCompatibleElementTypeForHloTypeInference(actual, expected);
};
auto allMatch = llvm::all_of(op->getOperandTypes(), typeMatch);
if (!allMatch) {
return op->emitOpError(
"requires compatible element types for all operands");
}

return success();
}
};

template <typename ConcreteType>
class CompatibleOperandsAndResultType
: public mlir::OpTrait::TraitBase<ConcreteType,
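As an illustrative sketch of the new CompatibleOperandsElementType trait (hypothetical IR, not from the PR's tests): operands that mix quantized and non-quantized element types fail verification with the diagnostic emitted above.

// expected error: requires compatible element types for all operands
%cmp = "stablehlo.compare"(%f0, %q0) {comparison_direction = #stablehlo<comparison_direction EQ>} : (tensor<1xf32>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1xi1>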
3 changes: 3 additions & 0 deletions stablehlo/dialect/Base.td
@@ -297,6 +297,9 @@ def HLO_CompatibleOperandsAndResultType : TraitList<
def HLO_CompatibleOperandsAndResultElementType :
HLO_NativeOpTrait<"CompatibleOperandsAndResultElementType">;

def HLO_CompatibleOperandsElementType :
HLO_NativeOpTrait<"CompatibleOperandsElementType">;

def HLO_BoundedAttrInterface : AttrInterface<"BoundedAttrInterface"> {
let cppNamespace = "::mlir::hlo";

17 changes: 7 additions & 10 deletions stablehlo/dialect/StablehloOps.td
@@ -1532,7 +1532,7 @@ def StableHLO_TupleOp : StableHLO_Op<"tuple", [Pure,
}

def StableHLO_CompareOp: StableHLO_Op<"compare", [Pure, Elementwise,
SameOperandsElementType /*compare_c1*/,
HLO_CompatibleOperandsElementType /*compare_c1*/,
SameOperandsAndResultShape /*compare_c2*/,
InferTensorTypeWithReify /*compare_c1, compare_c2*/]> {
let summary = "Compare operation";
@@ -1685,8 +1685,7 @@ def StableHLO_DynamicUpdateSliceOp: StableHLO_Op<"dynamic_update_slice",
//===----------------------------------------------------------------------===//

def StableHLO_BatchNormGradOp : StableHLO_Op<"batch_norm_grad", [Pure,
AllElementTypesMatch<["operand", "scale", "mean", "variance", "grad_output",
"grad_operand", "grad_scale", "grad_offset"] /*batch_norm_grad_c2*/>,
HLO_CompatibleOperandsAndResultElementType /*batch_norm_grad_c2*/,
InferTensorType /*batch_norm_grad_c3, batch_norm_grad_c4*/]> {
let summary = "BatchNormGrad operation";
let description = [{
@@ -1725,8 +1724,7 @@ def StableHLO_BatchNormGradOp : StableHLO_Op<"batch_norm_grad", [Pure,
}

def StableHLO_BatchNormInferenceOp : StableHLO_Op<"batch_norm_inference",
[Pure, AllElementTypesMatch<["operand", "scale", "offset", "mean",
"variance", "result"]> /*batch_norm_inference_c2*/,
[Pure, HLO_CompatibleOperandsAndResultElementType /*batch_norm_inference_c2*/,
InferTensorType /*batch_norm_inference_c7*/]> {
let summary = "BatchNormInference operation";
let description = [{
@@ -1759,8 +1757,7 @@ def StableHLO_BatchNormInferenceOp : StableHLO_Op<"batch_norm_inference",
}

def StableHLO_BatchNormTrainingOp : StableHLO_Op<"batch_norm_training",
[Pure, AllElementTypesMatch<["operand", "scale", "offset", "output",
"batch_mean", "batch_var"]> /*batch_norm_training_c2*/,
[Pure, HLO_CompatibleOperandsAndResultElementType /*batch_norm_training_c2*/,
InferTensorType /*batch_norm_training_c5, batch_norm_training_c6, batch_norm_training_c7*/]> {
let summary = "BatchNormTraining operation";
let description = [{
@@ -1927,7 +1924,7 @@ def StableHLO_DynamicBroadcastInDimOp : StableHLO_ShapedInterfaceOp<
// directly.

def StableHLO_CholeskyOp : StableHLO_Op<"cholesky",
[Pure, SameOperandsAndResultElementType /*cholesky_c1*/,
[Pure, HLO_CompatibleOperandsAndResultElementType /*cholesky_c1*/,
InferTensorType /*cholesky_c1*/]> {
let summary = "Cholesky operation";
let description = [{
@@ -1954,7 +1951,7 @@ def StableHLO_CholeskyOp : StableHLO_Op<"cholesky",
}

def StableHLO_ClampOp : StableHLO_ShapedInterfaceOp<"clamp", [Pure,
SameOperandsAndResultElementType /* clamp_c3 */, HLO_BroadcastingElementwise,
HLO_CompatibleOperandsAndResultElementType /* clamp_c3 */, HLO_BroadcastingElementwise,
InferTensorType]> {
let summary = "Clamp operation";
let description = [{
@@ -2814,7 +2811,7 @@ def StableHLO_TransposeOp: StableHLO_ShapedInterfaceOp<"transpose",
}

def StableHLO_TriangularSolveOp: StableHLO_Op<"triangular_solve",
[Pure, SameOperandsAndResultElementType, InferTensorType]> {
[Pure, HLO_CompatibleOperandsAndResultElementType, InferTensorType]> {
let summary = "TriangularSolve operation";
let description = [{
Solves batches of systems of linear equations with lower or upper triangular
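With HLO_CompatibleOperandsElementType on compare and the compatible-element-type traits on the ops above, quantized operands verify as long as their quantized element types agree in storage type, storage bounds, expressed type, and granularity. A minimal sketch (hypothetical IR):

%lt = "stablehlo.compare"(%q0, %q1) {comparison_direction = #stablehlo<comparison_direction LT>} : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2xi1>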
10 changes: 2 additions & 8 deletions stablehlo/tests/ops_stablehlo.mlir
@@ -5070,16 +5070,10 @@ func.func @is_compatible_dynamism_dim_mismatch(%arg0: tensor<1x?xf32>) {

// -----

// TODO(b/230263270): For stablehlo.add, the plan is to only allow fp+fp=fp, q+q=q and q+q=fp.
func.func @is_compatible_quant_mix_non_quant(%arg0: tensor<1xf32>, %arg1: tensor<1x!quant.uniform<i8:f32, 1.0:17>>) {
%0 = "stablehlo.add"(%arg0, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%1 = "stablehlo.add"(%arg0, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1x!quant.uniform<i8:f32, 1.0:17>>
%2 = "stablehlo.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i8:f32, 1.0:17>>
%3 = "stablehlo.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i8:f32, 1.0:17>>
%4 = "stablehlo.add"(%arg1, %arg0) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1xf32>) -> tensor<1xf32>
%5 = "stablehlo.add"(%arg1, %arg0) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1xf32>) -> tensor<1xf32>
%6 = "stablehlo.add"(%arg1, %arg1) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i8:f32, 1.0:17>>
%7 = "stablehlo.add"(%arg1, %arg1) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i8:f32, 1.0:17>>
%1 = "stablehlo.add"(%arg1, %arg1) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i8:f32, 1.0:17>>
%2 = "stablehlo.add"(%arg1, %arg1) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i8:f32, 1.0:17>>
func.return
}

2 changes: 1 addition & 1 deletion stablehlo/tests/ops_stablehlo_quantized.mlir
@@ -84,7 +84,7 @@ func.func @ops_per_tensor_quantization(
%sqrt = "stablehlo.sqrt"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%subtract = "stablehlo.subtract"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%tanh = "stablehlo.tanh"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%transpose = "stablehlo.transpose"(%arg0) {permutation = array<i64: 0, 2, 1>}: (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:1, {0.1:-30, 0.5:-20}>>
%transpose = "stablehlo.transpose"(%arg0) {permutation = array<i64: 0, 2, 1>}: (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8<-128:127>:f32, 1.0:17>>
%tuple = "stablehlo.tuple"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tuple<tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>>
%uniform_dequantize = "stablehlo.uniform_dequantize" (%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2xf32>
%uniform_quantize = "stablehlo.uniform_quantize" (%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
8 changes: 4 additions & 4 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir
@@ -2226,10 +2226,10 @@ func.func @type_dynamism_ranked(%arg0: tensor<?xf32>) -> tensor<?xf32> {
}

// CHECK-LABEL: "type_quantization"
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<f32>) -> tensor<f32> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<f32>) -> tensor<f32>
func.return %0 : tensor<f32>
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>) -> !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>>
func.return %0 : tensor<!quant.uniform<i8:f32, 34.0:16>>
}

// CHECK: function_type = #vhlo.type_v1<!vhlo.func_v1<(!vhlo.token_v1) -> !vhlo.token_v1>>
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir.bc
8 changes: 4 additions & 4 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir
@@ -2233,10 +2233,10 @@ func.func @type_dynamism_ranked(%arg0: tensor<?xf32>) -> tensor<?xf32> {
}

// CHECK-LABEL: "type_quantization"
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<f32>) -> tensor<f32> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<f32>) -> tensor<f32>
func.return %0 : tensor<f32>
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>) -> !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>>
func.return %0 : tensor<!quant.uniform<i8:f32, 34.0:16>>
}

// CHECK: function_type = #vhlo.type_v1<!vhlo.func_v1<(!vhlo.token_v1) -> !vhlo.token_v1>>
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir.bc
8 changes: 4 additions & 4 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir
@@ -2233,10 +2233,10 @@ func.func @type_dynamism_ranked(%arg0: tensor<?xf32>) -> tensor<?xf32> {
}

// CHECK-LABEL: "type_quantization"
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<f32>) -> tensor<f32> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<f32>) -> tensor<f32>
func.return %0 : tensor<f32>
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>) -> !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>>
func.return %0 : tensor<!quant.uniform<i8:f32, 34.0:16>>
}

// CHECK: function_type = #vhlo.type_v1<!vhlo.func_v1<(!vhlo.token_v1) -> !vhlo.token_v1>>
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir.bc
8 changes: 4 additions & 4 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir
@@ -2233,10 +2233,10 @@ func.func @type_dynamism_ranked(%arg0: tensor<?xf32>) -> tensor<?xf32> {
}

// CHECK-LABEL: "type_quantization"
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<f32>) -> tensor<f32> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<f32>) -> tensor<f32>
func.return %0 : tensor<f32>
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>) -> !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>>
func.return %0 : tensor<!quant.uniform<i8:f32, 34.0:16>>
}

// CHECK: function_type = #vhlo.type_v1<!vhlo.func_v1<(!vhlo.token_v1) -> !vhlo.token_v1>>
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir.bc
8 changes: 4 additions & 4 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir
@@ -2233,10 +2233,10 @@ func.func @type_dynamism_ranked(%arg0: tensor<?xf32>) -> tensor<?xf32> {
}

// CHECK-LABEL: "type_quantization"
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<f32>) -> tensor<f32> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<f32>) -> tensor<f32>
func.return %0 : tensor<f32>
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>) -> !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>>
func.return %0 : tensor<!quant.uniform<i8:f32, 34.0:16>>
}

// CHECK: function_type = #vhlo.type_v1<!vhlo.func_v1<(!vhlo.token_v1) -> !vhlo.token_v1>>
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir.bc
8 changes: 4 additions & 4 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir
@@ -2241,10 +2241,10 @@ func.func @type_dynamism_ranked(%arg0: tensor<?xf32>) -> tensor<?xf32> {
}

// CHECK-LABEL: "type_quantization"
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<f32>) -> tensor<f32> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<f32>) -> tensor<f32>
func.return %0 : tensor<f32>
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>) -> !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>>
func.return %0 : tensor<!quant.uniform<i8:f32, 34.0:16>>
}

// CHECK: function_type = #vhlo.type_v1<!vhlo.func_v1<(!vhlo.token_v1) -> !vhlo.token_v1>>
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir.bc
8 changes: 4 additions & 4 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir
@@ -2252,10 +2252,10 @@ func.func @type_dynamism_ranked(%arg0: tensor<?xf32>) -> tensor<?xf32> {
}

// CHECK-LABEL: "type_quantization"
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<f32>) -> tensor<f32> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<f32>) -> tensor<f32>
func.return %0 : tensor<f32>
func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>> {
// CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>) -> !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<!quant.uniform<i8:f32, 34.0:16>>) -> tensor<!quant.uniform<i8:f32, 34.0:16>>
func.return %0 : tensor<!quant.uniform<i8:f32, 34.0:16>>
}

// CHECK: function_type = #vhlo.type_v1<!vhlo.func_v1<(!vhlo.token_v1) -> !vhlo.token_v1>>
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir.bc