diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py
index 8838cb72d6..0f8986bcc5 100644
--- a/backends/arm/test/common.py
+++ b/backends/arm/test/common.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
@@ -9,12 +9,16 @@
 import tempfile
 from datetime import datetime
+
 from pathlib import Path
+from typing import Any
+import pytest
 
 from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
 from executorch.backends.arm.test.conftest import is_option_enabled
 from executorch.exir.backend.compile_spec_schema import CompileSpec
+from executorch.backends.arm.test.runner_utils import corstone300_installed, corstone320_installed
 
 
 def get_time_formatted_path(path: str, log_prefix: str) -> str:
@@ -185,3 +189,41 @@ def get_target_board(compile_spec: list[CompileSpec]) -> str | None:
     elif "u85" in flags:
         return "corstone-320"
     return None
+
+
+u55_fvp_mark = pytest.mark.skipif(
+    not corstone300_installed(), reason="Did not find Corstone-300 FVP on path"
+)
+"""Marks a test as running on an Ethos-U55 FVP, e.g. Corstone-300. Skips the test if the FVP is not installed."""
+
+u85_fvp_mark = pytest.mark.skipif(
+    not corstone320_installed(), reason="Did not find Corstone-320 FVP on path"
+)
+"""Marks a test as running on an Ethos-U85 FVP, e.g. Corstone-320. Skips the test if the FVP is not installed."""
+
+
+def parametrize(
+    arg_name: str, test_data: dict[str, Any], xfails: dict[str, str] | None = None
+):
+    """
+    Custom version of pytest.mark.parametrize with some syntactic sugar and
+    added xfail functionality:
+    - test_data is expected as a dict of (id, test_data) pairs.
+    - xfails allows specifying a dict of (id, failure_reason) pairs to mark
+      specific tests as xfail.
+    """
+    if xfails is None:
+        xfails = {}
+
+    def decorator_func(func):
+        pytest_testsuite = []
+        for id, test_parameters in test_data.items():
+            if id in xfails:
+                pytest_param = pytest.param(
+                    test_parameters, id=id, marks=pytest.mark.xfail(reason=xfails[id])
+                )
+            else:
+                pytest_param = pytest.param(test_parameters, id=id)
+            pytest_testsuite.append(pytest_param)
+
+        return pytest.mark.parametrize(arg_name, pytest_testsuite)(func)
+
+    return decorator_func
diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py
index 24faace007..8bd13886a2 100644
--- a/backends/arm/test/ops/test_add.py
+++ b/backends/arm/test/ops/test_add.py
@@ -1,178 +1,144 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
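+
+"""Tests a single add op, x + x and x + y."""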
-import unittest from typing import Tuple -import pytest import torch -from executorch.backends.arm.test import common, conftest -from executorch.backends.arm.test.tester.arm_tester import ArmTester -from executorch.exir import EdgeCompileConfig -from executorch.exir.backend.compile_spec_schema import CompileSpec -from parameterized import parameterized - - -class TestSimpleAdd(unittest.TestCase): - """Tests a single add op, x+x and x+y.""" - - class Add(torch.nn.Module): - test_parameters = [ - (torch.FloatTensor([1, 2, 3, 5, 7]),), - (3 * torch.ones(8),), - (10 * torch.randn(8),), - (torch.ones(1, 1, 4, 4),), - (torch.ones(1, 3, 4, 2),), - ] - - def forward(self, x): - return x + x - - class Add2(torch.nn.Module): - test_parameters = [ - ( - torch.FloatTensor([1, 2, 3, 5, 7]), - (torch.FloatTensor([2, 1, 2, 1, 10])), - ), - (torch.ones(1, 10, 4, 6), torch.ones(1, 10, 4, 6)), - (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)), - (torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)), - (10000 * torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)), - ] - - def __init__(self): - super().__init__() - - def forward(self, x, y): - return x + y - - _edge_compile_config: EdgeCompileConfig = EdgeCompileConfig( - _skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend. +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineBI, + EthosU85PipelineBI, + TosaPipelineBI, + TosaPipelineMI, +) + +aten_op = "torch.ops.aten.add.Tensor" +exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor" + +test_data = { + "5d_float": (torch.FloatTensor([1, 2, 3, 5, 7]),), + "1d_ones": ((3 * torch.ones(8),)), + "1d_randn": (10 * torch.randn(8),), + "4d_ones_1": (torch.ones(1, 1, 4, 4),), + "4d_ones_2": (torch.ones(1, 3, 4, 2),), +} +T1 = Tuple[torch.Tensor] + +test_data2 = { + "5d_float": ( + torch.FloatTensor([1, 2, 3, 5, 7]), + (torch.FloatTensor([2, 1, 2, 1, 10])), + ), + "4d_ones": (torch.ones(1, 10, 4, 6), torch.ones(1, 10, 4, 6)), + "4d_randn_1": (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)), + "4d_randn_2": (torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)), + "4d_randn_big": (10000 * torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)), +} +T2 = Tuple[torch.Tensor, torch.Tensor] + + +class Add(torch.nn.Module): + def forward(self, x): + return x + x + + +@common.parametrize("test_data", test_data) +def test_add_tosa_MI(test_data): + pipeline = TosaPipelineMI[T1](Add(), test_data, aten_op, exir_op) + pipeline.run() + + +@common.parametrize("test_data", test_data) +def test_add_tosa_BI(test_data): + pipeline = TosaPipelineBI[T1](Add(), test_data, aten_op, exir_op) + pipeline.run() + + +@common.parametrize("test_data", test_data) +def test_add_u55_BI(test_data): + pipeline = EthosU55PipelineBI[T1]( + Add(), test_data, aten_op, exir_op, run_on_fvp=False ) + pipeline.run() - def _test_add_tosa_MI_pipeline( - self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] - ): - ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), - ) - .export() - .check_count({"torch.ops.aten.add.Tensor": 1}) - .check_not(["torch.ops.quantized_decomposed"]) - .to_edge(config=self._edge_compile_config) - .partition() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .run_method_and_compare_outputs(inputs=test_data) - ) - - def _test_add_tosa_BI_pipeline( - self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] - ): - ( - 
ArmTester( - module, - example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), - ) - .quantize() - .export() - .check_count({"torch.ops.aten.add.Tensor": 1}) - .check(["torch.ops.quantized_decomposed"]) - .to_edge(config=self._edge_compile_config) - .partition() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .run_method_and_compare_outputs(inputs=test_data, qtol=1) - ) - - def _test_add_ethos_BI_pipeline( - self, - module: torch.nn.Module, - compile_spec: CompileSpec, - test_data: Tuple[torch.Tensor], - ): - tester = ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=compile_spec, - ) - .quantize() - .export() - .check_count({"torch.ops.aten.add.Tensor": 1}) - .check(["torch.ops.quantized_decomposed"]) - .to_edge() - .partition() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .serialize() - ) - if conftest.is_option_enabled("corstone_fvp"): - tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) - - return tester - - @parameterized.expand(Add.test_parameters) - def test_add_tosa_MI(self, test_data: torch.Tensor): - test_data = (test_data,) - self._test_add_tosa_MI_pipeline(self.Add(), test_data) - - @parameterized.expand(Add.test_parameters) - def test_add_tosa_BI(self, test_data: torch.Tensor): - test_data = (test_data,) - self._test_add_tosa_BI_pipeline(self.Add(), test_data) - - @parameterized.expand(Add.test_parameters) - @pytest.mark.corstone_fvp - def test_add_u55_BI(self, test_data: torch.Tensor): - test_data = (test_data,) - self._test_add_ethos_BI_pipeline( - self.Add(), - common.get_u55_compile_spec(permute_memory_to_nhwc=True), - test_data, - ) - - @parameterized.expand(Add.test_parameters) - @pytest.mark.corstone_fvp - def test_add_u85_BI(self, test_data: torch.Tensor): - test_data = (test_data,) - self._test_add_ethos_BI_pipeline( - self.Add(), - common.get_u85_compile_spec(permute_memory_to_nhwc=True), - test_data, - ) - - @parameterized.expand(Add2.test_parameters) - def test_add2_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor): - test_data = (operand1, operand2) - self._test_add_tosa_MI_pipeline(self.Add2(), test_data) - - @parameterized.expand(Add2.test_parameters) - def test_add2_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): - test_data = (operand1, operand2) - self._test_add_tosa_BI_pipeline(self.Add2(), test_data) - - @parameterized.expand(Add2.test_parameters) - @pytest.mark.corstone_fvp - def test_add2_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): - test_data = (operand1, operand2) - self._test_add_ethos_BI_pipeline( - self.Add2(), common.get_u55_compile_spec(), test_data - ) - - @parameterized.expand(Add2.test_parameters) - @pytest.mark.corstone_fvp - def test_add2_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): - test_data = (operand1, operand2) - self._test_add_ethos_BI_pipeline( - self.Add2(), common.get_u85_compile_spec(), test_data - ) + +@common.parametrize("test_data", test_data) +def test_add_u85_BI(test_data): + pipeline = EthosU85PipelineBI[T1]( + Add(), test_data, aten_op, exir_op, run_on_fvp=False + ) + pipeline.run() + + +@common.parametrize("test_data", test_data) +@common.u55_fvp_mark +def test_add_u55_BI_on_fvp(test_data): + pipeline = EthosU55PipelineBI[T1]( + Add(), test_data, aten_op, exir_op, run_on_fvp=True + ) + pipeline.run() + + +@common.parametrize("test_data", test_data) +@common.u85_fvp_mark +def test_add_u85_BI_on_fvp(test_data): + 
pipeline = EthosU85PipelineBI[T1]( + Add(), test_data, aten_op, exir_op, run_on_fvp=True + ) + pipeline.run() + + +class Add2(torch.nn.Module): + def forward(self, x, y): + return x + y + + +@common.parametrize("test_data", test_data2) +def test_add2_tosa_MI(test_data): + pipeline = TosaPipelineMI[T2](Add2(), test_data, aten_op, exir_op) + pipeline.run() + + +@common.parametrize("test_data", test_data2) +def test_add2_tosa_BI(test_data): + pipeline = TosaPipelineBI[T2](Add2(), test_data, aten_op, exir_op) + pipeline.run() + + +@common.parametrize("test_data", test_data2) +def test_add2_u55_BI(test_data): + pipeline = EthosU55PipelineBI[T2]( + Add2(), test_data, aten_op, exir_op, run_on_fvp=False + ) + pipeline.run() + + +@common.parametrize("test_data", test_data2) +@common.u55_fvp_mark +def test_add2_u55_BI_on_fvp(test_data): + pipeline = EthosU55PipelineBI[T2]( + Add2(), test_data, aten_op, exir_op, run_on_fvp=True + ) + pipeline.run() + + +@common.parametrize("test_data", test_data2) +def test_add2_u85_BI(test_data): + pipeline = EthosU85PipelineBI[T2]( + Add2(), test_data, aten_op, exir_op, run_on_fvp=False + ) + pipeline.run() + + +@common.parametrize("test_data", test_data2) +@common.u85_fvp_mark +def test_add2_u85_BI_on_fvp(test_data): + pipeline = EthosU85PipelineBI[T2]( + Add2(), test_data, aten_op, exir_op, run_on_fvp=True + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_conv2d.py b/backends/arm/test/ops/test_conv2d.py index 9ccac53940..ace4c513d6 100644 --- a/backends/arm/test/ops/test_conv2d.py +++ b/backends/arm/test/ops/test_conv2d.py @@ -1,20 +1,23 @@ -# Copyright 2024 Arm Limited and/or its affiliates. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import unittest from typing import List, Optional, Tuple, Union -import pytest - import torch -from executorch.backends.arm.test import common, conftest -from executorch.backends.arm.test.tester.arm_tester import ArmTester -from executorch.exir.backend.compile_spec_schema import CompileSpec -from parameterized import parameterized +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineBI, + EthosU85PipelineBI, + TosaPipelineBI, + TosaPipelineMI, +) + +aten_op = "torch.ops.aten.conv2d.default" +exir_op = "executorch_exir_dialects_edge__ops_aten_convolution_default" class Conv2d(torch.nn.Module): @@ -228,118 +231,72 @@ def forward(self, x): batches=1, ) -# Shenanigan to get a nicer output when test fails. 
With unittest it looks like: -# FAIL: test_conv2d_tosa_BI_2_3x3_1x3x12x12_st2_pd1 -testsuite = [ - ("2x2_3x2x40x40_nobias", conv2d_2x2_3x2x40x40_nobias), - ("3x3_1x3x256x256_st1", conv2d_3x3_1x3x256x256_st1), - ("3x3_1x3x12x12_st2_pd1", conv2d_3x3_1x3x12x12_st2_pd1), - ("1x1_1x2x128x128_st1", conv2d_1x1_1x2x128x128_st1), - ("2x2_1x1x14x13_st2_needs_adjust_pass", conv2d_2x2_1x1x14x13_st2), - ("conv2d_5x5_1x3x14x15_st3_pd1_needs_adjust_pass", conv2d_5x5_1x3x14x15_st3_pd1), - ("5x5_3x2x128x128_st1", conv2d_5x5_3x2x128x128_st1), - ("3x3_1x3x224x224_st2_pd1", conv2d_3x3_1x3x224x224_st2_pd1), - ("two_conv2d_nobias", two_conv2d_nobias), - ("two_conv2d", two_conv2d), -] - - -class TestConv2D(unittest.TestCase): - """Tests Conv2D, both single ops and multiple Convolutions in series.""" - - def _test_conv2d_tosa_MI_pipeline( - self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] - ): - ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80+MI", permute_memory_to_nhwc=True - ), - ) - .export() - .to_edge() - .partition() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"]) - .to_executorch() - .run_method_and_compare_outputs(inputs=test_data) - ) - - def _test_conv2d_tosa_BI_pipeline( - self, - module: torch.nn.Module, - test_data: Tuple[torch.Tensor], - ): - ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80+BI", permute_memory_to_nhwc=True - ), - ) - .quantize() - .export() - .to_edge() - .partition() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"]) - .to_executorch() - .run_method_and_compare_outputs(inputs=test_data, qtol=1) - ) - - def _test_conv2d_ethosu_BI_pipeline( - self, - compile_spec: CompileSpec, - module: torch.nn.Module, - test_data: Tuple[torch.Tensor], - ): - tester = ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=compile_spec, - ) - .quantize() - .export() - .to_edge() - .partition() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"]) - .to_executorch() - .serialize() - ) - if conftest.is_option_enabled("corstone_fvp"): - tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) - - @parameterized.expand(testsuite) - def test_conv2d_tosa_MI(self, test_name, model): - self._test_conv2d_tosa_MI_pipeline(model, model.get_inputs()) - - @parameterized.expand(testsuite) - def test_conv2d_tosa_BI(self, test_name, model): - self._test_conv2d_tosa_BI_pipeline(model, model.get_inputs()) - - # These cases have numerical issues on FVP, MLETORCH-520 - testsuite.remove(("2x2_3x2x40x40_nobias", conv2d_2x2_3x2x40x40_nobias)) - testsuite.remove(("5x5_3x2x128x128_st1", conv2d_5x5_3x2x128x128_st1)) - - @parameterized.expand(testsuite) - @pytest.mark.corstone_fvp - def test_conv2d_u55_BI(self, test_name, model): - self._test_conv2d_ethosu_BI_pipeline( - common.get_u55_compile_spec(permute_memory_to_nhwc=True), - model, - model.get_inputs(), - ) - - @parameterized.expand(testsuite) - @pytest.mark.corstone_fvp - def test_conv2d_u85_BI(self, test_name, model): - self._test_conv2d_ethosu_BI_pipeline( - common.get_u85_compile_spec(permute_memory_to_nhwc=True), - model, - model.get_inputs(), - ) +test_modules = { + "2x2_3x2x40x40_nobias": conv2d_2x2_3x2x40x40_nobias, + 
"3x3_1x3x256x256_st1": conv2d_3x3_1x3x256x256_st1, + "3x3_1x3x12x12_st2_pd1": conv2d_3x3_1x3x12x12_st2_pd1, + "1x1_1x2x128x128_st1": conv2d_1x1_1x2x128x128_st1, + "2x2_1x1x14x13_st2_needs_adjust_pass": conv2d_2x2_1x1x14x13_st2, + "conv2d_5x5_1x3x14x15_st3_pd1_needs_adjust_pass": conv2d_5x5_1x3x14x15_st3_pd1, + "5x5_3x2x128x128_st1": conv2d_5x5_3x2x128x128_st1, + "3x3_1x3x224x224_st2_pd1": conv2d_3x3_1x3x224x224_st2_pd1, + "two_conv2d_nobias": two_conv2d_nobias, + "two_conv2d": two_conv2d, +} +T1 = Tuple[torch.Tensor] + + +@common.parametrize("test_module", test_modules) +def test_conv2d_tosa_MI(test_module): + pipeline = TosaPipelineMI[T1]( + test_module, test_module.get_inputs(), aten_op, exir_op + ) + pipeline.run() + + +@common.parametrize("test_module", test_modules) +def test_conv2d_tosa_BI(test_module): + pipeline = TosaPipelineBI[T1]( + test_module, test_module.get_inputs(), aten_op, exir_op + ) + pipeline.run() + + +@common.parametrize("test_module", test_modules) +def test_conv2d_u55_BI(test_module): + pipeline = EthosU55PipelineBI[T1]( + test_module, test_module.get_inputs(), aten_op, exir_op, run_on_fvp=False + ) + pipeline.run() + + +@common.parametrize("test_module", test_modules) +def test_conv2d_u85_BI(test_module): + pipeline = EthosU85PipelineBI[T1]( + test_module, test_module.get_inputs(), aten_op, exir_op, run_on_fvp=False + ) + pipeline.run() + + +xfails = { + "2x2_3x2x40x40_nobias": "MLETORCH-520: Numerical issues on FVP.", + "5x5_3x2x128x128_st1": "MLETORCH-520: Numerical issues on FVP.", +} + + +@common.parametrize("test_module", test_modules, xfails) +@common.u55_fvp_mark +def test_conv2d_u55_BI_on_fvp(test_module): + pipeline = EthosU55PipelineBI[T1]( + test_module, test_module.get_inputs(), aten_op, exir_op, run_on_fvp=True + ) + pipeline.run() + + +@common.parametrize("test_module", test_modules, xfails) +@common.u85_fvp_mark +def test_conv2d_u85_BI_on_fvp(test_module): + pipeline = EthosU85PipelineBI[T1]( + test_module, test_module.get_inputs(), aten_op, exir_op, run_on_fvp=True + ) + pipeline.run() diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 9ae1a27cf7..1dce924e3d 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 Arm Limited and/or its affiliates. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -664,3 +664,21 @@ def _tosa_refmodel_loglevel(loglevel: int) -> str: } clamped_logging_level = max(min(loglevel // 10 * 10, 50), 0) return loglevel_map[clamped_logging_level] + + +def corstone300_installed() -> bool: + cmd = ["FVP_Corstone_SSE-300_Ethos-U55", "--version"] + try: + _run_cmd(cmd, check=True) + except: + return False + return True + + +def corstone320_installed() -> bool: + cmd = ["FVP_Corstone_SSE-320", "--version"] + try: + _run_cmd(cmd, check=True) + except: + return False + return True diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py new file mode 100644 index 0000000000..3b2dce8952 --- /dev/null +++ b/backends/arm/test/tester/test_pipeline.py @@ -0,0 +1,361 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+import logging
+from typing import Callable, Generic, List, TypeVar
+
+import torch
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+
+
+logger = logging.getLogger(__name__)
+T = TypeVar("T")
+"""Generic type for test data in the pipeline; matches the input type the operator expects."""
+
+
+class BasePipeline(Generic[T]):
+    """
+    Defines a list of stages to be run on a given module with inputs of data
+    type T. The list can be modified in any way before running the pipeline
+    to support various use cases.
+    """
+
+    class PipelineStage:
+        """
+        Helper class that stores a pipeline stage as a function call plus its
+        args, for calling later on.
+        """
+
+        def __init__(self, func, *args, **kwargs):
+            self.id: str = func.__name__
+            self.func: Callable = func
+            self.args = args
+            self.kwargs = kwargs
+            self.is_called = False
+
+        def __call__(self):
+            if self.is_called:
+                raise RuntimeError(f"{self.id} called twice.")
+            self.func(*self.args, **self.kwargs)
+            self.is_called = True
+
+        def update(self, *args, **kwargs):
+            if self.is_called:
+                raise RuntimeError(f"{self.id} args updated after being called.")
+            self.args = args
+            self.kwargs = kwargs
+
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        test_data: T,
+        aten_ops: str | List[str],
+        exir_ops: str | List[str],
+        compile_spec: List[CompileSpec],
+        use_to_edge_transform_and_lower: bool = False,
+    ):
+        self.tester = ArmTester(
+            module, example_inputs=test_data, compile_spec=compile_spec
+        )
+
+        self.aten_ops = aten_ops if isinstance(aten_ops, list) else [aten_ops]
+        self.exir_ops = exir_ops if isinstance(exir_ops, list) else [exir_ops]
+        self.test_data = test_data
+        self._stages = []
+
+        self.add_stage(-1, self.tester.export)
+        self.add_stage(-1, self.tester.check, self.aten_ops)
+        if use_to_edge_transform_and_lower:
+            self.add_stage(-1, self.tester.to_edge_transform_and_lower)
+        else:
+            self.add_stage(-1, self.tester.to_edge)
+            self.add_stage(-1, self.tester.check, self.exir_ops)
+            self.add_stage(-1, self.tester.partition)
+            self.add_stage(-1, self.tester.check_not, self.exir_ops)
+        self.add_stage(
+            -1,
+            self.tester.check_count,
+            {"torch.ops.higher_order.executorch_call_delegate": 1},
+        )
+        self.add_stage(-1, self.tester.to_executorch)
+
+    def add_stage(self, pos: int, func: Callable, *args, **kwargs):
+        """Adds a stage at position pos; pos=-1 appends to the end."""
+        pipeline_stage = self.PipelineStage(func, *args, **kwargs)
+        if pos == -1:
+            self._stages.append(pipeline_stage)
+        else:
+            self._stages.insert(pos, pipeline_stage)
+
+        return self
+
+    def pop_stage(self, pos: int):
+        """Removes and returns the stage at position pos."""
+        return self._stages.pop(pos)
+
+    def find_pos(self, stage_id: str):
+        """Returns the position of the stage with the given id."""
+        for i, stage in enumerate(self._stages):
+            if stage.id == stage_id:
+                return i
+
+        raise ValueError(f"Stage id {stage_id} not found in pipeline")
+
+    def add_stage_after(self, stage_id: str, func: Callable, *args, **kwargs):
+        """Adds a stage directly after the stage with the given id."""
+        pos = self.find_pos(stage_id)
+        self.add_stage(pos + 1, func, *args, **kwargs)
+        return self
+
+    def dump_artifact(self, stage_id: str):
+        """Dumps the test artifact after the stage with the given id."""
+        self.add_stage_after(stage_id, self.tester.dump_artifact)
+        return self
+
+    def plot(self, stage_id: str):
+        """Plots the module after the stage with the given id."""
+        self.add_stage_after(stage_id, self.tester.plot)
+        return self
+
+    def dump_operator_distribution(self, stage_id: str):
+        """Dumps the operator distribution after the stage with the given id."""
+        self.add_stage_after(stage_id, self.tester.dump_operator_distribution)
+        return self
+
+    def change_args(self, stage_id: str, *args, **kwargs):
+        """Updates the args of the stage with the given id, if it has not run yet."""
+        pos = self.find_pos(stage_id)
+        pipeline_stage = self._stages[pos]
+        pipeline_stage.update(*args, **kwargs)
+        return self
+
+    def run(self):
+        """Runs all stages in order, logging and re-raising any failure."""
+        stage_list = [stage.id for stage in self._stages]
+        logger.info(f"Running pipeline with stages {stage_list}.")
+
+        for stage in self._stages:
+            try:
+                stage()
+            except Exception as e:
+                logger.error(f"\nFailure in stage <{stage.id}>: \n{e}")
+                raise
+
+
+class TosaPipelineBI(BasePipeline, Generic[T]):
+    """Quantizes the module, lowers it to the TOSA BI profile, and compares outputs with qtol=1."""
+
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        test_data: T,
+        aten_op: str,
+        exir_op: str,
+        tosa_version: str = "TOSA-0.80+BI",
+        use_to_edge_transform_and_lower: bool = False,
+    ):
+        compile_spec = common.get_tosa_compile_spec(
+            tosa_version, permute_memory_to_nhwc=True
+        )
+        super().__init__(
+            module,
+            test_data,
+            aten_op,
+            exir_op,
+            compile_spec,
+            use_to_edge_transform_and_lower,
+        )
+        self.add_stage(0, self.tester.quantize)
+        self.add_stage_after(
+            "quantize",
+            self.tester.check,
+            ["torch.ops.quantized_decomposed.dequantize_per_tensor.default"],
+        )
+        self.add_stage_after(
+            "quantize",
+            self.tester.check,
+            ["torch.ops.quantized_decomposed.quantize_per_tensor.default"],
+        )
+        self.add_stage(
+            -1,
+            self.tester.run_method_and_compare_outputs,
+            qtol=1,
+            inputs=self.test_data,
+        )
+
+
+class TosaPipelineMI(BasePipeline, Generic[T]):
+    """Lowers a float module to the TOSA MI profile and compares outputs."""
+
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        test_data: T,
+        aten_op: str,
+        exir_op: str,
+        tosa_version: str = "TOSA-0.80+MI",
+        use_to_edge_transform_and_lower: bool = False,
+    ):
+        compile_spec = common.get_tosa_compile_spec(
+            tosa_version, permute_memory_to_nhwc=True
+        )
+        super().__init__(
+            module,
+            test_data,
+            aten_op,
+            exir_op,
+            compile_spec,
+            use_to_edge_transform_and_lower,
+        )
+        self.add_stage_after(
+            "export",
+            self.tester.check_not,
+            ["torch.ops.quantized_decomposed.dequantize_per_tensor.default"],
+        )
+        self.add_stage_after(
+            "export",
+            self.tester.check_not,
+            ["torch.ops.quantized_decomposed.quantize_per_tensor.default"],
+        )
+
+        self.add_stage(
+            -1, self.tester.run_method_and_compare_outputs, inputs=self.test_data
+        )
+
+
+class EthosU55PipelineBI(BasePipeline, Generic[T]):
+    """Quantizes the module and lowers it for Ethos-U55; if run_on_fvp is set,
+    runs it on a Corstone-300 FVP and compares outputs with qtol=1."""
+
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        test_data: T,
+        aten_ops: str | List[str],
+        exir_ops: str | List[str],
+        run_on_fvp: bool = False,
+        use_to_edge_transform_and_lower: bool = False,
+    ):
+        compile_spec = common.get_u55_compile_spec(permute_memory_to_nhwc=True)
+        super().__init__(
+            module,
+            test_data,
+            aten_ops,
+            exir_ops,
+            compile_spec,
+            use_to_edge_transform_and_lower,
+        )
+        self.add_stage(0, self.tester.quantize)
+        self.add_stage_after(
+            "quantize",
+            self.tester.check,
+            ["torch.ops.quantized_decomposed.dequantize_per_tensor.default"],
+        )
+        self.add_stage_after(
+            "quantize",
+            self.tester.check,
+            ["torch.ops.quantized_decomposed.quantize_per_tensor.default"],
+        )
+        if run_on_fvp:
+            self.add_stage(-1, self.tester.serialize)
+            self.add_stage(
+                -1,
+                self.tester.run_method_and_compare_outputs,
+                qtol=1,
+                inputs=self.test_data,
+            )
+
+
+class EthosU55PipelineMI(BasePipeline, Generic[T]):
+    """Lowers a float module for Ethos-U55; if run_on_fvp is set, runs it on a
+    Corstone-300 FVP and compares outputs."""
+
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        test_data: T,
+        aten_ops: str | List[str],
+        exir_ops: str | List[str],
+        run_on_fvp: bool = False,
+        use_to_edge_transform_and_lower: bool = False,
+    ):
+        compile_spec = common.get_u55_compile_spec(permute_memory_to_nhwc=True)
+        super().__init__(
+            module,
+            test_data,
+            aten_ops,
+            exir_ops,
+            compile_spec,
+            use_to_edge_transform_and_lower,
+        )
+        self.add_stage_after(
+            "export",
+            self.tester.check_not,
+            ["torch.ops.quantized_decomposed.dequantize_per_tensor.default"],
+        )
+        self.add_stage_after(
+            "export",
+            self.tester.check_not,
+            ["torch.ops.quantized_decomposed.quantize_per_tensor.default"],
+        )
+        if run_on_fvp:
+            self.add_stage(-1, self.tester.serialize)
+            self.add_stage(
+                -1, self.tester.run_method_and_compare_outputs, inputs=self.test_data
+            )
+
+
+class EthosU85PipelineBI(BasePipeline, Generic[T]):
+    """Quantizes the module and lowers it for Ethos-U85; if run_on_fvp is set,
+    runs it on a Corstone-320 FVP and compares outputs with qtol=1."""
+
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        test_data: T,
+        aten_ops: str | List[str],
+        exir_ops: str | List[str],
+        run_on_fvp: bool = False,
+        use_to_edge_transform_and_lower: bool = False,
+    ):
+        compile_spec = common.get_u85_compile_spec(permute_memory_to_nhwc=True)
+        super().__init__(
+            module,
+            test_data,
+            aten_ops,
+            exir_ops,
+            compile_spec,
+            use_to_edge_transform_and_lower,
+        )
+        self.add_stage(0, self.tester.quantize)
+        self.add_stage_after(
+            "quantize",
+            self.tester.check,
+            ["torch.ops.quantized_decomposed.dequantize_per_tensor.default"],
+        )
+        self.add_stage_after(
+            "quantize",
+            self.tester.check,
+            ["torch.ops.quantized_decomposed.quantize_per_tensor.default"],
+        )
+        if run_on_fvp:
+            self.add_stage(-1, self.tester.serialize)
+            self.add_stage(
+                -1,
+                self.tester.run_method_and_compare_outputs,
+                qtol=1,
+                inputs=self.test_data,
+            )
+
+
+class EthosU85PipelineMI(BasePipeline, Generic[T]):
+    """Lowers a float module for Ethos-U85; if run_on_fvp is set, runs it on a
+    Corstone-320 FVP and compares outputs."""
+
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        test_data: T,
+        aten_ops: str | List[str],
+        exir_ops: str | List[str],
+        run_on_fvp: bool = False,
+        use_to_edge_transform_and_lower: bool = False,
+    ):
+        compile_spec = common.get_u85_compile_spec(permute_memory_to_nhwc=True)
+        super().__init__(
+            module,
+            test_data,
+            aten_ops,
+            exir_ops,
+            compile_spec,
+            use_to_edge_transform_and_lower,
+        )
+        self.add_stage_after(
+            "export",
+            self.tester.check_not,
+            ["torch.ops.quantized_decomposed.dequantize_per_tensor.default"],
+        )
+        self.add_stage_after(
+            "export",
+            self.tester.check_not,
+            ["torch.ops.quantized_decomposed.quantize_per_tensor.default"],
+        )
+        if run_on_fvp:
+            self.add_stage(-1, self.tester.serialize)
+            self.add_stage(
+                -1, self.tester.run_method_and_compare_outputs, inputs=self.test_data
+            )