Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions backends/mlx/builder/program_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def to_int_or_vid_or_tid(self, v: Union[int, Slot]) -> IntOrVidOrTid:
return IntOrVidOrTid.from_vid(self.slot_to_vid(v))
return IntOrVidOrTid.from_literal(int(v))

def _mark_read(self, node: Node):
def _mark_read(self, node: Node, consumer: Optional[Node] = None):
assert self.node_info[node].handled, f"Node {node} is not handled"
assert (
self.node_info[node].remaining_reads > 0
Expand All @@ -335,9 +335,24 @@ def _mark_read(self, node: Node):
return
if not isinstance(slot, tuple):
slot = (slot,)
# When the consuming node reuses one of this node's slots in place as
# its own output (out == in, e.g. an in-place unary like exp_), the
# slot's lifetime transfers to the consumer: it must NOT be reclaimed
# here, or a later allocation could grab the same id while the
# consumer (which shares it) is still live. The consumer frees it when
# its own reads finish. Slot equality is identity-based.
aliased: set = set()
if consumer is not None:
consumer_slot = self.slot_manager.get_slot(consumer)
if consumer_slot is not None:
if not isinstance(consumer_slot, tuple):
consumer_slot = (consumer_slot,)
aliased = set(consumer_slot)
for s in slot:
if s.id_space != IdSpace.Temp:
continue
if s in aliased:
continue
if s.id_type == IdType.Tensor:
self.slot_manager.tid_managers[IdSpace.Temp].return_id(s.idx)
else:
Expand All @@ -359,7 +374,7 @@ def mark_read(n: Node):
for a in flat_args:
if isinstance(a, Node):
if a not in seen:
self._mark_read(a)
self._mark_read(a, consumer=n)
seen.add(a)

if isinstance(handler, PatternHandler):
Expand Down
247 changes: 246 additions & 1 deletion backends/mlx/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
)
from executorch.backends.mlx.builder.op_registry import REGISTRY
from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder
from executorch.backends.mlx.builder.slot_manager import IdType, Slot
from executorch.backends.mlx.builder.slot_manager import IdSpace, IdType, Slot
from executorch.backends.mlx.serialization.mlx_graph_schema import (
AbsNode,
AddIntNode,
Expand Down Expand Up @@ -164,6 +164,7 @@
# The corresponding edge ops are automatically registered
# For ops that are not in aten (e.g., dim order ops), directly register on exir_ops
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.passes.reinplace import _derive_edge_inplace_overload
from torch.fx.node import Node

_LEAKY_RELU_DEFAULT_NEGATIVE_SLOPE = 0.01
Expand Down Expand Up @@ -428,6 +429,217 @@ def handler(P: MLXProgramBuilder, n: Node) -> Slot:
REGISTRY.register(target=[_target])(_make_unary_handler(_node_cls, _op_name))


def _make_inplace_unary_handler(node_cls: Any, op_name: str):
"""Create a handler for an in-place unary op (e.g. aten.exp_).

These nodes are produced by the MLX reinplace pass (see passes.py), which
only rewrites a functional op to its in-place form when the input is a dead,
single-use temp. We bind the node's output slot to that input slot and emit
with out_tid == in_tid so MLX donates the input buffer at eval time (same
mechanism as SLICE_UPDATE/INDEX_COPY). If the input is not a reusable temp
(defensive — should not happen given the reinplace safety analysis), fall
back to allocating a fresh output slot.
"""

def handler(P: MLXProgramBuilder, n: Node) -> Slot:
args = P.args(n)
require_args(args, 1, 1, op_name)
require_kwargs(P.kwargs(n), set(), op_name)
x = args[0]
input_node = n.args[0]
# Only alias when n produces a fresh temp (no pre-assigned slot). Graph
# outputs / mutable buffers already own an Output/MutableBuffer slot
# (from _make_io_slots) and must keep it, so fall back to functional for
# those — donation on a terminal output is worthless anyway (it's copied
# out). Also require the input to be a dead, single-use temp.
if (
P.slot_manager.get_slot(n) is None
and isinstance(x, Slot)
and x.id_space == IdSpace.Temp
and isinstance(input_node, Node)
and len(input_node.users) == 1
):
# Reuse the dead input temp's slot as the output (out == in). The
# builder's slot-lifetime transfer (program_builder._mark_read) keeps
# this slot alive until n's own users are done.
P.set_slot(n, x)
P.emit(node_cls(x=P.slot_to_tid(x), out=P.slot_to_tid(x)))
return x
out = P.make_or_get_slot(n)
P.emit(node_cls(x=P.slot_to_tid(x), out=P.slot_to_tid(out)))
return out

handler.__name__ = f"_{op_name.replace('.', '_')}_handler"
handler.__doc__ = f"Handle {op_name} (in-place table-driven unary op)."
return handler


# Register in-place variants (e.g. aten.exp_) for every unary op MLX handles that
# has an aten in-place overload. REINPLACEABLE_UNARY_BASE_NAMES is the source of
# truth consumed by passes.py to build the reinplace pass's op set, so MLX has
# full control over exactly which ops get reinplaced (handlers exist for all of
# them, and nothing else — e.g. index_put is never included).
REINPLACEABLE_UNARY_BASE_NAMES: List[str] = []
for _target, _node_cls, _op_name in _UNARY_OPS:
_base = _op_name.split(".")[-1]
_ip_packet = getattr(torch.ops.aten, _base + "_", None)
_ip_op = getattr(_ip_packet, "default", None) if _ip_packet is not None else None
if _ip_op is None:
continue
REGISTRY.register(target=[_ip_op])(
_make_inplace_unary_handler(_node_cls, _op_name + "_")
)
REINPLACEABLE_UNARY_BASE_NAMES.append(_base)


def _inplace_alias_slot(P: MLXProgramBuilder, n: Node, a) -> Optional[Slot]:
"""Return ``a``'s slot if it is safe to reuse it as ``n``'s output (out == in).

The MLX reinplace pass only emits an in-place op when the mutated operand is
full-size and dtype-matching (the shape/dtype guard lives there, where it is
dynamic-shape/SymInt-safe). So this handler-side check just confirms ``a`` is
a reusable temp: ``n`` has no pre-assigned slot (not a graph output / mutable
buffer) and ``a`` is a single-use ``Temp`` tensor. (Reusing the slot is
runtime-correct regardless of shape — it is functional slot reuse; MLX only
donates the buffer when sizes are compatible.) Returns None otherwise.
"""
if P.slot_manager.get_slot(n) is not None:
return None
a_node = n.args[0] if n.args else None
if not (
isinstance(a, Slot)
and a.id_space == IdSpace.Temp
and isinstance(a_node, Node)
and len(a_node.users) == 1
):
return None
return a


def _make_inplace_binary_handler(node_cls: Any, op_name: str):
"""In-place binary handler (mul_/div_, no alpha): alias out == arg0 when safe.

Produced by the MLX reinplace pass, which already guarantees arg0 is a
full-size, dtype-matching, single-use dead temp; the alias check here is
defensive and also handles the graph-output fallback.
"""

def handler(P: MLXProgramBuilder, n: Node) -> Slot:
args = P.args(n)
require_args(args, 2, 2, op_name)
require_kwargs(P.kwargs(n), set(), op_name)
a, b = args[0], args[1]
alias = _inplace_alias_slot(P, n, a)
out = alias if alias is not None else P.make_or_get_slot(n)
P.emit(node_cls(a=P.slot_to_tid(a), b=P.slot_to_tid(b), out=P.slot_to_tid(out)))
if alias is not None:
P.set_slot(n, alias)
return out

handler.__name__ = f"_{op_name.replace('.', '_')}_handler"
handler.__doc__ = f"Handle {op_name} (in-place table-driven binary op)."
return handler


def _make_inplace_addsub_handler(node_cls: Any, op_name: str):
"""In-place add_/sub_ handler: handles the alpha kwarg and aliases out == arg0.

``alpha`` only scales the *other* operand (arg1), so it never blocks aliasing
arg0 (self); when ``alpha != 1`` we emit ``other * alpha`` into a temp first.
"""

def handler(P: MLXProgramBuilder, n: Node) -> Slot:
args = P.args(n)
require_args(args, 2, 2, op_name)
require_kwargs(P.kwargs(n), {"alpha"}, op_name)
a, b = args[0], args[1]
alpha = P.kwargs(n).get("alpha", 1)
if alpha != 1:
input_meta = n.args[0].meta.get("val")
dtype = input_meta.dtype if input_meta is not None else torch.float32
alpha_slot = emit_lifted_constant(P, alpha, dtype)
_, tmp = P.make_tmp_slot()
P.emit(
MultiplyNode(
a=P.slot_to_tid(b),
b=P.slot_to_tid(alpha_slot),
out=P.slot_to_tid(tmp),
)
)
b = tmp
alias = _inplace_alias_slot(P, n, a)
out = alias if alias is not None else P.make_or_get_slot(n)
P.emit(node_cls(a=P.slot_to_tid(a), b=P.slot_to_tid(b), out=P.slot_to_tid(out)))
if alias is not None:
P.set_slot(n, alias)
return out

handler.__name__ = f"_{op_name.replace('.', '_')}_handler"
handler.__doc__ = f"Handle {op_name} (in-place add/sub op)."
return handler


# In-place binary handlers + the (base, overload) source of truth consumed by
# passes.py to build the binary reinplace op set. Restricted to dtype-preserving
# arithmetic Tensor overloads; the reinplace pass additionally guards that arg0
# is full-size (no broadcast) before producing these in-place ops.
REINPLACEABLE_BINARY_BASE_OVERLOADS: List[Tuple[str, str]] = []
for _ip_target, _ip_node_cls, _ip_name, _is_addsub in (
(torch.ops.aten.add_.Tensor, AddNode, "aten.add_", True),
(torch.ops.aten.sub_.Tensor, SubtractNode, "aten.sub_", True),
(torch.ops.aten.mul_.Tensor, MultiplyNode, "aten.mul_", False),
(torch.ops.aten.div_.Tensor, DivideNode, "aten.div_", False),
):
_factory = (
_make_inplace_addsub_handler if _is_addsub else _make_inplace_binary_handler
)
REGISTRY.register(target=[_ip_target])(_factory(_ip_node_cls, _ip_name))
REINPLACEABLE_BINARY_BASE_OVERLOADS.append((_ip_name.split(".")[-1][:-1], "Tensor"))


def _make_inplace_passthrough_handler(functional_handler):
"""In-place handler that aliases out == self, then delegates to the op's
existing functional handler.

These functional handlers (clamp, pow, gelu, relu, leaky_relu, hardtanh)
obtain their output slot via ``P.make_or_get_slot(n)`` and write it with the
last op they emit. By pre-binding ``n``'s slot to the dead ``self`` temp
before delegating, that final write becomes in-place (out == in) and MLX can
donate the buffer. When ``self`` is not a reusable temp (e.g. a graph
output), no pre-bind happens and the functional handler runs unchanged.

The mutated ``self`` is always positional arg 0 for these ops, and every
op emitted before the output writer only *reads* ``self``, so the in-place
write (last) is safe. This "all reads of ``self`` happen before the final
write to out == self" ordering is a contract on each delegated functional
handler; the assertion below catches the easy-to-spot violation where a
handler stops using ``n``'s slot as its output, but a handler that reads
``self`` *after* writing out would still silently corrupt — keep that
invariant in mind when editing clamp/pow/gelu/relu/leaky_relu/hardtanh.
"""

def handler(P: MLXProgramBuilder, n: Node) -> Slot:
args = P.args(n)
self_slot = args[0] if args else None
alias = _inplace_alias_slot(P, n, self_slot)
if alias is not None:
P.set_slot(n, alias)
result = functional_handler(P, n)
# When we pre-bind out == self, the delegated handler must treat that
# slot as its output (write it last). Confirm it actually returned the
# aliased slot; otherwise the in-place aliasing silently did nothing.
assert alias is None or result is alias, (
f"{getattr(functional_handler, '__name__', functional_handler)} did "
f"not use the aliased out==self slot as its output for {n}; in-place "
f"passthrough requires the delegated handler to write n's slot."
)
return result

handler.__name__ = "_inplace_passthrough_handler"
handler.__doc__ = "In-place passthrough (aliases out==self, delegates)."
return handler


# ---------------------------------------------------------------------------
# Numerical checks
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -4441,3 +4653,36 @@ def emit_reverse(in_slot, out_slot):
)

