From 6bae98f50c21c654a5b9c21414c2bf6f5b88a75b Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Wed, 29 Apr 2026 10:42:59 +0200 Subject: [PATCH 1/2] Arm backend: Add infra for real implementations of Tosa ops TosaRealOpConfig configures how the real implementation is done. Main point of configuration is which node visitor is used. Currently, this needs to be done with a factory function with lazy import, due to direct imports causing a circular import. Tested by applying the infra to the avg_pool2d Tosa dialect op. The rewrite_avg_pool2d tests can now be ran to verify that the produced Tosa is correct. To make it completely correct, two additional passes need to be added to the test. Signed-off-by: Erik Lundell Change-Id: I7b573583fe241b63927864684e5ccdd1e4aa2cff --- backends/arm/operators/node_visitor.py | 10 ++ .../misc/tosa_dialect/test_tosa_gather.py | 44 +++++ .../misc/tosa_dialect/test_tosa_identity.py | 11 +- .../passes/test_rewrite_avg_pool2d_pass.py | 14 +- backends/arm/tosa/dialect/BUCK | 1 + backends/arm/tosa/dialect/lib.py | 30 ++-- backends/arm/tosa/dialect/ops/avg_pool2d.py | 4 +- backends/arm/tosa/dialect/ops/gather.py | 4 +- backends/arm/tosa/dialect/ops/identity.py | 4 +- backends/arm/tosa/dialect/ops_registration.py | 27 +++- backends/arm/tosa/dialect/real_impl.py | 152 ++++++++++++++++++ 11 files changed, 268 insertions(+), 33 deletions(-) create mode 100644 backends/arm/test/misc/tosa_dialect/test_tosa_gather.py create mode 100644 backends/arm/tosa/dialect/real_impl.py diff --git a/backends/arm/operators/node_visitor.py b/backends/arm/operators/node_visitor.py index c51b56a8f36..078d855b1d2 100644 --- a/backends/arm/operators/node_visitor.py +++ b/backends/arm/operators/node_visitor.py @@ -31,6 +31,7 @@ TosaSpecMapping, ) + logger = logging.getLogger(__name__) @@ -246,3 +247,12 @@ def get_node_visitors(*args) -> Dict[str, NodeVisitor]: node_visitors[target] = visitor(*args) return node_visitors + + +def get_node_visitor(target: str, tosa_spec: TosaSpecification): + node_visitor_tuples = _node_visitor_tuples.get(tosa_spec) + for target_name, node_visitor_cls in node_visitor_tuples: + if target_name == target: + return node_visitor_cls(tosa_spec) + + raise ValueError(f"No {target} NodeVisitor registered for {tosa_spec}") diff --git a/backends/arm/test/misc/tosa_dialect/test_tosa_gather.py b/backends/arm/test/misc/tosa_dialect/test_tosa_gather.py new file mode 100644 index 00000000000..8a65de56a57 --- /dev/null +++ b/backends/arm/test/misc/tosa_dialect/test_tosa_gather.py @@ -0,0 +1,44 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import executorch.backends.arm.tosa.dialect # noqa: F401 + +import torch +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch._subclasses.fake_tensor import FakeTensorMode + + +def test_gather_tosa_FP_fake() -> None: + values = torch.randn((1, 4, 3), dtype=torch.float32) + indices = torch.tensor([[0, 2]], dtype=torch.int32) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+FP") + ), FakeTensorMode() as mode: + output = exir_ops.backend.tosa.GATHER.default( + mode.from_tensor(values), + mode.from_tensor(indices), + ) + + assert output.dtype == values.dtype + assert tuple(output.shape) == (1, 2, 3) + + +def test_gather_tosa_FP_real() -> None: + values = torch.tensor( + [[[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]], + dtype=torch.float32, + ) + indices = torch.tensor([[3, 1]], dtype=torch.int32) + + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+FP")): + output = exir_ops.backend.tosa.GATHER.default(values, indices) + + expected = values[:, indices[0], :] + torch.testing.assert_close(output, expected) diff --git a/backends/arm/test/misc/tosa_dialect/test_tosa_identity.py b/backends/arm/test/misc/tosa_dialect/test_tosa_identity.py index 19461cb676c..f03f855e5e3 100644 --- a/backends/arm/test/misc/tosa_dialect/test_tosa_identity.py +++ b/backends/arm/test/misc/tosa_dialect/test_tosa_identity.py @@ -13,7 +13,7 @@ from torch._subclasses.fake_tensor import FakeTensorMode -def test_identity_tosa_FP() -> None: +def test_identity_tosa_FP_fake() -> None: sample_input = torch.randn((1, 2, 3, 4), dtype=torch.float32) with TosaLoweringContext( @@ -23,3 +23,12 @@ def test_identity_tosa_FP() -> None: assert output.dtype == sample_input.dtype assert tuple(output.shape) == tuple(sample_input.shape) + + +def test_identity_tosa_FP_real() -> None: + sample_input = torch.randn((1, 2, 3, 4), dtype=torch.float32) + + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+FP")): + output = exir_ops.backend.tosa.IDENTITY.default(sample_input) + + torch.testing.assert_close(output, sample_input) diff --git a/backends/arm/test/passes/test_rewrite_avg_pool2d_pass.py b/backends/arm/test/passes/test_rewrite_avg_pool2d_pass.py index 42214ba59b3..98abfe8fdfe 100644 --- a/backends/arm/test/passes/test_rewrite_avg_pool2d_pass.py +++ b/backends/arm/test/passes/test_rewrite_avg_pool2d_pass.py @@ -6,7 +6,7 @@ from typing import cast, Dict, Protocol, Tuple import torch -from executorch.backends.arm._passes.rewrite_avg_pool2d_pass import RewriteAvgPool2dPass +from executorch.backends.arm._passes import RewriteAvgPool2dPass from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import PassPipeline from executorch.backends.test.harness.stages import StageType @@ -29,7 +29,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AvgPool2dWithoutStride(torch.nn.Module): def get_inputs(self) -> input_t: - return (torch.rand(1, 3, 8, 8),) + return (torch.rand(1, 3, 9, 9),) def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.avg_pool2d(x, kernel_size=3) @@ -37,7 +37,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AvgPool2dListKernel(torch.nn.Module): def get_inputs(self) -> input_t: - return (torch.rand(1, 3, 8, 8),) + return (torch.rand(1, 3, 8, 9),) def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.avg_pool2d(x, kernel_size=[2, 3]) @@ -45,7 +45,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AvgPool2dScalarPadding(torch.nn.Module): def get_inputs(self) -> input_t: - return (torch.rand(1, 3, 8, 8),) + return (torch.rand(1, 3, 9, 9),) def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.avg_pool2d(x, kernel_size=3, stride=2, padding=1) @@ -53,7 +53,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AvgPool2dWithEmptyStride(torch.nn.Module): def get_inputs(self) -> input_t: - return (torch.rand(1, 3, 8, 8),) + return (torch.rand(1, 3, 8, 9),) def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.avg_pool2d(x, kernel_size=[2, 3], stride=[]) @@ -81,9 +81,6 @@ def test_rewrite_avg_pool2d_tosa(module: ModuleWithInputs) -> None: }, pass_list=[RewriteAvgPool2dPass], ) - pipeline.pop_stage( - "run_method_and_compare_outputs" - ) # Cannot run aten graph with tosa dialect ops pipeline.run() @@ -119,7 +116,6 @@ def test_rewrite_avg_pool2d_tosa_empty_stride_uses_kernel_size() -> None: }, pass_list=[RewriteAvgPool2dPass], ) - pipeline.pop_stage("run_method_and_compare_outputs") pipeline.run() tosa_node = _get_tosa_avg_pool2d_node(pipeline) diff --git a/backends/arm/tosa/dialect/BUCK b/backends/arm/tosa/dialect/BUCK index 5081f5d6945..68961e24f04 100644 --- a/backends/arm/tosa/dialect/BUCK +++ b/backends/arm/tosa/dialect/BUCK @@ -6,6 +6,7 @@ fbcode_target(_kind = runtime.python_library, srcs = [ "lib.py", "ops_registration.py", + "real_impl.py", "shape.py", ], deps = [ diff --git a/backends/arm/tosa/dialect/lib.py b/backends/arm/tosa/dialect/lib.py index ed26a21a297..7db9f1d755b 100644 --- a/backends/arm/tosa/dialect/lib.py +++ b/backends/arm/tosa/dialect/lib.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -14,16 +14,22 @@ tosa_lib = Library("tosa", "DEF") -def register_tosa_dialect_op(op_schema, func) -> Callable: +def register_tosa_dialect_op( + op_schema, + fake_func, + real_func: Callable | None = None, + real_dispatch_key: str = "CompositeExplicitAutograd", +) -> Callable: """Register a TOSA dialect operator with the backend op library. Args: op_schema (str): Operator schema without namespace or overload name. - func (Callable): Fake implementation used for registration. + fake_func (Callable): Fake implementation used for registration. + real_func (Optional[Callable]): Optional eager implementation. + real_dispatch_key (str): Dispatch key used for the eager implementation. Returns: - Callable: Backend dialect operator handle exposed via ``exir_ops`` and - marked ``not_callable`` for runtime use. + Callable: Backend dialect operator handle exposed via ``exir_ops``. """ if tosa_lib.ns not in _BACKEND_OP_LIB: @@ -46,19 +52,11 @@ def register_tosa_dialect_op(op_schema, func) -> Callable: overload_name = "default" op_qualified_name = f"{tosa_lib.ns}::{opname}" - register_fake(op_qualified_name, func, lib=tosa_lib) - + register_fake(op_qualified_name, fake_func, lib=tosa_lib) + if real_func is not None: + tosa_lib.impl(opname, real_func, real_dispatch_key) op = getattr(getattr(getattr(exir_ops.backend, tosa_lib.ns), opname), overload_name) - # For now, since the TOSA operators are only used for lowering and serialization in the backend - # the op doesn't need to be callable. This can be changed in the future if needed to support - # execution of TOSA ops directly. - def not_callable(): - """Raise when the dialect op handle is invoked at runtime.""" - raise RuntimeError("TOSA dialect op is not callable") - - op.__equvalent_callable__ = not_callable - return op diff --git a/backends/arm/tosa/dialect/ops/avg_pool2d.py b/backends/arm/tosa/dialect/ops/avg_pool2d.py index 968b335fc7b..0acc39c29f7 100644 --- a/backends/arm/tosa/dialect/ops/avg_pool2d.py +++ b/backends/arm/tosa/dialect/ops/avg_pool2d.py @@ -8,7 +8,7 @@ import sympy # type: ignore[import-untyped] import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_shape_env, get_context_spec, @@ -105,7 +105,7 @@ def validate_avg_pool2d_args( ) -@register_fake_tosa_op( +@register_tosa_op( "AVG_POOL2D(Tensor input, Tensor input_zp, Tensor output_zp, int[2] kernel, int[2] stride, SymInt[4] pad, ScalarType acc_type) -> Tensor", TosaSpecification.all_versions_and_profiles(), ) diff --git a/backends/arm/tosa/dialect/ops/gather.py b/backends/arm/tosa/dialect/ops/gather.py index 49374142cd6..4c634445713 100644 --- a/backends/arm/tosa/dialect/ops/gather.py +++ b/backends/arm/tosa/dialect/ops/gather.py @@ -7,14 +7,14 @@ import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_spec, TosaSpecification, ) -@register_fake_tosa_op( +@register_tosa_op( "GATHER(Tensor values, Tensor indices) -> Tensor", TosaSpecification.all_versions_and_profiles(), ) diff --git a/backends/arm/tosa/dialect/ops/identity.py b/backends/arm/tosa/dialect/ops/identity.py index 6e26d8e8b22..ff8a366e72d 100644 --- a/backends/arm/tosa/dialect/ops/identity.py +++ b/backends/arm/tosa/dialect/ops/identity.py @@ -4,11 +4,11 @@ # LICENSE file in the root directory of this source tree. import torch -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import TosaSpecification -@register_fake_tosa_op( +@register_tosa_op( "IDENTITY(Tensor input) -> Tensor", TosaSpecification.all_versions_and_profiles(), ) diff --git a/backends/arm/tosa/dialect/ops_registration.py b/backends/arm/tosa/dialect/ops_registration.py index 6581673bfec..3cfa2a6281b 100644 --- a/backends/arm/tosa/dialect/ops_registration.py +++ b/backends/arm/tosa/dialect/ops_registration.py @@ -7,6 +7,9 @@ from typing import Callable, Iterable, List, ParamSpec, TypeVar from executorch.backends.arm.tosa.dialect.lib import register_tosa_dialect_op +from executorch.backends.arm.tosa.dialect.real_impl import ( + make_tosa_reference_model_impl, +) from executorch.backends.arm.tosa.specification import ( get_context_spec, @@ -41,10 +44,32 @@ def register_fake_tosa_op( """ + return register_tosa_op(op_schema, tosa_specs, include_real_impl=False) + + +def register_tosa_op( + op_schema: str, + tosa_specs: Iterable[TosaSpecification], + include_real_impl: bool = True, +) -> Callable[[Callable[P, R]], Callable[P, R]]: + """Register a TOSA op with fake/meta and optional real eager execution.""" + def decorator(func: Callable[P, R]) -> Callable[P, R]: # Only call register_tosa_dialect_op if the function hasn't been registered yet. if func not in _registered_tosa_ops_by_func: - op_callable = register_tosa_dialect_op(op_schema, func) + real_impl = None + if include_real_impl: + real_impl = make_tosa_reference_model_impl( + fake_func=func, + op_schema=op_schema, + ) + + op_callable = register_tosa_dialect_op( + op_schema, + fake_func=func, + real_func=real_impl, + real_dispatch_key="CompositeExplicitAutograd", + ) _registered_tosa_ops_by_func[func] = op_callable else: op_callable = _registered_tosa_ops_by_func[func] diff --git a/backends/arm/tosa/dialect/real_impl.py b/backends/arm/tosa/dialect/real_impl.py new file mode 100644 index 00000000000..a9bbd2df2ab --- /dev/null +++ b/backends/arm/tosa/dialect/real_impl.py @@ -0,0 +1,152 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import inspect +import logging +from typing import Any, Callable + +import numpy as np +import torch +import tosa_reference_model as reference_model # type: ignore[import-not-found, import-untyped] +import tosa_serializer as ts +from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import get_context_spec +from executorch.exir.dialects._ops import ops as exir_ops +from torch._subclasses.fake_tensor import FakeTensorMode + +logger = logging.getLogger(__name__) + + +def _torch_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray: + tensor = tensor.detach().cpu() + if tensor.dtype == torch.bfloat16: + tensor = tensor.view(torch.uint16) + return tensor.numpy() + + +def _numpy_to_torch_tensor(array: np.ndarray, dtype: torch.dtype) -> torch.Tensor: + if array.dtype.type is np.void: + return torch.frombuffer(array, dtype=dtype) + + tensor = torch.from_numpy(array) + if dtype == torch.bfloat16: + return tensor.view(torch.bfloat16) + return tensor + + +def make_tosa_reference_model_impl( + fake_func: Callable, + op_schema: str, +) -> Callable: + """Create a real eager implementation from a fake TOSA dialect op.""" + + signature = inspect.signature(fake_func) + op_name = op_schema.split("(")[0] + + def real_impl(*args, **kwargs) -> torch.Tensor: + + bound = signature.bind(*args, **kwargs) + bound.apply_defaults() + normalized_args = list(bound.arguments.values()) + + fake_output = fake_func(*args, **kwargs) + if not isinstance(fake_output, torch.Tensor): + raise TypeError( + f"Only single-tensor outputs are supported for real TOSA op execution, got {type(fake_output).__name__}" + ) + + tensor_args = [arg for arg in normalized_args if isinstance(arg, torch.Tensor)] + if not tensor_args: + raise ValueError( + f"Real TOSA op execution requires at least one tensor input: {op_name}" + ) + + graph = torch.fx.Graph() + node_args: list[Any] = [] + placeholder_nodes: list[torch.fx.Node] = [] + op_handle = getattr(exir_ops.backend.tosa, op_name).default + + with FakeTensorMode(allow_non_fake_inputs=True) as mode: + for parameter, arg in zip(signature.parameters.values(), normalized_args): + if isinstance(arg, torch.Tensor): + placeholder = graph.placeholder(parameter.name) + placeholder.meta["val"] = mode.from_tensor(arg.detach().cpu()) + placeholder_nodes.append(placeholder) + node_args.append(placeholder) + else: + node_args.append(arg) + + op_node = graph.call_function(op_handle, tuple(node_args), {}) + op_node.meta["val"] = mode.from_tensor(fake_output.detach().cpu()) + graph.output((op_node,)) + + tosa_spec = get_context_spec() + version = tosa_spec.version + tosa_graph = ts.TosaSerializer( + "", + targetMajor=version.major, + targetMinor=version.minor, + targetPatch=version.micro, + targetDraft=False, + ) + + for node in placeholder_nodes: + arg = TosaArg(node, tosa_spec) + tosa_graph.addInputTensor( + ts.TosaSerializerTensor(arg.name, list(arg.shape), arg.dtype, data=None) + ) + + output_arg = TosaArg(op_node, tosa_spec) + tosa_graph.currRegion.currBasicBlock.addTensor( + output_arg.name, + list(output_arg.shape), + output_arg.dtype, + ) + from executorch.backends.arm.operators.node_visitor import get_node_visitor + + visitor = get_node_visitor(f"tosa.{op_name}.default", tosa_spec) + visitor.define_node( + op_node, + tosa_graph, + [TosaArg(arg, tosa_spec) for arg in op_node.args], + output_arg, + ) + tosa_graph.addOutputTensor( + tosa_graph.currRegion.currBasicBlock.tensors[output_arg.name] + ) + + outputs_np, status = reference_model.run( + tosa_graph.serialize(), + [_torch_tensor_to_numpy(arg) for arg in tensor_args], + verbosity=_tosa_refmodel_loglevel(logger.getEffectiveLevel()), + initialize_variable_tensor_from_numpy=True, + debug_mode="ALL" if logger.isEnabledFor(logging.DEBUG) else None, + ) + if status != reference_model.GraphStatus.TOSA_VALID: + raise RuntimeError( + f"TOSA reference model rejected tosa.{op_name} graph: {status}" + ) + + return _numpy_to_torch_tensor(outputs_np[0], fake_output.dtype).to( + device=tensor_args[0].device + ) + + return real_impl + + +def _tosa_refmodel_loglevel(loglevel: int) -> str: + """Converts a logging loglevel to tosa_reference_model logginglevel, + returned as string. + """ + loglevel_map = { + logging.INFO: "INFO", + logging.CRITICAL: "LOW", + logging.ERROR: "LOW", + logging.WARNING: "MED", + logging.DEBUG: "HIGH", + logging.NOTSET: "MED", + } + clamped_logging_level = max(min(loglevel // 10 * 10, 50), 0) + return loglevel_map[clamped_logging_level] From 96c234b442ea77df9f18f840ab26ac1d32fb44e6 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Thu, 28 May 2026 11:47:25 +0200 Subject: [PATCH 2/2] Arm backend: Add real impls to TOSA dialect ops Additionally, - Start pre-computing TOSA ops with real impl in ComputeOpsAOT. - Start running the model in tests were this was previously impossible due to ops not having a real impl. - Some modifications are needed in real_impl to run operators with kwargs. Signed-off-by: Erik Lundell Change-Id: I94ed6aa08842d8cd57e9f0fb331edc5261b8d044 --- .../arm/_passes/fuse_constant_ops_pass.py | 45 +++++-- backends/arm/operators/node_visitor.py | 3 + .../misc/tosa_dialect/test_tosa_rescale.py | 60 +++++++++ .../test_ensure_unique_output_nodes_pass.py | 2 - .../passes/test_fuse_constant_ops_pass.py | 124 ++++++++++++++++-- .../arm/test/passes/test_rewrite_conv_pass.py | 2 - .../passes/test_rewrite_max_pool2d_pass.py | 12 +- backends/arm/tosa/dialect/ops/activation.py | 10 +- .../tosa/dialect/ops/avg_pool2d_adaptive.py | 4 +- .../tosa/dialect/ops/cast_to_block_scaled.py | 4 +- backends/arm/tosa/dialect/ops/conv2d.py | 4 +- backends/arm/tosa/dialect/ops/conv3d.py | 4 +- backends/arm/tosa/dialect/ops/custom.py | 4 +- .../arm/tosa/dialect/ops/depthwise_conv2d.py | 4 +- backends/arm/tosa/dialect/ops/matmul.py | 4 +- .../tosa/dialect/ops/matmul_t_block_scaled.py | 4 +- backends/arm/tosa/dialect/ops/max_pool2d.py | 4 +- .../tosa/dialect/ops/max_pool2d_adaptive.py | 4 +- backends/arm/tosa/dialect/ops/rescale.py | 4 +- backends/arm/tosa/dialect/ops/resize.py | 4 +- backends/arm/tosa/dialect/ops/scatter.py | 4 +- backends/arm/tosa/dialect/real_impl.py | 45 +++++-- .../test_quantize_fused_convbn_bias_pass.py | 2 +- 23 files changed, 275 insertions(+), 82 deletions(-) create mode 100644 backends/arm/test/misc/tosa_dialect/test_tosa_rescale.py diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py index 7115b3d6f0e..ed67461a30e 100644 --- a/backends/arm/_passes/fuse_constant_ops_pass.py +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -92,6 +92,20 @@ def _is_tosa_dialect_op(target) -> bool: or " bool: + schema = getattr(target, "_schema", None) + op_name = getattr(schema, "name", None) + if op_name is None: + return False + + try: + return torch._C._dispatch_has_kernel_for_dispatch_key( + op_name, "CompositeExplicitAutograd" + ) + except RuntimeError: + return False + @staticmethod def _arg_contains_symbolic_shape(arg) -> bool: if isinstance(arg, torch.fx.Node): @@ -197,19 +211,31 @@ def resolve_arg(arg, arg_index=None): return True + def maybe_delete(self, input_nodes_to_maybe_delete): + for input_node in input_nodes_to_maybe_delete: + if input_node.meta.get("is_input", False): + # Never delete submodule inputs, they need to match the parameters from the outer module. + continue + if len(input_node.users) == 0: + self._delete_constant_placeholder(input_node) + def call(self, graph_module): modified = False input_nodes_to_maybe_delete = set() for node in graph_module.graph.nodes: if node.op != "call_function": continue - # Don't fuse TOSA dialect ops as they do not have eager forward functions. - # Also don't fuse ops whose explicit args/kwargs include symbolic shape values. - if ( - self._is_tosa_dialect_op(node.target) - or self._arg_contains_symbolic_shape(node.args) - or self._arg_contains_symbolic_shape(node.kwargs) - ): + if node.target == exir_ops.backend.tosa.RESCALE.default: + # Leave fusing of RESCALES to the compiler. + continue + if self._is_tosa_dialect_op( + node.target + ) and not self._has_real_tosa_dialect_impl(node.target): + continue + # Don't fuse ops whose explicit args/kwargs include symbolic shape values. + if self._arg_contains_symbolic_shape( + node.args + ) or self._arg_contains_symbolic_shape(node.kwargs): continue input_nodes = node.all_input_nodes @@ -241,10 +267,7 @@ def call(self, graph_module): if modified: graph_module.graph.eliminate_dead_code() - for input_node in input_nodes_to_maybe_delete: - if len(input_node.users) == 0: - self._delete_constant_placeholder(input_node) - + self.maybe_delete(input_nodes_to_maybe_delete) graph_module = super().call(graph_module).graph_module return PassResult(graph_module, modified) diff --git a/backends/arm/operators/node_visitor.py b/backends/arm/operators/node_visitor.py index 078d855b1d2..48433e8f396 100644 --- a/backends/arm/operators/node_visitor.py +++ b/backends/arm/operators/node_visitor.py @@ -250,6 +250,9 @@ def get_node_visitors(*args) -> Dict[str, NodeVisitor]: def get_node_visitor(target: str, tosa_spec: TosaSpecification): + # Ensure all operator modules are imported so visitors are registered. + import executorch.backends.arm.operators # noqa: F401 + node_visitor_tuples = _node_visitor_tuples.get(tosa_spec) for target_name, node_visitor_cls in node_visitor_tuples: if target_name == target: diff --git a/backends/arm/test/misc/tosa_dialect/test_tosa_rescale.py b/backends/arm/test/misc/tosa_dialect/test_tosa_rescale.py new file mode 100644 index 00000000000..cc52ac87970 --- /dev/null +++ b/backends/arm/test/misc/tosa_dialect/test_tosa_rescale.py @@ -0,0 +1,60 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import executorch.backends.arm.tosa.dialect # noqa: F401 +import pytest +import torch + +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode + + +@pytest.mark.parametrize( + "kwargs", + [ + {}, + {"input_unsigned": False, "output_unsigned": False}, + ], +) +def test_rescale_real_impl_with_and_without_kwargs(kwargs): + input_tensor = torch.tensor( + [[1, -2, 3], [4, 0, -5]], + dtype=torch.int32, + ) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+INT") + ), FakeTensorMode() as mode: + fake_output = exir_ops.backend.tosa.RESCALE.default( + mode.from_tensor(input_tensor), + torch.int32, + [1.0], + 0, + 0, + **kwargs, + ) + + assert isinstance(fake_output, FakeTensor) + assert fake_output.dtype == torch.int32 + assert tuple(fake_output.shape) == tuple(input_tensor.shape) + + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+INT")): + output = exir_ops.backend.tosa.RESCALE.default( + input_tensor, + torch.int32, + [1.0], + 0, + 0, + **kwargs, + ) + + assert not isinstance(output, FakeTensor) + assert output.dtype == torch.int32 + assert tuple(output.shape) == tuple(input_tensor.shape) + assert torch.equal(output, input_tensor) diff --git a/backends/arm/test/passes/test_ensure_unique_output_nodes_pass.py b/backends/arm/test/passes/test_ensure_unique_output_nodes_pass.py index 4dd03c1ca6e..077c7aae72c 100644 --- a/backends/arm/test/passes/test_ensure_unique_output_nodes_pass.py +++ b/backends/arm/test/passes/test_ensure_unique_output_nodes_pass.py @@ -35,7 +35,6 @@ def test_ensure_unique_output_nodes_no_target_inserts_identity_per_repeated_outp "executorch_exir_dialects_backend__ops_tosa_IDENTITY_default": 2, }, ) - pipeline.pop_stage("run_method_and_compare_outputs") pipeline.run() graph_module = ( @@ -62,5 +61,4 @@ def test_ensure_unique_output_nodes_no_target_keeps_unique_outputs_unchanged() - "executorch_exir_dialects_backend__ops_tosa_IDENTITY_default", ], ) - pipeline.pop_stage("run_method_and_compare_outputs") pipeline.run() diff --git a/backends/arm/test/passes/test_fuse_constant_ops_pass.py b/backends/arm/test/passes/test_fuse_constant_ops_pass.py index 9a07cd6e820..6b5884cd63e 100644 --- a/backends/arm/test/passes/test_fuse_constant_ops_pass.py +++ b/backends/arm/test/passes/test_fuse_constant_ops_pass.py @@ -16,6 +16,7 @@ from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.backends.arm.test.tester.test_pipeline import PassPipeline +from executorch.backends.arm.tosa.backend import TOSABackend from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.backends.arm.tosa.specification import ( TosaLoweringContext, @@ -24,6 +25,7 @@ from executorch.backends.test.harness.stages import StageType from executorch.backends.test.program_builder import ProgramBuilder from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.graph_module import get_cond_while_submodules from torch.export.graph_signature import InputKind input_t = Tuple[torch.Tensor] # Input x @@ -341,6 +343,52 @@ def test_fuse_constant_args_fuses_chains_without_recompile() -> None: torch.testing.assert_close(actual, expected) +def test_fuse_constant_args_preserves_unused_control_flow_inputs() -> None: + class CondWithCapturedBuffer(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer("buf", torch.ones(1, 1, 1, 1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + def true_branch(buf: torch.Tensor, arg: torch.Tensor) -> torch.Tensor: + return arg + buf.view(1, 1, 1, 1) + + def false_branch(buf: torch.Tensor, arg: torch.Tensor) -> torch.Tensor: + return arg * buf + + return torch.cond(x.sum() > 0, true_branch, false_branch, [self.buf, x]) + + compile_spec = common.get_tosa_compile_spec( + TosaSpecification.create_from_string("TOSA-1.0+FP+cf") + ) + tester = ArmTester( + CondWithCapturedBuffer(), + example_inputs=(torch.randn(1, 1, 2, 2),), + compile_spec=compile_spec, + ) + tester.export().to_edge() + exported_program = tester.get_artifact(StageType.TO_EDGE).exported_program() + + submodule_0 = exported_program.graph_module.get_submodule("submodule_0") + _, _, cond_node = get_cond_while_submodules(exported_program.graph_module)[0] + TOSABackend._regularize_submodule(submodule_0, cond_node) + + pass_result = FuseConstantArgsPass(exported_program).call(submodule_0) + + placeholders = [ + node + for node in pass_result.graph_module.graph.nodes + if node.op == "placeholder" + ] + assert [node.name for node in placeholders] == [ + "aten_view_copy_default_fused_const", + "b_buf", + "x", + ] + assert placeholders[1].meta["is_input"] is True + assert len(placeholders[1].users) == 0 + + def test_fuse_constant_args_identifies_tosa_dialect_targets() -> None: class FakeTosaTarget: def __str__(self) -> str: @@ -351,6 +399,15 @@ def __str__(self) -> str: exir_ops.backend.tosa.GATHER.default ) assert not FuseConstantArgsPass._is_tosa_dialect_op(torch.ops.aten.add.Tensor) + assert FuseConstantArgsPass._has_real_tosa_dialect_impl( + exir_ops.backend.tosa.GATHER.default + ) + assert FuseConstantArgsPass._has_real_tosa_dialect_impl( + exir_ops.backend.tosa.RESCALE.default + ) + assert not FuseConstantArgsPass._has_real_tosa_dialect_impl( + exir_ops.backend.tosa.TABLE.default + ) def test_fuse_constant_args_identifies_symbolic_shape_args() -> None: @@ -364,24 +421,63 @@ def test_fuse_constant_args_identifies_symbolic_shape_args() -> None: ) -def test_fuse_constant_args_skips_backend_tosa_gather(caplog) -> None: - with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.1+FP+shape")): +def test_fuse_constant_args_skips_tosa_ops_without_real_impl(caplog) -> None: + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+INT")): builder = ProgramBuilder() - values = builder.placeholder( - "values", - torch.randn(1, 4, 3), + input_tensor = builder.placeholder( + "input_tensor", + torch.tensor([-3, 2], dtype=torch.int8), input_kind=InputKind.CONSTANT_TENSOR, ) - indices = builder.placeholder( - "indices", - torch.tensor([[0, 2]], dtype=torch.int32), + table = builder.placeholder( + "table", + torch.arange(256, dtype=torch.int16).to(torch.int8), + input_kind=InputKind.CONSTANT_TENSOR, + ) + table_lookup = builder.call_operator( + exir_ops.backend.tosa.TABLE.default, + (input_tensor, table), + ) + builder.output([table_lookup]) + + exported_program = builder.get_program() + graph_module = exported_program.graph_module + + with caplog.at_level("WARNING"): + FuseConstantArgsPass(exported_program)(graph_module) + + warning_messages = [ + record.getMessage() + for record in caplog.records + if record.name == "executorch.backends.arm._passes.fuse_constant_ops_pass" + ] + assert not any( + "Failed to fuse constant op" in message and "TABLE" in message + for message in warning_messages + ) + assert ( + sum( + node.op == "call_function" + and node.target == exir_ops.backend.tosa.TABLE.default + for node in graph_module.graph.nodes + ) + == 1 + ) + + +def test_fuse_constant_args_skips_tosa_rescale(caplog) -> None: + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+INT")): + builder = ProgramBuilder() + input_tensor = builder.placeholder( + "input_tensor", + torch.tensor([1, 2], dtype=torch.int32), input_kind=InputKind.CONSTANT_TENSOR, ) - gather = builder.call_operator( - exir_ops.backend.tosa.GATHER.default, - (values, indices), + rescale = builder.call_operator( + exir_ops.backend.tosa.RESCALE.default, + (input_tensor, torch.int8, [1.0], 0, 0), ) - builder.output([gather]) + builder.output([rescale]) exported_program = builder.get_program() graph_module = exported_program.graph_module @@ -395,13 +491,13 @@ def test_fuse_constant_args_skips_backend_tosa_gather(caplog) -> None: if record.name == "executorch.backends.arm._passes.fuse_constant_ops_pass" ] assert not any( - "Failed to fuse constant op" in message and "GATHER" in message + "Failed to fuse constant op" in message and "RESCALE" in message for message in warning_messages ) assert ( sum( node.op == "call_function" - and node.target == exir_ops.backend.tosa.GATHER.default + and node.target == exir_ops.backend.tosa.RESCALE.default for node in graph_module.graph.nodes ) == 1 diff --git a/backends/arm/test/passes/test_rewrite_conv_pass.py b/backends/arm/test/passes/test_rewrite_conv_pass.py index 736aa685b86..de8135bc049 100644 --- a/backends/arm/test/passes/test_rewrite_conv_pass.py +++ b/backends/arm/test/passes/test_rewrite_conv_pass.py @@ -213,8 +213,6 @@ def test_rewrite_conv_tosa_FP(): pipeline = PassPipeline( module, module.get_inputs(), passes_with_exported_program=[RewriteConvPass] ) - # We cannot run TOSA backend dialect operators in eager mode. - pipeline.pop_stage("run_method_and_compare_outputs") pipeline.run() diff --git a/backends/arm/test/passes/test_rewrite_max_pool2d_pass.py b/backends/arm/test/passes/test_rewrite_max_pool2d_pass.py index 52efb0929f2..9ba31f1e6e0 100644 --- a/backends/arm/test/passes/test_rewrite_max_pool2d_pass.py +++ b/backends/arm/test/passes/test_rewrite_max_pool2d_pass.py @@ -37,7 +37,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MaxPool2dWithoutStride(torch.nn.Module): def get_inputs(self) -> input_t: - return (torch.rand(1, 3, 8, 8),) + return (torch.rand(1, 3, 9, 9),) def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.max_pool2d(x, kernel_size=3) @@ -45,7 +45,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MaxPool2dListKernel(torch.nn.Module): def get_inputs(self) -> input_t: - return (torch.rand(1, 3, 8, 8),) + return (torch.rand(1, 3, 8, 9),) def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.max_pool2d(x, kernel_size=[2, 3]) @@ -56,7 +56,7 @@ def get_inputs(self) -> input_t: return (torch.rand(1, 3, 8, 8),) def forward(self, x: torch.Tensor) -> torch.Tensor: - return torch.nn.functional.max_pool2d(x, kernel_size=[2, 3], stride=[]) + return torch.nn.functional.max_pool2d(x, kernel_size=[2, 2], stride=[]) class MaxPool2dDynamic(torch.nn.Module): @@ -94,9 +94,6 @@ def test_rewrite_max_pool2d_tosa(module: ModuleWithInputs) -> None: }, pass_list=[RemoveGetItemPass, RewriteMaxPool2dPass], ) - pipeline.pop_stage( - "run_method_and_compare_outputs" - ) # Cannnot run aten graph with tosa dialect ops pipeline.run() @@ -131,11 +128,10 @@ def test_rewrite_max_pool2d_tosa_empty_stride_uses_kernel_size() -> None: }, pass_list=[RemoveGetItemPass, RewriteMaxPool2dPass], ) - pipeline.pop_stage("run_method_and_compare_outputs") pipeline.run() tosa_node = _get_tosa_max_pool2d_node(pipeline) - assert tosa_node.args[2] == [2, 3] + assert tosa_node.args[2] == [2, 2] def test_rewrite_max_pool2d_tosa_dynamic_shape() -> None: diff --git a/backends/arm/tosa/dialect/ops/activation.py b/backends/arm/tosa/dialect/ops/activation.py index 3c3fbffe176..a17110a7e85 100644 --- a/backends/arm/tosa/dialect/ops/activation.py +++ b/backends/arm/tosa/dialect/ops/activation.py @@ -8,7 +8,7 @@ import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError from executorch.backends.arm.tosa.dialect.ops._common import validate_nan_mode -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_spec, TosaSpecification, @@ -84,7 +84,7 @@ def _validate_integer_clamp_bounds( ) -@register_fake_tosa_op( +@register_tosa_op( 'CLAMP(Tensor input, Scalar min_val, Scalar max_val, *, str nan_mode="PROPAGATE") -> Tensor', TosaSpecification.all_versions_and_profiles(), ) @@ -111,7 +111,7 @@ def CLAMP( return torch.empty_like(input, dtype=input.dtype) -@register_fake_tosa_op( +@register_tosa_op( "ERF(Tensor input) -> Tensor", FP_SPECS, ) @@ -120,7 +120,7 @@ def ERF(input: torch.Tensor) -> torch.Tensor: return torch.empty_like(input, dtype=input.dtype) -@register_fake_tosa_op( +@register_tosa_op( "SIGMOID(Tensor input) -> Tensor", FP_SPECS, ) @@ -129,7 +129,7 @@ def SIGMOID(input: torch.Tensor) -> torch.Tensor: return torch.empty_like(input, dtype=input.dtype) -@register_fake_tosa_op( +@register_tosa_op( "TANH(Tensor input) -> Tensor", FP_SPECS, ) diff --git a/backends/arm/tosa/dialect/ops/avg_pool2d_adaptive.py b/backends/arm/tosa/dialect/ops/avg_pool2d_adaptive.py index 7d71b85eca7..6d2195abd6c 100644 --- a/backends/arm/tosa/dialect/ops/avg_pool2d_adaptive.py +++ b/backends/arm/tosa/dialect/ops/avg_pool2d_adaptive.py @@ -11,7 +11,7 @@ compute_avg_pool2d_output_shape, validate_avg_pool2d_dtype, ) -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_shape_env, get_context_spec, @@ -36,7 +36,7 @@ def _is_directly_representable( return remainder in (0, 1) -@register_fake_tosa_op( +@register_tosa_op( "AVG_POOL2D_ADAPTIVE(Tensor input, Tensor input_zp, Tensor output_zp, SymInt[2] kernel, SymInt[2] stride, SymInt[4] pad, ScalarType acc_type) -> Tensor", TosaSpecification.all_profiles_for_version("1.1"), ) diff --git a/backends/arm/tosa/dialect/ops/cast_to_block_scaled.py b/backends/arm/tosa/dialect/ops/cast_to_block_scaled.py index 8dbff7c11c5..01369696d19 100644 --- a/backends/arm/tosa/dialect/ops/cast_to_block_scaled.py +++ b/backends/arm/tosa/dialect/ops/cast_to_block_scaled.py @@ -11,7 +11,7 @@ from executorch.backends.arm.ao_ext.mxfp import mxfp_str_to_dtype from executorch.backends.arm.tosa.dialect.lib import TosaValueError -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_spec, TosaSpecification, @@ -19,7 +19,7 @@ from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP6_E2M3, DTYPE_FP6_E3M2 -@register_fake_tosa_op( +@register_tosa_op( "CAST_TO_BLOCK_SCALED(Tensor input, SymInt block_size, str output_dtype) -> (Tensor, Tensor)", [TosaSpecification.create_from_string("TOSA-1.1+FP")], ) diff --git a/backends/arm/tosa/dialect/ops/conv2d.py b/backends/arm/tosa/dialect/ops/conv2d.py index d0db2d60fcd..0320efd03b5 100644 --- a/backends/arm/tosa/dialect/ops/conv2d.py +++ b/backends/arm/tosa/dialect/ops/conv2d.py @@ -7,7 +7,7 @@ import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_spec, TosaSpecification, @@ -105,7 +105,7 @@ def conv_output_dim( return (input_dim + total_pad - receptive_field) // stride + 1 -@register_fake_tosa_op( +@register_tosa_op( "CONV2D(Tensor input, " "Tensor weight, " "Tensor bias, " diff --git a/backends/arm/tosa/dialect/ops/conv3d.py b/backends/arm/tosa/dialect/ops/conv3d.py index a81ae0dae53..39f54f8355f 100644 --- a/backends/arm/tosa/dialect/ops/conv3d.py +++ b/backends/arm/tosa/dialect/ops/conv3d.py @@ -11,7 +11,7 @@ conv_output_dim, validate_conv2d_args_dtypes, ) -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_spec, TosaSpecification, @@ -32,7 +32,7 @@ def validate_conv3d_args_dtypes( return validate_conv2d_args_dtypes(tosa_spec, x, weight, bias, op="CONV3D") -@register_fake_tosa_op( +@register_tosa_op( "CONV3D(Tensor input, " "Tensor weight, " "Tensor bias, " diff --git a/backends/arm/tosa/dialect/ops/custom.py b/backends/arm/tosa/dialect/ops/custom.py index 6376124d6f2..0313456ad9e 100644 --- a/backends/arm/tosa/dialect/ops/custom.py +++ b/backends/arm/tosa/dialect/ops/custom.py @@ -31,7 +31,7 @@ from collections.abc import Callable import torch -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_spec, @@ -132,7 +132,7 @@ def run_registered_fake_tosa_impl( return outputs -@register_fake_tosa_op( +@register_tosa_op( "CUSTOM(Tensor[] inputs, str operator_name, str domain_name, int[] implementation_attrs) -> Tensor[]", TosaSpecification.all_versions_and_profiles(), ) diff --git a/backends/arm/tosa/dialect/ops/depthwise_conv2d.py b/backends/arm/tosa/dialect/ops/depthwise_conv2d.py index 83ef3ff72fb..91609cd6b62 100644 --- a/backends/arm/tosa/dialect/ops/depthwise_conv2d.py +++ b/backends/arm/tosa/dialect/ops/depthwise_conv2d.py @@ -8,7 +8,7 @@ conv_output_dim, validate_conv2d_args_dtypes, ) -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_spec, @@ -16,7 +16,7 @@ ) -@register_fake_tosa_op( +@register_tosa_op( "DEPTHWISE_CONV2D(Tensor input, " "Tensor weight, " "Tensor bias, " diff --git a/backends/arm/tosa/dialect/ops/matmul.py b/backends/arm/tosa/dialect/ops/matmul.py index 8023df88072..6d035c0e5a6 100644 --- a/backends/arm/tosa/dialect/ops/matmul.py +++ b/backends/arm/tosa/dialect/ops/matmul.py @@ -5,7 +5,7 @@ import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_spec, @@ -14,7 +14,7 @@ from executorch.exir.dialects._ops import ops as exir_ops -@register_fake_tosa_op( +@register_tosa_op( "MATMUL(Tensor input1, Tensor input2) -> Tensor", # schema TosaSpecification.all_versions_and_profiles(), ) diff --git a/backends/arm/tosa/dialect/ops/matmul_t_block_scaled.py b/backends/arm/tosa/dialect/ops/matmul_t_block_scaled.py index fcea104320f..49b1c07129e 100644 --- a/backends/arm/tosa/dialect/ops/matmul_t_block_scaled.py +++ b/backends/arm/tosa/dialect/ops/matmul_t_block_scaled.py @@ -13,7 +13,7 @@ SUPPORTED_MXFP_DTYPES, ) from executorch.backends.arm.tosa.dialect.lib import TosaValueError -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_spec, TosaSpecification, @@ -128,7 +128,7 @@ def _validate_shapes( return N, H, W -@register_fake_tosa_op( +@register_tosa_op( "MATMUL_T_BLOCK_SCALED(Tensor A_data, Tensor A_scale, Tensor B_data, " "Tensor B_scale, SymInt block_size, str payload_dtype='') -> Tensor", [TosaSpecification.create_from_string("TOSA-1.1+FP")], diff --git a/backends/arm/tosa/dialect/ops/max_pool2d.py b/backends/arm/tosa/dialect/ops/max_pool2d.py index 1b1a399a757..9b93d79a71b 100644 --- a/backends/arm/tosa/dialect/ops/max_pool2d.py +++ b/backends/arm/tosa/dialect/ops/max_pool2d.py @@ -8,7 +8,7 @@ import sympy # type: ignore[import-untyped] import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_shape_env, get_context_spec, @@ -68,7 +68,7 @@ def validate_max_pool2d_dtype( raise TosaValueError(f"Unsupported input dtype {x.dtype} pools", op=op) -@register_fake_tosa_op( +@register_tosa_op( "MAX_POOL2D(Tensor input, int[2] kernel, int[2] stride, SymInt[4] pad) -> Tensor", TosaSpecification.all_versions_and_profiles(), ) diff --git a/backends/arm/tosa/dialect/ops/max_pool2d_adaptive.py b/backends/arm/tosa/dialect/ops/max_pool2d_adaptive.py index 605d94d2af1..193d083f012 100644 --- a/backends/arm/tosa/dialect/ops/max_pool2d_adaptive.py +++ b/backends/arm/tosa/dialect/ops/max_pool2d_adaptive.py @@ -10,7 +10,7 @@ compute_max_pool2d_output_shape, validate_max_pool2d_dtype, ) -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_shape_env, get_context_spec, @@ -35,7 +35,7 @@ def _is_directly_representable( return remainder in (0, 1) -@register_fake_tosa_op( +@register_tosa_op( "MAX_POOL2D_ADAPTIVE(Tensor input, SymInt[2] kernel, SymInt[2] stride, SymInt[4] pad) -> Tensor", TosaSpecification.all_profiles_for_version("1.1"), ) diff --git a/backends/arm/tosa/dialect/ops/rescale.py b/backends/arm/tosa/dialect/ops/rescale.py index c782ab4ae81..0ea5e4b25aa 100644 --- a/backends/arm/tosa/dialect/ops/rescale.py +++ b/backends/arm/tosa/dialect/ops/rescale.py @@ -7,7 +7,7 @@ import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import ( get_context_spec, @@ -15,7 +15,7 @@ ) -@register_fake_tosa_op( +@register_tosa_op( "RESCALE(Tensor input1, ScalarType dtype, float[] scale, int in_zp, int out_zp, *, bool input_unsigned=False, bool output_unsigned=False) -> Tensor", # schema TosaSpecification.all_versions_for_profile("INT"), # target TOSA specifications ) diff --git a/backends/arm/tosa/dialect/ops/resize.py b/backends/arm/tosa/dialect/ops/resize.py index 0d06253ccd8..18ebe6c6210 100644 --- a/backends/arm/tosa/dialect/ops/resize.py +++ b/backends/arm/tosa/dialect/ops/resize.py @@ -7,7 +7,7 @@ import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.resize_utils import ( calculate_tosa_resize_output_hw, get_tosa_resize_output_hw_validation_error, @@ -68,7 +68,7 @@ def _validate_resize_parameters(input_hw, output_hw, scale, offset, border, tosa raise TosaValueError(validation_error, op="RESIZE") -@register_fake_tosa_op( +@register_tosa_op( "RESIZE(Tensor input, SymInt[4] scale_factors, SymInt[2] offset, SymInt[2] border, *, str resize_mode) -> Tensor", # schema TosaSpecification.all_versions_and_profiles(), # target TOSA specifications ) diff --git a/backends/arm/tosa/dialect/ops/scatter.py b/backends/arm/tosa/dialect/ops/scatter.py index 6f13bd6154a..b42b2a5ae30 100644 --- a/backends/arm/tosa/dialect/ops/scatter.py +++ b/backends/arm/tosa/dialect/ops/scatter.py @@ -5,12 +5,12 @@ import torch -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op from executorch.backends.arm.tosa.specification import TosaSpecification -@register_fake_tosa_op( +@register_tosa_op( "SCATTER(Tensor values_in, Tensor indices, Tensor input) -> Tensor", # schema TosaSpecification.all_versions_and_profiles(), # target TOSA specifications ) diff --git a/backends/arm/tosa/dialect/real_impl.py b/backends/arm/tosa/dialect/real_impl.py index a9bbd2df2ab..65ba4441f12 100644 --- a/backends/arm/tosa/dialect/real_impl.py +++ b/backends/arm/tosa/dialect/real_impl.py @@ -23,16 +23,26 @@ def _torch_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray: tensor = tensor.detach().cpu() if tensor.dtype == torch.bfloat16: tensor = tensor.view(torch.uint16) + elif tensor.dtype == torch.float8_e4m3fn: + tensor = tensor.view(torch.uint8) + elif tensor.dtype == torch.float8_e5m2: + tensor = tensor.view(torch.uint8) return tensor.numpy() -def _numpy_to_torch_tensor(array: np.ndarray, dtype: torch.dtype) -> torch.Tensor: +def _numpy_to_torch_tensor( + array: np.ndarray, expected_dtype: torch.dtype, expected_shape: torch.Size +) -> torch.Tensor: if array.dtype.type is np.void: - return torch.frombuffer(array, dtype=dtype) + return torch.frombuffer(array, dtype=expected_dtype).reshape(expected_shape) tensor = torch.from_numpy(array) - if dtype == torch.bfloat16: - return tensor.view(torch.bfloat16) + if expected_dtype == torch.bfloat16: + tensor = tensor.view(torch.bfloat16) + elif expected_dtype == torch.float8_e4m3fn: + tensor = tensor.view(torch.float8_e4m3fn) + elif expected_dtype == torch.float8_e5m2: + tensor = tensor.view(torch.float8_e5m2) return tensor @@ -65,22 +75,31 @@ def real_impl(*args, **kwargs) -> torch.Tensor: graph = torch.fx.Graph() node_args: list[Any] = [] + node_kwargs: dict[str, Any] = {} placeholder_nodes: list[torch.fx.Node] = [] op_handle = getattr(exir_ops.backend.tosa, op_name).default - with FakeTensorMode(allow_non_fake_inputs=True) as mode: for parameter, arg in zip(signature.parameters.values(), normalized_args): if isinstance(arg, torch.Tensor): placeholder = graph.placeholder(parameter.name) placeholder.meta["val"] = mode.from_tensor(arg.detach().cpu()) placeholder_nodes.append(placeholder) - node_args.append(placeholder) + fx_arg = placeholder else: - node_args.append(arg) - - op_node = graph.call_function(op_handle, tuple(node_args), {}) + fx_arg = arg + if parameter.kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + node_args.append(fx_arg) + elif parameter.kind == inspect.Parameter.KEYWORD_ONLY: + node_kwargs[parameter.name] = fx_arg + else: + raise NotImplementedError( + f"Unsupported parameter kind for tosa.{op_name}: {parameter.kind}" + ) + op_node = graph.call_function(op_handle, tuple(node_args), node_kwargs) op_node.meta["val"] = mode.from_tensor(fake_output.detach().cpu()) - graph.output((op_node,)) tosa_spec = get_context_spec() version = tosa_spec.version @@ -129,9 +148,9 @@ def real_impl(*args, **kwargs) -> torch.Tensor: f"TOSA reference model rejected tosa.{op_name} graph: {status}" ) - return _numpy_to_torch_tensor(outputs_np[0], fake_output.dtype).to( - device=tensor_args[0].device - ) + return _numpy_to_torch_tensor( + outputs_np[0], fake_output.dtype, fake_output.shape + ).to(device=tensor_args[0].device) return real_impl diff --git a/backends/transforms/test/test_quantize_fused_convbn_bias_pass.py b/backends/transforms/test/test_quantize_fused_convbn_bias_pass.py index f8d0269630b..80f62ddd5f6 100644 --- a/backends/transforms/test/test_quantize_fused_convbn_bias_pass.py +++ b/backends/transforms/test/test_quantize_fused_convbn_bias_pass.py @@ -10,7 +10,7 @@ # Stub modules that are transitively imported by arm_quantizer but never # exercised by these tests. -for _mod in ("tosa_serializer", "tosa", "tosa.TosaGraph"): +for _mod in ("tosa_serializer", "tosa", "tosa.TosaGraph, tosa_reference_model"): if _mod not in sys.modules: try: __import__(_mod)