Skip to content

Commit

Permalink
added per-axis, per-tensor in test names
Browse files Browse the repository at this point in the history
  • Loading branch information
abhigunj committed Feb 20, 2024
1 parent 4f5a839 commit 87cfca4
Showing 1 changed file with 26 additions and 26 deletions.
52 changes: 26 additions & 26 deletions stablehlo/tests/ops_stablehlo_quantized.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
// -----
// Tests for StableHLO OPs supporting per-axis quantization. These OPs also support per-tensor quantization.

// CHECK-LABEL: @per_axis_quantized_ops
func.func @per_axis_quantized_ops(
// CHECK-LABEL: @ops_per_axis_quantization
func.func @ops_per_axis_quantization(
%arg0: tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:2, {0.1:-30, 0.5:-20}>>,
%arg1: tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:0, {0.1:-30}>>,
%token0: !stablehlo.token) {
Expand All @@ -25,7 +25,7 @@ func.func @per_axis_quantized_ops(
}

// %arg1 can be a per-axis Quantized
func.func @dot_general_quantization(%arg0: tensor<2x3x4x!quant.uniform<i8:f32, 1.0:17>>, %arg1: tensor<2x3x5x!quant.uniform<i8:f32:0, {0.1:-30}>>) -> tensor<2x4x5x!quant.uniform<i8:f32:0, {0.1:-30}>> {
func.func @dot_general_per_axis_quantization(%arg0: tensor<2x3x4x!quant.uniform<i8:f32, 1.0:17>>, %arg1: tensor<2x3x5x!quant.uniform<i8:f32:0, {0.1:-30}>>) -> tensor<2x4x5x!quant.uniform<i8:f32:0, {0.1:-30}>> {
%0 = "stablehlo.dot_general"(%arg0, %arg1) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
Expand All @@ -40,7 +40,7 @@ func.func @dot_general_quantization(%arg0: tensor<2x3x4x!quant.uniform<i8:f32, 1
// -----
// Tests for StableHLO OPs supporting per-tensor quantization. These OPs may or may not support per-axis quantization

func.func @per_tensor_quantized_ops(
func.func @ops_per_tensor_quantization(
%arg0: tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>,
%arg1: tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>,
%arg2: tensor<!quant.uniform<i8:f32, 1.0:17>>,
Expand Down Expand Up @@ -92,23 +92,23 @@ func.func @per_tensor_quantized_ops(
func.return
}

func.func @batch_norm_grad_quantization(%input: tensor<2x2x2x2x!quant.uniform<i8:f32, 1.0:17>>, %scale: tensor<2x!quant.uniform<i8:f32, 1.0:17>>, %mean: tensor<2x!quant.uniform<i8:f32, 1.0:17>>, %variance: tensor<2x!quant.uniform<i8:f32, 1.0:17>>, %grad_output: tensor<2x2x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<2x2x2x2x!quant.uniform<i8:f32, 1.0:17>> {
func.func @batch_norm_grad_per_tensor_quantization(%input: tensor<2x2x2x2x!quant.uniform<i8:f32, 1.0:17>>, %scale: tensor<2x!quant.uniform<i8:f32, 1.0:17>>, %mean: tensor<2x!quant.uniform<i8:f32, 1.0:17>>, %variance: tensor<2x!quant.uniform<i8:f32, 1.0:17>>, %grad_output: tensor<2x2x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<2x2x2x2x!quant.uniform<i8:f32, 1.0:17>> {
%0:3 = "stablehlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output)
{epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<2x!quant.uniform<i8:f32, 1.0:17>>, tensor<2x!quant.uniform<i8:f32, 1.0:17>>, tensor<2x!quant.uniform<i8:f32, 1.0:17>>, tensor<2x2x2x2x!quant.uniform<i8:f32, 1.0:17>>)
-> (tensor<2x2x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<2x!quant.uniform<i8:f32, 1.0:17>>, tensor<2x!quant.uniform<i8:f32, 1.0:17>>)
func.return %0#0 : tensor<2x2x2x2x!quant.uniform<i8:f32, 1.0:17>>
}


func.func @batch_norm_inference_quantization(%input: tensor<4x256x!quant.uniform<i8:f32, 1.0:17>>, %scale: tensor<256x!quant.uniform<i8:f32, 1.0:17>>, %offset: tensor<256x!quant.uniform<i8:f32, 1.0:17>>, %mean: tensor<256x!quant.uniform<i8:f32, 1.0:17>>, %variance: tensor<256x!quant.uniform<i8:f32, 1.0:17>>) -> (tensor<4x256x!quant.uniform<i8:f32, 1.0:17>>) {
func.func @batch_norm_inference_per_tensor_quantization(%input: tensor<4x256x!quant.uniform<i8:f32, 1.0:17>>, %scale: tensor<256x!quant.uniform<i8:f32, 1.0:17>>, %offset: tensor<256x!quant.uniform<i8:f32, 1.0:17>>, %mean: tensor<256x!quant.uniform<i8:f32, 1.0:17>>, %variance: tensor<256x!quant.uniform<i8:f32, 1.0:17>>) -> (tensor<4x256x!quant.uniform<i8:f32, 1.0:17>>) {
%0 = "stablehlo.batch_norm_inference" (%input, %scale, %offset, %mean, %variance) {
epsilon = 1.001000e-05 : f32,
feature_index = 1 : i64
} : (tensor<4x256x!quant.uniform<i8:f32, 1.0:17>>, tensor<256x!quant.uniform<i8:f32, 1.0:17>>, tensor<256x!quant.uniform<i8:f32, 1.0:17>>, tensor<256x!quant.uniform<i8:f32, 1.0:17>>, tensor<256x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<4x256x!quant.uniform<i8:f32, 1.0:17>>
func.return %0 : tensor<4x256x!quant.uniform<i8:f32, 1.0:17>>
}

func.func @batch_norm_training_quantization(%input: tensor<2x2x2x2x!quant.uniform<i8:f32, 1.0:17>>, %scale: tensor<2x!quant.uniform<i8:f32, 1.0:17>>, %offset: tensor<2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<2x2x2x2x!quant.uniform<i8:f32, 1.0:17>> {
func.func @batch_norm_training_per_tensor_quantization(%input: tensor<2x2x2x2x!quant.uniform<i8:f32, 1.0:17>>, %scale: tensor<2x!quant.uniform<i8:f32, 1.0:17>>, %offset: tensor<2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<2x2x2x2x!quant.uniform<i8:f32, 1.0:17>> {
%0:3 = "stablehlo.batch_norm_training" (%input, %scale, %offset) {
epsilon = 0.001 : f32,
feature_index = 1 : i64
Expand All @@ -117,7 +117,7 @@ func.func @batch_norm_training_quantization(%input: tensor<2x2x2x2x!quant.unifor
func.return %0#0 : tensor<2x2x2x2x!quant.uniform<i8:f32, 1.0:17>>
}

func.func @dot_general_quantization(%arg0: tensor<2x3x4x!quant.uniform<i8:f32, 1.0:17>>, %arg1: tensor<2x3x5x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<2x4x5x!quant.uniform<i8:f32, 1.0:17>> {
func.func @dot_general_per_tensor_quantization(%arg0: tensor<2x3x4x!quant.uniform<i8:f32, 1.0:17>>, %arg1: tensor<2x3x5x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<2x4x5x!quant.uniform<i8:f32, 1.0:17>> {
%0 = "stablehlo.dot_general"(%arg0, %arg1) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
Expand All @@ -129,17 +129,17 @@ func.func @dot_general_quantization(%arg0: tensor<2x3x4x!quant.uniform<i8:f32, 1
func.return %0 : tensor<2x4x5x!quant.uniform<i8:f32, 1.0:17>>
}

func.func @dynamic_slice_quantization(%arg0: tensor<3x4x!quant.uniform<i8:f32, 1.0:17>>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4x!quant.uniform<i8:f32, 1.0:17>> {
func.func @dynamic_slice_per_tensor_quantization(%arg0: tensor<3x4x!quant.uniform<i8:f32, 1.0:17>>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4x!quant.uniform<i8:f32, 1.0:17>> {
%0 = "stablehlo.dynamic_slice"(%arg0, %arg1, %arg2) {slice_sizes = array<i64: 1, 4>} : (tensor<3x4x!quant.uniform<i8:f32, 1.0:17>>, tensor<i64>, tensor<i64>) -> tensor<1x4x!quant.uniform<i8:f32, 1.0:17>>
func.return %0 : tensor<1x4x!quant.uniform<i8:f32, 1.0:17>>
}

func.func @dynamic_update_slice_pertensor_quantization(%operand: tensor<3x4x!quant.uniform<i8:f32, 1.0:17>>, %update: tensor<1x4x!quant.uniform<i8:f32, 1.0:17>>, %start_indices0: tensor<i64>, %start_indices1: tensor<i64>) -> tensor<3x4x!quant.uniform<i8:f32, 1.0:17>> {
func.func @dynamic_update_slice_per_tensor_quantization(%operand: tensor<3x4x!quant.uniform<i8:f32, 1.0:17>>, %update: tensor<1x4x!quant.uniform<i8:f32, 1.0:17>>, %start_indices0: tensor<i64>, %start_indices1: tensor<i64>) -> tensor<3x4x!quant.uniform<i8:f32, 1.0:17>> {
%0 = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1) : (tensor<3x4x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x4x!quant.uniform<i8:f32, 1.0:17>>, tensor<i64>, tensor<i64>) -> tensor<3x4x!quant.uniform<i8:f32, 1.0:17>>
func.return %0 : tensor<3x4x!quant.uniform<i8:f32, 1.0:17>>
}

func.func @gather_quantization(%operand : tensor<*x!quant.uniform<i8:f32, 1.0:17>>, %start_indices : tensor<1x5x2xi32>) -> tensor<8x?x7x1x6x1x?x!quant.uniform<i8:f32, 1.0:17>> {
func.func @gather_per_tensor_quantization(%operand : tensor<*x!quant.uniform<i8:f32, 1.0:17>>, %start_indices : tensor<1x5x2xi32>) -> tensor<8x?x7x1x6x1x?x!quant.uniform<i8:f32, 1.0:17>> {
%res = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [0, 2, 3, 4, 5],
Expand All @@ -153,15 +153,15 @@ func.func @gather_quantization(%operand : tensor<*x!quant.uniform<i8:f32, 1.0:17
func.return %res : tensor<8x?x7x1x6x1x?x!quant.uniform<i8:f32, 1.0:17>>
}

func.func @map_quantization(%arg0: tensor<4x!quant.uniform<i8:f32, 1.0:17>>, %arg1: tensor<4x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<4x!quant.uniform<i8:f32, 1.0:17>> {
func.func @map_per_tensor_quantization(%arg0: tensor<4x!quant.uniform<i8:f32, 1.0:17>>, %arg1: tensor<4x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<4x!quant.uniform<i8:f32, 1.0:17>> {
%0 = "stablehlo.map"(%arg0, %arg1) ({
^bb0(%arg2: tensor<!quant.uniform<i8:f32, 1.0:17>>, %arg3: tensor<!quant.uniform<i8:f32, 1.0:17>>):
"stablehlo.return"(%arg2) : (tensor<!quant.uniform<i8:f32, 1.0:17>>) -> ()
}) {dimensions = array<i64: 0>} : (tensor<4x!quant.uniform<i8:f32, 1.0:17>>, tensor<4x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<4x!quant.uniform<i8:f32, 1.0:17>>
func.return %0 : tensor<4x!quant.uniform<i8:f32, 1.0:17>>
}

func.func @pad_quantization(%arg0: tensor<1x2x3x!quant.uniform<i8:f32, 1.0:17>>, %arg1: tensor<!quant.uniform<i8:f32, 1.0:17>>) -> tensor<2x4x7x!quant.uniform<i8:f32, 1.0:17>> {
func.func @pad_per_tensor_quantization(%arg0: tensor<1x2x3x!quant.uniform<i8:f32, 1.0:17>>, %arg1: tensor<!quant.uniform<i8:f32, 1.0:17>>) -> tensor<2x4x7x!quant.uniform<i8:f32, 1.0:17>> {
%0 = "stablehlo.pad"(%arg0, %arg1) {
edge_padding_low = array<i64: 0, 1, 2>,
edge_padding_high = array<i64: 1, 1, 0>,
Expand All @@ -170,7 +170,7 @@ func.func @pad_quantization(%arg0: tensor<1x2x3x!quant.uniform<i8:f32, 1.0:17>>,
func.return %0 : tensor<2x4x7x!quant.uniform<i8:f32, 1.0:17>>
}

func.func @reduce_quantization(%arg0: tensor<16x!quant.uniform<i8:f32, 1.0:17>>, %arg1: tensor<!quant.uniform<i8:f32, 1.0:17>>) -> tensor<!quant.uniform<i8:f32, 1.0:17>> {
func.func @reduce_per_tensor_quantization(%arg0: tensor<16x!quant.uniform<i8:f32, 1.0:17>>, %arg1: tensor<!quant.uniform<i8:f32, 1.0:17>>) -> tensor<!quant.uniform<i8:f32, 1.0:17>> {
%0 = "stablehlo.reduce"(%arg0, %arg1) ({
^bb0(%arg2: tensor<!quant.uniform<i8:f32, 1.0:17>>, %arg3: tensor<!quant.uniform<i8:f32, 1.0:17>>):
%1 = "stablehlo.add"(%arg2, %arg3) : (tensor<!quant.uniform<i8:f32, 1.0:17>>, tensor<!quant.uniform<i8:f32, 1.0:17>>) -> tensor<!quant.uniform<i8:f32, 1.0:17>>
Expand All @@ -181,7 +181,7 @@ func.func @reduce_quantization(%arg0: tensor<16x!quant.uniform<i8:f32, 1.0:17>>,
func.return %0 : tensor<!quant.uniform<i8:f32, 1.0:17>>
}

func.func @reduce_precision_quantization(%arg0: tensor<6x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<6x!quant.uniform<i8:f32, 1.0:17>> {
func.func @reduce_per_tensor_precision_quantization(%arg0: tensor<6x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<6x!quant.uniform<i8:f32, 1.0:17>> {
%output = "stablehlo.reduce_precision"(%arg0) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
Expand All @@ -190,7 +190,7 @@ func.func @reduce_precision_quantization(%arg0: tensor<6x!quant.uniform<i8:f32,
}


func.func @reduce_scatter_quantization(%data: tensor<4x16x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<4x4x!quant.uniform<i8:f32, 1.0:17>> {
func.func @reduce_scatter_per_tensor_quantization(%data: tensor<4x16x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<4x4x!quant.uniform<i8:f32, 1.0:17>> {
%0 = "stablehlo.reduce_scatter"(%data) ({
^bb0(%arg2: tensor<!quant.uniform<i8:f32, 1.0:17>>, %arg3: tensor<!quant.uniform<i8:f32, 1.0:17>>):
%1 = stablehlo.add %arg2, %arg3 : tensor<!quant.uniform<i8:f32, 1.0:17>>
Expand All @@ -203,7 +203,7 @@ func.func @reduce_scatter_quantization(%data: tensor<4x16x!quant.uniform<i8:f32,
}


func.func @op_reduce_window_quantization(%arg0: tensor<2x17x31x7x!quant.uniform<i8:f32, 0.1:-30>>, %arg1: tensor<!quant.uniform<i8:f32, 0.1:-30>>) -> tensor<2x9x16x7x!quant.uniform<i8:f32, 0.1:-30>> {
func.func @op_reduce_window_per_tensor_quantization(%arg0: tensor<2x17x31x7x!quant.uniform<i8:f32, 0.1:-30>>, %arg1: tensor<!quant.uniform<i8:f32, 0.1:-30>>) -> tensor<2x9x16x7x!quant.uniform<i8:f32, 0.1:-30>> {
%0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
^bb0(%arg2: tensor<!quant.uniform<i8:f32, 0.1:-30>>, %arg3: tensor<!quant.uniform<i8:f32, 0.1:-30>>):
%1 = "stablehlo.maximum"(%arg2, %arg3) : (tensor<!quant.uniform<i8:f32, 0.1:-30>>, tensor<!quant.uniform<i8:f32, 0.1:-30>>) -> tensor<!quant.uniform<i8:f32, 0.1:-30>>
Expand All @@ -218,25 +218,25 @@ func.func @op_reduce_window_quantization(%arg0: tensor<2x17x31x7x!quant.uniform<
func.return %0 : tensor<2x9x16x7x!quant.uniform<i8:f32, 0.1:-30>>
}

func.func @reverse_quantization(%operand: tensor<3x2x!quant.uniform<i8:f32, 0.1:-30>>) -> tensor<3x2x!quant.uniform<i8:f32, 0.1:-30>> {
func.func @reverse_per_tensor_quantization(%operand: tensor<3x2x!quant.uniform<i8:f32, 0.1:-30>>) -> tensor<3x2x!quant.uniform<i8:f32, 0.1:-30>> {
%result = "stablehlo.reverse"(%operand) {
dimensions = array<i64: 1>
} : (tensor<3x2x!quant.uniform<i8:f32, 0.1:-30>>) -> tensor<3x2x!quant.uniform<i8:f32, 0.1:-30>>
func.return %result : tensor<3x2x!quant.uniform<i8:f32, 0.1:-30>>
}

func.func @round_afz_quantization(%arg0: tensor<2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<2x!quant.uniform<i8:f32, 1.0:17>> {
func.func @round_afz_per_tensor_quantization(%arg0: tensor<2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<2x!quant.uniform<i8:f32, 1.0:17>> {
%0 = "stablehlo.round_nearest_afz"(%arg0) {} : (tensor<2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<2x!quant.uniform<i8:f32, 1.0:17>>
func.return %0 : tensor<2x!quant.uniform<i8:f32, 1.0:17>>
}

func.func @round_even_quantization(%arg0: tensor<2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<2x!quant.uniform<i8:f32, 1.0:17>> {
func.func @round_even_per_tensor_quantization(%arg0: tensor<2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<2x!quant.uniform<i8:f32, 1.0:17>> {
%0 = "stablehlo.round_nearest_even"(%arg0) {} : (tensor<2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<2x!quant.uniform<i8:f32, 1.0:17>>
func.return %0 : tensor<2x!quant.uniform<i8:f32, 1.0:17>>
}


func.func @scatter_quantization(%arg0: tensor<200x100x300x!quant.uniform<i8:f32, 1.0:17>>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<200x100x300x!quant.uniform<i8:f32, 1.0:17>> {
func.func @scatter_per_tensor_quantization(%arg0: tensor<200x100x300x!quant.uniform<i8:f32, 1.0:17>>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<200x100x300x!quant.uniform<i8:f32, 1.0:17>> {
%0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({
^bb0(%arg3: tensor<!quant.uniform<i8:f32, 1.0:17>>, %arg4: tensor<!quant.uniform<i8:f32, 1.0:17>>):
%1 = "stablehlo.add"(%arg3, %arg4) : (tensor<!quant.uniform<i8:f32, 1.0:17>>, tensor<!quant.uniform<i8:f32, 1.0:17>>) -> tensor<!quant.uniform<i8:f32, 1.0:17>>
Expand All @@ -252,12 +252,12 @@ func.func @scatter_quantization(%arg0: tensor<200x100x300x!quant.uniform<i8:f32,
func.return %0 : tensor<200x100x300x!quant.uniform<i8:f32, 1.0:17>>
}

func.func @select_quantization(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3x!quant.uniform<i8:f32, 1.0:17>>, %arg2: tensor<2x3x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<2x3x!quant.uniform<i8:f32, 1.0:17>> {
func.func @select_per_tensor_quantization(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3x!quant.uniform<i8:f32, 1.0:17>>, %arg2: tensor<2x3x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<2x3x!quant.uniform<i8:f32, 1.0:17>> {
%0 = "stablehlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3x!quant.uniform<i8:f32, 1.0:17>>, tensor<2x3x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<2x3x!quant.uniform<i8:f32, 1.0:17>>
func.return %0 : tensor<2x3x!quant.uniform<i8:f32, 1.0:17>>
}

func.func @select_and_scatter_quantization(%arg0: tensor<10x24x24x64x!quant.uniform<i8:f32, 1.0:17>>, %arg1: tensor<10x23x23x64x!quant.uniform<i8:f32, 1.0:17>>, %arg2: tensor<!quant.uniform<i8:f32, 1.0:17>>) -> tensor<10x24x24x64x!quant.uniform<i8:f32, 1.0:17>> {
func.func @select_and_scatter_per_tensor_quantization(%arg0: tensor<10x24x24x64x!quant.uniform<i8:f32, 1.0:17>>, %arg1: tensor<10x23x23x64x!quant.uniform<i8:f32, 1.0:17>>, %arg2: tensor<!quant.uniform<i8:f32, 1.0:17>>) -> tensor<10x24x24x64x!quant.uniform<i8:f32, 1.0:17>> {
%0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({
^bb0(%arg3: tensor<!quant.uniform<i8:f32, 1.0:17>>, %arg4: tensor<!quant.uniform<i8:f32, 1.0:17>>):
%1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo<comparison_type TOTALORDER>, comparison_direction = #stablehlo<comparison_direction GE>} : (tensor<!quant.uniform<i8:f32, 1.0:17>>, tensor<!quant.uniform<i8:f32, 1.0:17>>) -> tensor<i1>
Expand All @@ -272,12 +272,12 @@ func.func @select_and_scatter_quantization(%arg0: tensor<10x24x24x64x!quant.unif
func.return %0 : tensor<10x24x24x64x!quant.uniform<i8:f32, 1.0:17>>
}

func.func @slice_qunatization(%arg0: tensor<3x4x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x!quant.uniform<i8:f32, 1.0:17>> {
func.func @slice_per_tensor_qunatization(%arg0: tensor<3x4x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x!quant.uniform<i8:f32, 1.0:17>> {
%0 = "stablehlo.slice"(%arg0) {start_indices = array<i64: 1, 0>, limit_indices = array<i64: 2, 4>, strides = array<i64: 1, 2>} : (tensor<3x4x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x!quant.uniform<i8:f32, 1.0:17>>
func.return %0 : tensor<1x2x!quant.uniform<i8:f32, 1.0:17>>
}

func.func @sort_quantization(%input0: tensor<16x16x!quant.uniform<i8:f32, 1.0:17>>, %input1: tensor<16x16x!quant.uniform<i8:f32, 1.0:17>>) {
func.func @sort_per_tensor_quantization(%input0: tensor<16x16x!quant.uniform<i8:f32, 1.0:17>>, %input1: tensor<16x16x!quant.uniform<i8:f32, 1.0:17>>) {
%0:2 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<!quant.uniform<i8:f32, 1.0:17>>, %arg1: tensor<!quant.uniform<i8:f32, 1.0:17>>, %arg2: tensor<!quant.uniform<i8:f32, 1.0:17>>, %arg3: tensor<!quant.uniform<i8:f32, 1.0:17>>):
%7 = "stablehlo.compare"(%arg0, %arg1) {comparison_direction = #stablehlo<comparison_direction GT>} : (tensor<!quant.uniform<i8:f32, 1.0:17>>, tensor<!quant.uniform<i8:f32, 1.0:17>>) -> tensor<i1>
Expand All @@ -286,7 +286,7 @@ func.func @sort_quantization(%input0: tensor<16x16x!quant.uniform<i8:f32, 1.0:17
func.return
}

func.func @while_quantization(%arg0: tensor<4x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<*x!quant.uniform<i8:f32, 1.0:17>> {
func.func @while_per_tensor_quantization(%arg0: tensor<4x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<*x!quant.uniform<i8:f32, 1.0:17>> {
%while = "stablehlo.while"(%arg0) ({
^bb0(%arg1: tensor<?x!quant.uniform<i8:f32, 1.0:17>>):
%1 = stablehlo.constant dense<true> : tensor<i1>
Expand Down

0 comments on commit 87cfca4

Please sign in to comment.