Skip to content

Commit

Permalink
[tuner] clean up candidate gen (#797)
Browse files Browse the repository at this point in the history
- removed unused function `apply_params`
- removed unused function `validate_translation`

Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu authored Jan 9, 2025
1 parent f715322 commit 7849f8e
Showing 1 changed file with 2 additions and 43 deletions.
45 changes: 2 additions & 43 deletions tuner/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,6 @@


class DispatchTuner(DispatchParser):
# TODO(https://github.com/nod-ai/shark-ai/issues/453): Remove this in favor of configuring using transform dialect.
@abstractmethod
def apply_params(
self,
problem_size: ProblemSize,
template: list[str],
compilation_info: iree_codegen.CompilationInfoAttr,
) -> MLIRTransformation:
"""Apply parameter transformations to the operation."""
pass

@abstractmethod
def get_td_spec(
self,
Expand All @@ -59,25 +48,13 @@ def get_td_spec(


class DispatchTunerRegistry:
def __init__(self, check_translation_info=True):
self.check_translation_info = check_translation_info
def __init__(self):
self.registry = set()

def register(self, dispatch_tuners: list[DispatchTuner]) -> None:
for dispatch_tuner in dispatch_tuners:
self.registry.add(dispatch_tuner)

# TODO(Max191): Remove translation info validation.
def validate_translation(self, attrs: list[ir.NamedAttribute]) -> bool:
if not self.check_translation_info:
return True
for attr in attrs:
if (attr.name == "translation_info") and (
"LLVMGPUVectorDistribute" in str(attr.attr)
):
return True
assert False, "Translation info not supported"

def find_handler(self, op_name: str) -> DispatchTuner:
for dispatch_tuner in self.registry:
if dispatch_tuner.supports(op_name):
Expand All @@ -86,14 +63,6 @@ def find_handler(self, op_name: str) -> DispatchTuner:


class ContractionOpInterfaceTuner(DispatchTuner, ContractionOpInterfaceParser):
def apply_params(
self,
problem_size: ProblemSize,
template: list[str],
compilation_info: iree_codegen.CompilationInfoAttr,
) -> MLIRTransformation:
raise NotImplementedError

def get_td_spec(
self,
ir_module: ir.Module,
Expand All @@ -114,14 +83,6 @@ def get_td_spec(


class ConvolutionOpInterfaceTuner(DispatchTuner, ConvolutionOpInterfaceParser):
def apply_params(
self,
problem_size: ProblemSize,
template: list[str],
compilation_info: iree_codegen.CompilationInfoAttr,
) -> MLIRTransformation:
raise NotImplementedError

def get_td_spec(
self,
ir_module: ir.Module,
Expand Down Expand Up @@ -158,8 +119,6 @@ def walk_callback_get_fn(
walk_result: OpWalkResult,
dispatch_tuner_registry: DispatchTunerRegistry,
) -> ir.WalkResult:
if op.name == "func.func":
dispatch_tuner_registry.validate_translation([a for a in op.opview.attributes])
if op.name == "util.func":
func_name = str(op.opview.sym_name)
walk_result.was_interrupted = True
Expand Down Expand Up @@ -198,7 +157,7 @@ def generate_configs_and_td_specs(
pipeline_options_search_space: PipelineOptionsSearchSpace = PipelineOptionsSearchSpace(),
codegen_pipeline: iree_codegen.DispatchLoweringPassPipeline = iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute,
) -> list[ir.Module]:
dispatch_tuner_registry = DispatchTunerRegistry(check_translation_info=False)
dispatch_tuner_registry = DispatchTunerRegistry()
dispatch_tuner_registry.register(
[
ContractionOpInterfaceTuner(),
Expand Down

0 comments on commit 7849f8e

Please sign in to comment.