From 87cfca40360fa61e479f325028cf8968e1dbf739 Mon Sep 17 00:00:00 2001 From: Abhinav Date: Tue, 20 Feb 2024 20:15:55 +0000 Subject: [PATCH] added per-axis, per-tensor in test names --- stablehlo/tests/ops_stablehlo_quantized.mlir | 52 ++++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/stablehlo/tests/ops_stablehlo_quantized.mlir b/stablehlo/tests/ops_stablehlo_quantized.mlir index 82eab643a32..3150f57e860 100644 --- a/stablehlo/tests/ops_stablehlo_quantized.mlir +++ b/stablehlo/tests/ops_stablehlo_quantized.mlir @@ -4,8 +4,8 @@ // ----- // Tests for StableHLO OPs supporting per-axis quantization. These OPs also support per-tensor quantization. -// CHECK-LABEL: @per_axis_quantized_ops -func.func @per_axis_quantized_ops( +// CHECK-LABEL: @ops_per_axis_quantization +func.func @ops_per_axis_quantization( %arg0: tensor<1x2x2x!quant.uniform:f32:2, {0.1:-30, 0.5:-20}>>, %arg1: tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30}>>, %token0: !stablehlo.token) { @@ -25,7 +25,7 @@ func.func @per_axis_quantized_ops( } // %arg1 can be a per-axis Quantized -func.func @dot_general_quantization(%arg0: tensor<2x3x4x!quant.uniform>, %arg1: tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> { +func.func @dot_general_per_axis_quantization(%arg0: tensor<2x3x4x!quant.uniform>, %arg1: tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> { %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< lhs_batching_dimensions = [0], @@ -40,7 +40,7 @@ func.func @dot_general_quantization(%arg0: tensor<2x3x4x!quant.uniform>, %arg1: tensor<1x2x2x!quant.uniform>, %arg2: tensor>, @@ -92,7 +92,7 @@ func.func @per_tensor_quantized_ops( func.return } -func.func @batch_norm_grad_quantization(%input: tensor<2x2x2x2x!quant.uniform>, %scale: tensor<2x!quant.uniform>, %mean: tensor<2x!quant.uniform>, %variance: tensor<2x!quant.uniform>, %grad_output: tensor<2x2x2x2x!quant.uniform>) -> tensor<2x2x2x2x!quant.uniform> { +func.func @batch_norm_grad_per_tensor_quantization(%input: tensor<2x2x2x2x!quant.uniform>, %scale: tensor<2x!quant.uniform>, %mean: tensor<2x!quant.uniform>, %variance: tensor<2x!quant.uniform>, %grad_output: tensor<2x2x2x2x!quant.uniform>) -> tensor<2x2x2x2x!quant.uniform> { %0:3 = "stablehlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2x!quant.uniform>, tensor<2x!quant.uniform>, tensor<2x!quant.uniform>, tensor<2x!quant.uniform>, tensor<2x2x2x2x!quant.uniform>) -> (tensor<2x2x2x2x!quant.uniform>, tensor<2x!quant.uniform>, tensor<2x!quant.uniform>) @@ -100,7 +100,7 @@ func.func @batch_norm_grad_quantization(%input: tensor<2x2x2x2x!quant.uniform>, %scale: tensor<256x!quant.uniform>, %offset: tensor<256x!quant.uniform>, %mean: tensor<256x!quant.uniform>, %variance: tensor<256x!quant.uniform>) -> (tensor<4x256x!quant.uniform>) { +func.func @batch_norm_inference_per_tensor_quantization(%input: tensor<4x256x!quant.uniform>, %scale: tensor<256x!quant.uniform>, %offset: tensor<256x!quant.uniform>, %mean: tensor<256x!quant.uniform>, %variance: tensor<256x!quant.uniform>) -> (tensor<4x256x!quant.uniform>) { %0 = "stablehlo.batch_norm_inference" (%input, %scale, %offset, %mean, %variance) { epsilon = 1.001000e-05 : f32, feature_index = 1 : i64 @@ -108,7 +108,7 @@ func.func @batch_norm_inference_quantization(%input: tensor<4x256x!quant.uniform func.return %0 : tensor<4x256x!quant.uniform> } -func.func @batch_norm_training_quantization(%input: tensor<2x2x2x2x!quant.uniform>, %scale: tensor<2x!quant.uniform>, %offset: tensor<2x!quant.uniform>) -> tensor<2x2x2x2x!quant.uniform> { +func.func @batch_norm_training_per_tensor_quantization(%input: tensor<2x2x2x2x!quant.uniform>, %scale: tensor<2x!quant.uniform>, %offset: tensor<2x!quant.uniform>) -> tensor<2x2x2x2x!quant.uniform> { %0:3 = "stablehlo.batch_norm_training" (%input, %scale, %offset) { epsilon = 0.001 : f32, feature_index = 1 : i64 @@ -117,7 +117,7 @@ func.func @batch_norm_training_quantization(%input: tensor<2x2x2x2x!quant.unifor func.return %0#0 : tensor<2x2x2x2x!quant.uniform> } -func.func @dot_general_quantization(%arg0: tensor<2x3x4x!quant.uniform>, %arg1: tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> { +func.func @dot_general_per_tensor_quantization(%arg0: tensor<2x3x4x!quant.uniform>, %arg1: tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> { %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< lhs_batching_dimensions = [0], @@ -129,17 +129,17 @@ func.func @dot_general_quantization(%arg0: tensor<2x3x4x!quant.uniform> } -func.func @dynamic_slice_quantization(%arg0: tensor<3x4x!quant.uniform>, %arg1: tensor, %arg2: tensor) -> tensor<1x4x!quant.uniform> { +func.func @dynamic_slice_per_tensor_quantization(%arg0: tensor<3x4x!quant.uniform>, %arg1: tensor, %arg2: tensor) -> tensor<1x4x!quant.uniform> { %0 = "stablehlo.dynamic_slice"(%arg0, %arg1, %arg2) {slice_sizes = array} : (tensor<3x4x!quant.uniform>, tensor, tensor) -> tensor<1x4x!quant.uniform> func.return %0 : tensor<1x4x!quant.uniform> } -func.func @dynamic_update_slice_pertensor_quantization(%operand: tensor<3x4x!quant.uniform>, %update: tensor<1x4x!quant.uniform>, %start_indices0: tensor, %start_indices1: tensor) -> tensor<3x4x!quant.uniform> { +func.func @dynamic_update_slice_per_tensor_quantization(%operand: tensor<3x4x!quant.uniform>, %update: tensor<1x4x!quant.uniform>, %start_indices0: tensor, %start_indices1: tensor) -> tensor<3x4x!quant.uniform> { %0 = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1) : (tensor<3x4x!quant.uniform>, tensor<1x4x!quant.uniform>, tensor, tensor) -> tensor<3x4x!quant.uniform> func.return %0 : tensor<3x4x!quant.uniform> } -func.func @gather_quantization(%operand : tensor<*x!quant.uniform>, %start_indices : tensor<1x5x2xi32>) -> tensor<8x?x7x1x6x1x?x!quant.uniform> { +func.func @gather_per_tensor_quantization(%operand : tensor<*x!quant.uniform>, %start_indices : tensor<1x5x2xi32>) -> tensor<8x?x7x1x6x1x?x!quant.uniform> { %res = "stablehlo.gather"(%operand, %start_indices) { dimension_numbers = #stablehlo.gather< offset_dims = [0, 2, 3, 4, 5], @@ -153,7 +153,7 @@ func.func @gather_quantization(%operand : tensor<*x!quant.uniform> } -func.func @map_quantization(%arg0: tensor<4x!quant.uniform>, %arg1: tensor<4x!quant.uniform>) -> tensor<4x!quant.uniform> { +func.func @map_per_tensor_quantization(%arg0: tensor<4x!quant.uniform>, %arg1: tensor<4x!quant.uniform>) -> tensor<4x!quant.uniform> { %0 = "stablehlo.map"(%arg0, %arg1) ({ ^bb0(%arg2: tensor>, %arg3: tensor>): "stablehlo.return"(%arg2) : (tensor>) -> () @@ -161,7 +161,7 @@ func.func @map_quantization(%arg0: tensor<4x!quant.uniform>, %ar func.return %0 : tensor<4x!quant.uniform> } -func.func @pad_quantization(%arg0: tensor<1x2x3x!quant.uniform>, %arg1: tensor>) -> tensor<2x4x7x!quant.uniform> { +func.func @pad_per_tensor_quantization(%arg0: tensor<1x2x3x!quant.uniform>, %arg1: tensor>) -> tensor<2x4x7x!quant.uniform> { %0 = "stablehlo.pad"(%arg0, %arg1) { edge_padding_low = array, edge_padding_high = array, @@ -170,7 +170,7 @@ func.func @pad_quantization(%arg0: tensor<1x2x3x!quant.uniform>, func.return %0 : tensor<2x4x7x!quant.uniform> } -func.func @reduce_quantization(%arg0: tensor<16x!quant.uniform>, %arg1: tensor>) -> tensor> { +func.func @reduce_per_tensor_quantization(%arg0: tensor<16x!quant.uniform>, %arg1: tensor>) -> tensor> { %0 = "stablehlo.reduce"(%arg0, %arg1) ({ ^bb0(%arg2: tensor>, %arg3: tensor>): %1 = "stablehlo.add"(%arg2, %arg3) : (tensor>, tensor>) -> tensor> @@ -181,7 +181,7 @@ func.func @reduce_quantization(%arg0: tensor<16x!quant.uniform>, func.return %0 : tensor> } -func.func @reduce_precision_quantization(%arg0: tensor<6x!quant.uniform>) -> tensor<6x!quant.uniform> { +func.func @reduce_per_tensor_precision_quantization(%arg0: tensor<6x!quant.uniform>) -> tensor<6x!quant.uniform> { %output = "stablehlo.reduce_precision"(%arg0) { exponent_bits = 5 : i32, mantissa_bits = 10 : i32 @@ -190,7 +190,7 @@ func.func @reduce_precision_quantization(%arg0: tensor<6x!quant.uniform>) -> tensor<4x4x!quant.uniform> { +func.func @reduce_scatter_per_tensor_quantization(%data: tensor<4x16x!quant.uniform>) -> tensor<4x4x!quant.uniform> { %0 = "stablehlo.reduce_scatter"(%data) ({ ^bb0(%arg2: tensor>, %arg3: tensor>): %1 = stablehlo.add %arg2, %arg3 : tensor> @@ -203,7 +203,7 @@ func.func @reduce_scatter_quantization(%data: tensor<4x16x!quant.uniform>, %arg1: tensor>) -> tensor<2x9x16x7x!quant.uniform> { +func.func @op_reduce_window_per_tensor_quantization(%arg0: tensor<2x17x31x7x!quant.uniform>, %arg1: tensor>) -> tensor<2x9x16x7x!quant.uniform> { %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ ^bb0(%arg2: tensor>, %arg3: tensor>): %1 = "stablehlo.maximum"(%arg2, %arg3) : (tensor>, tensor>) -> tensor> @@ -218,25 +218,25 @@ func.func @op_reduce_window_quantization(%arg0: tensor<2x17x31x7x!quant.uniform< func.return %0 : tensor<2x9x16x7x!quant.uniform> } -func.func @reverse_quantization(%operand: tensor<3x2x!quant.uniform>) -> tensor<3x2x!quant.uniform> { +func.func @reverse_per_tensor_quantization(%operand: tensor<3x2x!quant.uniform>) -> tensor<3x2x!quant.uniform> { %result = "stablehlo.reverse"(%operand) { dimensions = array } : (tensor<3x2x!quant.uniform>) -> tensor<3x2x!quant.uniform> func.return %result : tensor<3x2x!quant.uniform> } -func.func @round_afz_quantization(%arg0: tensor<2x!quant.uniform>) -> tensor<2x!quant.uniform> { +func.func @round_afz_per_tensor_quantization(%arg0: tensor<2x!quant.uniform>) -> tensor<2x!quant.uniform> { %0 = "stablehlo.round_nearest_afz"(%arg0) {} : (tensor<2x!quant.uniform>) -> tensor<2x!quant.uniform> func.return %0 : tensor<2x!quant.uniform> } -func.func @round_even_quantization(%arg0: tensor<2x!quant.uniform>) -> tensor<2x!quant.uniform> { +func.func @round_even_per_tensor_quantization(%arg0: tensor<2x!quant.uniform>) -> tensor<2x!quant.uniform> { %0 = "stablehlo.round_nearest_even"(%arg0) {} : (tensor<2x!quant.uniform>) -> tensor<2x!quant.uniform> func.return %0 : tensor<2x!quant.uniform> } -func.func @scatter_quantization(%arg0: tensor<200x100x300x!quant.uniform>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300x!quant.uniform>) -> tensor<200x100x300x!quant.uniform> { +func.func @scatter_per_tensor_quantization(%arg0: tensor<200x100x300x!quant.uniform>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300x!quant.uniform>) -> tensor<200x100x300x!quant.uniform> { %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ ^bb0(%arg3: tensor>, %arg4: tensor>): %1 = "stablehlo.add"(%arg3, %arg4) : (tensor>, tensor>) -> tensor> @@ -252,12 +252,12 @@ func.func @scatter_quantization(%arg0: tensor<200x100x300x!quant.uniform> } -func.func @select_quantization(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3x!quant.uniform>, %arg2: tensor<2x3x!quant.uniform>) -> tensor<2x3x!quant.uniform> { +func.func @select_per_tensor_quantization(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3x!quant.uniform>, %arg2: tensor<2x3x!quant.uniform>) -> tensor<2x3x!quant.uniform> { %0 = "stablehlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3x!quant.uniform>, tensor<2x3x!quant.uniform>) -> tensor<2x3x!quant.uniform> func.return %0 : tensor<2x3x!quant.uniform> } -func.func @select_and_scatter_quantization(%arg0: tensor<10x24x24x64x!quant.uniform>, %arg1: tensor<10x23x23x64x!quant.uniform>, %arg2: tensor>) -> tensor<10x24x24x64x!quant.uniform> { +func.func @select_and_scatter_per_tensor_quantization(%arg0: tensor<10x24x24x64x!quant.uniform>, %arg1: tensor<10x23x23x64x!quant.uniform>, %arg2: tensor>) -> tensor<10x24x24x64x!quant.uniform> { %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ ^bb0(%arg3: tensor>, %arg4: tensor>): %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor>, tensor>) -> tensor @@ -272,12 +272,12 @@ func.func @select_and_scatter_quantization(%arg0: tensor<10x24x24x64x!quant.unif func.return %0 : tensor<10x24x24x64x!quant.uniform> } -func.func @slice_qunatization(%arg0: tensor<3x4x!quant.uniform>) -> tensor<1x2x!quant.uniform> { +func.func @slice_per_tensor_qunatization(%arg0: tensor<3x4x!quant.uniform>) -> tensor<1x2x!quant.uniform> { %0 = "stablehlo.slice"(%arg0) {start_indices = array, limit_indices = array, strides = array} : (tensor<3x4x!quant.uniform>) -> tensor<1x2x!quant.uniform> func.return %0 : tensor<1x2x!quant.uniform> } -func.func @sort_quantization(%input0: tensor<16x16x!quant.uniform>, %input1: tensor<16x16x!quant.uniform>) { +func.func @sort_per_tensor_quantization(%input0: tensor<16x16x!quant.uniform>, %input1: tensor<16x16x!quant.uniform>) { %0:2 = "stablehlo.sort"(%input0, %input1) ({ ^bb0(%arg0: tensor>, %arg1: tensor>, %arg2: tensor>, %arg3: tensor>): %7 = "stablehlo.compare"(%arg0, %arg1) {comparison_direction = #stablehlo} : (tensor>, tensor>) -> tensor @@ -286,7 +286,7 @@ func.func @sort_quantization(%input0: tensor<16x16x!quant.uniform>) -> tensor<*x!quant.uniform> { +func.func @while_per_tensor_quantization(%arg0: tensor<4x!quant.uniform>) -> tensor<*x!quant.uniform> { %while = "stablehlo.while"(%arg0) ({ ^bb0(%arg1: tensor>): %1 = stablehlo.constant dense : tensor