Arm unittest refactor of Add and Conv2D test cases #7541

Open
wants to merge 1 commit into base: main
44 changes: 43 additions & 1 deletion backends/arm/test/common.py
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
@@ -9,12 +9,16 @@

import tempfile
from datetime import datetime

from pathlib import Path
from typing import Any

import pytest
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder

from executorch.backends.arm.test.conftest import is_option_enabled
from executorch.exir.backend.compile_spec_schema import CompileSpec
from runner_utils import corstone300_installed, corstone320_installed


def get_time_formatted_path(path: str, log_prefix: str) -> str:
@@ -185,3 +189,41 @@ def get_target_board(compile_spec: list[CompileSpec]) -> str | None:
elif "u85" in flags:
return "corstone-320"
return None


u55_fvp_mark = pytest.mark.skipif(
not corstone300_installed(), reason="Did not find Corstone-300 FVP on path"
)
""" Marks a test as running on Ethos-U55 FVP, e.g. Corstone 300. Skips the test if this is not installed."""

u85_fvp_mark = pytest.mark.skipif(
not corstone320_installed(), reason="Did not find Corstone-320 FVP on path"
)
""" Marks a test as running on Ethos-U85 FVP, e.g. Corstone 320. Skips the test if this is not installed."""


def parametrize(
arg_name: str, test_data: dict[str, Any], xfails: dict[str, str] | None = None
):
"""
Custom version of pytest.mark.parametrize with some syntatic sugar and added xfail functionality
- test_data is expected as a dict of (id, test_data) pairs
- alllows to specifiy a dict of (id, failure_reason) pairs to mark specific tests as xfail
"""
if xfails is None:
xfails = {}

def decorator_func(func):
pytest_testsuite = []
for id, test_parameters in test_data.items():
if id in xfails:
pytest_param = pytest.param(
test_parameters, id=id, marks=pytest.mark.xfail(reason=xfails[id])
)
else:
pytest_param = pytest.param(test_parameters, id=id)
pytest_testsuite.append(pytest_param)

return pytest.mark.parametrize(arg_name, pytest_testsuite)(func)

return decorator_func
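
As an illustration of how these helpers are meant to be combined, here is a minimal usage sketch; the test name, data dict, and xfail reason are hypothetical and not part of this patch, and it assumes torch and common are imported as in the test files below:

test_data_example = {
    "ones": (torch.ones(1, 3, 4, 4),),
    "randn": (torch.randn(1, 3, 4, 4),),
}


@common.parametrize(
    "test_data",
    test_data_example,
    xfails={"randn": "hypothetical known failure on this input"},
)
@common.u55_fvp_mark
def test_example_u55_BI_on_fvp(test_data):
    # Expands to one pytest case per key in test_data_example; the "randn"
    # case is marked xfail, and every case is skipped when no Corstone-300
    # FVP is found on the path.
    ...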
298 changes: 132 additions & 166 deletions backends/arm/test/ops/test_add.py
@@ -1,178 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Tuple

import pytest
import torch
from executorch.backends.arm.test import common, conftest
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir import EdgeCompileConfig
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized


class TestSimpleAdd(unittest.TestCase):
"""Tests a single add op, x+x and x+y."""

class Add(torch.nn.Module):
test_parameters = [
(torch.FloatTensor([1, 2, 3, 5, 7]),),
(3 * torch.ones(8),),
(10 * torch.randn(8),),
(torch.ones(1, 1, 4, 4),),
(torch.ones(1, 3, 4, 2),),
]

def forward(self, x):
return x + x

class Add2(torch.nn.Module):
test_parameters = [
(
torch.FloatTensor([1, 2, 3, 5, 7]),
(torch.FloatTensor([2, 1, 2, 1, 10])),
),
(torch.ones(1, 10, 4, 6), torch.ones(1, 10, 4, 6)),
(torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)),
(torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)),
(10000 * torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)),
]

def __init__(self):
super().__init__()

def forward(self, x, y):
return x + y

_edge_compile_config: EdgeCompileConfig = EdgeCompileConfig(
_skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend.
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineBI,
EthosU85PipelineBI,
TosaPipelineBI,
TosaPipelineMI,
)

aten_op = "torch.ops.aten.add.Tensor"
exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"

test_data = {
"5d_float": (torch.FloatTensor([1, 2, 3, 5, 7]),),
"1d_ones": ((3 * torch.ones(8),)),
"1d_randn": (10 * torch.randn(8),),
"4d_ones_1": (torch.ones(1, 1, 4, 4),),
"4d_ones_2": (torch.ones(1, 3, 4, 2),),
}
T1 = Tuple[torch.Tensor]

test_data2 = {
"5d_float": (
torch.FloatTensor([1, 2, 3, 5, 7]),
(torch.FloatTensor([2, 1, 2, 1, 10])),
),
"4d_ones": (torch.ones(1, 10, 4, 6), torch.ones(1, 10, 4, 6)),
"4d_randn_1": (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)),
"4d_randn_2": (torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)),
"4d_randn_big": (10000 * torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)),
}
T2 = Tuple[torch.Tensor, torch.Tensor]


class Add(torch.nn.Module):
def forward(self, x):
return x + x


@common.parametrize("test_data", test_data)
def test_add_tosa_MI(test_data):
pipeline = TosaPipelineMI[T1](Add(), test_data, aten_op, exir_op)
pipeline.run()


@common.parametrize("test_data", test_data)
def test_add_tosa_BI(test_data):
pipeline = TosaPipelineBI[T1](Add(), test_data, aten_op, exir_op)
pipeline.run()


@common.parametrize("test_data", test_data)
def test_add_u55_BI(test_data):
pipeline = EthosU55PipelineBI[T1](
Add(), test_data, aten_op, exir_op, run_on_fvp=False
)
pipeline.run()

def _test_add_tosa_MI_pipeline(
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
):
(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
)
.export()
.check_count({"torch.ops.aten.add.Tensor": 1})
.check_not(["torch.ops.quantized_decomposed"])
.to_edge(config=self._edge_compile_config)
.partition()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(inputs=test_data)
)

def _test_add_tosa_BI_pipeline(
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
):
(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
)
.quantize()
.export()
.check_count({"torch.ops.aten.add.Tensor": 1})
.check(["torch.ops.quantized_decomposed"])
.to_edge(config=self._edge_compile_config)
.partition()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(inputs=test_data, qtol=1)
)

def _test_add_ethos_BI_pipeline(
self,
module: torch.nn.Module,
compile_spec: CompileSpec,
test_data: Tuple[torch.Tensor],
):
tester = (
ArmTester(
module,
example_inputs=test_data,
compile_spec=compile_spec,
)
.quantize()
.export()
.check_count({"torch.ops.aten.add.Tensor": 1})
.check(["torch.ops.quantized_decomposed"])
.to_edge()
.partition()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.serialize()
)
if conftest.is_option_enabled("corstone_fvp"):
tester.run_method_and_compare_outputs(qtol=1, inputs=test_data)

return tester

@parameterized.expand(Add.test_parameters)
def test_add_tosa_MI(self, test_data: torch.Tensor):
test_data = (test_data,)
self._test_add_tosa_MI_pipeline(self.Add(), test_data)

@parameterized.expand(Add.test_parameters)
def test_add_tosa_BI(self, test_data: torch.Tensor):
test_data = (test_data,)
self._test_add_tosa_BI_pipeline(self.Add(), test_data)

@parameterized.expand(Add.test_parameters)
@pytest.mark.corstone_fvp
def test_add_u55_BI(self, test_data: torch.Tensor):
test_data = (test_data,)
self._test_add_ethos_BI_pipeline(
self.Add(),
common.get_u55_compile_spec(permute_memory_to_nhwc=True),
test_data,
)

@parameterized.expand(Add.test_parameters)
@pytest.mark.corstone_fvp
def test_add_u85_BI(self, test_data: torch.Tensor):
test_data = (test_data,)
self._test_add_ethos_BI_pipeline(
self.Add(),
common.get_u85_compile_spec(permute_memory_to_nhwc=True),
test_data,
)

@parameterized.expand(Add2.test_parameters)
def test_add2_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
test_data = (operand1, operand2)
self._test_add_tosa_MI_pipeline(self.Add2(), test_data)

@parameterized.expand(Add2.test_parameters)
def test_add2_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
test_data = (operand1, operand2)
self._test_add_tosa_BI_pipeline(self.Add2(), test_data)

@parameterized.expand(Add2.test_parameters)
@pytest.mark.corstone_fvp
def test_add2_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
test_data = (operand1, operand2)
self._test_add_ethos_BI_pipeline(
self.Add2(), common.get_u55_compile_spec(), test_data
)

@parameterized.expand(Add2.test_parameters)
@pytest.mark.corstone_fvp
def test_add2_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
test_data = (operand1, operand2)
self._test_add_ethos_BI_pipeline(
self.Add2(), common.get_u85_compile_spec(), test_data
)

@common.parametrize("test_data", test_data)
def test_add_u85_BI(test_data):
pipeline = EthosU85PipelineBI[T1](
Add(), test_data, aten_op, exir_op, run_on_fvp=False
)
pipeline.run()


@common.parametrize("test_data", test_data)
@common.u55_fvp_mark
def test_add_u55_BI_on_fvp(test_data):
pipeline = EthosU55PipelineBI[T1](
Add(), test_data, aten_op, exir_op, run_on_fvp=True
)
pipeline.run()


@common.parametrize("test_data", test_data)
@common.u85_fvp_mark
def test_add_u85_BI_on_fvp(test_data):
pipeline = EthosU85PipelineBI[T1](
Add(), test_data, aten_op, exir_op, run_on_fvp=True
)
pipeline.run()


class Add2(torch.nn.Module):
def forward(self, x, y):
return x + y


@common.parametrize("test_data", test_data2)
def test_add2_tosa_MI(test_data):
pipeline = TosaPipelineMI[T2](Add2(), test_data, aten_op, exir_op)
pipeline.run()


@common.parametrize("test_data", test_data2)
def test_add2_tosa_BI(test_data):
pipeline = TosaPipelineBI[T2](Add2(), test_data, aten_op, exir_op)
pipeline.run()


@common.parametrize("test_data", test_data2)
def test_add2_u55_BI(test_data):
pipeline = EthosU55PipelineBI[T2](
Add2(), test_data, aten_op, exir_op, run_on_fvp=False
)
pipeline.run()


@common.parametrize("test_data", test_data2)
@common.u55_fvp_mark
def test_add2_u55_BI_on_fvp(test_data):
pipeline = EthosU55PipelineBI[T2](
Add2(), test_data, aten_op, exir_op, run_on_fvp=True
)
pipeline.run()


@common.parametrize("test_data", test_data2)
def test_add2_u85_BI(test_data):
pipeline = EthosU85PipelineBI[T2](
Add2(), test_data, aten_op, exir_op, run_on_fvp=False
)
pipeline.run()


@common.parametrize("test_data", test_data2)
@common.u85_fvp_mark
def test_add2_u85_BI_on_fvp(test_data):
pipeline = EthosU85PipelineBI[T2](
Add2(), test_data, aten_op, exir_op, run_on_fvp=True
)
pipeline.run()