diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py index 7115b3d6f0e..ed67461a30e 100644 --- a/backends/arm/_passes/fuse_constant_ops_pass.py +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -92,6 +92,20 @@ def _is_tosa_dialect_op(target) -> bool: or " bool: + schema = getattr(target, "_schema", None) + op_name = getattr(schema, "name", None) + if op_name is None: + return False + + try: + return torch._C._dispatch_has_kernel_for_dispatch_key( + op_name, "CompositeExplicitAutograd" + ) + except RuntimeError: + return False + @staticmethod def _arg_contains_symbolic_shape(arg) -> bool: if isinstance(arg, torch.fx.Node): @@ -197,19 +211,31 @@ def resolve_arg(arg, arg_index=None): return True + def maybe_delete(self, input_nodes_to_maybe_delete): + for input_node in input_nodes_to_maybe_delete: + if input_node.meta.get("is_input", False): + # Never delete submodule inputs, they need to match the parameters from the outer module. + continue + if len(input_node.users) == 0: + self._delete_constant_placeholder(input_node) + def call(self, graph_module): modified = False input_nodes_to_maybe_delete = set() for node in graph_module.graph.nodes: if node.op != "call_function": continue - # Don't fuse TOSA dialect ops as they do not have eager forward functions. - # Also don't fuse ops whose explicit args/kwargs include symbolic shape values. - if ( - self._is_tosa_dialect_op(node.target) - or self._arg_contains_symbolic_shape(node.args) - or self._arg_contains_symbolic_shape(node.kwargs) - ): + if node.target == exir_ops.backend.tosa.RESCALE.default: + # Leave fusing of RESCALES to the compiler. + continue + if self._is_tosa_dialect_op( + node.target + ) and not self._has_real_tosa_dialect_impl(node.target): + continue + # Don't fuse ops whose explicit args/kwargs include symbolic shape values. + if self._arg_contains_symbolic_shape( + node.args + ) or self._arg_contains_symbolic_shape(node.kwargs): continue input_nodes = node.all_input_nodes @@ -241,10 +267,7 @@ def call(self, graph_module): if modified: graph_module.graph.eliminate_dead_code() - for input_node in input_nodes_to_maybe_delete: - if len(input_node.users) == 0: - self._delete_constant_placeholder(input_node) - + self.maybe_delete(input_nodes_to_maybe_delete) graph_module = super().call(graph_module).graph_module return PassResult(graph_module, modified) diff --git a/backends/arm/operators/node_visitor.py b/backends/arm/operators/node_visitor.py index c51b56a8f36..48433e8f396 100644 --- a/backends/arm/operators/node_visitor.py +++ b/backends/arm/operators/node_visitor.py @@ -31,6 +31,7 @@ TosaSpecMapping, ) + logger = logging.getLogger(__name__) @@ -246,3 +247,15 @@ def get_node_visitors(*args) -> Dict[str, NodeVisitor]: node_visitors[target] = visitor(*args) return node_visitors + + +def get_node_visitor(target: str, tosa_spec: TosaSpecification): + # Ensure all operator modules are imported so visitors are registered. + import executorch.backends.arm.operators # noqa: F401 + + node_visitor_tuples = _node_visitor_tuples.get(tosa_spec) + for target_name, node_visitor_cls in node_visitor_tuples: + if target_name == target: + return node_visitor_cls(tosa_spec) + + raise ValueError(f"No {target} NodeVisitor registered for {tosa_spec}") diff --git a/backends/arm/test/misc/tosa_dialect/test_tosa_gather.py b/backends/arm/test/misc/tosa_dialect/test_tosa_gather.py new file mode 100644 index 00000000000..8a65de56a57 --- /dev/null +++ b/backends/arm/test/misc/tosa_dialect/test_tosa_gather.py @@ -0,0 +1,44 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import executorch.backends.arm.tosa.dialect # noqa: F401 + +import torch +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch._subclasses.fake_tensor import FakeTensorMode + + +def test_gather_tosa_FP_fake() -> None: + values = torch.randn((1, 4, 3), dtype=torch.float32) + indices = torch.tensor([[0, 2]], dtype=torch.int32) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+FP") + ), FakeTensorMode() as mode: + output = exir_ops.backend.tosa.GATHER.default( + mode.from_tensor(values), + mode.from_tensor(indices), + ) + + assert output.dtype == values.dtype + assert tuple(output.shape) == (1, 2, 3) + + +def test_gather_tosa_FP_real() -> None: + values = torch.tensor( + [[[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]], + dtype=torch.float32, + ) + indices = torch.tensor([[3, 1]], dtype=torch.int32) + + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+FP")): + output = exir_ops.backend.tosa.GATHER.default(values, indices) + + expected = values[:, indices[0], :] + torch.testing.assert_close(output, expected) diff --git a/backends/arm/test/misc/tosa_dialect/test_tosa_identity.py b/backends/arm/test/misc/tosa_dialect/test_tosa_identity.py index 19461cb676c..f03f855e5e3 100644 --- a/backends/arm/test/misc/tosa_dialect/test_tosa_identity.py +++ b/backends/arm/test/misc/tosa_dialect/test_tosa_identity.py @@ -13,7 +13,7 @@ from torch._subclasses.fake_tensor import FakeTensorMode -def test_identity_tosa_FP() -> None: +def test_identity_tosa_FP_fake() -> None: sample_input = torch.randn((1, 2, 3, 4), dtype=torch.float32) with TosaLoweringContext( @@ -23,3 +23,12 @@ def test_identity_tosa_FP() -> None: assert output.dtype == sample_input.dtype assert tuple(output.shape) == tuple(sample_input.shape) + + +def test_identity_tosa_FP_real() -> None: + sample_input = torch.randn((1, 2, 3, 4), dtype=torch.float32) + + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+FP")): + output = exir_ops.backend.tosa.IDENTITY.default(sample_input) + + torch.testing.assert_close(output, sample_input) diff --git a/backends/arm/test/misc/tosa_dialect/test_tosa_rescale.py b/backends/arm/test/misc/tosa_dialect/test_tosa_rescale.py new file mode 100644 index 00000000000..cc52ac87970 --- /dev/null +++ b/backends/arm/test/misc/tosa_dialect/test_tosa_rescale.py @@ -0,0 +1,60 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import executorch.backends.arm.tosa.dialect # noqa: F401 +import pytest +import torch + +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode + + +@pytest.mark.parametrize( + "kwargs", + [ + {}, + {"input_unsigned": False, "output_unsigned": False}, + ], +) +def test_rescale_real_impl_with_and_without_kwargs(kwargs): + input_tensor = torch.tensor( + [[1, -2, 3], [4, 0, -5]], + dtype=torch.int32, + ) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+INT") + ), FakeTensorMode() as mode: + fake_output = exir_ops.backend.tosa.RESCALE.default( + mode.from_tensor(input_tensor), + torch.int32, + [1.0], + 0, + 0, + **kwargs, + ) + + assert isinstance(fake_output, FakeTensor) + assert fake_output.dtype == torch.int32 + assert tuple(fake_output.shape) == tuple(input_tensor.shape) + + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+INT")): + output = exir_ops.backend.tosa.RESCALE.default( + input_tensor, + torch.int32, + [1.0], + 0, + 0, + **kwargs, + ) + + assert not isinstance(output, FakeTensor) + assert output.dtype == torch.int32 + assert tuple(output.shape) == tuple(input_tensor.shape) + assert torch.equal(output, input_tensor) diff --git a/backends/arm/test/passes/test_ensure_unique_output_nodes_pass.py b/backends/arm/test/passes/test_ensure_unique_output_nodes_pass.py index 4dd03c1ca6e..077c7aae72c 100644 --- a/backends/arm/test/passes/test_ensure_unique_output_nodes_pass.py +++ b/backends/arm/test/passes/test_ensure_unique_output_nodes_pass.py @@ -35,7 +35,6 @@ def test_ensure_unique_output_nodes_no_target_inserts_identity_per_repeated_outp "executorch_exir_dialects_backend__ops_tosa_IDENTITY_default": 2, }, ) - pipeline.pop_stage("run_method_and_compare_outputs") pipeline.run() graph_module = ( @@ -62,5 +61,4 @@ def test_ensure_unique_output_nodes_no_target_keeps_unique_outputs_unchanged() - "executorch_exir_dialects_backend__ops_tosa_IDENTITY_default", ], ) - pipeline.pop_stage("run_method_and_compare_outputs") pipeline.run() diff --git a/backends/arm/test/passes/test_fuse_constant_ops_pass.py b/backends/arm/test/passes/test_fuse_constant_ops_pass.py index 9a07cd6e820..6b5884cd63e 100644 --- a/backends/arm/test/passes/test_fuse_constant_ops_pass.py +++ b/backends/arm/test/passes/test_fuse_constant_ops_pass.py @@ -16,6 +16,7 @@ from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.backends.arm.test.tester.test_pipeline import PassPipeline +from executorch.backends.arm.tosa.backend import TOSABackend from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.backends.arm.tosa.specification import ( TosaLoweringContext, @@ -24,6 +25,7 @@ from executorch.backends.test.harness.stages import StageType from executorch.backends.test.program_builder import ProgramBuilder from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.graph_module import get_cond_while_submodules from torch.export.graph_signature import InputKind input_t = Tuple[torch.Tensor] # Input x @@ -341,6 +343,52 @@ def test_fuse_constant_args_fuses_chains_without_recompile() -> None: torch.testing.assert_close(actual, expected) +def test_fuse_constant_args_preserves_unused_control_flow_inputs() -> None: + class CondWithCapturedBuffer(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer("buf", torch.ones(1, 1, 1, 1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + def true_branch(buf: torch.Tensor, arg: torch.Tensor) -> torch.Tensor: + return arg + buf.view(1, 1, 1, 1) + + def false_branch(buf: torch.Tensor, arg: torch.Tensor) -> torch.Tensor: + return arg * buf + + return torch.cond(x.sum() > 0, true_branch, false_branch, [self.buf, x]) + + compile_spec = common.get_tosa_compile_spec( + TosaSpecification.create_from_string("TOSA-1.0+FP+cf") + ) + tester = ArmTester( + CondWithCapturedBuffer(), + example_inputs=(torch.randn(1, 1, 2, 2),), + compile_spec=compile_spec, + ) + tester.export().to_edge() + exported_program = tester.get_artifact(StageType.TO_EDGE).exported_program() + + submodule_0 = exported_program.graph_module.get_submodule("submodule_0") + _, _, cond_node = get_cond_while_submodules(exported_program.graph_module)[0] + TOSABackend._regularize_submodule(submodule_0, cond_node) + + pass_result = FuseConstantArgsPass(exported_program).call(submodule_0) + + placeholders = [ + node + for node in pass_result.graph_module.graph.nodes + if node.op == "placeholder" + ] + assert [node.name for node in placeholders] == [ + "aten_view_copy_default_fused_const", + "b_buf", + "x", + ] + assert placeholders[1].meta["is_input"] is True + assert len(placeholders[1].users) == 0 + + def test_fuse_constant_args_identifies_tosa_dialect_targets() -> None: class FakeTosaTarget: def __str__(self) -> str: @@ -351,6 +399,15 @@ def __str__(self) -> str: exir_ops.backend.tosa.GATHER.default ) assert not FuseConstantArgsPass._is_tosa_dialect_op(torch.ops.aten.add.Tensor) + assert FuseConstantArgsPass._has_real_tosa_dialect_impl( + exir_ops.backend.tosa.GATHER.default + ) + assert FuseConstantArgsPass._has_real_tosa_dialect_impl( + exir_ops.backend.tosa.RESCALE.default + ) + assert not FuseConstantArgsPass._has_real_tosa_dialect_impl( + exir_ops.backend.tosa.TABLE.default + ) def test_fuse_constant_args_identifies_symbolic_shape_args() -> None: @@ -364,24 +421,63 @@ def test_fuse_constant_args_identifies_symbolic_shape_args() -> None: ) -def test_fuse_constant_args_skips_backend_tosa_gather(caplog) -> None: - with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.1+FP+shape")): +def test_fuse_constant_args_skips_tosa_ops_without_real_impl(caplog) -> None: + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+INT")): builder = ProgramBuilder() - values = builder.placeholder( - "values", - torch.randn(1, 4, 3), + input_tensor = builder.placeholder( + "input_tensor", + torch.tensor([-3, 2], dtype=torch.int8), input_kind=InputKind.CONSTANT_TENSOR, ) - indices = builder.placeholder( - "indices", - torch.tensor([[0, 2]], dtype=torch.int32), + table = builder.placeholder( + "table", + torch.arange(256, dtype=torch.int16).to(torch.int8), + input_kind=InputKind.CONSTANT_TENSOR, + ) + table_lookup = builder.call_operator( + exir_ops.backend.tosa.TABLE.default, + (input_tensor, table), + ) + builder.output([table_lookup]) + + exported_program = builder.get_program() + graph_module = exported_program.graph_module + + with caplog.at_level("WARNING"): + FuseConstantArgsPass(exported_program)(graph_module) + + warning_messages = [ + record.getMessage() + for record in caplog.records + if record.name == "executorch.backends.arm._passes.fuse_constant_ops_pass" + ] + assert not any( + "Failed to fuse constant op" in message and "TABLE" in message + for message in warning_messages + ) + assert ( + sum( + node.op == "call_function" + and node.target == exir_ops.backend.tosa.TABLE.default + for node in graph_module.graph.nodes + ) + == 1 + ) + + +def test_fuse_constant_args_skips_tosa_rescale(caplog) -> None: + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+INT")): + builder = ProgramBuilder() + input_tensor = builder.placeholder( + "input_tensor", + torch.tensor([1, 2], dtype=torch.int32), input_kind=InputKind.CONSTANT_TENSOR, ) - gather = builder.call_operator( - exir_ops.backend.tosa.GATHER.default, - (values, indices), + rescale = builder.call_operator( + exir_ops.backend.tosa.RESCALE.default, + (input_tensor, torch.int8, [1.0], 0, 0), ) - builder.output([gather]) + builder.output([rescale]) exported_program = builder.get_program() graph_module = exported_program.graph_module @@ -395,13 +491,13 @@ def test_fuse_constant_args_skips_backend_tosa_gather(caplog) -> None: if record.name == "executorch.backends.arm._passes.fuse_constant_ops_pass" ] assert not any( - "Failed to fuse constant op" in message and "GATHER" in message + "Failed to fuse constant op" in message and "RESCALE" in message for message in warning_messages ) assert ( sum( node.op == "call_function" - and node.target == exir_ops.backend.tosa.GATHER.default + and node.target == exir_ops.backend.tosa.RESCALE.default for node in graph_module.graph.nodes ) == 1 diff --git a/backends/arm/test/passes/test_rewrite_avg_pool2d_pass.py b/backends/arm/test/passes/test_rewrite_avg_pool2d_pass.py index 42214ba59b3..98abfe8fdfe 100644 --- a/backends/arm/test/passes/test_rewrite_avg_pool2d_pass.py +++ b/backends/arm/test/passes/test_rewrite_avg_pool2d_pass.py @@ -6,7 +6,7 @@ from typing import cast, Dict, Protocol, Tuple import torch -from executorch.backends.arm._passes.rewrite_avg_pool2d_pass import RewriteAvgPool2dPass +from executorch.backends.arm._passes import RewriteAvgPool2dPass from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import PassPipeline from executorch.backends.test.harness.stages import StageType @@ -29,7 +29,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AvgPool2dWithoutStride(torch.nn.Module): def get_inputs(self) -> input_t: - return (torch.rand(1, 3, 8, 8),) + return (torch.rand(1, 3, 9, 9),) def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.avg_pool2d(x, kernel_size=3) @@ -37,7 +37,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AvgPool2dListKernel(torch.nn.Module): def get_inputs(self) -> input_t: - return (torch.rand(1, 3, 8, 8),) + return (torch.rand(1, 3, 8, 9),) def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.avg_pool2d(x, kernel_size=[2, 3]) @@ -45,7 +45,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AvgPool2dScalarPadding(torch.nn.Module): def get_inputs(self) -> input_t: - return (torch.rand(1, 3, 8, 8),) + return (torch.rand(1, 3, 9, 9),) def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.avg_pool2d(x, kernel_size=3, stride=2, padding=1) @@ -53,7 +53,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AvgPool2dWithEmptyStride(torch.nn.Module): def get_inputs(self) -> input_t: - return (torch.rand(1, 3, 8, 8),) + return (torch.rand(1, 3, 8, 9),) def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.avg_pool2d(x, kernel_size=[2, 3], stride=[]) @@ -81,9 +81,6 @@ def test_rewrite_avg_pool2d_tosa(module: ModuleWithInputs) -> None: }, pass_list=[RewriteAvgPool2dPass], ) - pipeline.pop_stage( - "run_method_and_compare_outputs" - ) # Cannot run aten graph with tosa dialect ops pipeline.run() @@ -119,7 +116,6 @@ def test_rewrite_avg_pool2d_tosa_empty_stride_uses_kernel_size() -> None: }, pass_list=[RewriteAvgPool2dPass], ) - pipeline.pop_stage("run_method_and_compare_outputs") pipeline.run() tosa_node = _get_tosa_avg_pool2d_node(pipeline) diff --git a/backends/arm/test/passes/test_rewrite_conv_pass.py b/backends/arm/test/passes/test_rewrite_conv_pass.py index 736aa685b86..de8135bc049 100644 --- a/backends/arm/test/passes/test_rewrite_conv_pass.py +++ b/backends/arm/test/passes/test_rewrite_conv_pass.py @@ -213,8 +213,6 @@ def test_rewrite_conv_tosa_FP(): pipeline = PassPipeline( module, module.get_inputs(), passes_with_exported_program=[RewriteConvPass] ) - # We cannot run TOSA backend dialect operators in eager mode. - pipeline.pop_stage("run_method_and_compare_outputs") pipeline.run() diff --git a/backends/arm/test/passes/test_rewrite_max_pool2d_pass.py b/backends/arm/test/passes/test_rewrite_max_pool2d_pass.py index 52efb0929f2..9ba31f1e6e0 100644 --- a/backends/arm/test/passes/test_rewrite_max_pool2d_pass.py +++ b/backends/arm/test/passes/test_rewrite_max_pool2d_pass.py @@ -37,7 +37,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MaxPool2dWithoutStride(torch.nn.Module): def get_inputs(self) -> input_t: - return (torch.rand(1, 3, 8, 8),) + return (torch.rand(1, 3, 9, 9),) def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.max_pool2d(x, kernel_size=3) @@ -45,7 +45,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MaxPool2dListKernel(torch.nn.Module): def get_inputs(self) -> input_t: - return (torch.rand(1, 3, 8, 8),) + return (torch.rand(1, 3, 8, 9),) def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.max_pool2d(x, kernel_size=[2, 3]) @@ -56,7 +56,7 @@ def get_inputs(self) -> input_t: return (torch.rand(1, 3, 8, 8),) def forward(self, x: torch.Tensor) -> torch.Tensor: - return torch.nn.functional.max_pool2d(x, kernel_size=[2, 3], stride=[]) + return torch.nn.functional.max_pool2d(x, kernel_size=[2, 2], stride=[]) class MaxPool2dDynamic(torch.nn.Module): @@ -94,9 +94,6 @@ def test_rewrite_max_pool2d_tosa(module: ModuleWithInputs) -> None: }, pass_list=[RemoveGetItemPass, RewriteMaxPool2dPass], ) - pipeline.pop_stage( - "run_method_and_compare_outputs" - ) # Cannnot run aten graph with tosa dialect ops pipeline.run() @@ -131,11 +128,10 @@ def test_rewrite_max_pool2d_tosa_empty_stride_uses_kernel_size() -> None: }, pass_list=[RemoveGetItemPass, RewriteMaxPool2dPass], ) - pipeline.pop_stage("run_method_and_compare_outputs") pipeline.run() tosa_node = _get_tosa_max_pool2d_node(pipeline) - assert tosa_node.args[2] == [2, 3] + assert tosa_node.args[2] == [2, 2] def test_rewrite_max_pool2d_tosa_dynamic_shape() -> None: diff --git a/backends/arm/tosa/dialect/BUCK b/backends/arm/tosa/dialect/BUCK index 5081f5d6945..68961e24f04 100644 --- a/backends/arm/tosa/dialect/BUCK +++ b/backends/arm/tosa/dialect/BUCK @@ -6,6 +6,7 @@ fbcode_target(_kind = runtime.python_library, srcs = [ "lib.py", "ops_registration.py", + "real_impl.py", "shape.py", ], deps = [ diff --git a/backends/arm/tosa/dialect/lib.py b/backends/arm/tosa/dialect/lib.py index ed26a21a297..7db9f1d755b 100644 --- a/backends/arm/tosa/dialect/lib.py +++ b/backends/arm/tosa/dialect/lib.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -14,16 +14,22 @@ tosa_lib = Library("tosa", "DEF") -def register_tosa_dialect_op(op_schema, func) -> Callable: +def register_tosa_dialect_op( + op_schema, + fake_func, + real_func: Callable | None = None, + real_dispatch_key: str = "CompositeExplicitAutograd", +) -> Callable: """Register a TOSA dialect operator with the backend op library. Args: op_schema (str): Operator schema without namespace or overload name. - func (Callable): Fake implementation used for registration. + fake_func (Callable): Fake implementation used for registration. + real_func (Optional[Callable]): Optional eager implementation. + real_dispatch_key (str): Dispatch key used for the eager implementation. Returns: - Callable: Backend dialect operator handle exposed via ``exir_ops`` and - marked ``not_callable`` for runtime use. + Callable: Backend dialect operator handle exposed via ``exir_ops``. """ if tosa_lib.ns not in _BACKEND_OP_LIB: @@ -46,19 +52,11 @@ def register_tosa_dialect_op(op_schema, func) -> Callable: overload_name = "default" op_qualified_name = f"{tosa_lib.ns}::{opname}" - register_fake(op_qualified_name, func, lib=tosa_lib) - + register_fake(op_qualified_name, fake_func, lib=tosa_lib) + if real_func is not None: + tosa_lib.impl(opname, real_func, real_dispatch_key) op = getattr(getattr(getattr(exir_ops.backend, tosa_lib.ns), opname), overload_name) - # For now, since the TOSA operators are only used for lowering and serialization in the backend - # the op doesn't need to be callable. This can be changed in the future if needed to support - # execution of TOSA ops directly. - def not_callable(): - """Raise when the dialect op handle is invoked at runtime.""" - raise RuntimeError("TOSA dialect op is not callable") - - op.__equvalent_callable__ = not_callable - return op diff --git a/backends/arm/tosa/dialect/ops/activation.py b/backends/arm/tosa/dialect/ops/activation.py index 3c3fbffe176..a17110a7e85 100644 --- a/backends/arm/tosa/dialect/ops/activation.py +++ b/backends/arm/tosa/dialect/ops/activation.py @@ -8,7 +8,7 @@ import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError from executorch.backends.arm.tosa.dialect.ops._common import validate_nan_mode -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_spec, TosaSpecification, @@ -84,7 +84,7 @@ def _validate_integer_clamp_bounds( ) -@register_fake_tosa_op( +@register_tosa_op( 'CLAMP(Tensor input, Scalar min_val, Scalar max_val, *, str nan_mode="PROPAGATE") -> Tensor', TosaSpecification.all_versions_and_profiles(), ) @@ -111,7 +111,7 @@ def CLAMP( return torch.empty_like(input, dtype=input.dtype) -@register_fake_tosa_op( +@register_tosa_op( "ERF(Tensor input) -> Tensor", FP_SPECS, ) @@ -120,7 +120,7 @@ def ERF(input: torch.Tensor) -> torch.Tensor: return torch.empty_like(input, dtype=input.dtype) -@register_fake_tosa_op( +@register_tosa_op( "SIGMOID(Tensor input) -> Tensor", FP_SPECS, ) @@ -129,7 +129,7 @@ def SIGMOID(input: torch.Tensor) -> torch.Tensor: return torch.empty_like(input, dtype=input.dtype) -@register_fake_tosa_op( +@register_tosa_op( "TANH(Tensor input) -> Tensor", FP_SPECS, ) diff --git a/backends/arm/tosa/dialect/ops/avg_pool2d.py b/backends/arm/tosa/dialect/ops/avg_pool2d.py index 968b335fc7b..0acc39c29f7 100644 --- a/backends/arm/tosa/dialect/ops/avg_pool2d.py +++ b/backends/arm/tosa/dialect/ops/avg_pool2d.py @@ -8,7 +8,7 @@ import sympy # type: ignore[import-untyped] import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_shape_env, get_context_spec, @@ -105,7 +105,7 @@ def validate_avg_pool2d_args( ) -@register_fake_tosa_op( +@register_tosa_op( "AVG_POOL2D(Tensor input, Tensor input_zp, Tensor output_zp, int[2] kernel, int[2] stride, SymInt[4] pad, ScalarType acc_type) -> Tensor", TosaSpecification.all_versions_and_profiles(), ) diff --git a/backends/arm/tosa/dialect/ops/avg_pool2d_adaptive.py b/backends/arm/tosa/dialect/ops/avg_pool2d_adaptive.py index 7d71b85eca7..6d2195abd6c 100644 --- a/backends/arm/tosa/dialect/ops/avg_pool2d_adaptive.py +++ b/backends/arm/tosa/dialect/ops/avg_pool2d_adaptive.py @@ -11,7 +11,7 @@ compute_avg_pool2d_output_shape, validate_avg_pool2d_dtype, ) -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_shape_env, get_context_spec, @@ -36,7 +36,7 @@ def _is_directly_representable( return remainder in (0, 1) -@register_fake_tosa_op( +@register_tosa_op( "AVG_POOL2D_ADAPTIVE(Tensor input, Tensor input_zp, Tensor output_zp, SymInt[2] kernel, SymInt[2] stride, SymInt[4] pad, ScalarType acc_type) -> Tensor", TosaSpecification.all_profiles_for_version("1.1"), ) diff --git a/backends/arm/tosa/dialect/ops/cast_to_block_scaled.py b/backends/arm/tosa/dialect/ops/cast_to_block_scaled.py index 8dbff7c11c5..01369696d19 100644 --- a/backends/arm/tosa/dialect/ops/cast_to_block_scaled.py +++ b/backends/arm/tosa/dialect/ops/cast_to_block_scaled.py @@ -11,7 +11,7 @@ from executorch.backends.arm.ao_ext.mxfp import mxfp_str_to_dtype from executorch.backends.arm.tosa.dialect.lib import TosaValueError -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_spec, TosaSpecification, @@ -19,7 +19,7 @@ from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP6_E2M3, DTYPE_FP6_E3M2 -@register_fake_tosa_op( +@register_tosa_op( "CAST_TO_BLOCK_SCALED(Tensor input, SymInt block_size, str output_dtype) -> (Tensor, Tensor)", [TosaSpecification.create_from_string("TOSA-1.1+FP")], ) diff --git a/backends/arm/tosa/dialect/ops/conv2d.py b/backends/arm/tosa/dialect/ops/conv2d.py index d0db2d60fcd..0320efd03b5 100644 --- a/backends/arm/tosa/dialect/ops/conv2d.py +++ b/backends/arm/tosa/dialect/ops/conv2d.py @@ -7,7 +7,7 @@ import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_spec, TosaSpecification, @@ -105,7 +105,7 @@ def conv_output_dim( return (input_dim + total_pad - receptive_field) // stride + 1 -@register_fake_tosa_op( +@register_tosa_op( "CONV2D(Tensor input, " "Tensor weight, " "Tensor bias, " diff --git a/backends/arm/tosa/dialect/ops/conv3d.py b/backends/arm/tosa/dialect/ops/conv3d.py index a81ae0dae53..39f54f8355f 100644 --- a/backends/arm/tosa/dialect/ops/conv3d.py +++ b/backends/arm/tosa/dialect/ops/conv3d.py @@ -11,7 +11,7 @@ conv_output_dim, validate_conv2d_args_dtypes, ) -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_spec, TosaSpecification, @@ -32,7 +32,7 @@ def validate_conv3d_args_dtypes( return validate_conv2d_args_dtypes(tosa_spec, x, weight, bias, op="CONV3D") -@register_fake_tosa_op( +@register_tosa_op( "CONV3D(Tensor input, " "Tensor weight, " "Tensor bias, " diff --git a/backends/arm/tosa/dialect/ops/custom.py b/backends/arm/tosa/dialect/ops/custom.py index 6376124d6f2..0313456ad9e 100644 --- a/backends/arm/tosa/dialect/ops/custom.py +++ b/backends/arm/tosa/dialect/ops/custom.py @@ -31,7 +31,7 @@ from collections.abc import Callable import torch -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_spec, @@ -132,7 +132,7 @@ def run_registered_fake_tosa_impl( return outputs -@register_fake_tosa_op( +@register_tosa_op( "CUSTOM(Tensor[] inputs, str operator_name, str domain_name, int[] implementation_attrs) -> Tensor[]", TosaSpecification.all_versions_and_profiles(), ) diff --git a/backends/arm/tosa/dialect/ops/depthwise_conv2d.py b/backends/arm/tosa/dialect/ops/depthwise_conv2d.py index 83ef3ff72fb..91609cd6b62 100644 --- a/backends/arm/tosa/dialect/ops/depthwise_conv2d.py +++ b/backends/arm/tosa/dialect/ops/depthwise_conv2d.py @@ -8,7 +8,7 @@ conv_output_dim, validate_conv2d_args_dtypes, ) -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_spec, @@ -16,7 +16,7 @@ ) -@register_fake_tosa_op( +@register_tosa_op( "DEPTHWISE_CONV2D(Tensor input, " "Tensor weight, " "Tensor bias, " diff --git a/backends/arm/tosa/dialect/ops/gather.py b/backends/arm/tosa/dialect/ops/gather.py index 49374142cd6..4c634445713 100644 --- a/backends/arm/tosa/dialect/ops/gather.py +++ b/backends/arm/tosa/dialect/ops/gather.py @@ -7,14 +7,14 @@ import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_spec, TosaSpecification, ) -@register_fake_tosa_op( +@register_tosa_op( "GATHER(Tensor values, Tensor indices) -> Tensor", TosaSpecification.all_versions_and_profiles(), ) diff --git a/backends/arm/tosa/dialect/ops/identity.py b/backends/arm/tosa/dialect/ops/identity.py index 6e26d8e8b22..ff8a366e72d 100644 --- a/backends/arm/tosa/dialect/ops/identity.py +++ b/backends/arm/tosa/dialect/ops/identity.py @@ -4,11 +4,11 @@ # LICENSE file in the root directory of this source tree. import torch -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import TosaSpecification -@register_fake_tosa_op( +@register_tosa_op( "IDENTITY(Tensor input) -> Tensor", TosaSpecification.all_versions_and_profiles(), ) diff --git a/backends/arm/tosa/dialect/ops/matmul.py b/backends/arm/tosa/dialect/ops/matmul.py index 8023df88072..6d035c0e5a6 100644 --- a/backends/arm/tosa/dialect/ops/matmul.py +++ b/backends/arm/tosa/dialect/ops/matmul.py @@ -5,7 +5,7 @@ import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_spec, @@ -14,7 +14,7 @@ from executorch.exir.dialects._ops import ops as exir_ops -@register_fake_tosa_op( +@register_tosa_op( "MATMUL(Tensor input1, Tensor input2) -> Tensor", # schema TosaSpecification.all_versions_and_profiles(), ) diff --git a/backends/arm/tosa/dialect/ops/matmul_t_block_scaled.py b/backends/arm/tosa/dialect/ops/matmul_t_block_scaled.py index fcea104320f..49b1c07129e 100644 --- a/backends/arm/tosa/dialect/ops/matmul_t_block_scaled.py +++ b/backends/arm/tosa/dialect/ops/matmul_t_block_scaled.py @@ -13,7 +13,7 @@ SUPPORTED_MXFP_DTYPES, ) from executorch.backends.arm.tosa.dialect.lib import TosaValueError -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_spec, TosaSpecification, @@ -128,7 +128,7 @@ def _validate_shapes( return N, H, W -@register_fake_tosa_op( +@register_tosa_op( "MATMUL_T_BLOCK_SCALED(Tensor A_data, Tensor A_scale, Tensor B_data, " "Tensor B_scale, SymInt block_size, str payload_dtype='') -> Tensor", [TosaSpecification.create_from_string("TOSA-1.1+FP")], diff --git a/backends/arm/tosa/dialect/ops/max_pool2d.py b/backends/arm/tosa/dialect/ops/max_pool2d.py index 1b1a399a757..9b93d79a71b 100644 --- a/backends/arm/tosa/dialect/ops/max_pool2d.py +++ b/backends/arm/tosa/dialect/ops/max_pool2d.py @@ -8,7 +8,7 @@ import sympy # type: ignore[import-untyped] import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_shape_env, get_context_spec, @@ -68,7 +68,7 @@ def validate_max_pool2d_dtype( raise TosaValueError(f"Unsupported input dtype {x.dtype} pools", op=op) -@register_fake_tosa_op( +@register_tosa_op( "MAX_POOL2D(Tensor input, int[2] kernel, int[2] stride, SymInt[4] pad) -> Tensor", TosaSpecification.all_versions_and_profiles(), ) diff --git a/backends/arm/tosa/dialect/ops/max_pool2d_adaptive.py b/backends/arm/tosa/dialect/ops/max_pool2d_adaptive.py index 605d94d2af1..193d083f012 100644 --- a/backends/arm/tosa/dialect/ops/max_pool2d_adaptive.py +++ b/backends/arm/tosa/dialect/ops/max_pool2d_adaptive.py @@ -10,7 +10,7 @@ compute_max_pool2d_output_shape, validate_max_pool2d_dtype, ) -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_shape_env, get_context_spec, @@ -35,7 +35,7 @@ def _is_directly_representable( return remainder in (0, 1) -@register_fake_tosa_op( +@register_tosa_op( "MAX_POOL2D_ADAPTIVE(Tensor input, SymInt[2] kernel, SymInt[2] stride, SymInt[4] pad) -> Tensor", TosaSpecification.all_profiles_for_version("1.1"), ) diff --git a/backends/arm/tosa/dialect/ops/rescale.py b/backends/arm/tosa/dialect/ops/rescale.py index c782ab4ae81..0ea5e4b25aa 100644 --- a/backends/arm/tosa/dialect/ops/rescale.py +++ b/backends/arm/tosa/dialect/ops/rescale.py @@ -7,7 +7,7 @@ import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_spec, @@ -15,7 +15,7 @@ ) -@register_fake_tosa_op( +@register_tosa_op( "RESCALE(Tensor input1, ScalarType dtype, float[] scale, int in_zp, int out_zp, *, bool input_unsigned=False, bool output_unsigned=False) -> Tensor", # schema TosaSpecification.all_versions_for_profile("INT"), # target TOSA specifications ) diff --git a/backends/arm/tosa/dialect/ops/resize.py b/backends/arm/tosa/dialect/ops/resize.py index 0d06253ccd8..18ebe6c6210 100644 --- a/backends/arm/tosa/dialect/ops/resize.py +++ b/backends/arm/tosa/dialect/ops/resize.py @@ -7,7 +7,7 @@ import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.resize_utils import ( calculate_tosa_resize_output_hw, get_tosa_resize_output_hw_validation_error, @@ -68,7 +68,7 @@ def _validate_resize_parameters(input_hw, output_hw, scale, offset, border, tosa raise TosaValueError(validation_error, op="RESIZE") -@register_fake_tosa_op( +@register_tosa_op( "RESIZE(Tensor input, SymInt[4] scale_factors, SymInt[2] offset, SymInt[2] border, *, str resize_mode) -> Tensor", # schema TosaSpecification.all_versions_and_profiles(), # target TOSA specifications ) diff --git a/backends/arm/tosa/dialect/ops/scatter.py b/backends/arm/tosa/dialect/ops/scatter.py index 6f13bd6154a..b42b2a5ae30 100644 --- a/backends/arm/tosa/dialect/ops/scatter.py +++ b/backends/arm/tosa/dialect/ops/scatter.py @@ -5,12 +5,12 @@ import torch -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import TosaSpecification -@register_fake_tosa_op( +@register_tosa_op( "SCATTER(Tensor values_in, Tensor indices, Tensor input) -> Tensor", # schema TosaSpecification.all_versions_and_profiles(), # target TOSA specifications ) diff --git a/backends/arm/tosa/dialect/ops_registration.py b/backends/arm/tosa/dialect/ops_registration.py index 6581673bfec..3cfa2a6281b 100644 --- a/backends/arm/tosa/dialect/ops_registration.py +++ b/backends/arm/tosa/dialect/ops_registration.py @@ -7,6 +7,9 @@ from typing import Callable, Iterable, List, ParamSpec, TypeVar from executorch.backends.arm.tosa.dialect.lib import register_tosa_dialect_op +from executorch.backends.arm.tosa.dialect.real_impl import ( + make_tosa_reference_model_impl, +) from executorch.backends.arm.tosa.specification import ( get_context_spec, @@ -41,10 +44,32 @@ def register_fake_tosa_op( """ + return register_tosa_op(op_schema, tosa_specs, include_real_impl=False) + + +def register_tosa_op( + op_schema: str, + tosa_specs: Iterable[TosaSpecification], + include_real_impl: bool = True, +) -> Callable[[Callable[P, R]], Callable[P, R]]: + """Register a TOSA op with fake/meta and optional real eager execution.""" + def decorator(func: Callable[P, R]) -> Callable[P, R]: # Only call register_tosa_dialect_op if the function hasn't been registered yet. if func not in _registered_tosa_ops_by_func: - op_callable = register_tosa_dialect_op(op_schema, func) + real_impl = None + if include_real_impl: + real_impl = make_tosa_reference_model_impl( + fake_func=func, + op_schema=op_schema, + ) + + op_callable = register_tosa_dialect_op( + op_schema, + fake_func=func, + real_func=real_impl, + real_dispatch_key="CompositeExplicitAutograd", + ) _registered_tosa_ops_by_func[func] = op_callable else: op_callable = _registered_tosa_ops_by_func[func] diff --git a/backends/arm/tosa/dialect/real_impl.py b/backends/arm/tosa/dialect/real_impl.py new file mode 100644 index 00000000000..65ba4441f12 --- /dev/null +++ b/backends/arm/tosa/dialect/real_impl.py @@ -0,0 +1,171 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import inspect +import logging +from typing import Any, Callable + +import numpy as np +import torch +import tosa_reference_model as reference_model # type: ignore[import-not-found, import-untyped] +import tosa_serializer as ts +from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import get_context_spec +from executorch.exir.dialects._ops import ops as exir_ops +from torch._subclasses.fake_tensor import FakeTensorMode + +logger = logging.getLogger(__name__) + + +def _torch_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray: + tensor = tensor.detach().cpu() + if tensor.dtype == torch.bfloat16: + tensor = tensor.view(torch.uint16) + elif tensor.dtype == torch.float8_e4m3fn: + tensor = tensor.view(torch.uint8) + elif tensor.dtype == torch.float8_e5m2: + tensor = tensor.view(torch.uint8) + return tensor.numpy() + + +def _numpy_to_torch_tensor( + array: np.ndarray, expected_dtype: torch.dtype, expected_shape: torch.Size +) -> torch.Tensor: + if array.dtype.type is np.void: + return torch.frombuffer(array, dtype=expected_dtype).reshape(expected_shape) + + tensor = torch.from_numpy(array) + if expected_dtype == torch.bfloat16: + tensor = tensor.view(torch.bfloat16) + elif expected_dtype == torch.float8_e4m3fn: + tensor = tensor.view(torch.float8_e4m3fn) + elif expected_dtype == torch.float8_e5m2: + tensor = tensor.view(torch.float8_e5m2) + return tensor + + +def make_tosa_reference_model_impl( + fake_func: Callable, + op_schema: str, +) -> Callable: + """Create a real eager implementation from a fake TOSA dialect op.""" + + signature = inspect.signature(fake_func) + op_name = op_schema.split("(")[0] + + def real_impl(*args, **kwargs) -> torch.Tensor: + + bound = signature.bind(*args, **kwargs) + bound.apply_defaults() + normalized_args = list(bound.arguments.values()) + + fake_output = fake_func(*args, **kwargs) + if not isinstance(fake_output, torch.Tensor): + raise TypeError( + f"Only single-tensor outputs are supported for real TOSA op execution, got {type(fake_output).__name__}" + ) + + tensor_args = [arg for arg in normalized_args if isinstance(arg, torch.Tensor)] + if not tensor_args: + raise ValueError( + f"Real TOSA op execution requires at least one tensor input: {op_name}" + ) + + graph = torch.fx.Graph() + node_args: list[Any] = [] + node_kwargs: dict[str, Any] = {} + placeholder_nodes: list[torch.fx.Node] = [] + op_handle = getattr(exir_ops.backend.tosa, op_name).default + with FakeTensorMode(allow_non_fake_inputs=True) as mode: + for parameter, arg in zip(signature.parameters.values(), normalized_args): + if isinstance(arg, torch.Tensor): + placeholder = graph.placeholder(parameter.name) + placeholder.meta["val"] = mode.from_tensor(arg.detach().cpu()) + placeholder_nodes.append(placeholder) + fx_arg = placeholder + else: + fx_arg = arg + if parameter.kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + node_args.append(fx_arg) + elif parameter.kind == inspect.Parameter.KEYWORD_ONLY: + node_kwargs[parameter.name] = fx_arg + else: + raise NotImplementedError( + f"Unsupported parameter kind for tosa.{op_name}: {parameter.kind}" + ) + op_node = graph.call_function(op_handle, tuple(node_args), node_kwargs) + op_node.meta["val"] = mode.from_tensor(fake_output.detach().cpu()) + + tosa_spec = get_context_spec() + version = tosa_spec.version + tosa_graph = ts.TosaSerializer( + "", + targetMajor=version.major, + targetMinor=version.minor, + targetPatch=version.micro, + targetDraft=False, + ) + + for node in placeholder_nodes: + arg = TosaArg(node, tosa_spec) + tosa_graph.addInputTensor( + ts.TosaSerializerTensor(arg.name, list(arg.shape), arg.dtype, data=None) + ) + + output_arg = TosaArg(op_node, tosa_spec) + tosa_graph.currRegion.currBasicBlock.addTensor( + output_arg.name, + list(output_arg.shape), + output_arg.dtype, + ) + from executorch.backends.arm.operators.node_visitor import get_node_visitor + + visitor = get_node_visitor(f"tosa.{op_name}.default", tosa_spec) + visitor.define_node( + op_node, + tosa_graph, + [TosaArg(arg, tosa_spec) for arg in op_node.args], + output_arg, + ) + tosa_graph.addOutputTensor( + tosa_graph.currRegion.currBasicBlock.tensors[output_arg.name] + ) + + outputs_np, status = reference_model.run( + tosa_graph.serialize(), + [_torch_tensor_to_numpy(arg) for arg in tensor_args], + verbosity=_tosa_refmodel_loglevel(logger.getEffectiveLevel()), + initialize_variable_tensor_from_numpy=True, + debug_mode="ALL" if logger.isEnabledFor(logging.DEBUG) else None, + ) + if status != reference_model.GraphStatus.TOSA_VALID: + raise RuntimeError( + f"TOSA reference model rejected tosa.{op_name} graph: {status}" + ) + + return _numpy_to_torch_tensor( + outputs_np[0], fake_output.dtype, fake_output.shape + ).to(device=tensor_args[0].device) + + return real_impl + + +def _tosa_refmodel_loglevel(loglevel: int) -> str: + """Converts a logging loglevel to tosa_reference_model logginglevel, + returned as string. + """ + loglevel_map = { + logging.INFO: "INFO", + logging.CRITICAL: "LOW", + logging.ERROR: "LOW", + logging.WARNING: "MED", + logging.DEBUG: "HIGH", + logging.NOTSET: "MED", + } + clamped_logging_level = max(min(loglevel // 10 * 10, 50), 0) + return loglevel_map[clamped_logging_level] diff --git a/backends/transforms/test/test_quantize_fused_convbn_bias_pass.py b/backends/transforms/test/test_quantize_fused_convbn_bias_pass.py index f8d0269630b..80f62ddd5f6 100644 --- a/backends/transforms/test/test_quantize_fused_convbn_bias_pass.py +++ b/backends/transforms/test/test_quantize_fused_convbn_bias_pass.py @@ -10,7 +10,7 @@ # Stub modules that are transitively imported by arm_quantizer but never # exercised by these tests. -for _mod in ("tosa_serializer", "tosa", "tosa.TosaGraph"): +for _mod in ("tosa_serializer", "tosa", "tosa.TosaGraph, tosa_reference_model"): if _mod not in sys.modules: try: __import__(_mod)