return output_slots


# ---------------------------------------------------------------------------
# In-place variants for ops with bespoke functional handlers (clamp, pow,
# activations). Each reuses its functional handler via a passthrough that
# aliases out == self when self is a dead temp (see
# _make_inplace_passthrough_handler). Registered last, after every functional
# handler above is defined. REINPLACEABLE_EXTRA_BASE_OVERLOADS feeds passes.py.
# ---------------------------------------------------------------------------
REINPLACEABLE_EXTRA_BASE_OVERLOADS: List[Tuple[str, str]] = []
for _base, _overload in (
("clamp", "default"),
("clamp", "Tensor"),
("gelu", "default"),
("relu", "default"),
("leaky_relu", "default"),
("hardtanh", "default"),
("pow", "Tensor_Scalar"),
("pow", "Tensor_Tensor"),
):
_func_aten = getattr(getattr(torch.ops.aten, _base), _overload, None)
if _func_aten is None:
continue
_func_handler = REGISTRY._handlers.get(_func_aten)
if _func_handler is None:
continue
_ip_edge = _derive_edge_inplace_overload(_func_aten)
if _ip_edge is None:
continue
REGISTRY.register(target=[_ip_edge._op])(
_make_inplace_passthrough_handler(_func_handler)
)
REINPLACEABLE_EXTRA_BASE_OVERLOADS.append((_base, _overload))
Loading
Loading