diff --git a/backends/mlx/builder/program_builder.py b/backends/mlx/builder/program_builder.py index 9909bacc417..ebd025a42e9 100644 --- a/backends/mlx/builder/program_builder.py +++ b/backends/mlx/builder/program_builder.py @@ -322,7 +322,7 @@ def to_int_or_vid_or_tid(self, v: Union[int, Slot]) -> IntOrVidOrTid: return IntOrVidOrTid.from_vid(self.slot_to_vid(v)) return IntOrVidOrTid.from_literal(int(v)) - def _mark_read(self, node: Node): + def _mark_read(self, node: Node, consumer: Optional[Node] = None): assert self.node_info[node].handled, f"Node {node} is not handled" assert ( self.node_info[node].remaining_reads > 0 @@ -335,9 +335,24 @@ def _mark_read(self, node: Node): return if not isinstance(slot, tuple): slot = (slot,) + # When the consuming node reuses one of this node's slots in place as + # its own output (out == in, e.g. an in-place unary like exp_), the + # slot's lifetime transfers to the consumer: it must NOT be reclaimed + # here, or a later allocation could grab the same id while the + # consumer (which shares it) is still live. The consumer frees it when + # its own reads finish. Slot equality is identity-based. + aliased: set = set() + if consumer is not None: + consumer_slot = self.slot_manager.get_slot(consumer) + if consumer_slot is not None: + if not isinstance(consumer_slot, tuple): + consumer_slot = (consumer_slot,) + aliased = set(consumer_slot) for s in slot: if s.id_space != IdSpace.Temp: continue + if s in aliased: + continue if s.id_type == IdType.Tensor: self.slot_manager.tid_managers[IdSpace.Temp].return_id(s.idx) else: @@ -359,7 +374,7 @@ def mark_read(n: Node): for a in flat_args: if isinstance(a, Node): if a not in seen: - self._mark_read(a) + self._mark_read(a, consumer=n) seen.add(a) if isinstance(handler, PatternHandler): diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index 86e322a16e7..9b212455d3e 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -30,7 +30,7 @@ ) from executorch.backends.mlx.builder.op_registry import REGISTRY from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder -from executorch.backends.mlx.builder.slot_manager import IdType, Slot +from executorch.backends.mlx.builder.slot_manager import IdSpace, IdType, Slot from executorch.backends.mlx.serialization.mlx_graph_schema import ( AbsNode, AddIntNode, @@ -164,6 +164,7 @@ # The corresponding edge ops are automatically registered # For ops that are not in aten (e.g., dim order ops), directly register on exir_ops from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.passes.reinplace import _derive_edge_inplace_overload from torch.fx.node import Node _LEAKY_RELU_DEFAULT_NEGATIVE_SLOPE = 0.01 @@ -428,6 +429,217 @@ def handler(P: MLXProgramBuilder, n: Node) -> Slot: REGISTRY.register(target=[_target])(_make_unary_handler(_node_cls, _op_name)) +def _make_inplace_unary_handler(node_cls: Any, op_name: str): + """Create a handler for an in-place unary op (e.g. aten.exp_). + + These nodes are produced by the MLX reinplace pass (see passes.py), which + only rewrites a functional op to its in-place form when the input is a dead, + single-use temp. We bind the node's output slot to that input slot and emit + with out_tid == in_tid so MLX donates the input buffer at eval time (same + mechanism as SLICE_UPDATE/INDEX_COPY). If the input is not a reusable temp + (defensive — should not happen given the reinplace safety analysis), fall + back to allocating a fresh output slot. + """ + + def handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 1, 1, op_name) + require_kwargs(P.kwargs(n), set(), op_name) + x = args[0] + input_node = n.args[0] + # Only alias when n produces a fresh temp (no pre-assigned slot). Graph + # outputs / mutable buffers already own an Output/MutableBuffer slot + # (from _make_io_slots) and must keep it, so fall back to functional for + # those — donation on a terminal output is worthless anyway (it's copied + # out). Also require the input to be a dead, single-use temp. + if ( + P.slot_manager.get_slot(n) is None + and isinstance(x, Slot) + and x.id_space == IdSpace.Temp + and isinstance(input_node, Node) + and len(input_node.users) == 1 + ): + # Reuse the dead input temp's slot as the output (out == in). The + # builder's slot-lifetime transfer (program_builder._mark_read) keeps + # this slot alive until n's own users are done. + P.set_slot(n, x) + P.emit(node_cls(x=P.slot_to_tid(x), out=P.slot_to_tid(x))) + return x + out = P.make_or_get_slot(n) + P.emit(node_cls(x=P.slot_to_tid(x), out=P.slot_to_tid(out))) + return out + + handler.__name__ = f"_{op_name.replace('.', '_')}_handler" + handler.__doc__ = f"Handle {op_name} (in-place table-driven unary op)." + return handler + + +# Register in-place variants (e.g. aten.exp_) for every unary op MLX handles that +# has an aten in-place overload. REINPLACEABLE_UNARY_BASE_NAMES is the source of +# truth consumed by passes.py to build the reinplace pass's op set, so MLX has +# full control over exactly which ops get reinplaced (handlers exist for all of +# them, and nothing else — e.g. index_put is never included). +REINPLACEABLE_UNARY_BASE_NAMES: List[str] = [] +for _target, _node_cls, _op_name in _UNARY_OPS: + _base = _op_name.split(".")[-1] + _ip_packet = getattr(torch.ops.aten, _base + "_", None) + _ip_op = getattr(_ip_packet, "default", None) if _ip_packet is not None else None + if _ip_op is None: + continue + REGISTRY.register(target=[_ip_op])( + _make_inplace_unary_handler(_node_cls, _op_name + "_") + ) + REINPLACEABLE_UNARY_BASE_NAMES.append(_base) + + +def _inplace_alias_slot(P: MLXProgramBuilder, n: Node, a) -> Optional[Slot]: + """Return ``a``'s slot if it is safe to reuse it as ``n``'s output (out == in). + + The MLX reinplace pass only emits an in-place op when the mutated operand is + full-size and dtype-matching (the shape/dtype guard lives there, where it is + dynamic-shape/SymInt-safe). So this handler-side check just confirms ``a`` is + a reusable temp: ``n`` has no pre-assigned slot (not a graph output / mutable + buffer) and ``a`` is a single-use ``Temp`` tensor. (Reusing the slot is + runtime-correct regardless of shape — it is functional slot reuse; MLX only + donates the buffer when sizes are compatible.) Returns None otherwise. + """ + if P.slot_manager.get_slot(n) is not None: + return None + a_node = n.args[0] if n.args else None + if not ( + isinstance(a, Slot) + and a.id_space == IdSpace.Temp + and isinstance(a_node, Node) + and len(a_node.users) == 1 + ): + return None + return a + + +def _make_inplace_binary_handler(node_cls: Any, op_name: str): + """In-place binary handler (mul_/div_, no alpha): alias out == arg0 when safe. + + Produced by the MLX reinplace pass, which already guarantees arg0 is a + full-size, dtype-matching, single-use dead temp; the alias check here is + defensive and also handles the graph-output fallback. + """ + + def handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 2, op_name) + require_kwargs(P.kwargs(n), set(), op_name) + a, b = args[0], args[1] + alias = _inplace_alias_slot(P, n, a) + out = alias if alias is not None else P.make_or_get_slot(n) + P.emit(node_cls(a=P.slot_to_tid(a), b=P.slot_to_tid(b), out=P.slot_to_tid(out))) + if alias is not None: + P.set_slot(n, alias) + return out + + handler.__name__ = f"_{op_name.replace('.', '_')}_handler" + handler.__doc__ = f"Handle {op_name} (in-place table-driven binary op)." + return handler + + +def _make_inplace_addsub_handler(node_cls: Any, op_name: str): + """In-place add_/sub_ handler: handles the alpha kwarg and aliases out == arg0. + + ``alpha`` only scales the *other* operand (arg1), so it never blocks aliasing + arg0 (self); when ``alpha != 1`` we emit ``other * alpha`` into a temp first. + """ + + def handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 2, op_name) + require_kwargs(P.kwargs(n), {"alpha"}, op_name) + a, b = args[0], args[1] + alpha = P.kwargs(n).get("alpha", 1) + if alpha != 1: + input_meta = n.args[0].meta.get("val") + dtype = input_meta.dtype if input_meta is not None else torch.float32 + alpha_slot = emit_lifted_constant(P, alpha, dtype) + _, tmp = P.make_tmp_slot() + P.emit( + MultiplyNode( + a=P.slot_to_tid(b), + b=P.slot_to_tid(alpha_slot), + out=P.slot_to_tid(tmp), + ) + ) + b = tmp + alias = _inplace_alias_slot(P, n, a) + out = alias if alias is not None else P.make_or_get_slot(n) + P.emit(node_cls(a=P.slot_to_tid(a), b=P.slot_to_tid(b), out=P.slot_to_tid(out))) + if alias is not None: + P.set_slot(n, alias) + return out + + handler.__name__ = f"_{op_name.replace('.', '_')}_handler" + handler.__doc__ = f"Handle {op_name} (in-place add/sub op)." + return handler + + +# In-place binary handlers + the (base, overload) source of truth consumed by +# passes.py to build the binary reinplace op set. Restricted to dtype-preserving +# arithmetic Tensor overloads; the reinplace pass additionally guards that arg0 +# is full-size (no broadcast) before producing these in-place ops. +REINPLACEABLE_BINARY_BASE_OVERLOADS: List[Tuple[str, str]] = [] +for _ip_target, _ip_node_cls, _ip_name, _is_addsub in ( + (torch.ops.aten.add_.Tensor, AddNode, "aten.add_", True), + (torch.ops.aten.sub_.Tensor, SubtractNode, "aten.sub_", True), + (torch.ops.aten.mul_.Tensor, MultiplyNode, "aten.mul_", False), + (torch.ops.aten.div_.Tensor, DivideNode, "aten.div_", False), +): + _factory = ( + _make_inplace_addsub_handler if _is_addsub else _make_inplace_binary_handler + ) + REGISTRY.register(target=[_ip_target])(_factory(_ip_node_cls, _ip_name)) + REINPLACEABLE_BINARY_BASE_OVERLOADS.append((_ip_name.split(".")[-1][:-1], "Tensor")) + + +def _make_inplace_passthrough_handler(functional_handler): + """In-place handler that aliases out == self, then delegates to the op's + existing functional handler. + + These functional handlers (clamp, pow, gelu, relu, leaky_relu, hardtanh) + obtain their output slot via ``P.make_or_get_slot(n)`` and write it with the + last op they emit. By pre-binding ``n``'s slot to the dead ``self`` temp + before delegating, that final write becomes in-place (out == in) and MLX can + donate the buffer. When ``self`` is not a reusable temp (e.g. a graph + output), no pre-bind happens and the functional handler runs unchanged. + + The mutated ``self`` is always positional arg 0 for these ops, and every + op emitted before the output writer only *reads* ``self``, so the in-place + write (last) is safe. This "all reads of ``self`` happen before the final + write to out == self" ordering is a contract on each delegated functional + handler; the assertion below catches the easy-to-spot violation where a + handler stops using ``n``'s slot as its output, but a handler that reads + ``self`` *after* writing out would still silently corrupt — keep that + invariant in mind when editing clamp/pow/gelu/relu/leaky_relu/hardtanh. + """ + + def handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + self_slot = args[0] if args else None + alias = _inplace_alias_slot(P, n, self_slot) + if alias is not None: + P.set_slot(n, alias) + result = functional_handler(P, n) + # When we pre-bind out == self, the delegated handler must treat that + # slot as its output (write it last). Confirm it actually returned the + # aliased slot; otherwise the in-place aliasing silently did nothing. + assert alias is None or result is alias, ( + f"{getattr(functional_handler, '__name__', functional_handler)} did " + f"not use the aliased out==self slot as its output for {n}; in-place " + f"passthrough requires the delegated handler to write n's slot." + ) + return result + + handler.__name__ = "_inplace_passthrough_handler" + handler.__doc__ = "In-place passthrough (aliases out==self, delegates)." + return handler + + # --------------------------------------------------------------------------- # Numerical checks # --------------------------------------------------------------------------- @@ -4441,3 +4653,36 @@ def emit_reverse(in_slot, out_slot): ) return output_slots + + +# --------------------------------------------------------------------------- +# In-place variants for ops with bespoke functional handlers (clamp, pow, +# activations). Each reuses its functional handler via a passthrough that +# aliases out == self when self is a dead temp (see +# _make_inplace_passthrough_handler). Registered last, after every functional +# handler above is defined. REINPLACEABLE_EXTRA_BASE_OVERLOADS feeds passes.py. +# --------------------------------------------------------------------------- +REINPLACEABLE_EXTRA_BASE_OVERLOADS: List[Tuple[str, str]] = [] +for _base, _overload in ( + ("clamp", "default"), + ("clamp", "Tensor"), + ("gelu", "default"), + ("relu", "default"), + ("leaky_relu", "default"), + ("hardtanh", "default"), + ("pow", "Tensor_Scalar"), + ("pow", "Tensor_Tensor"), +): + _func_aten = getattr(getattr(torch.ops.aten, _base), _overload, None) + if _func_aten is None: + continue + _func_handler = REGISTRY._handlers.get(_func_aten) + if _func_handler is None: + continue + _ip_edge = _derive_edge_inplace_overload(_func_aten) + if _ip_edge is None: + continue + REGISTRY.register(target=[_ip_edge._op])( + _make_inplace_passthrough_handler(_func_handler) + ) + REINPLACEABLE_EXTRA_BASE_OVERLOADS.append((_base, _overload)) diff --git a/backends/mlx/passes.py b/backends/mlx/passes.py index ef4c768a2f8..e88cd83f0bc 100644 --- a/backends/mlx/passes.py +++ b/backends/mlx/passes.py @@ -20,7 +20,12 @@ walk_back, ) from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.pass_base import ( + ExportedProgramPassBase, + ExportedProgramPassResult, + ExportPass, + PassResult, +) from executorch.exir.passes.cse_pass import CSEPass from torch.fx import GraphModule, Node @@ -37,9 +42,107 @@ def get_default_passes() -> List[ExportPass]: CollapseDtypeConversionPass(), RemoveNoOpsPass(), CSEPass(), + # Must run last: rewrites the settled/deduped graph's functional + # elementwise chains into in-place ops so the MLX builder can donate + # buffers (out == in). + MLXReinplacePass(), ] +class MLXReinplacePass(ExportedProgramPassBase): + """Reinplace MLX-handled elementwise ops to enable MLX buffer donation. + + Rewrites functional unary/binary chains (e.g. ``exp(log(exp(x)))``, + ``h + ffn(h)``) into their in-place edge forms via ExecuTorch's + ``reinplace_pass``, restricted to an explicit, MLX-owned op set. Every op in + that set has a corresponding in-place MLX handler (see + ``REINPLACEABLE_UNARY_BASE_NAMES`` / ``REINPLACEABLE_BINARY_BASE_OVERLOADS`` + in ops.py) that binds the output slot to the dead input slot (out == in). + + Binary ops are safe to include because ``reinplace_pass`` only reinplaces + when the mutated argument already holds the output's shape and dtype (no + broadcast growth / dtype change). + + We deliberately do NOT use ``DEFAULT_INPLACEABLE_OPS``: passing an explicit + ``ops_to_inplace`` fully replaces the default, so ``index_put`` is never + reinplaced and MLX's existing KV-cache / index_copy functional patterns are + untouched. A fresh set is built per call so ``reinplace_pass`` cannot mutate + shared state. + + Runs as an EP-aware pass (it needs ``graph_signature`` to protect mutable + inputs/buffers); the pass infra hands ``ExportedProgramPassBase`` instances + the full ExportedProgram. + """ + + def call(self, exported_program) -> ExportedProgramPassResult: + # Imported lazily to avoid a module-load cycle with ops.py (which + # registers handlers on import). + from executorch.backends.mlx.ops import ( + REINPLACEABLE_BINARY_BASE_OVERLOADS, + REINPLACEABLE_EXTRA_BASE_OVERLOADS, + REINPLACEABLE_UNARY_BASE_NAMES, + ) + from executorch.exir.passes.reinplace import reinplace_pass + + # Explicit, MLX-owned op set (every op has an in-place MLX handler). + # Binary ops are safe to pass to reinplace_pass because it guards that + # the mutated arg is full-size + dtype-matching (no broadcast growth). + ops_to_inplace = { + getattr(exir_ops.edge.aten, base).default + for base in REINPLACEABLE_UNARY_BASE_NAMES + } + ops_to_inplace |= { + getattr(getattr(exir_ops.edge.aten, base), overload) + for base, overload in ( + REINPLACEABLE_BINARY_BASE_OVERLOADS + REINPLACEABLE_EXTRA_BASE_OVERLOADS + ) + } + if ops_to_inplace: + reinplace_pass(exported_program, ops_to_inplace=ops_to_inplace) + self._resync_output_specs(exported_program) + return ExportedProgramPassResult(exported_program, True) + + @staticmethod + def _resync_output_specs(exported_program) -> None: + """Re-sync graph-signature output names after reinplace. + + ``reinplace_pass`` rewrites an output-producing node (e.g. the final + ``exp`` -> ``exp_``) via ``replace_all_uses_with`` + erase, but does not + update ``graph_signature.output_specs``. Output order is preserved, so we + positionally re-sync each spec's argument name to the current output node + arg; otherwise ``ExportedProgram.validate()`` (run by the pass manager) + raises a SpecViolationError. + + This positional pairing is only valid because reinplace does a 1:1 + ``replace_all_uses_with`` + erase and never drops, adds, or reorders + outputs. We assert ``len(output_specs) == len(out_args)`` so that a + future change violating that invariant fails loudly here instead of + silently mis-pairing names (``zip`` would otherwise truncate). Specs + whose ``arg`` is not a named tensor (e.g. ``ConstantArgument``) carry no + ``name`` and are skipped by the ``getattr`` guard below. + """ + out_node = next( + n for n in reversed(exported_program.graph.nodes) if n.op == "output" + ) + out_args = out_node.args[0] + if not isinstance(out_args, (tuple, list)): + out_args = (out_args,) + output_specs = exported_program.graph_signature.output_specs + assert len(output_specs) == len(out_args), ( + "reinplace changed graph output count: " + f"{len(output_specs)} output_specs vs {len(out_args)} output args. " + "Positional output-spec re-sync assumes a 1:1, order-preserving " + "rewrite." + ) + for spec, arg in zip(output_specs, out_args): + if ( + isinstance(arg, torch.fx.Node) + and getattr(spec.arg, "name", None) is not None + and spec.arg.name != arg.name + ): + spec.arg.name = arg.name + + @dataclass class RMSNormMatch(PatternMatch): """ diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index afd4f276dde..c390716619b 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -27,6 +27,7 @@ import os from typing import Callable, Dict, List, Optional, Tuple +import executorch.exir as exir import torch import torch.nn as nn @@ -113,6 +114,62 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]: return (x, y) +class ReinplaceChainModel(nn.Module): + """Elementwise chain the reinplace pass converts to in-place ops. + + Mixes a pure unary op (exp), activations (sigmoid/relu/clamp/gelu), and + binary ops (add/mul/sub) so the chain exercises the unary, activation, and + binary in-place handlers. Every op after the first sigmoid consumes a + single-use temp, so all become in-place (sigmoid_/add_/relu_/mul_/clamp_/ + exp_/gelu_/sub_) and run on one rolling buffer; the terminal neg writes the + graph output. Inputs are kept NaN/Inf-free: sigmoid -> bounded, clamp to + [-2, 2] before exp so exp stays in [e^-2, e^2]. + """ + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + s = torch.sigmoid(x) # reads input -> fresh temp + s = s + y # add_ + s = torch.relu(s) # relu_ (activation) + s = s * y # mul_ + s = torch.clamp(s, -2.0, 2.0) # clamp_ (activation), bounds exp below + s = torch.exp(s) # exp_ (pure unary) + s = torch.nn.functional.gelu(s) # gelu_ (activation) + s = s - y # sub_ + return torch.neg(s) # terminal output (not in-place) + + +@register_test +class ReinplaceChainTest(OpTestCase): + """On-device numeric check that reinplaced (out==in) ops are correct. + + Lowers with get_default_passes() so the MLXReinplacePass + in-place handlers + (out == in buffer donation) run through the actual MLX runtime. The + build-level aliasing is unit-tested in test_passes.py; only on-device + execution catches a read-after-overwrite bug from buffer reuse. + """ + + name = "reinplace_chain" + rtol = 1e-4 + atol = 1e-4 + + def create_model(self) -> nn.Module: + return ReinplaceChainModel() + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(2, 16, 64), torch.randn(2, 16, 64)) + + def get_edge_compile_config(self) -> Optional[exir.EdgeCompileConfig]: + # Reinplace introduces non-core-ATen in-place ops (add_, sigmoid_, ...), + # so disable the strict edge verifier — matching the production export + # path (which also runs get_default_passes with this config). + return exir.EdgeCompileConfig(_check_ir_validity=False, _skip_dim_order=True) + + def get_transform_passes(self) -> Optional[list]: + from executorch.backends.mlx.passes import get_default_passes + + return get_default_passes() + + class SubModel(nn.Module): """Model that performs element-wise subtraction, optionally with alpha.""" diff --git a/backends/mlx/test/test_passes.py b/backends/mlx/test/test_passes.py index 97172c1411a..b09289d85cd 100644 --- a/backends/mlx/test/test_passes.py +++ b/backends/mlx/test/test_passes.py @@ -630,7 +630,14 @@ def forward(self, x): gm = _to_edge_gm(M(), (torch.randn(1, 16),)) + from executorch.exir.pass_base import ExportedProgramPassBase + for p in get_default_passes(): + # EP-aware passes (e.g. MLXReinplacePass) require a full + # ExportedProgram (graph_signature), not a bare GraphModule; they are + # exercised in the reinplace tests. + if isinstance(p, ExportedProgramPassBase): + continue p(gm) gm.graph.lint() @@ -651,7 +658,13 @@ def forward(self, x): gm = _to_edge_gm(module, (x,)) + from executorch.exir.pass_base import ExportedProgramPassBase + for p in get_default_passes(): + # EP-aware passes (e.g. MLXReinplacePass) require a full + # ExportedProgram, not a bare GraphModule; skip here. + if isinstance(p, ExportedProgramPassBase): + continue p(gm) actual = gm(x) @@ -661,5 +674,155 @@ def forward(self, x): torch.testing.assert_close(actual, expected) +class TestReinplacePass(unittest.TestCase): + """MLXReinplacePass rewrites functional elementwise chains into in-place edge + ops; the in-place handlers alias out == in to enable MLX buffer donation.""" + + def _lower_and_get_ep(self, module, example_inputs, dynamic_shapes=None): + from executorch.backends.mlx.passes import get_default_passes + + module.eval() + ep = export(module, example_inputs, dynamic_shapes=dynamic_shapes, strict=False) + edge = exir.to_edge( + ep, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + ) + edge = edge.transform(get_default_passes()) + return edge.exported_program() + + def test_unary_chain_is_reinplaced_and_aliased(self): + """exp(log(exp(x))): the dead middle temp is reinplaced and aliased.""" + from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder + + class M(nn.Module): + def forward(self, x): + return torch.exp(torch.log(torch.exp(x))) + + eep = self._lower_and_get_ep(M(), (torch.randn(4, 4),)) + + # The pass introduced an in-place edge op for the dead temp (log_). + targets = [str(n.target) for n in eep.graph.nodes if n.op == "call_function"] + self.assertTrue(any("aten.log_" in t for t in targets), targets) + + # In the built MLX program, at least one link must alias out == in. + g = MLXProgramBuilder(eep).build() + aliased = [] + for chain in g.instruction_chains: + for instr in chain.instructions: + op = instr.op + if type(op).__name__ in ("ExpNode", "LogNode"): + aliased.append(op.x.idx == op.out.idx) + self.assertTrue(any(aliased), f"expected an in-place (out==in) link: {aliased}") + + def test_full_size_binary_is_reinplaced_and_aliased(self): + """A full-size, dtype-matching dead-temp binary op is reinplaced+aliased.""" + from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder + + class M(nn.Module): + def forward(self, x, y): + h = torch.exp(x) # full-size dead temp + s = h + y # intermediate full-size add + return torch.neg(s) + + eep = self._lower_and_get_ep(M(), (torch.randn(4, 8), torch.randn(4, 8))) + targets = [str(n.target) for n in eep.graph.nodes if n.op == "call_function"] + self.assertTrue(any("aten.add_" in t for t in targets), targets) + + g = MLXProgramBuilder(eep).build() + add_ops = [ + instr.op + for chain in g.instruction_chains + for instr in chain.instructions + if type(instr.op).__name__ == "AddNode" + ] + self.assertTrue( + any(op.a.idx == op.out.idx for op in add_ops), + f"expected an in-place AddNode (out==a): {[(o.a.idx, o.out.idx) for o in add_ops]}", + ) + + def test_dynamic_shapes_lower_and_alias(self): + """Reinplace must work (and not raise) under dynamic shapes: a full-size + chain reinplaces+aliases, building the MLX program cleanly.""" + from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder + + class M(nn.Module): + def forward(self, x, y): + h = torch.exp(x) # [B, D] dynamic B, dead temp + s = h + y # full-size, same symbol B + return torch.neg(s) + + x = torch.randn(3, 8) + y = torch.randn(3, 8) + dynamic_shapes = { + "x": {0: torch.export.Dim("B")}, + "y": {0: torch.export.Dim("B")}, + } + eep = self._lower_and_get_ep(M(), (x, y), dynamic_shapes=dynamic_shapes) + + targets = [str(n.target) for n in eep.graph.nodes if n.op == "call_function"] + self.assertTrue(any("aten.add_" in t for t in targets), targets) + + # Build must succeed and produce an in-place AddNode. + g = MLXProgramBuilder(eep).build() + add_ops = [ + instr.op + for chain in g.instruction_chains + for instr in chain.instructions + if type(instr.op).__name__ == "AddNode" + ] + self.assertTrue(any(op.a.idx == op.out.idx for op in add_ops)) + + def test_extra_ops_build_with_inplace_handlers(self): + """clamp / pow / activations: the in-place edge op is produced and the + MLX builder lowers it via the in-place handler (build succeeds). + + Numerics are covered by the upstream reinplace tests; here we only + verify the MLX-specific path — that an in-place handler exists for each + and the program builds. + """ + import torch.nn.functional as F + + from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder + + cases = { + "clamp": lambda x: torch.neg(torch.clamp(torch.exp(x), -1.0, 1.0)), + "pow": lambda x: torch.neg(torch.exp(x) ** 2), + "gelu": lambda x: torch.neg(F.gelu(torch.exp(x))), + "relu": lambda x: torch.neg(torch.relu(torch.exp(x))), + "leaky_relu": lambda x: torch.neg(F.leaky_relu(torch.exp(x), 0.1)), + "hardtanh": lambda x: torch.neg(F.hardtanh(torch.exp(x))), + } + for name, fn in cases.items(): + with self.subTest(op=name): + + class M(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x): + return self.fn(x) + + eep = self._lower_and_get_ep(M(fn), (torch.randn(4, 8),)) + targets = [ + str(n.target) for n in eep.graph.nodes if n.op == "call_function" + ] + # An in-place edge op (other than the terminal neg) is present. + self.assertTrue( + any( + "aten." in t + and "_." in t.split("aten.")[-1] + and "neg_" not in t + for t in targets + ), + f"{name}: expected an in-place op, got {targets}", + ) + # The MLX builder must lower the in-place op (handler registered). + MLXProgramBuilder(eep).build() + + if __name__ == "__main__": unittest.main() diff --git a/backends/mlx/test/test_utils.py b/backends/mlx/test/test_utils.py index 1a964bea935..9746609a8d3 100644 --- a/backends/mlx/test/test_utils.py +++ b/backends/mlx/test/test_utils.py @@ -275,6 +275,7 @@ def export_model_to_pte( dynamic_shapes: Optional[Dict] = None, verbose: bool = False, edge_compile_config: Optional[exir.EdgeCompileConfig] = None, + transform_passes: Optional[list] = None, ) -> None: """ Export a PyTorch model to a .pte file using the MLX delegate. @@ -287,6 +288,8 @@ def export_model_to_pte( dynamic_shapes: Optional dynamic shapes specification for torch.export. Example: {0: {0: Dim("batch", min=1, max=32)}} for dynamic batch on first input. verbose: Whether to print the exported program for debugging. + transform_passes: Optional edge-dialect passes to run before partitioning + (e.g. ``get_default_passes()``). Mirrors the production export path. """ from executorch.backends.mlx import MLXPartitioner from executorch.exir.capture._config import ExecutorchBackendConfig @@ -308,10 +311,14 @@ def export_model_to_pte( # Lower to edge and delegate to MLX compile_config = edge_compile_config or exir.EdgeCompileConfig() + lower_kwargs = {} + if transform_passes is not None: + lower_kwargs["transform_passes"] = transform_passes edge_program = exir.to_edge_transform_and_lower( exported_program, partitioner=[MLXPartitioner()], compile_config=compile_config, + **lower_kwargs, ) # Print edge program if verbose @@ -877,6 +884,15 @@ def get_edge_compile_config(self) -> Optional[exir.EdgeCompileConfig]: """Return EdgeCompileConfig for export, or None for default.""" return None + def get_transform_passes(self) -> Optional[list]: + """Return edge-dialect transform passes to run before partitioning. + + Defaults to None (no passes), matching the historical op-test path. + Override to return ``get_default_passes()`` to exercise the production + lowering pipeline (e.g. for reinplace/donation coverage). + """ + return None + def get_test_dir(self) -> Path: """Get the directory for this test's files.""" test_dir = Path(__file__).parent / "op_tests" / self.name @@ -947,6 +963,7 @@ def generate_test_files(self, verbose: bool = False) -> Tuple[Path, Path, Path]: dynamic_shapes=dynamic_shapes, verbose=verbose, edge_compile_config=self.get_edge_compile_config(), + transform_passes=self.get_transform_passes(), ) # Save test inputs diff --git a/exir/passes/reinplace.py b/exir/passes/reinplace.py index 0dae20f4e22..332878b5190 100644 --- a/exir/passes/reinplace.py +++ b/exir/passes/reinplace.py @@ -226,8 +226,59 @@ def _derive_mutated_args(inplace_op: Any) -> Tuple[int, ...]: # --------------------------------------------------------------------------- +def _val_shape_dtype(node: torch.fx.Node): + """Return (shape, dtype) of ``node``'s fake value, or None if not a tensor.""" + v = node.meta.get("val") + if isinstance(v, torch.Tensor): # FakeTensor is a torch.Tensor subclass + return tuple(v.shape), v.dtype + return None + + +def _is_inplace_shape_dtype_compatible( + arg_node: torch.fx.Node, consuming_node: torch.fx.Node +) -> bool: + """True if writing ``consuming_node``'s output into ``arg_node`` is legal. + + An in-place op writes its result into the mutated argument, so that argument + must already hold the output's shape and dtype. Reinplacing a broadcasting + argument (e.g. ``add_(self[D], other[N, D])`` whose output is ``[N, D]``) or + a dtype-changing one (e.g. ``lt_`` whose output is bool) is invalid. + + Dynamic-shape safe: dimensions may be ``SymInt``s, so we compare each dim + with ``statically_known_true`` rather than ``==`` (which would call + ``bool(SymInt == SymInt)`` and can raise / add a guard for distinct + symbols). This is conservative — a dim that cannot be *proven* equal is + treated as incompatible (no reinplace), which is the safe default. + + When fake-tensor metadata is unavailable for either node, fall back to + permissive (preserves prior behavior for shape/dtype-preserving ops like + ``index_put``). + """ + from torch.fx.experimental.symbolic_shapes import statically_known_true + + a = _val_shape_dtype(arg_node) + o = _val_shape_dtype(consuming_node) + if a is None or o is None: + return True + (a_shape, a_dtype) = a + (o_shape, o_dtype) = o + if a_dtype != o_dtype: + return False + if len(a_shape) != len(o_shape): + return False + for ad, od in zip(a_shape, o_shape): + if isinstance(ad, int) and isinstance(od, int): + if ad != od: + return False + elif not statically_known_true(ad == od): + # SymInt dim(s): only treat as equal when provably so (no guard). + return False + return True + + def _is_safe_to_reinplace( node: torch.fx.Node, + consuming_node: torch.fx.Node, later_nodes: Set[torch.fx.Node], inputs: Set[torch.fx.Node], mutable_inputs: Set[torch.fx.Node], @@ -236,6 +287,11 @@ def _is_safe_to_reinplace( # There is probably a faster way to do this but this works for now. if node in later_nodes: return False + # The mutated argument must already hold the op's output shape/dtype; + # otherwise the in-place op would have to broadcast-grow or change dtype, + # which is invalid (e.g. add_(small, big), lt_(float, ...)). + if not _is_inplace_shape_dtype_compatible(node, consuming_node): + return False # If its not an input then we can reinplace it if node not in inputs: return True @@ -383,6 +439,7 @@ def reinplace_pass( # noqa: C901 break if first_tensor_idx is not None and _is_safe_to_reinplace( node.args[first_tensor_idx], # pyre-ignore[6] + node, seen_nodes, inputs, mutable_nodes, @@ -413,10 +470,25 @@ def reinplace_pass( # noqa: C901 f"Tensor(a!). A Tensor input in an FX graph " f"must be a torch.fx.Node." ) - if not _is_safe_to_reinplace(arg_node, seen_nodes, inputs, mutable_nodes): + if not _is_safe_to_reinplace( + arg_node, node, seen_nodes, inputs, mutable_nodes + ): all_safe = False break if all_safe: + # We intentionally skip `seen_nodes.update(node.all_input_nodes)` + # here even though the in-place op reads its non-mutated operands: + # the rewrite inserts the in-place node before `node`, and the live + # `reversed` walk visits it next. Its in-place overload isn't in + # `resolved`, so it falls into the generic branch above, which + # records all its operands (including the non-mutated ones). + # + # Mark the mutated args as used so a different consumer of the same + # value (earlier in topo order, visited later in this reverse walk) + # is not also reinplaced — that would read a value this op already + # overwrote in place. + for arg_idx in mutated_args: + seen_nodes.add(node.args[arg_idx]) with ep.graph.inserting_before(node): # Forward both args and kwargs: the in-place overload # is schema-matched to the functional one, so any diff --git a/exir/tests/test_reinplace_pass.py b/exir/tests/test_reinplace_pass.py index c7b72a73a13..6f1e07b7f32 100644 --- a/exir/tests/test_reinplace_pass.py +++ b/exir/tests/test_reinplace_pass.py @@ -326,6 +326,131 @@ def forward( # kernels in the other tests in this file. edge.to_executorch() + def test_broadcasting_self_not_reinplaced(self) -> None: + """An op whose mutated arg (self) broadcasts up to a larger output + must NOT be reinplaced: the in-place form cannot grow self + (``add_(self[D], other[N, D])`` is invalid). Only the shape guard + prevents this — the arg is otherwise a dead, single-use temp. + """ + + class M(torch.nn.Module): + def forward(self, bias: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + c = torch.relu(bias) # intermediate temp, shape [D] + return (c + x) * 2.0 # c [D] broadcasts to [N, D] + + model = M() + bias = torch.randn(4) + x = torch.randn(3, 4) + edge = to_edge(export(model, (bias, x), strict=True)) + ep = reinplace_pass( + edge.exported_program(), + ops_to_inplace={edge_ops.edge.aten.add.Tensor}, + ) + self.assertEqual( + len(_find_nodes(ep, "add_")), + 0, + "broadcast-first add must stay functional", + ) + ep.validate() + torch.testing.assert_close(ep.module()(bias, x), model(bias, x)) + + def test_dtype_changing_op_not_reinplaced(self) -> None: + """A comparison whose output dtype (bool) differs from its mutated + arg (float) must NOT be reinplaced: ``lt_`` cannot change self's + dtype. The shape/dtype guard blocks it.""" + + class M(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + t = torch.exp(x) # float temp + return (t < y).float() # bool output, dtype != self + + model = M() + x = torch.randn(3, 4) + y = torch.randn(3, 4) + edge = to_edge(export(model, (x, y), strict=True)) + ep = reinplace_pass( + edge.exported_program(), + ops_to_inplace={edge_ops.edge.aten.lt.Tensor}, + ) + self.assertEqual( + len(_find_nodes(ep, "lt_")), + 0, + "dtype-changing comparison must stay functional", + ) + ep.validate() + + def test_multi_use_self_reinplaced_at_most_once(self) -> None: + """When a value is the mutated arg of two ops, reinplacing both + would let the earlier consumer read a value the later one already + overwrote in place. Only the last consumer (execution order) may be + reinplaced; numerics must be preserved.""" + + class M(torch.nn.Module): + def forward( + self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor + ) -> torch.Tensor: + t = torch.exp(x) # intermediate temp with TWO consumers + a = t + y + b = t + z + return a * b + + model = M() + x, y, z = torch.randn(5), torch.randn(5), torch.randn(5) + edge = to_edge(export(model, (x, y, z), strict=True)) + ep = reinplace_pass( + edge.exported_program(), + ops_to_inplace={edge_ops.edge.aten.add.Tensor}, + ) + self.assertLessEqual( + len(_find_nodes(ep, "add_")), + 1, + "a value with two consumers must be reinplaced at most once", + ) + # Note: no ep.validate() here — reinplacing intentionally introduces a + # non-core-ATen in-place op (add_), which the strict edge verifier + # rejects (see test_ops_to_inplace_extends_with_add). Correctness is + # checked by running the rewritten graph instead. + torch.testing.assert_close(ep.module()(x, y, z), model(x, y, z)) + + def test_dynamic_shapes_no_guard_error(self) -> None: + """With dynamic (SymInt) dims the shape guard must not raise: + a same-symbol full-size op is reinplaced, while a broadcasting one + (static dim 1 vs a dynamic dim) is left functional.""" + + class M(torch.nn.Module): + def forward(self, x: torch.Tensor, row: torch.Tensor) -> torch.Tensor: + t = torch.exp(x) # [B, D] dynamic B, single-use temp + s = t + x # full-size, same symbol B -> reinplaceable + r = torch.relu(row) # [1, D] temp + return r + s # [1, D] + [B, D] -> self=r broadcasts -> skip + + model = M() + x = torch.randn(3, 4) + row = torch.randn(1, 4) + dynamic_shapes = {"x": {0: torch.export.Dim("B")}, "row": None} + exported_program = export( + model, (x, row), dynamic_shapes=dynamic_shapes, strict=True + ) + edge = to_edge(exported_program) + + # Must not raise GuardOnDataDependentSymNode on the SymInt comparison. + ep = reinplace_pass( + edge.exported_program(), + ops_to_inplace={edge_ops.edge.aten.add.Tensor}, + ) + # The same-symbol full-size add is reinplaced... + self.assertGreaterEqual( + len(_find_nodes(ep, "add_")), + 1, + "same-symbol full-size add should reinplace under dynamic shapes", + ) + # ...while the broadcasting add stays functional. + self.assertGreaterEqual( + len(_find_nodes(ep, "aten.add.Tensor")), + 1, + "broadcasting add must remain functional under dynamic shapes", + ) + def test_ops_to_inplace_empty_disables_all_rewrites(self) -> None: """Passing an empty ``ops_to_inplace`` set should disable every rewrite, even ops that are in ``DEFAULT_INPLACEABLE_OPS``.