Skip to content

Commit

Permalink
Update on "Bump ExecuTorch's PyTorch nightly pin to dev20241121"
Browse files Browse the repository at this point in the history
Require at least 11/18 to unblock #7040 .

Differential Revision: [D66398425](https://our.internmc.facebook.com/intern/diff/D66398425/)

[ghstack-poisoned]
  • Loading branch information
swolchok committed Dec 2, 2024
2 parents 2ed5b98 + ec7367c commit 0520d25
Show file tree
Hide file tree
Showing 60 changed files with 1,798 additions and 479 deletions.
11 changes: 11 additions & 0 deletions backends/arm/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,14 @@ python_library(
"//executorch/backends/arm/operators:node_visitor",
],
)

python_library(
name = "arm_model_evaluator",
src = [
"util/arm_model_evaluator.py",
],
typing = True,
deps = [
"//caffe2:torch",
]
)
6 changes: 3 additions & 3 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
DecomposeSoftmaxesPass,
)
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
InsertSqueezeAfterSumPass,
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
KeepDimsFalseToSqueezePass,
)
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
Expand Down Expand Up @@ -71,7 +71,7 @@ def transform_to_backend_pipeline(
self.add_pass(DecomposeMeanDimPass())
self.add_pass(MatchArgRanksPass(exported_program))
self.add_pass(DecomposeDivPass())
self.add_pass(InsertSqueezeAfterSumPass())
self.add_pass(KeepDimsFalseToSqueezePass())
self.add_pass(ConvertSplitToSlicePass())
self.add_pass(Conv1dUnsqueezePass(exported_program))
self.add_pass(DecomposeSoftmaxesPass())
Expand Down
58 changes: 58 additions & 0 deletions backends/arm/_passes/arm_pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

# pyre-unsafe

from inspect import isclass
from typing import Optional

import torch
Expand Down Expand Up @@ -133,3 +134,60 @@ def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor:
fake_tensor, FakeTensor
), f'Found {fake_tensor} in meta["val"] of {node}, expected to find FakeTensor.'
return fake_tensor


def get_node_arg(args: list | dict, key: int | str | type, default_value=None):
"""
Help-function for getting a value from node.args/ kwargs, three cases:
1. By position in node.args - Returns arg at given position or default_value if index is one out of bounds
2. By key in node.kwargs - Returns kwarg with given key or default_value if it deos not exist
3. By type in node.args - Returns first arg of args of given type. Useful for cases where arg postions may differ but types are unique.
"""
if isinstance(key, int):
if 0 <= key < len(args):
return args[key]
elif key == len(args):
if default_value is not None:
return default_value
else:
raise RuntimeError(f"No defult value given for index {key}")
else:
raise RuntimeError(
f"Out of bounds index {key} for getting value in args (of size {len(args)})"
)
elif isinstance(key, str):
return args.get(key, default_value)
elif isclass(key):
for arg in args:
if isinstance(arg, key):
return arg
if default_value is not None:
return default_value
else:
raise RuntimeError(f"No arg of type {key}")
else:
raise RuntimeError("Invalid type")


def set_node_arg(node: torch.fx.Node, i: int | str, value):
"""
Help-function for setting a value in node.args/ kwargs. If the index is one larger than the list size, the value is instead appended to the list.
"""
if isinstance(i, int):
if 0 <= i < len(node.args):
args = list(node.args)
args[i] = value
node.args = tuple(args)
return
elif i == len(node.args):
node.args = node.args + (value,)
else:
raise RuntimeError(
f"Out of bounds index {i} for setting value in {node} args (of size {len(node.args)})"
)
elif isinstance(i, str):
kwargs = dict(node.kwargs)
kwargs[i] = value
node.kwargs = kwargs
else:
raise RuntimeError("Invalid type")
13 changes: 7 additions & 6 deletions backends/arm/_passes/decompose_meandim_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# pyre-unsafe

import torch
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand Down Expand Up @@ -42,16 +43,16 @@ def call_operator(self, op, args, kwargs, meta):
if op not in (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim):
return super().call_operator(op, args, kwargs, meta)

x = args[0]
dim = args[1]
keepdim = args[2] if len(args) > 2 else False
if not keepdim:
return super().call_operator(op, args, kwargs, meta)
# if keepdim == True and dim == [-1, -2], mean.dim can be
x = get_node_arg(args, 0)
dim = get_node_arg(args, 1)
keepdim = get_node_arg(args, 2, False)

# if dim == [-1, -2], mean.dim can be
# decomposed to avg_pool2d. This is handled by ConvertMeanDimToAveragePool.
if dim == [-1, -2]:
# Simply return the mean.dim operator for future decomposition.
return super().call_operator(op, args, kwargs, meta)

shape = meta["val"].size()
dtype = meta["val"].dtype
input_shape = x.data.size()
Expand Down
27 changes: 16 additions & 11 deletions backends/arm/_passes/decompose_var_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


