diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index bc01bb709..a3252130e 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -51,6 +51,12 @@ def apply_configuration( ) = lowering_config.subgroup_count_mn workgroup_sizes = lowering_config.workgroup_tile_sizes reduction_sizes = lowering_config.reduction_tile_sizes + gpu_pipeline_options = configuration.translation_info.configuration[ + GPU_PIPELINE_OPTIONS_KEY + ] + waves_per_eu = configuration.translation_info.configuration[LLVM_FUNC_ATTRS_KEY][ + WAVES_PER_EU_KEY + ] tune_logger.info(f"Applying: {configuration}") expr0 = re.compile( r", subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>" @@ -63,11 +69,11 @@ def apply_configuration( expr4 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>") expr5 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") repl0 = f"" - repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},' + repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.translation_info.workgroup_size))}] subgroup_size = {configuration.translation_info.subgroup_size},' repl2 = f"workgroup = {workgroup_sizes}" repl3 = f"reduction = {reduction_sizes}" - repl4 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}" - repl5 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"' + repl4 = f"gpu_pipeline_options = {gpu_pipeline_options}" + repl5 = f'"amdgpu-waves-per-eu" = {waves_per_eu}' new_mlir = "" for line in template: @@ -128,15 +134,6 @@ class MmtTuner(DispatchTuner, MmtParser): def get_transform_function_mmt( self, problem_size: ProblemSize, functionName: str, configuration: Configuration ) -> str: - lowering_config = configuration.lowering_config - intrinsic = lowering_config.mma_kind - ( - subgroup_m_count, - subgroup_n_count, - ) = lowering_config.subgroup_count_mn - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) return f""" transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op @@ -145,13 +142,8 @@ def get_transform_function_mmt( transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = {configuration.lowering_config}>, - translation_info = #iree_codegen.translation_info - {extra_config}}}> + lowering_config = {configuration.lowering_config}, + translation_info = {configuration.translation_info} > -> !transform.any_param transform.yield %matmul, %config : !transform.any_op, !transform.any_param }} @@ -197,16 +189,6 @@ def get_transform_function_conv( filter = f"tensor<{problem_size.rhs_type}>" output = f"tensor<{dynamic_batch_output_ty}>" - lowering_config = configuration.lowering_config - intrinsic = lowering_config.mma_kind - ( - subgroup_m_count, - subgroup_n_count, - ) = lowering_config.subgroup_count_mn - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) - return f""" transform.named_sequence @{functionName}(%conv: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ @@ -217,13 +199,8 @@ def get_transform_function_conv( outs(%out : {output}) -> {output} }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = {configuration.lowering_config}>, - translation_info = #iree_codegen.translation_info - {extra_config}}}> + lowering_config = {configuration.lowering_config}, + translation_info = {configuration.translation_info} > -> !transform.any_param transform.yield %conv, %config : !transform.any_op, !transform.any_param }} @@ -262,16 +239,6 @@ def get_transform_function_broadcast_rhs_mmt( functionName: str, configuration: Configuration, ) -> str: - lowering_config = configuration.lowering_config - intrinsic = lowering_config.mma_kind - ( - subgroup_m_count, - subgroup_n_count, - ) = lowering_config.subgroup_count_mn - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) - lhs_dynamic_batch = problem_size.lhs_type lhs_dynamic_batch.shape = lhs_dynamic_batch.shape.copy() lhs_dynamic_batch.shape[0] = -1 @@ -284,13 +251,8 @@ def get_transform_function_broadcast_rhs_mmt( transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = {configuration.lowering_config}>, - translation_info = #iree_codegen.translation_info - {extra_config}}}> + lowering_config = {configuration.lowering_config}, + translation_info = {configuration.translation_info} > -> !transform.any_param transform.yield %generic, %config : !transform.any_op, !transform.any_param }} @@ -351,16 +313,6 @@ def get_transform_function_batch_mmt( functionName: str, configuration: Configuration, ) -> str: - lowering_config = configuration.lowering_config - intrinsic = lowering_config.mma_kind - ( - subgroup_m_count, - subgroup_n_count, - ) = lowering_config.subgroup_count_mn - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) - return f""" transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ %mmt = transform.include @match_batch_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op @@ -369,13 +321,8 @@ def get_transform_function_batch_mmt( transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = {configuration.lowering_config}>, - translation_info = #iree_codegen.translation_info - {extra_config}}}> + lowering_config = {configuration.lowering_config}, + translation_info = {configuration.translation_info} > -> !transform.any_param transform.yield %generic, %config : !transform.any_op, !transform.any_param }} @@ -421,16 +368,6 @@ def get_transform_function_batch_matmul( input1 = f"tensor<{problem_size.rhs_type}>" output = f"tensor<{problem_size.res_type}>" - lowering_config = configuration.lowering_config - intrinsic = lowering_config.mma_kind - ( - subgroup_m_count, - subgroup_n_count, - ) = lowering_config.subgroup_count_mn - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) - return f""" transform.named_sequence @{functionName}(%batch_matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ @@ -441,13 +378,8 @@ def get_transform_function_batch_matmul( outs(%out : {output}) -> {output} }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = {configuration.lowering_config}>, - translation_info = #iree_codegen.translation_info - {extra_config}}}> + lowering_config = {configuration.lowering_config}, + translation_info = {configuration.translation_info} > -> !transform.any_param transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param }} diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py index 45da323c5..7f104bcd9 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -14,6 +14,7 @@ from iree.compiler import ir # type: ignore from iree.compiler.dialects import iree_gpu # type: ignore +from iree.compiler.dialects import iree_codegen # type: ignore from . import candidate_gen from . import common @@ -56,14 +57,17 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=16, subgroup_n_count=16, ) + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get(prefetch_shared_memory=True) + config_dict = common.get_translation_info_config(pipeline_options, 8) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [16, 16, 1], 16, config_dict + ) config = common.Configuration( - subgroup_size=16, - workgroup_size=[16, 16, 1], + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get( - prefetch_shared_memory=True - ), - waves_per_eu=8, ) problem_size = common.ProblemSize( @@ -118,16 +122,21 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=1, subgroup_n_count=4, ) + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get( + reorder_workgroups_strategy=iree_gpu.ReorderWorkgroupsStrategyAttr.get( + iree_gpu.ReorderWorkgroupsStrategy.Transpose + ) + ) + config_dict = common.get_translation_info_config(pipeline_options, 2) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [256, 1, 1], 64, config_dict + ) config = common.Configuration( - subgroup_size=64, - workgroup_size=[256, 1, 1], + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get( - reorder_workgroups_strategy=iree_gpu.ReorderWorkgroupsStrategyAttr.get( - iree_gpu.ReorderWorkgroupsStrategy.Transpose - ) - ), - waves_per_eu=2, ) problem_size = common.ProblemSize( @@ -191,12 +200,17 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=1, subgroup_n_count=4, ) + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get() + config_dict = common.get_translation_info_config(pipeline_options, 2) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [256, 1, 1], 64, config_dict + ) config = common.Configuration( - subgroup_size=64, - workgroup_size=[256, 1, 1], + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), - waves_per_eu=2, ) tf_mlir = candidate_gen.ContractionTuner("mk", "nk", tile_dims).apply_params( @@ -246,12 +260,17 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=2, subgroup_n_count=2, ) + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get() + config_dict = common.get_translation_info_config(pipeline_options, 2) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [128, 2, 1], 64, config_dict + ) config = common.Configuration( - subgroup_size=64, - workgroup_size=[128, 2, 1], + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), - waves_per_eu=2, ) tf_mlir = candidate_gen.BatchMatmulTuner("mk", "nk", tile_dims).apply_params( @@ -304,12 +323,17 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=2, subgroup_n_count=2, ) + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get() + config_dict = common.get_translation_info_config(pipeline_options, 2) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [128, 2, 1], 64, config_dict + ) config = common.Configuration( - subgroup_size=64, - workgroup_size=[128, 2, 1], + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), - waves_per_eu=2, ) tf_mlir = candidate_gen.BatchMmtTuner().apply_params( @@ -360,12 +384,17 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=2, subgroup_n_count=2, ) + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get() + config_dict = common.get_translation_info_config(pipeline_options, 4) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [128, 2, 1], 64, config_dict + ) config = common.Configuration( - subgroup_size=64, - workgroup_size=[128, 2, 1], + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), - waves_per_eu=4, ) tf_mlir = candidate_gen.BatchMmtTuner().apply_params( @@ -440,12 +469,17 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=2, subgroup_n_count=2, ) + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get() + config_dict = common.get_translation_info_config(pipeline_options, 4) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [128, 2, 1], 64, config_dict + ) config = common.Configuration( - subgroup_size=64, - workgroup_size=[128, 2, 1], + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), - waves_per_eu=4, ) tf_mlir = candidate_gen.ContractionTuner( diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index 0a2b03fd1..c683c5bdc 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -14,6 +14,7 @@ from iree.compiler import ir # type: ignore from iree.compiler.dialects import iree_gpu # type: ignore +from iree.compiler.dialects import iree_codegen # type: ignore class CommonTypes: @@ -112,11 +113,16 @@ def is_comptible(mma_intrinsic: iree_gpu.MMAIntrinsic) -> bool: @dataclass class Configuration: - subgroup_size: int - workgroup_size: list[int] + translation_info: iree_codegen.TranslationInfoAttr lowering_config: iree_gpu.LoweringConfigAttr - gpu_pipeline_options: iree_gpu.PipelineOptionsAttr - waves_per_eu: int + + +# The key name for GPUPipelineOptionsAttr in the translation info config dictionary. +GPU_PIPELINE_OPTIONS_KEY = "gpu_pipeline_options" +# The key name for llvm_func_attrs attribute in the translation info config dictionary. +LLVM_FUNC_ATTRS_KEY = "llvm_func_attrs" +# The Key name for the 'amdgpu-waves-per-eu' within the llvm_func_attrs attribute. +WAVES_PER_EU_KEY = "amdgpu-waves-per-eu" def get_lowering_config( @@ -157,15 +163,34 @@ def get_lowering_config( return iree_gpu.LoweringConfigAttr.get(lowering_config_attrs) -def get_pipeline_config(configuration: Configuration) -> str: - extra_config = "" - pipeline_options = configuration.gpu_pipeline_options - if pipeline_options != iree_gpu.PipelineOptionsAttr.get(): - extra_config += f", gpu_pipeline_options = {pipeline_options}" - - if configuration.waves_per_eu != 2: - extra_config += f', llvm_func_attrs = {{"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"}}' - return extra_config +# Generate a config dictionary used in translation_info attribute. +def get_translation_info_config( + pipeline_options: iree_gpu.PipelineOptionsAttr, waves_per_eu: int +) -> ir.DictAttr: + """ + Example IR + translation_info = #iree_codegen.translation_info< + pipeline = LLVMGPUVectorDistribute workgroup_size = [512, 1, 1] subgroup_size = 64, + {gpu_pipeline_options = #iree_gpu.pipeline_options<...>, + llvm_func_attrs = {"amdgpu-waves-per-eu" = "3"} + } + > + """ + waves_per_eu_str = str(waves_per_eu) + + # Create the waves_per_eu dictionary attribute. + waves_per_eu_dict = ir.DictAttr.get( + {WAVES_PER_EU_KEY: ir.StringAttr.get(waves_per_eu_str)} + ) + + config_dict = ir.DictAttr.get( + { + GPU_PIPELINE_OPTIONS_KEY: pipeline_options, + LLVM_FUNC_ATTRS_KEY: waves_per_eu_dict, + } + ) + + return config_dict def read_input_mlir(filename: str) -> list[str]: diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index 6d76c216f..af1e1bf9a 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -15,6 +15,7 @@ from iree.compiler import ir # type: ignore from iree.compiler.dialects import iree_gpu # type: ignore +from iree.compiler.dialects import iree_codegen # type: ignore @pytest.fixture @@ -84,27 +85,36 @@ def test_get_pipeline_config(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=1, subgroup_n_count=1, ) + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get() + config_dict = common.get_translation_info_config(pipeline_options, 2) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [16, 16, 1], 32, config_dict + ) config = common.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), - waves_per_eu=2, ) - config1_str: str = common.get_pipeline_config(config) - assert config1_str == "" - - config.waves_per_eu = 4 - config2_str: str = common.get_pipeline_config(config) - assert config2_str == ', llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' + config1_str: str = str( + config.translation_info.configuration[common.LLVM_FUNC_ATTRS_KEY] + ) + assert config1_str == '{"amdgpu-waves-per-eu" = "2"}' - config.gpu_pipeline_options = iree_gpu.PipelineOptionsAttr.get( - prefetch_shared_memory=True + pipeline_options = iree_gpu.PipelineOptionsAttr.get(prefetch_shared_memory=True) + config_dict = common.get_translation_info_config(pipeline_options, 4) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [16, 16, 1], 32, config_dict ) - config3_str = common.get_pipeline_config(config) + config = common.Configuration( + translation_info=translation_info, + lowering_config=lowering_config, + ) + config2_str: str = str(config.translation_info.configuration) assert ( - config3_str - == ', gpu_pipeline_options = #iree_gpu.pipeline_options, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' + config2_str + == '{gpu_pipeline_options = #iree_gpu.pipeline_options, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}}' ) @@ -207,12 +217,17 @@ def test_get_lowering_config(tuner_ctx: common.TunerContext) -> None: == "#iree_gpu.lowering_config<{reduction = [0, 0, 16], subgroup_m_count = 1 : i64, subgroup_n_count = 1 : i64, workgroup = [4, 8, 0]}>" ) + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get() + config_dict = common.get_translation_info_config(pipeline_options, 2) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [16, 16, 1], 32, config_dict + ) config = common.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), - waves_per_eu=2, ) assert config.lowering_config.mma_kind is None diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index f86523389..8ba202310 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -244,13 +244,21 @@ def generate_solutions( subgroup_m_count=lookup(sg_m_cnt), subgroup_n_count=lookup(sg_n_cnt), ) - config = Configuration( - lookup(subgroup_size), + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get() + config_dict = get_translation_info_config( + pipeline_options, lookup(waves_per_eu) + ) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, + None, [lookup(wg_x), lookup(wg_y), lookup(wg_z)], - lowering_config, - iree_gpu.PipelineOptionsAttr.get(), - lookup(waves_per_eu), + lookup(subgroup_size), + config_dict, ) + config = Configuration(translation_info, lowering_config) solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars))))) i += 1 yield config diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py index db8c4a7da..a63576808 100644 --- a/tuner/tuner/dispatch_parser_test.py +++ b/tuner/tuner/dispatch_parser_test.py @@ -15,6 +15,7 @@ from iree.compiler import ir # type: ignore from iree.compiler.dialects import func # type: ignore from iree.compiler.dialects import iree_gpu # type: ignore +from iree.compiler.dialects import iree_codegen # type: ignore from . import common from . import dispatch_parser @@ -50,12 +51,17 @@ def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=1, subgroup_n_count=4, ) - config = dispatch_parser.Configuration( - subgroup_size=0, - workgroup_size=[], + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get() + config_dict = common.get_translation_info_config(pipeline_options, 0) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [], 0, config_dict + ) + config = common.Configuration( + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), - waves_per_eu=0, ) lowering_config = config.lowering_config assert lowering_config.workgroup_tile_sizes == [128, 320, 0] @@ -73,12 +79,17 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=1, subgroup_n_count=4, ) - config = dispatch_parser.Configuration( - subgroup_size=64, - workgroup_size=[256, 1, 1], + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get() + config_dict = common.get_translation_info_config(pipeline_options, 1) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [256, 1, 1], 64, config_dict + ) + config = common.Configuration( + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), - waves_per_eu=1, ) assert config.lowering_config.workgroup_tile_sizes == [1, 1, 464, 320, 1, 1, 0] assert config.lowering_config.reduction_tile_sizes == [0, 0, 0, 0, 0, 0, 16] @@ -95,12 +106,17 @@ def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None: subgroup_m_count=1, subgroup_n_count=1, ) - config = dispatch_parser.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get() + config_dict = common.get_translation_info_config(pipeline_options, 2) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [16, 16, 1], 32, config_dict + ) + config = common.Configuration( + translation_info=translation_info, lowering_config=lowering_config, - gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(), - waves_per_eu=2, ) assert dispatch_parser.get_contract_workgroup_sizes(config, "mnk") == [4, 8, 0] assert dispatch_parser.get_contract_reduction_sizes(config, "mnk") == [0, 0, 16]