
[MLX] enable reinplacement for unary/binary ops #20557

Open
metascroy wants to merge 2 commits into main from mlx-reinplace


@metascroy

Contributor

This adds ExecuTorch's reinplace pass to the MLX backend's default passes and targets all unary/binary ops.

@pytorch-bot

pytorch-bot Bot commented Jun 26, 2026


🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20557

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit 2a97360 with merge base 51729bb (image):

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label Jun 26, 2026
@metascroy

Contributor Author

@claude review this code

@claude

claude Bot commented Jun 26, 2026


Claude finished @metascroy's task in 14m 14s


Code Review — [MLX] enable reinplacement for unary/binary ops

  • Read the diff against origin/main
  • Review exir/passes/reinplace.py changes
  • Review MLX backend changes (ops.py, passes.py, program_builder.py)
  • Review tests
  • Post review feedback

Nice work — the schema-derived in-place mapping, the dynamic-shape-safe shape/dtype guard, and the slot-lifetime transfer in _mark_read are all carefully done, and the test coverage (broadcast, dtype-change, dynamic shapes, on-device numeric chain) is thoughtful. One correctness issue stands out, plus a few smaller notes.


🔴 Correctness: non-mutated operand reads aren't recorded in seen_nodes

exir/passes/reinplace.py:478-484 — when a node is reinplaced, only the mutated args are added to seen_nodes:

if all_safe:
    for arg_idx in mutated_args:
        seen_nodes.add(node.args[arg_idx])
    ...
    continue                         # <-- skips seen_nodes.update(node.all_input_nodes)

But the reinplaced in-place op still reads its non-mutated operands at runtime. Because the reverse walk executes continue before reaching the usual seen_nodes.update(node.all_input_nodes), those reads are never recorded. An earlier-executing node that mutates the same value can then be reinplaced and overwrite it before this op reads it.

Consider (all temps, so the input-placeholder guard doesn't save us):

t = torch.exp(x)   # temp, 2 users
w = torch.exp(z)   # temp, 1 user
a = t + y          # arg0 = t  (mutated)
b = w + t          # arg0 = w  (mutated, single-use temp), t read as arg1
return a * b

Reverse walk with ops_to_inplace={add.Tensor}:

  1. b = w + t → mutates w (dead single-use temp) → reinplaced to w.add_(t). Adds only w to seen_nodes (not t).
  2. a = t + y → mutates t; t is not in seen_nodes → reinplaced to t.add_(y).

Execution order then runs t.add_(y) first, corrupting t before b = w.add_(t) reads it. Result: b == w + (t + y) instead of w + t. ep.module() numerics are wrong. The existing test_multi_use_self_reinplaced_at_most_once doesn't catch this because there t is the mutated arg of both consumers (covered by the new mutated-arg tracking); here t is mutated by one and merely read by the other.
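To make the corruption concrete, here is a minimal sketch that replays both schedules on plain Python lists standing in for tensors. The helpers add and add_ are hypothetical stand-ins (not PyTorch APIs), and the values of t, w, and y are arbitrary; in the scenario above t and w would come from exp.

```python
def add(a, b):
    """Functional add: returns a fresh list and leaves both inputs untouched."""
    return [p + q for p, q in zip(a, b)]


def add_(a, b):
    """In-place add: overwrites a's storage and returns the same object."""
    for i, q in enumerate(b):
        a[i] += q
    return a


def functional(t, w, y):
    a = add(t, y)
    b = add(w, t)      # reads the original t
    return a, b


def both_reinplaced(t, w, y):
    a = add_(t, y)     # t's buffer now holds t + y
    b = add_(w, t)     # reads the clobbered t, so b == w + (t + y)
    return a, b


t, w, y = [1.0, 2.0], [10.0, 20.0], [100.0, 200.0]
ref_a, ref_b = functional(list(t), list(w), list(y))
bad_a, bad_b = both_reinplaced(list(t), list(w), list(y))
print(ref_b, bad_b)
```

Only b diverges: a is computed before anything clobbers its operands, which is why a check that looks at a alone would not notice the problem.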

Note: the MLX backend incidentally dodges this at build time, because _inplace_alias_slot / the unary handler require len(a_node.users) == 1 before aliasing — so MLX emits functionally and stays correct. But the shared reinplace_pass is now exercised on arbitrary binary ops, and the bug is real for any consumer of the pass (and for ep.module() replay).

Suggested fix — record all reads on the reinplace path too (superset of the current mutated-arg set, and still correct for the duplicate-mutated-arg case):

if all_safe:
    ... rewrite ...
    seen_nodes.update(node.all_input_nodes)   # in-place op reads ALL its operands
    continue
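As a rough illustration of why recording every read matters, the sketch below models the reverse walk on a toy graph. It is not the code in exir/passes/reinplace.py: the Node class, toy_reinplace, and the record_all_reads flag are hypothetical, the only "safety" checks modeled are the input-placeholder guard and the seen set, and the rewrite itself is omitted.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    name: str
    op: str              # "input", "exp", "add", or "mul"
    args: tuple = ()     # names of the values this node reads


def toy_reinplace(graph, record_all_reads):
    """Return the names of add nodes a reverse walk would rewrite in place."""
    inputs = {n.name for n in graph if n.op == "input"}
    seen = set()         # values read by nodes that execute later
    reinplaced = []
    for node in reversed(graph):
        mutated = node.args[0] if node.op == "add" else None
        safe = mutated is not None and mutated not in inputs and mutated not in seen
        if safe:
            reinplaced.append(node.name)
            if record_all_reads:
                seen.update(node.args)   # in-place op still reads every operand
            else:
                seen.add(mutated)        # models the gap: only the mutated arg
            continue
        seen.update(node.args)
    return reinplaced


GRAPH = [
    Node("x", "input"), Node("y", "input"), Node("z", "input"),
    Node("t", "exp", ("x",)),
    Node("w", "exp", ("z",)),
    Node("a", "add", ("t", "y")),
    Node("b", "add", ("w", "t")),
    Node("out", "mul", ("a", "b")),
]

buggy = toy_reinplace(GRAPH, record_all_reads=False)
fixed = toy_reinplace(GRAPH, record_all_reads=True)
print(buggy, fixed)
```

With only the mutated argument recorded, both a and b are selected; with all reads recorded, t is already in the seen set when the walk reaches a, so only b is rewritten.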

I'd add a regression test mirroring the scenario above, asserting that ep.module() matches eager.


🟡 Minor / discussion

  • _make_inplace_passthrough_handler relies on "output write is the last emitted op." For leaky_relu/gelu/clamp the functional handlers compute intermediates (cond, scaled, lifted constants) into fresh temps before the final write to the aliased out==self slot, so it's correct today — but it's an implicit contract on handler internals (ops.py:4660+). If a future edit to one of those functional handlers reads self after writing out, this silently corrupts. Worth a one-line assertion or at least the contract being stated at each delegated handler, not only in the passthrough docstring.

  • _resync_output_specs (passes.py) positionally re-syncs spec.arg.name. This assumes reinplace never changes output count or order — true for the current rewrite (1:1 replace_all_uses_with + erase), but brittle if the pass ever drops/reorders outputs. A short comment asserting len(output_specs) == len(out_args) (or an explicit check) would make a future violation fail loudly instead of silently mis-pairing names.

  • getattr(spec.arg, "name", None): OutputSpec.arg can be a non-TensorArgument (e.g. ConstantArgument) without a name; the getattr guard handles it, but pairing a renamed tensor output against a constant-arg spec by position could still mis-sync if outputs are heterogeneous. Edge case only; current tests don't hit it.

  • Lazy import in MLXReinplacePass.call to break the cycle with ops.py is reasonable and well-commented. No change needed — just confirming the cycle is real (ops.py imports from reinplace at module load).
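For the positional re-sync point, a small sketch of what an explicit length check buys over a bare zip. The function resync_names and its list-of-names inputs are hypothetical simplifications; the real _resync_output_specs operates on spec objects, and None here stands in for a spec whose arg has no name.

```python
def resync_names(spec_names, out_names):
    """Positionally copy graph output names onto spec names.

    Entries that are None (specs without a name) are left untouched.
    """
    if len(spec_names) != len(out_names):
        raise ValueError(
            f"output count changed: {len(spec_names)} specs "
            f"vs {len(out_names)} graph outputs"
        )
    return [new if old is not None else None
            for old, new in zip(spec_names, out_names)]


# Same count: names are re-paired by position.
synced = resync_names(["add", None, "mul"], ["add_", "c", "mul"])

# A bare zip silently truncates when the counts differ.
truncated = list(zip(["add", "mul"], ["add_"]))

# The explicit check turns that into a loud failure instead.
try:
    resync_names(["add", "mul"], ["add_"])
    raised = False
except ValueError:
    raised = True
```

The check does not address the heterogeneous-output case: a constant spec and a renamed tensor output at the same position still line up purely by index.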


✅ Things that look good

  • Dynamic-shape handling via statically_known_true (conservative = no reinplace on unprovable dims) is the right call and is tested.
  • The shape and dtype guard correctly blocks lt_-style dtype changes and broadcast-growth — good catch to gate binary ops on this rather than excluding them.
  • Building a fresh ops_to_inplace set per call and not using DEFAULT_INPLACEABLE_OPS keeps index_put/KV-cache patterns untouched — clearly reasoned in the docstring.
  • _mark_read slot-lifetime transfer correctly avoids reclaiming an aliased temp; the comment explains the hazard well.
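The shape and dtype guard can be sketched in a few lines. This is an assumption-laden toy, not the pass's implementation: broadcast_shape and can_reinplace are hypothetical names, dtypes are plain strings, and an unknown dimension is modeled as None and always treated as unprovable, which is stricter than a symbolic check such as statically_known_true would be.

```python
def broadcast_shape(a, b):
    """Broadcast two shapes NumPy-style; None marks a dim that cannot be proven."""
    out = []
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da is None or db is None:
            out.append(None)
        elif da == db or db == 1:
            out.append(da)
        elif da == 1:
            out.append(db)
        else:
            raise ValueError(f"shapes {a} and {b} do not broadcast")
    return tuple(reversed(out))


def can_reinplace(self_shape, self_dtype, other_shape, out_dtype):
    """True only if the result provably fits self's buffer unchanged."""
    if self_dtype != out_dtype:        # e.g. a comparison producing bool
        return False
    try:
        out_shape = broadcast_shape(self_shape, other_shape)
    except ValueError:
        return False
    if None in out_shape:              # unprovable dim: stay functional
        return False
    return out_shape == tuple(self_shape)


same = can_reinplace((2, 3), "float32", (3,), "float32")
grows = can_reinplace((3,), "float32", (2, 3), "float32")
dtype_change = can_reinplace((2, 3), "float32", (2, 3), "bool")
dynamic = can_reinplace((None, 3), "float32", (3,), "float32")
```

The second case is the broadcast-growth hazard: the output is larger than self, so writing it into self's storage is not possible.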

Overall: solid PR. The seen_nodes gap is the one item I'd consider blocking, since it affects the shared pass even though MLX's own handler guards around it.
· branch mlx-reinplace

@github-actions


This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@meta-codesync

meta-codesync Bot commented Jun 26, 2026

Contributor

@metascroy has imported this pull request. If you are a Meta employee, you can view this in D109898238.

@metascroy

Contributor Author

@claude review this code

@claude

claude Bot commented Jun 26, 2026


Claude finished @metascroy's task in 9m 50s


Code Review — [MLX] enable reinplacement for unary/binary ops

  • Read exir/passes/reinplace.py changes
  • Read MLX backend changes (ops.py, passes.py, program_builder.py)
  • Read tests
  • Post review feedback

This revision addresses every point from the prior review — the blocking seen_nodes correctness issue, the _resync_output_specs length assertion, and the passthrough-handler contract assertion are all in. Re-verifying the headline concern below.


✅ The blocking seen_nodes bug is correctly resolved

The prior review flagged that the reinplace path skipped seen_nodes.update(node.all_input_nodes), so a value read as a non-mutated operand by a reinplaced node could be overwritten by an earlier-executing reinplaced node. This PR fixes it — but via a different mechanism than the originally suggested seen_nodes.update(...), so it's worth confirming the new approach holds.

The author's comment at exir/passes/reinplace.py:479-491 argues the non-mutated reads are recorded because the reverse walk revisits the inserted in-place node. Tracing FX's linked-list semantics, this is sound:

  1. The reverse iterator (_NodeIter, advancing via _prev) is positioned at node when the body runs.
  2. inserting_before(node) links the new in-place node as node._prev.
  3. erase_node_remove_from_list only rewires the neighbors' pointers; it leaves the erased node's own _prev pointing at the inserted node. So the next __next__ yields the inserted node.
  4. The inserted node's target is the edge in-place op, which is not a key in resolved (only functional ops are), so it hits the generic branch at reinplace.py:427-430 and runs seen_nodes.update(node.all_input_nodes) — recording all operands, mutated and non-mutated alike.

Replaying the prior counterexample (b = w + t reinplaced first, t read as arg1): after b.add_ is inserted and revisited, t lands in seen_nodes, so a = t + y is no longer reinplaced. Correct. The on-device ReinplaceChainTest and ep.module() numeric checks back this up.
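The revisit argument can be checked against a toy doubly linked list that mimics the behavior described in the four steps above. This is a model of that description, not torch.fx code: the Node class, prepend, remove_from_list, and reverse_iter are hypothetical, and the only property borrowed from the trace is that unlinking rewires the neighbors while leaving the removed node's own _prev alone.

```python
class Node:
    def __init__(self, name):
        self.name = name
        self._prev = self
        self._next = self

    def prepend(self, new):
        """Link new immediately before self."""
        p = self._prev
        p._next, new._prev = new, p
        new._next, self._prev = self, new

    def remove_from_list(self):
        """Unlink self by rewiring neighbors only; self._prev is left as-is."""
        p, n = self._prev, self._next
        p._next, n._prev = n, p


def reverse_iter(root):
    """Walk backwards from the sentinel, advancing via _prev on each step."""
    cur = root
    while True:
        cur = cur._prev
        if cur is root:
            return
        yield cur


def walk_with_rewrite():
    root = Node("root")                       # sentinel
    for name in ["t", "w", "a", "b", "out"]:
        root.prepend(Node(name))              # prepend to sentinel == append
    visited = []
    for node in reverse_iter(root):
        visited.append(node.name)
        if node.name == "b":
            node.prepend(Node("b_inplace"))   # insert the in-place node before b
            node.remove_from_list()           # then erase b
    return visited


visited = walk_with_rewrite()
print(visited)
```

In this model the inserted node is yielded right after the erased one, which is the step where the generic branch would record all of its operands.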

Minor: the explicit for arg_idx in mutated_args: seen_nodes.add(...) at lines 490-491 is now strictly redundant with the revisit (which records the mutated args too, since they're also all_input_nodes of the inserted node). It's harmless belt-and-suspenders, but the comment could note that the revisit is the load-bearing mechanism and the explicit add is just defensive; today's wording reads as if the explicit add is what prevents the double-reinplace, when really the revisit does.


🟡 Minor / discussion

  • Coupling to FX iterator internals. The whole correctness argument rests on "erase_node leaves _prev intact, so reversed() visits the inserted node next." That's stable, long-standing FX behavior, but it's an implicit dependency. If you wanted to be robust against a future FX iterator change (or a reader who doesn't trust the revisit), adding the suggested seen_nodes.update(node.all_input_nodes) directly on the reinplace path would make the pass self-contained and independent of iterator semantics — it's a superset of what the revisit records, so it can't hurt. Worth considering as defense-in-depth even though the current code is correct.

  • _make_inplace_addsub_handler (ops.py:4500ish) takes the alpha-constant dtype from n.args[0] (self/arg0), but alpha scales arg1 (b). For the in-place arithmetic ops this PR targets, self and other share a dtype (the reinplace guard already enforces self↔out dtype match), so this is fine in practice — just slightly surprising to read "scale b" while sourcing the dtype from a. A one-word comment would settle it.

  • REGISTRY._handlers.get(_func_aten) (ops.py:4670ish) reaches into a private attribute to look up the already-registered functional handler. Works, but if REGISTRY exposes a public getter it'd be cleaner; otherwise fine as an internal-module access.


✅ Things that look good

  • _resync_output_specs now asserts len(output_specs) == len(out_args), so a future rewrite that drops/reorders outputs fails loudly instead of silently mis-pairing names via zip truncation — exactly the prior suggestion.
  • _make_inplace_passthrough_handler asserts the delegated handler returned the aliased slot, catching the "stopped writing n's slot" violation; the docstring honestly flags the residual read-after-write hazard it can't catch.
  • _mark_read's slot-lifetime transfer (program_builder.py:335-355) correctly skips reclaiming a temp slot the consumer has aliased in place — the comment explains the use-after-free hazard well.
  • Building a fresh ops_to_inplace per call and deliberately not using DEFAULT_INPLACEABLE_OPS keeps index_put/KV-cache patterns untouched.
  • Dynamic-shape handling via statically_known_true (conservative — no reinplace on unprovable dims) is right and tested; the shape+dtype guard correctly blocks lt_ dtype changes and broadcast growth.

Overall: the prior blocking issue is genuinely resolved and the feedback is well-incorporated. My only real suggestion is the defense-in-depth note — adding the explicit all_input_nodes update on the reinplace path so correctness doesn't hinge on FX iterator internals. Not blocking. Nice work.
· branch mlx-reinplace

@metascroy metascroy requested a review from Gasoonjia June 26, 2026 23:59