diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 1e1e48d66..07a694131 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -37,17 +37,6 @@ class DispatchTuner(DispatchParser): - # TODO(https://github.com/nod-ai/shark-ai/issues/453): Remove this in favor of configuring using transform dialect. - @abstractmethod - def apply_params( - self, - problem_size: ProblemSize, - template: list[str], - compilation_info: iree_codegen.CompilationInfoAttr, - ) -> MLIRTransformation: - """Apply parameter transformations to the operation.""" - pass - @abstractmethod def get_td_spec( self, @@ -59,25 +48,13 @@ def get_td_spec( class DispatchTunerRegistry: - def __init__(self, check_translation_info=True): - self.check_translation_info = check_translation_info + def __init__(self): self.registry = set() def register(self, dispatch_tuners: list[DispatchTuner]) -> None: for dispatch_tuner in dispatch_tuners: self.registry.add(dispatch_tuner) - # TODO(Max191): Remove translation info validation. - def validate_translation(self, attrs: list[ir.NamedAttribute]) -> bool: - if not self.check_translation_info: - return True - for attr in attrs: - if (attr.name == "translation_info") and ( - "LLVMGPUVectorDistribute" in str(attr.attr) - ): - return True - assert False, "Translation info not supported" - def find_handler(self, op_name: str) -> DispatchTuner: for dispatch_tuner in self.registry: if dispatch_tuner.supports(op_name): @@ -86,14 +63,6 @@ def find_handler(self, op_name: str) -> DispatchTuner: class ContractionOpInterfaceTuner(DispatchTuner, ContractionOpInterfaceParser): - def apply_params( - self, - problem_size: ProblemSize, - template: list[str], - compilation_info: iree_codegen.CompilationInfoAttr, - ) -> MLIRTransformation: - raise NotImplementedError - def get_td_spec( self, ir_module: ir.Module, @@ -114,14 +83,6 @@ def get_td_spec( class ConvolutionOpInterfaceTuner(DispatchTuner, ConvolutionOpInterfaceParser): - def apply_params( - self, - problem_size: ProblemSize, - template: list[str], - compilation_info: iree_codegen.CompilationInfoAttr, - ) -> MLIRTransformation: - raise NotImplementedError - def get_td_spec( self, ir_module: ir.Module, @@ -158,8 +119,6 @@ def walk_callback_get_fn( walk_result: OpWalkResult, dispatch_tuner_registry: DispatchTunerRegistry, ) -> ir.WalkResult: - if op.name == "func.func": - dispatch_tuner_registry.validate_translation([a for a in op.opview.attributes]) if op.name == "util.func": func_name = str(op.opview.sym_name) walk_result.was_interrupted = True @@ -198,7 +157,7 @@ def generate_configs_and_td_specs( pipeline_options_search_space: PipelineOptionsSearchSpace = PipelineOptionsSearchSpace(), codegen_pipeline: iree_codegen.DispatchLoweringPassPipeline = iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute, ) -> list[ir.Module]: - dispatch_tuner_registry = DispatchTunerRegistry(check_translation_info=False) + dispatch_tuner_registry = DispatchTunerRegistry() dispatch_tuner_registry.register( [ ContractionOpInterfaceTuner(),