From 2861e1031185133e8391652d046cdedccec9fbd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Thu, 28 May 2026 12:11:02 +0200 Subject: [PATCH 1/2] XNNPACK: Lift constant mul scalars for partitioning MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit XNNPACK supports the tensor overload for multiply, but plain aten.mul.Scalar is not selected by the XNNPACK partitioner. Adds a narrow scalar-lifting pass that rewrites aten.mul.Scalar to aten.mul.Tensor by registering the scalar as a small constant buffer. This avoids introducing an aten.full op while allowing the existing multiply lowering path to partition the op. Keep SDPA scale multipliers as aten.mul.Scalar so ConvertToSDPAPass can still recover the attention scale before replacing the pattern. Add test coverage for that guard. Allow the XNNPACK tester to pass transform passes through to to_edge_transform_and_lower. This keeps op tests on the same path as existing XNNPACK model tests that already use explicit transform passes. For DeIT Tiny, this removes 24 portable aten.mul.Scalar nodes and reduces delegate count from 62 to 50. In current local timing checks the latency impact is modest: about 1% faster on both Android SME2 and the aarch64 XNNPACK/KleidiAI NEON-class host runner. These are modest uplifts but may introduce more opportunities for improvements. Signed-off-by: Måns Nilsson Change-Id: I83b6ad53925edb72afdf0077b5dbb99b5d9c4648 --- .../lift_constant_scalar_operands_pass.py | 151 ++++++++++++++++++ backends/xnnpack/test/ops/test_multiply.py | 31 ++++ ...test_lift_constant_scalar_operands_pass.py | 74 +++++++++ backends/xnnpack/test/tester/tester.py | 46 +++++- backends/xnnpack/utils/configs.py | 5 +- 5 files changed, 304 insertions(+), 3 deletions(-) create mode 100644 backends/xnnpack/_passes/lift_constant_scalar_operands_pass.py create mode 100644 backends/xnnpack/test/passes/test_lift_constant_scalar_operands_pass.py diff --git a/backends/xnnpack/_passes/lift_constant_scalar_operands_pass.py b/backends/xnnpack/_passes/lift_constant_scalar_operands_pass.py new file mode 100644 index 00000000000..8c6be2b62b5 --- /dev/null +++ b/backends/xnnpack/_passes/lift_constant_scalar_operands_pass.py @@ -0,0 +1,151 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from numbers import Number +from typing import Dict, Optional, Union + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, PassResult +from torch._ops import OpOverload + + +ScalarOp = Union[EdgeOpOverload, OpOverload] + + +class LiftConstantScalarOperandsPass(ExportPass): + """ + Lift scalar operands into tensor constants for selected binary ops. + + XNNPACK already supports the tensor overloads for these binary operations. + This pass converts explicitly listed scalar overloads to their tensor + overloads by replacing constant scalar operands with small tensor constants. + The constants are registered as buffers so they do not become portable + ``full`` kernels. Keep the op map narrow until each new scalar overload is + covered by tests. + """ + + default_scalar_to_tensor_ops: Dict[ScalarOp, ScalarOp] = { + exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor, + } + sdpa_passthrough_ops = { + exir_ops.edge.aten.expand_copy.default, + exir_ops.edge.aten.view_copy.default, + } + + def __init__( + self, + scalar_to_tensor_ops: Optional[Dict[ScalarOp, ScalarOp]] = None, + ) -> None: + super().__init__() + self.scalar_to_tensor_ops = ( + scalar_to_tensor_ops + if scalar_to_tensor_ops is not None + else self.default_scalar_to_tensor_ops + ) + self._modified = False + + def _create_constant_node( + self, + graph_module: torch.fx.GraphModule, + node: torch.fx.Node, + value: Number, + ) -> torch.fx.Node: + input_node = node.args[0] + if not isinstance(input_node, torch.fx.Node): + raise RuntimeError("Expected scalar op input to be an FX node.") + + input_value = input_node.meta["val"] + tensor = torch.tensor(value, dtype=input_value.dtype, device=input_value.device) + name = self._get_new_attr_name(graph_module) + graph_module.register_buffer(name, tensor) + + fake_mode = node.meta["val"].fake_mode + with graph_module.graph.inserting_before(node): + constant_node = graph_module.graph.get_attr(name) + constant_node.meta["val"] = fake_mode.from_tensor( + tensor, static_shapes=True + ) + return constant_node + + def _get_new_attr_name(self, graph_module: torch.fx.GraphModule) -> str: + prefix = "_tensor_constant_" + index = 0 + while hasattr(graph_module, f"{prefix}{index}"): + index += 1 + return f"{prefix}{index}" + + def _feeds_sdpa_qk_bmm(self, node: torch.fx.Node) -> bool: + """ + Return true for the scale muls consumed by XNNPACK's SDPA pattern. + + ConvertToSDPAPass recovers the user-specified attention scale from the + pre-QK^T ``aten.mul.Scalar`` nodes. Keep those scalar muls intact so + SDPA conversion can still find the scale before replacing the pattern. + """ + users_to_visit = list(node.users) + visited = set() + while users_to_visit: + user = users_to_visit.pop() + if user in visited: + continue + visited.add(user) + + if ( + user.op == "call_function" + and user.target == exir_ops.edge.aten.bmm.default + ): + return True + + if user.op == "call_function" and user.target in self.sdpa_passthrough_ops: + users_to_visit.extend(user.users) + + return False + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + self._modified = False + + for node in list(graph_module.graph.nodes): + if ( + node.op != "call_function" + or node.target not in self.scalar_to_tensor_ops + or len(node.args) != 2 + or not isinstance(node.args[0], torch.fx.Node) + or not isinstance(node.args[1], Number) + ): + continue + + if ( + node.target == exir_ops.edge.aten.mul.Scalar + and self._feeds_sdpa_qk_bmm(node) + ): + continue + + input_value = node.args[0].meta.get("val") + output_value = node.meta.get("val") + if ( + input_value is None + or output_value is None + or input_value.dtype != output_value.dtype + ): + continue + + tensor_arg = self._create_constant_node(graph_module, node, node.args[1]) + node.args = (node.args[0], tensor_arg) + node.target = self.scalar_to_tensor_ops[node.target] + self._modified = True + + graph_module.graph.eliminate_dead_code() + graph_module.graph.lint() + graph_module.recompile() + + modified = self._modified + self._modified = False + return PassResult(graph_module, modified) diff --git a/backends/xnnpack/test/ops/test_multiply.py b/backends/xnnpack/test/ops/test_multiply.py index 3315200005d..118136fcd08 100644 --- a/backends/xnnpack/test/ops/test_multiply.py +++ b/backends/xnnpack/test/ops/test_multiply.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -8,6 +9,10 @@ import torch from executorch.backends.xnnpack.test.tester import Tester +from executorch.backends.xnnpack.utils.configs import ( + get_transform_passes, + get_xnnpack_edge_compile_config, +) class TestMul(unittest.TestCase): @@ -29,6 +34,10 @@ def forward(self, x, y): z = torch.mul(x, y) * torch.functional.torch.mul(x, y) return z + class MulScalar(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.mul.Scalar(x, 0.5) + class MulRelu(torch.nn.Module): def forward(self, x, y): z = x * y @@ -58,6 +67,28 @@ def test_fp32_mul(self): inputs = (torch.randn((1, 3)), torch.randn((4, 3))) self._test_mul(inputs) + def test_fp32_mul_scalar(self): + ( + Tester(self.MulScalar(), (torch.randn(2, 3),)) + .export() + .to_edge_transform_and_lower( + transform_passes=get_transform_passes(), + edge_compile_config=get_xnnpack_edge_compile_config( + skip_dim_order=True + ), + ) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .check_not( + [ + "executorch_exir_dialects_edge__ops_aten_mul_Tensor", + "executorch_exir_dialects_edge__ops_aten_mul_Scalar", + ] + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + def test_qs8_mul(self): inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)) ( diff --git a/backends/xnnpack/test/passes/test_lift_constant_scalar_operands_pass.py b/backends/xnnpack/test/passes/test_lift_constant_scalar_operands_pass.py new file mode 100644 index 00000000000..5ad44f78af0 --- /dev/null +++ b/backends/xnnpack/test/passes/test_lift_constant_scalar_operands_pass.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from copy import deepcopy + +import torch +from executorch.backends.xnnpack._passes.lift_constant_scalar_operands_pass import ( + LiftConstantScalarOperandsPass, +) +from executorch.backends.xnnpack.partition.graphs import sdpa +from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config +from executorch.exir import to_edge +from executorch.exir.dialects._ops import ops as exir_ops + + +class TestLiftConstantScalarOperandsPass(unittest.TestCase): + def setUp(self): + torch._dynamo.reset() + + class MulScalar(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.mul.Scalar(x, 0.5) + + class AddScalar(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.add.Scalar(x, 0.5) + + def _to_edge_graph(self, module): + edge = to_edge( + torch.export.export(module, (torch.randn(2, 3),), strict=True), + compile_config=get_xnnpack_edge_compile_config(skip_dim_order=True), + ) + return edge.transform([LiftConstantScalarOperandsPass()]).exported_program() + + def test_lifts_mul_scalar_operand(self): + graph = self._to_edge_graph(self.MulScalar()).graph_module.graph + + self.assertFalse( + any(node.target == exir_ops.edge.aten.mul.Scalar for node in graph.nodes) + ) + self.assertTrue( + any(node.target == exir_ops.edge.aten.mul.Tensor for node in graph.nodes) + ) + self.assertTrue(any(node.op == "get_attr" for node in graph.nodes)) + + def test_keeps_unmapped_scalar_op(self): + graph = self._to_edge_graph(self.AddScalar()).graph_module.graph + + self.assertTrue( + any(node.target == exir_ops.edge.aten.add.Scalar for node in graph.nodes) + ) + + def test_keeps_sdpa_scale_mul_scalar(self): + graph_module = deepcopy(sdpa.get_graphs()[0]) + + LiftConstantScalarOperandsPass()(graph_module) + + scale_mul_count = 0 + lifted_mul_count = 0 + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + if node.target == exir_ops.edge.aten.mul.Scalar: + scale_mul_count += 1 + if node.target == exir_ops.edge.aten.mul.Tensor: + lifted_mul_count += 1 + + self.assertEqual(scale_mul_count, 2) + self.assertEqual(lifted_mul_count, 0) diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py index fc12da231c0..481864e265f 100644 --- a/backends/xnnpack/test/tester/tester.py +++ b/backends/xnnpack/test/tester/tester.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -24,9 +24,10 @@ QuantizationConfig, ) from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config -from executorch.exir import EdgeCompileConfig +from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower from executorch.exir.backend.partitioner import Partitioner from torch._export.pass_base import PassType +from torch.export import ExportedProgram from torchao.quantization.pt2e.quantizer import Quantizer @@ -77,6 +78,7 @@ def __init__( self, partitioners: Optional[List[Partitioner]] = None, edge_compile_config: Optional[EdgeCompileConfig] = None, + transform_passes: Optional[List[PassType]] = None, ): super().__init__( default_partitioner_cls=XnnpackPartitioner, @@ -84,6 +86,21 @@ def __init__( edge_compile_config=edge_compile_config or get_xnnpack_edge_compile_config(), ) + self.transform_passes = transform_passes + + def run( + self, + artifact: ExportedProgram, + inputs=None, + generate_etrecord: bool = False, + ) -> None: + self.edge_dialect_program = to_edge_transform_and_lower( + artifact, + transform_passes=self.transform_passes, + compile_config=self.edge_compile_conf, + partitioner=self.partitioners, + generate_etrecord=generate_etrecord, + ) class Partition(BaseStages.Partition): @@ -132,3 +149,28 @@ def __init__( dynamic_shapes=dynamic_shapes, **kwargs, ) + + def to_edge_transform_and_lower( + self, + to_edge_and_lower_stage: Optional[BaseStages.ToEdgeTransformAndLower] = None, + generate_etrecord: bool = False, + *, + partitioners: Optional[List[Partitioner]] = None, + edge_compile_config: Optional[EdgeCompileConfig] = None, + transform_passes: Optional[List[PassType]] = None, + ): + if to_edge_and_lower_stage is None: + to_edge_and_lower_stage = ToEdgeTransformAndLower( + partitioners=partitioners, + edge_compile_config=edge_compile_config, + transform_passes=transform_passes, + ) + else: + if partitioners is not None: + to_edge_and_lower_stage.partitioners = partitioners + if edge_compile_config is not None: + to_edge_and_lower_stage.edge_compile_conf = edge_compile_config + return super().to_edge_transform_and_lower( + to_edge_and_lower_stage, + generate_etrecord=generate_etrecord, + ) diff --git a/backends/xnnpack/utils/configs.py b/backends/xnnpack/utils/configs.py index 3016e94146b..ec47b81e835 100644 --- a/backends/xnnpack/utils/configs.py +++ b/backends/xnnpack/utils/configs.py @@ -9,6 +9,9 @@ import executorch.exir as exir +from executorch.backends.xnnpack._passes.lift_constant_scalar_operands_pass import ( + LiftConstantScalarOperandsPass, +) from executorch.backends.xnnpack._passes.remove_noop_expand_copy_pass import ( RemoveNoopExpandCopyPass, ) @@ -25,7 +28,7 @@ def get_xnnpack_edge_compile_config( def get_transform_passes(additional_passes=None) -> List[PassType]: - passes = [RemoveNoopExpandCopyPass()] + passes = [RemoveNoopExpandCopyPass(), LiftConstantScalarOperandsPass()] if additional_passes: passes.extend(additional_passes) return passes From 3952854c6eacd2a4c517512f28d83f44a4efb9d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Fri, 26 Jun 2026 10:12:14 +0200 Subject: [PATCH 2/2] Address review comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Måns Nilsson Change-Id: I4ab9a85590762c24f474b935a26f61d90b9a0565 --- .../lift_constant_scalar_operands_pass.py | 9 +++---- ...test_lift_constant_scalar_operands_pass.py | 19 +++++++++++--- backends/xnnpack/test/tester/tester.py | 26 +++++++++++++------ 3 files changed, 38 insertions(+), 16 deletions(-) diff --git a/backends/xnnpack/_passes/lift_constant_scalar_operands_pass.py b/backends/xnnpack/_passes/lift_constant_scalar_operands_pass.py index 8c6be2b62b5..e256a31750e 100644 --- a/backends/xnnpack/_passes/lift_constant_scalar_operands_pass.py +++ b/backends/xnnpack/_passes/lift_constant_scalar_operands_pass.py @@ -50,7 +50,6 @@ def __init__( if scalar_to_tensor_ops is not None else self.default_scalar_to_tensor_ops ) - self._modified = False def _create_constant_node( self, @@ -65,6 +64,8 @@ def _create_constant_node( input_value = input_node.meta["val"] tensor = torch.tensor(value, dtype=input_value.dtype, device=input_value.device) name = self._get_new_attr_name(graph_module) + # Keep constants as module attributes so the portable path can emit them + # without introducing aten.full, while XNNPACK can still read them as params. graph_module.register_buffer(name, tensor) fake_mode = node.meta["val"].fake_mode @@ -110,7 +111,7 @@ def _feeds_sdpa_qk_bmm(self, node: torch.fx.Node) -> bool: return False def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - self._modified = False + modified = False for node in list(graph_module.graph.nodes): if ( @@ -140,12 +141,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: tensor_arg = self._create_constant_node(graph_module, node, node.args[1]) node.args = (node.args[0], tensor_arg) node.target = self.scalar_to_tensor_ops[node.target] - self._modified = True + modified = True graph_module.graph.eliminate_dead_code() graph_module.graph.lint() graph_module.recompile() - modified = self._modified - self._modified = False return PassResult(graph_module, modified) diff --git a/backends/xnnpack/test/passes/test_lift_constant_scalar_operands_pass.py b/backends/xnnpack/test/passes/test_lift_constant_scalar_operands_pass.py index 5ad44f78af0..5c61731a786 100644 --- a/backends/xnnpack/test/passes/test_lift_constant_scalar_operands_pass.py +++ b/backends/xnnpack/test/passes/test_lift_constant_scalar_operands_pass.py @@ -16,6 +16,7 @@ from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config from executorch.exir import to_edge from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_manager import ExportedProgramPassManager class TestLiftConstantScalarOperandsPass(unittest.TestCase): @@ -30,12 +31,17 @@ class AddScalar(torch.nn.Module): def forward(self, x): return torch.ops.aten.add.Scalar(x, 0.5) - def _to_edge_graph(self, module): - edge = to_edge( + def _to_edge_program_manager(self, module): + return to_edge( torch.export.export(module, (torch.randn(2, 3),), strict=True), compile_config=get_xnnpack_edge_compile_config(skip_dim_order=True), ) - return edge.transform([LiftConstantScalarOperandsPass()]).exported_program() + + def _to_edge_graph(self, module): + edge = self._to_edge_program_manager(module) + return ExportedProgramPassManager([LiftConstantScalarOperandsPass()])( + edge.exported_program() + ).exported_program def test_lifts_mul_scalar_operand(self): graph = self._to_edge_graph(self.MulScalar()).graph_module.graph @@ -48,6 +54,13 @@ def test_lifts_mul_scalar_operand(self): ) self.assertTrue(any(node.op == "get_attr" for node in graph.nodes)) + def test_lifted_mul_scalar_can_emit_without_delegation(self): + edge = self._to_edge_program_manager(self.MulScalar()).transform( + (LiftConstantScalarOperandsPass(),) + ) + + self.assertIsNotNone(edge.to_executorch()) + def test_keeps_unmapped_scalar_op(self): graph = self._to_edge_graph(self.AddScalar()).graph_module.graph diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py index 481864e265f..396e149565f 100644 --- a/backends/xnnpack/test/tester/tester.py +++ b/backends/xnnpack/test/tester/tester.py @@ -26,6 +26,7 @@ from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower from executorch.exir.backend.partitioner import Partitioner +from executorch.exir.pass_manager import PassType as ExirPassType from torch._export.pass_base import PassType from torch.export import ExportedProgram from torchao.quantization.pt2e.quantizer import Quantizer @@ -78,7 +79,7 @@ def __init__( self, partitioners: Optional[List[Partitioner]] = None, edge_compile_config: Optional[EdgeCompileConfig] = None, - transform_passes: Optional[List[PassType]] = None, + transform_passes: Optional[List[ExirPassType]] = None, ): super().__init__( default_partitioner_cls=XnnpackPartitioner, @@ -152,25 +153,34 @@ def __init__( def to_edge_transform_and_lower( self, - to_edge_and_lower_stage: Optional[BaseStages.ToEdgeTransformAndLower] = None, + to_edge_and_transform_stage: Optional[ + BaseStages.ToEdgeTransformAndLower + ] = None, generate_etrecord: bool = False, *, partitioners: Optional[List[Partitioner]] = None, edge_compile_config: Optional[EdgeCompileConfig] = None, - transform_passes: Optional[List[PassType]] = None, + transform_passes: Optional[List[ExirPassType]] = None, ): - if to_edge_and_lower_stage is None: - to_edge_and_lower_stage = ToEdgeTransformAndLower( + if to_edge_and_transform_stage is None: + to_edge_and_transform_stage = ToEdgeTransformAndLower( partitioners=partitioners, edge_compile_config=edge_compile_config, transform_passes=transform_passes, ) else: if partitioners is not None: - to_edge_and_lower_stage.partitioners = partitioners + to_edge_and_transform_stage.partitioners = partitioners if edge_compile_config is not None: - to_edge_and_lower_stage.edge_compile_conf = edge_compile_config + to_edge_and_transform_stage.edge_compile_conf = edge_compile_config + if transform_passes is not None: + if not isinstance(to_edge_and_transform_stage, ToEdgeTransformAndLower): + raise ValueError( + "transform_passes requires the XNNPACK " + "ToEdgeTransformAndLower stage." + ) + to_edge_and_transform_stage.transform_passes = transform_passes return super().to_edge_transform_and_lower( - to_edge_and_lower_stage, + to_edge_and_transform_stage, generate_etrecord=generate_etrecord, )