Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 34 additions & 11 deletions backends/arm/_passes/fuse_constant_ops_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,20 @@ def _is_tosa_dialect_op(target) -> bool:
or "<EdgeOpOverload: tosa." in target_str
)

@staticmethod
def _has_real_tosa_dialect_impl(target) -> bool:
schema = getattr(target, "_schema", None)
op_name = getattr(schema, "name", None)
if op_name is None:
return False

try:
return torch._C._dispatch_has_kernel_for_dispatch_key(
op_name, "CompositeExplicitAutograd"
)
except RuntimeError:
return False

@staticmethod
def _arg_contains_symbolic_shape(arg) -> bool:
if isinstance(arg, torch.fx.Node):
Expand Down Expand Up @@ -197,19 +211,31 @@ def resolve_arg(arg, arg_index=None):

return True

def maybe_delete(self, input_nodes_to_maybe_delete):
for input_node in input_nodes_to_maybe_delete:
if input_node.meta.get("is_input", False):
# Never delete submodule inputs, they need to match the parameters from the outer module.
continue
if len(input_node.users) == 0:
self._delete_constant_placeholder(input_node)

def call(self, graph_module):
modified = False
input_nodes_to_maybe_delete = set()
for node in graph_module.graph.nodes:
if node.op != "call_function":
continue
# Don't fuse TOSA dialect ops as they do not have eager forward functions.
# Also don't fuse ops whose explicit args/kwargs include symbolic shape values.
if (
self._is_tosa_dialect_op(node.target)
or self._arg_contains_symbolic_shape(node.args)
or self._arg_contains_symbolic_shape(node.kwargs)
):
if node.target == exir_ops.backend.tosa.RESCALE.default:
# Leave fusing of RESCALES to the compiler.
continue
if self._is_tosa_dialect_op(
node.target
) and not self._has_real_tosa_dialect_impl(node.target):
continue
# Don't fuse ops whose explicit args/kwargs include symbolic shape values.
if self._arg_contains_symbolic_shape(
node.args
) or self._arg_contains_symbolic_shape(node.kwargs):
continue

input_nodes = node.all_input_nodes
Expand Down Expand Up @@ -241,10 +267,7 @@ def call(self, graph_module):

if modified:
graph_module.graph.eliminate_dead_code()
for input_node in input_nodes_to_maybe_delete:
if len(input_node.users) == 0:
self._delete_constant_placeholder(input_node)

self.maybe_delete(input_nodes_to_maybe_delete)
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, modified)
Expand Down
13 changes: 13 additions & 0 deletions backends/arm/operators/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
TosaSpecMapping,
)


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -246,3 +247,15 @@ def get_node_visitors(*args) -> Dict[str, NodeVisitor]:
node_visitors[target] = visitor(*args)

return node_visitors


def get_node_visitor(target: str, tosa_spec: TosaSpecification):
# Ensure all operator modules are imported so visitors are registered.
import executorch.backends.arm.operators # noqa: F401

node_visitor_tuples = _node_visitor_tuples.get(tosa_spec)
for target_name, node_visitor_cls in node_visitor_tuples:
if target_name == target:
return node_visitor_cls(tosa_spec)

raise ValueError(f"No {target} NodeVisitor registered for {tosa_spec}")
44 changes: 44 additions & 0 deletions backends/arm/test/misc/tosa_dialect/test_tosa_gather.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import executorch.backends.arm.tosa.dialect # noqa: F401

import torch
from executorch.backends.arm.tosa.specification import (
TosaLoweringContext,
TosaSpecification,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch._subclasses.fake_tensor import FakeTensorMode


def test_gather_tosa_FP_fake() -> None:
values = torch.randn((1, 4, 3), dtype=torch.float32)
indices = torch.tensor([[0, 2]], dtype=torch.int32)

with TosaLoweringContext(
TosaSpecification.create_from_string("TOSA-1.0+FP")
), FakeTensorMode() as mode:
output = exir_ops.backend.tosa.GATHER.default(
mode.from_tensor(values),
mode.from_tensor(indices),
)

assert output.dtype == values.dtype
assert tuple(output.shape) == (1, 2, 3)


def test_gather_tosa_FP_real() -> None:
values = torch.tensor(
[[[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]],
dtype=torch.float32,
)
indices = torch.tensor([[3, 1]], dtype=torch.int32)

with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+FP")):
output = exir_ops.backend.tosa.GATHER.default(values, indices)

expected = values[:, indices[0], :]
torch.testing.assert_close(output, expected)
11 changes: 10 additions & 1 deletion backends/arm/test/misc/tosa_dialect/test_tosa_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch._subclasses.fake_tensor import FakeTensorMode


def test_identity_tosa_FP() -> None:
def test_identity_tosa_FP_fake() -> None:
sample_input = torch.randn((1, 2, 3, 4), dtype=torch.float32)

with TosaLoweringContext(
Expand All @@ -23,3 +23,12 @@ def test_identity_tosa_FP() -> None:

assert output.dtype == sample_input.dtype
assert tuple(output.shape) == tuple(sample_input.shape)


def test_identity_tosa_FP_real() -> None:
sample_input = torch.randn((1, 2, 3, 4), dtype=torch.float32)

with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+FP")):
output = exir_ops.backend.tosa.IDENTITY.default(sample_input)

torch.testing.assert_close(output, sample_input)
60 changes: 60 additions & 0 deletions backends/arm/test/misc/tosa_dialect/test_tosa_rescale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import executorch.backends.arm.tosa.dialect # noqa: F401
import pytest
import torch

from executorch.backends.arm.tosa.specification import (
TosaLoweringContext,
TosaSpecification,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode


@pytest.mark.parametrize(
"kwargs",
[
{},
{"input_unsigned": False, "output_unsigned": False},
],
)
def test_rescale_real_impl_with_and_without_kwargs(kwargs):
input_tensor = torch.tensor(
[[1, -2, 3], [4, 0, -5]],
dtype=torch.int32,
)

with TosaLoweringContext(
TosaSpecification.create_from_string("TOSA-1.0+INT")
), FakeTensorMode() as mode:
fake_output = exir_ops.backend.tosa.RESCALE.default(
mode.from_tensor(input_tensor),
torch.int32,
[1.0],
0,
0,
**kwargs,
)

assert isinstance(fake_output, FakeTensor)
assert fake_output.dtype == torch.int32
assert tuple(fake_output.shape) == tuple(input_tensor.shape)

with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+INT")):
output = exir_ops.backend.tosa.RESCALE.default(
input_tensor,
torch.int32,
[1.0],
0,
0,
**kwargs,
)

assert not isinstance(output, FakeTensor)
assert output.dtype == torch.int32
assert tuple(output.shape) == tuple(input_tensor.shape)
assert torch.equal(output, input_tensor)
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def test_ensure_unique_output_nodes_no_target_inserts_identity_per_repeated_outp
"executorch_exir_dialects_backend__ops_tosa_IDENTITY_default": 2,
},
)
pipeline.pop_stage("run_method_and_compare_outputs")
pipeline.run()

graph_module = (
Expand All @@ -62,5 +61,4 @@ def test_ensure_unique_output_nodes_no_target_keeps_unique_outputs_unchanged() -
"executorch_exir_dialects_backend__ops_tosa_IDENTITY_default",
],
)
pipeline.pop_stage("run_method_and_compare_outputs")
pipeline.run()
Loading
Loading