import torch
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand Down Expand Up @@ -53,26 +54,30 @@ def call_operator(self, op, args, kwargs, meta):
torch.ops.aten.var.dim,
):
return super().call_operator(op, args, kwargs, meta)
shape = meta["val"].size()

x = args[0]
input_shape = x.data.size()
shape = list(meta["val"].size())
if shape == []:
shape = [1 for _ in input_shape]

dtype = meta["val"].dtype
dim = args[1] if len(args) > 1 else list(range(len(shape)))
# Get dim from args based on argument type
dim = get_node_arg(args, key=list, default_value=list(range(len(shape))))

if op == torch.ops.aten.var.dim:
correction = args[-2]
keepdim = args[-1]
keepdim = get_node_arg(args, bool, False)
correction = get_node_arg(args, int, 1)
else:
correction = kwargs["correction"]
keepdim = kwargs.get("keepdim", False)
if not keepdim:
return super().call_operator(op, args, kwargs, meta)
correction = get_node_arg(kwargs, "correction", 1)
keepdim = get_node_arg(kwargs, "keepdim", False)

x = args[0]
input_shape = x.data.size()
N = 1
for d in dim:
N *= input_shape[d]

mean_op, diff_op, mul_op, sum_op, full_op = get_var_decomposition(op)
mean = super().call_operator(mean_op, (x, dim, keepdim), {}, meta)
mean = super().call_operator(mean_op, (x, dim, True), {}, meta)
diff = super().call_operator(diff_op, (x, mean), {}, meta)
squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta)
sum = super().call_operator(sum_op, (squared_diff, dim, keepdim), {}, meta)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,18 @@

import torch
import torch.fx
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_node_arg,
set_node_arg,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class InsertSqueezeAfterSumPass(ExportPass):
class KeepDimsFalseToSqueezePass(ExportPass):
"""
In Pytorch, the default behaviour of Tensor.sum is to squeeze
In Pytorch, the default behaviour of for example Tensor.sum is to squeeze
the dimension that is summed (keep_dim = False).
However, in TOSA, REDUCE_SUM always preserves the
rank of the input (keep_dim = True).
Expand All @@ -31,28 +35,52 @@ class InsertSqueezeAfterSumPass(ExportPass):
squeeze(dim = dims)
"""

# CURRENTLY NOT HANDLED OPS
# exir_ops.edge.aten.amax,
# exir_ops.edge.aten.amin,
# exir_ops.edge.aten.any.dim,
# exir_ops.edge.aten.any.dims,
# exir_ops.edge.aten.argmax,
# exir_ops.edge.aten.argmin,
# exir_ops.edge.aten.max.dim,
# exir_ops.edge.aten.min.dim,
# exir_ops.edge.aten.prod.dim_int,

# HANDLED OPS
# exir_ops.edge.aten.sum.dim_IntList
# exir_ops.edge.aten.var.correction (decomposed in decompose_var_pass)
# exir_ops.edge.aten.var.dim (decomposed in decompose_var_pass)
# exir_ops.edge.aten.mean.dim (decomposed in decompose_meandim_pass)

def call(self, graph_module: torch.fx.GraphModule):
for node in graph_module.graph.nodes:
keep_dim_index = None

if node.op != "call_function":
continue
if node.target != exir_ops.edge.aten.sum.dim_IntList:
if node.target == exir_ops.edge.aten.sum.dim_IntList:
keep_dim_index = 2
else:
continue

sum_node = cast(torch.fx.Node, node)
keep_dim = cast(bool, sum_node.args[2] if len(sum_node.args) > 2 else False)
keep_dim = get_node_arg(sum_node.args, keep_dim_index, False)

if keep_dim:
continue

dim_list = cast(list[int], sum_node.args[1])
dim_list = get_node_arg(sum_node.args, 1, [0])

# Add keep_dim = True arg to sum node.
sum_node.args = sum_node.args[0:2] + (True,)
set_node_arg(sum_node, 2, True)

with graph_module.graph.inserting_after(sum_node):
squeeze_node = create_node(
graph_module.graph, exir_ops.edge.aten.squeeze_copy.dims, ()
)
sum_node.replace_all_uses_with(squeeze_node)
squeeze_node.args = (sum_node, dim_list)

graph_module.graph.eliminate_dead_code()
graph_module.recompile()
graph_module = super().call(graph_module).graph_module
Expand Down
7 changes: 1 addition & 6 deletions backends/arm/operator_support/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,4 @@

# pyre-unsafe

from . import ( # noqa
mean_dim_support,
right_shift_support,
tosa_supported_operators,
var_correction_support,
)
from . import right_shift_support, to_copy_support, tosa_supported_operators # noqa
33 changes: 0 additions & 33 deletions backends/arm/operator_support/mean_dim_support.py

This file was deleted.

Loading

0 comments on commit 0520d25

Please sign in to comment.