Good First Issue: Runtime MoE expert-sort for decode (MLX backend, Qwen 3.5 MoE) #20554

Description

@metascroy

Problem

The MLX MoE sorts tokens by expert so each expert's tokens are
contiguous for a grouped matmul (gather_mm/gather_qmm with
sorted_indices=True). This helps prefill (many tokens to batch) but
is pure overhead during decode (1 token → nothing to batch): the
argsort, activation gather, and output scatter all run for nothing, and
the sorted/segmented matmul kernel is a poor fit for a tiny batch.

The sort is decided by a static export-time flag (sort_experts),
so a single dynamic-shape .pte runs it in both phases
(backends/mlx/llm/switch.py:243-271,
examples/models/qwen3_5_moe/mlx_source_transformations.py:324-332).

Fix: decide sort vs. no-sort at runtime, based on token count M,
inside the backend lowering, using the existing emit_if_else helper.


Approach: two glue ops around the existing grouped matmuls

SwitchMLP.forward becomes:

x_input, idx, sort_experts, inv_order = torch.ops.mlx.moe_gather_inputs(
    x, expert_indices, self.top_k, self.sort_cutoff
)

if self.fuse_gate_up:
    gate_up = self.gate_up_proj(x_input, idx, sorted_indices=sort_experts)
    gate = gate_up[..., : self.intermediate_size]
    up = gate_up[..., self.intermediate_size :]
else:
    gate = self.gate_proj(x_input, idx, sorted_indices=sort_experts)
    up = self.up_proj(x_input, idx, sorted_indices=sort_experts)
h = self.activation(gate) * up
down = self.down_proj(h, idx, sorted_indices=sort_experts)

down = torch.ops.mlx.moe_scatter_outputs(
    down, sort_experts, inv_order, self.top_k
)
return (down * expert_weights.unsqueeze(-1)).sum(dim=-2)

sort_experts is now a runtime 0-d int tensor (1 = sorted, 0 =
unsorted) produced by moe_gather_inputs. It feeds both the
grouped matmuls (sorted_indices=) and moe_scatter_outputs. The two new ops'
handlers branch on M via emit_if_else, comparing the runtime token
count M against the static sort_cutoff: sort when M > sort_cutoff,
skip the sort otherwise. sort_cutoff is a compile-time arg of the op
(like top_k), consulted during emission — sort_cutoff=1 recovers
"sort whenever there is more than one token." The grouped-matmul call
sites are structurally unchanged; only their sorted_indices argument now
carries the runtime sort_experts tensor.

Two design constraints (both must hold)

1. Both emit_if_else branches must produce the same shapes.
torch.export fixes each op output's rank/shape from its register_fake,
and the single downstream gather is traced against those shapes. So the
sorted and unsorted branches of moe_gather_inputs must write
identically-shaped outputs. We therefore unify on the sorted layout, so
the unsorted branch replicates each token top_k times instead of using
the rank-4 broadcast form:

| output | shape | sorted (M > sort_cutoff) | unsorted (M ≤ sort_cutoff) |
| --- | --- | --- | --- |
| x_input | [N*top_k, 1, D] | x[order // top_k] then unsqueeze(-2) | repeat_interleave(x, top_k) then unsqueeze(-2) |
| idx | [N*top_k] int32 | flat_indices[order] | expert_indices.flatten() |
| sort_experts | [] (0-d) int32 | 1 | 0 |
| inv_order | [N*top_k] int32 (fake/sorted) | argsort(order) | 0-element sentinel (unread) |

The unsorted branch's repeat_interleave and the sorted branch's
x[order // top_k] both materialize M*top_k rows, so neither is cheaper
on that axis — the sort decision is governed by the matmul, not the
materialization (see sort_cutoff below).

moe_scatter_outputs(down, sort_experts, inv_order, top_k) takes
down [N*top_k, 1, H], applies squeeze(-2), and returns [N, top_k, H]
(prefill: down[inv_order].reshape(N, top_k, -1); decode:
down.reshape(N, top_k, -1)).
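
To make the gather/scatter contract concrete, here is a minimal pure-Python sketch of both layouts. It is not the MLX lowering: lists stand in for tensors, D = H = 1, and toy_grouped_matmul is a hypothetical stand-in that just scales each row by a per-expert factor. It only illustrates that inv_order = argsort(order) undoes the expert sort, so both paths return the same [N, top_k] result.

```python
from typing import List, Sequence


def argsort(values: Sequence[int]) -> List[int]:
    """Stable argsort over a flat list (stand-in for a tensor argsort)."""
    return sorted(range(len(values)), key=lambda i: values[i])


def toy_grouped_matmul(x_input, idx, scale):
    """Hypothetical stand-in for gather_mm: row j is scaled by expert idx[j]."""
    return [v * scale[e] for v, e in zip(x_input, idx)]


def moe_round_trip(x, expert_indices, top_k, scale, sort):
    flat = [e for row in expert_indices for e in row]      # [N*top_k]
    if sort:
        order = argsort(flat)                               # group rows by expert
        inv_order = argsort(order)                          # permutation that undoes it
        idx = [flat[i] for i in order]
        x_input = [x[i // top_k] for i in order]            # token owning each slot
    else:
        inv_order = None                                    # plays the unread sentinel
        idx = flat
        x_input = [v for v in x for _ in range(top_k)]      # repeat_interleave
    down = toy_grouped_matmul(x_input, idx, scale)
    if inv_order is not None:
        down = [down[i] for i in inv_order]                 # scatter back to token order
    return [down[i:i + top_k] for i in range(0, len(down), top_k)]


x = [1.0, 2.0, 3.0]                        # N = 3 tokens
expert_indices = [[2, 0], [1, 2], [0, 1]]  # top_k = 2
scale = [10.0, 100.0, 1000.0]              # one factor per expert
sorted_out = moe_round_trip(x, expert_indices, 2, scale, sort=True)
unsorted_out = moe_round_trip(x, expert_indices, 2, scale, sort=False)
print(sorted_out == unsorted_out)  # True
```

Because the two layouts differ only by a permutation that is inverted before the reshape, the choice between them is a performance decision and not a numerical one.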

2. The grouped matmul's sorted_indices must be runtime-settable.
Today sorted_indices is a static bool baked onto
GatherMmNode/GatherQmmNode
(backends/mlx/ops.py:1557,1595;
schema.fbs:951,966). It is an MLX kernel correctness contract, not a
no-op hint: sorted_indices=True tells the kernel the indices are grouped
by expert, so it must be True only when we actually sorted. Since the
sort decision is now at runtime, upgrade sorted_indices from bool to
IntOrVid (0 = unsorted, nonzero = sorted), the same mechanism
PartitionNode.kth already uses.


The change

1. Upgrade sorted_indices: bool → IntOrVid

This is the enabling change and the only one touching C++/schema. It
mirrors PartitionNode.kth (kth: IntOrVid in schema.fbs:683,
resolved at runtime by resolve_int).

  • Schema (backends/mlx/serialization/schema.fbs:945-967): change
    sorted_indices: bool = false; → sorted_indices: IntOrVid; on
    GatherMmNode and GatherQmmNode.
    BC note: schema.fbs:14 says never change an existing field's type.
    Two options: (A, recommended) change the type directly — MLX
    delegate .pte are regenerated from source each export, so this is
    acceptable; (B, strict append-only) keep the legacy bool field,
    append a new sorted_indices_iov: IntOrVid; at the end of each table,
    and have the runtime prefer the IntOrVid when present.
  • Codegen: run python backends/mlx/serialization/generate.py to
    regenerate mlx_graph_schema.py, the serializers/inspector, the
    _generated/ bindings, runtime/schema_generated.h, and the partial
    runtime/MLXLoader.{h,cpp}.
  • Loader (backends/mlx/runtime/MLXLoader.cpp:2056,2085): change
    node.sorted_indices = fb->sorted_indices(); →
    node.sorted_indices = convert_int_or_vid(fb->sorted_indices());
    (exactly like kth at MLXLoader.cpp:1452,1467). MLXLoader.h:952,967
    field type bool → IntOrVid.
  • Runtime (backends/mlx/runtime/MLXInterpreter.h:875,909): resolve
    to a bool at exec time:
    bool sorted = resolve_int(n.sorted_indices, st) != 0; then pass
    sorted to gather_mm(...) / gather_qmm(...).
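
As a rough mental model of the literal-or-runtime-value resolution described above, the following Python sketch mimics the semantics. The real implementation is C++ in the MLX runtime. The class layout, the dict used as runtime state, and resolve_sorted are assumptions made for illustration; only the names IntOrVid, from_literal, from_vid, and resolve_int come from this issue.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass(frozen=True)
class IntOrVid:
    """Toy model: either a compile-time literal or a runtime value id."""
    literal: Optional[int] = None
    vid: Optional[int] = None

    @staticmethod
    def from_literal(value: int) -> "IntOrVid":
        return IntOrVid(literal=value)

    @staticmethod
    def from_vid(vid: int) -> "IntOrVid":
        return IntOrVid(vid=vid)


def resolve_int(iov: IntOrVid, values: Dict[int, int]) -> int:
    """Return the literal, or look the value id up in the (assumed) runtime state."""
    if iov.vid is not None:
        return values[iov.vid]
    if iov.literal is None:
        raise ValueError("IntOrVid carries neither a literal nor a vid")
    return iov.literal


def resolve_sorted(iov: IntOrVid, values: Dict[int, int]) -> bool:
    """0 = unsorted, nonzero = sorted (the kernel's bool contract)."""
    return resolve_int(iov, values) != 0


static_unsorted = IntOrVid.from_literal(0)   # what a None sorted_indices lowers to
runtime_flag = IntOrVid.from_vid(7)          # sort_experts threaded through a vid
print(resolve_sorted(static_unsorted, {}))   # False
print(resolve_sorted(runtime_flag, {7: 1}))  # True
```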

2. Custom ops accept a runtime sorted_indices

backends/mlx/custom_ops.py: change gather_mm (:282-322) and
gather_qmm (:325-393) from sorted_indices: bool = False to
sorted_indices: Optional[Tensor] = None (0-d int; None/0 = unsorted).
The eager reference and register_fake already ignore sorted_indices
(it is layout-only with identical numerics — custom_ops.py:299-303,
347-369), so only the signatures change. Update SwitchLinear.forward
(backends/mlx/llm/switch.py:130-161) to pass the tensor straight
through.

3. Two new custom ops (backends/mlx/custom_ops.py)

The eager references branch on M on purpose — they are the
executable spec the lowering handler (section 4) mirrors branch-for-branch.
Sorting is an invertible permutation (identical numerics either way), so
the two paths exist for the lowering's sake, not the math's. Write them
to read exactly like the two MLX branches the handler will emit:

@torch.library.custom_op("mlx::moe_gather_inputs", mutates_args=())
def moe_gather_inputs(
    x: Tensor, expert_indices: Tensor, top_k: int, sort_cutoff: int
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    N = x.shape[0]
    if N > sort_cutoff:  # SORTED path  (handler: emit_sorted)
        flat = expert_indices.flatten()
        order = flat.argsort().to(torch.int32)
        inv_order = order.argsort().to(torch.int32)
        idx = flat[order].to(torch.int32)                      # [N*top_k]
        x_input = x[(order // top_k).to(torch.int64)].unsqueeze(-2)  # [N*top_k, 1, D]
        sort_experts = torch.ones((), dtype=torch.int32)
    else:                # UNSORTED path (handler: emit_unsorted)
        x_input = x.repeat_interleave(top_k, dim=0).unsqueeze(-2)    # [N*top_k, 1, D]
        idx = expert_indices.flatten().to(torch.int32)              # [N*top_k]
        sort_experts = torch.zeros((), dtype=torch.int32)
        inv_order = torch.empty(0, dtype=torch.int32)               # sentinel: never read
    return x_input, idx, sort_experts, inv_order

@torch.library.custom_op("mlx::moe_scatter_outputs", mutates_args=())
def moe_scatter_outputs(
    down: Tensor, sort_experts: Tensor, inv_order: Tensor, top_k: int
) -> Tensor:
    down = down.squeeze(-2)                                     # [N*top_k, H]
    if sort_experts.item():  # prefill: scatter back   (handler: emit_then)
        down = down[inv_order]
    # decode: no scatter (inv_order is the unread sentinel)  (handler: emit_else)
    return down.reshape(down.shape[0] // top_k, top_k, -1)     # [N, top_k, H]

The reference body is opaque to torch.export (custom ops are leaf
nodes), so the Python if N > sort_cutoff is fine — it runs only in
eager/parity tests. The outputs consumed downstream — x_input, idx,
sort_experts — are identically-shaped across both branches.
inv_order is the exception ([N*top_k] when sorted, a 0-element
sentinel when unsorted), but it is never read on the unsorted path
(see moe_scatter_outputs), so the shape divergence is harmless.

Add a register_fake for each. Unlike the reference, the fake must not
branch on M: under export M is a symbolic SymInt, so data-dependent
control flow on it is illegal. The fake returns one shape for all M:
x_input [N*top_k, 1, D], idx [N*top_k], sort_experts [], and
inv_order [N*top_k] (the sorted-path shape). sort_experts must be a
0-d tensor (there is no scalar-bool slot); its runtime value (0/1) is
chosen by the handler branches in section 4. The mismatch between the
fake's [N*top_k] inv_order and the unsorted reference's 0-element
sentinel is benign: inv_order feeds no downstream fake
(moe_scatter_outputs derives its output shape from down + top_k).
Because decode deliberately returns a sentinel, run
torch.library.opcheck with M > sort_cutoff (M > 1 at the default), or
skip the inv_order shape check.
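
The shape contract can be summarized with a small pure-Python sketch. It is not the actual register_fake (which would operate on fake tensors with a symbolic N); the function names are hypothetical and n is a plain int here. It shows that the branch-free fake and the branching reference agree on every output except inv_order on the unsorted path.

```python
from typing import Tuple

Shape = Tuple[int, ...]
Shapes = Tuple[Shape, Shape, Shape, Shape]  # (x_input, idx, sort_experts, inv_order)


def fake_gather_shapes(n: int, top_k: int, d: int) -> Shapes:
    """Shapes a branch-free fake would declare: identical for every n."""
    rows = n * top_k
    return (rows, 1, d), (rows,), (), (rows,)


def reference_gather_shapes(n: int, top_k: int, d: int, sort_cutoff: int) -> Shapes:
    """Shapes the eager reference produces; only inv_order depends on the branch."""
    rows = n * top_k
    inv_order = (rows,) if n > sort_cutoff else (0,)  # 0-element sentinel when unsorted
    return (rows, 1, d), (rows,), (), inv_order


prefill_match = fake_gather_shapes(4, 2, 8) == reference_gather_shapes(4, 2, 8, sort_cutoff=1)
decode_fake = fake_gather_shapes(1, 2, 8)
decode_ref = reference_gather_shapes(1, 2, 8, sort_cutoff=1)
print(prefill_match)                      # True
print(decode_fake[:3] == decode_ref[:3])  # True
print(decode_fake[3], decode_ref[3])      # (2,) (0,)
```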

4. Handlers (backends/mlx/ops.py) — branch on M

Each handler is a node-for-node lowering of the section-3 reference:
emit_sorted mirrors the N > sort_cutoff branch, emit_unsorted mirrors
the else branch. Both pre-allocate every output slot and both branches
write all of them (the IfNode only selects a chain; downstream reads
fixed slot ids). This is the multi-output generalization of the
single-output emit_if_else pattern in _sample_handler
(ops.py:3537,3712) and the GGUF Q4_K linear
(backends/mlx/custom_kernel_ops/gguf/q4k/linear.py:453-477).

@REGISTRY.register(target=[torch.ops.mlx.moe_gather_inputs.default])
def _moe_gather_inputs_handler(P, n):
    x, expert_indices = P.args(n)[0], P.args(n)[1]
    top_k = P.args(n)[2]                       # static int
    sort_cutoff = P.args(n)[3]                 # static int, consulted here at emission
    out_slots = P.make_or_get_slots(n)         # (x_input, idx, sort_experts, inv_order)

    m_iov = emit_shape(P, n.args[0], x, end_dim=1)[0]   # M = N (token count)

    def emit_sorted():    # M > cutoff: argsort -> gather -> sorted_idx -> inv_order; sort flag = 1
        ...   # ArgsortNode x2, FloorDivideNode (order // top_k), TakeNode/GatherNode,
              # ExpandDimsNode, a 0-d const "1" -> write all four out_slots
    def emit_unsorted():  # M <= cutoff: repeat token, flatten idx, sentinel inv_order; sort flag = 0
        ...   # BroadcastTo/Reshape + ExpandDimsNode, ReshapeNode (flatten idx),
              # a 0-element const for inv_order (sentinel; never read when sort flag == 0,
              # so no dynamic IotaNode needed), a 0-d const "0" -> write the SAME four out_slots

    # inv_order is only consumed by moe_scatter_outputs's sorted branch, which runs only when
    # sort_experts == 1. So emit_unsorted writes a 0-element sentinel for inv_order
    # (matching the eager reference's unsorted branch) instead of a real arange -- no
    # dynamic IotaNode needed. register_fake still declares inv_order as [N*top_k].

    # cond = (M - 1) // sort_cutoff: 0 (-> else/unsorted) for M <= sort_cutoff,
    # >= 1 (-> then/sorted) for M > sort_cutoff. The IfNode rule is nonzero -> then
    # (MLXInterpreter.h:1879-1881). There is no scalar-int compare node, but
    # SubtractIntNode + FloorDivideIntNode suffice. If M is a compile-time literal,
    # both fold and emit_if_else picks one branch -- no IfNode emitted.
    cond = emit_floordiv(P, emit_sub_int(P, m_iov, IntOrVid.from_literal(1)),
                         IntOrVid.from_literal(sort_cutoff))
    emit_if_else(P, cond, emit_sorted, emit_unsorted)
    return out_slots
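
The cond arithmetic used by the handler can be checked in isolation. This is a plain-Python sketch of the same integer expression, assuming M ≥ 1 and sort_cutoff ≥ 1; the guard and the function name sort_cond are illustrative additions, not part of the handler.

```python
def sort_cond(m: int, sort_cutoff: int) -> int:
    """(M - 1) // sort_cutoff: zero selects the unsorted branch, nonzero the sorted one."""
    if m < 1 or sort_cutoff < 1:
        # Assumption of this sketch: the formula is only meaningful for
        # at least one token and a positive cutoff.
        raise ValueError("expected m >= 1 and sort_cutoff >= 1")
    return (m - 1) // sort_cutoff


# Nonzero exactly when M > sort_cutoff, for every pair in the checked range.
for cutoff in (1, 4, 16):
    for m in range(1, 65):
        assert (sort_cond(m, cutoff) != 0) == (m > cutoff)

print(sort_cond(1, 1), sort_cond(2, 1))  # 0 1
print(sort_cond(4, 4), sort_cond(5, 4))  # 0 1
```

Two edge cases fall outside the stated range: sort_cutoff = 0 would divide by zero, and M = 0 yields -1 under Python's floor division, which is nonzero and would select the sorted branch.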

moe_scatter_outputs handler: read the sort_experts 0-d tensor
to a Vid via ItemIntNode, then emit_if_else on it —
then: TakeNode(down, inv_order) then ReshapeNode; else: ReshapeNode
only (skip the gather). Both write the single output slot.

Grouped-matmul handlers (_gather_mm_handler ops.py:1545,
_gather_qmm_handler ops.py:1573): read sorted_indices; if it is a
runtime tensor, thread it to a Vid (ItemIntNode) and set
GatherMmNode.sorted_indices = IntOrVid.from_vid(...); if absent/None,
use IntOrVid.from_literal(0).

Available schema nodes (no new node types needed):
ArgsortNode (mlx_graph_schema.py:845), TakeNode (:216),
GatherNode (:560), FloorDivideNode (tensor order // top_k,
:1042), FloorDivideIntNode (scalar (M-1) // sort_cutoff, :293),
SubtractIntNode (:279), ReshapeNode (:531),
ExpandDimsNode (:194), SqueezeNode (:458),
BroadcastToNode (:507), ItemIntNode (:188),
SymSizeNode (:306), plus IdCopyNode for pass-through writes
(ops.py:2023). inv_order = argsort(argsort) = two ArgsortNodes,
matching switch.py:245-246. Add a small emit_floordiv helper
alongside emit_ceil_div (op_helpers.py:326-352) that folds when its
operand is literal.
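
One possible shape for the folding behaviour of emit_floordiv is sketched below against a toy program model. IntOrVid and ToyProgram here are simplified stand-ins written for this example, not the builder's real classes, and the node is recorded as a plain tuple; the real helper would follow emit_ceil_div's conventions.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass(frozen=True)
class IntOrVid:
    """Toy stand-in: a compile-time literal or a runtime value id."""
    literal: Optional[int] = None
    vid: Optional[int] = None


@dataclass
class ToyProgram:
    """Toy stand-in for the builder: records emitted nodes and hands out vids."""
    nodes: List[Tuple[str, IntOrVid, IntOrVid, int]] = field(default_factory=list)
    next_vid: int = 0

    def new_vid(self) -> int:
        vid = self.next_vid
        self.next_vid += 1
        return vid


def emit_floordiv(P: ToyProgram, a: IntOrVid, b: IntOrVid) -> IntOrVid:
    """Fold when both operands are literals; otherwise emit one runtime node."""
    if a.literal is not None and b.literal is not None:
        return IntOrVid(literal=a.literal // b.literal)   # no node emitted
    out = P.new_vid()
    P.nodes.append(("FloorDivideIntNode", a, b, out))
    return IntOrVid(vid=out)


P = ToyProgram()
folded = emit_floordiv(P, IntOrVid(literal=7), IntOrVid(literal=4))
dynamic = emit_floordiv(P, IntOrVid(vid=P.new_vid()), IntOrVid(literal=4))
print(folded.literal, len(P.nodes))  # 1 1
```

With a static M the condition becomes a literal, which is what lets emit_if_else pick a single branch without emitting an IfNode.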

5. Replace the static sort_experts flag with a sort_cutoff knob

The export-time sort_experts bool (always/never sort) becomes a
sort_cutoff int that is threaded the same way but now selects the
runtime threshold consulted during emission:

  • SwitchMLP.__init__/forward (backends/mlx/llm/switch.py:220-271):
    replace the if sort_experts: / else: blocks with the two op calls;
    drop the sort_experts parameter and store self.sort_cutoff (passed
    to moe_gather_inputs).
  • _sparse_moe_forward (mlx_source_transformations.py:55-78): replace
    the sort_experts=getattr(self, "_sort_experts", False) argument with
    nothing; SwitchMLP now reads self.sort_cutoff.
  • _swap_sparse_moe (:324-332) and mlx_source_transformations
    (:335-364): swap the sort_experts bool for a sort_cutoff int
    (e.g. default 1).
  • examples/models/qwen3_5_moe/export.py (~:52-58): pass sort_cutoff
    instead of sort_experts (default 1; tune per the matmul crossover).

sort_cutoff is a compile-time constant baked at export, not a
per-call runtime input — it sets the M threshold the IfNode compares
against. Pick it by benchmarking the sorted-vs-unsorted matmul crossover
for the model's num_experts/top_k; 1 reproduces "sort whenever
M > 1."
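
A hypothetical way to turn benchmark results into a cutoff is sketched below. The function name and the selection rule (largest measured M up to which the unsorted path is never slower, floored at 1) are assumptions for illustration, and the timings are placeholder numbers, not measurements.

```python
from typing import Dict


def pick_sort_cutoff(unsorted_ms: Dict[int, float], sorted_ms: Dict[int, float]) -> int:
    """Largest measured M for which the unsorted path has not yet lost; at least 1."""
    cutoff = 1
    for m in sorted(set(unsorted_ms) & set(sorted_ms)):
        if unsorted_ms[m] > sorted_ms[m]:
            break                     # sorted path wins from here on (assumed monotone)
        cutoff = max(cutoff, m)
    return cutoff


# Placeholder latencies in milliseconds keyed by token count M (made up for the example).
unsorted_ms = {1: 0.10, 2: 0.18, 4: 0.35, 8: 0.70, 16: 1.40}
sorted_ms = {1: 0.30, 2: 0.32, 4: 0.36, 8: 0.45, 16: 0.60}
print(pick_sort_cutoff(unsorted_ms, sorted_ms))  # 4
```

The rule only considers measured token counts, so the true crossover may lie between two sampled values of M; denser sampling near the break point tightens the estimate.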


Acceptance criteria

  • sorted_indices is an IntOrVid on GatherMmNode/GatherQmmNode
    (schema + generate.py re-run + loader + interpreter via
    resolve_int), resolving to the correct kernel bool at runtime.
  • mlx::moe_gather_inputs / mlx::moe_scatter_outputs
    exist with register_fake; their handlers branch on M with
    emit_if_else, both branches writing identically-shaped outputs.
  • SwitchMLP.forward calls the two ops; the sort_experts bool is
    replaced by a sort_cutoff int threaded through _sparse_moe_forward,
    _swap_sparse_moe, and export.py.
  • The handler consults sort_cutoff at emission via
    cond = (M-1) // sort_cutoff: the unsorted branch (M ≤ sort_cutoff)
    emits no argsort/scatter and runs the grouped matmul with
    sorted_indices == 0; the sorted branch (M > sort_cutoff) sorts and
    runs with sorted_indices == 1.
  • When M is statically known, cond folds to a literal and no
    IfNode is emitted (the helper picks one branch).
  • Numerics unchanged; the CI exact-token gate still passes.

Testing

  • Unit / lowering: backends/mlx/test/test_ops.py. Add an
    OpTestCase (model wrapping SwitchMLP.forward, or the two ops
    directly) with expected_node_counts, following the Sample*Test
    node-count pattern (test_ops.py:7680-7756). With a fixed
    sort_cutoff, assert both batch_size ≤ sort_cutoff (unsorted:
    sentinel inv_order, sorted_indices==0, no ArgsortNode) and
    batch_size > sort_cutoff (sorted: ArgsortNode×2,
    TakeNode/GatherNode, IfNode). Cover the static-M fold (a
    fixed-shape input of each kind emits no IfNode). Existing
    grouped-matmul tests GatherMmTest (test_ops.py:6563) and
    GatherQmmTest (:6669) — each already has a batch_size=1 config —
    should be extended to pass a 0-d sorted_indices tensor through the new
    IntOrVid path.
  • E2E CI gate: .github/workflows/mlx.yml → test-mlx-qwen35-moe:
    python -m executorch.examples.models.qwen3_5_moe.export \
      --tiny-test --backend mlx --qlinear 4w --qlinear-group-size 32 \
      --output-dir /tmp/qwen35_moe_mlx_tiny
    python -m executorch.examples.models.qwen3_5_moe.run \
      --pte /tmp/qwen35_moe_mlx_tiny/model.pte --prompt-len 4 --max-new-tokens 5
    Must keep the exact output Generated token ids: [167, 94, 253, 88, 227]
    (exercises prefill + decode, i.e. both branches).
  • Perf check (optional): report decode_token_per_sec before/after.

Pointers

  • MoE forward + the sort: backends/mlx/llm/switch.py:220-271
    (SwitchMLP.forward), SwitchLinear.forward :130-161.
  • Grouped-matmul ops: backends/mlx/custom_ops.py:282-393
    (gather_mm/gather_qmm); handlers backends/mlx/ops.py:1545-1641.
  • sorted_indices is a correctness contract, ignored by the eager
    reference and passed straight to MLX core
    (MLXInterpreter.h:875,909).
  • IntOrVid precedent end-to-end (PartitionNode.kth): schema.fbs:683,
    mlx_graph_schema.py:856, MLXLoader.cpp:1452,
    MLXInterpreter.h:1819 (resolve_int).
  • Branch helpers (backends/mlx/builder/op_helpers.py): emit_if_else
    :355, emit_sub_int :307, emit_shape :177, emit_product
    :232.
  • Multi-output handler template: _split_with_sizes_handler
    (ops.py:1500-1542, make_or_get_slots + tuple return);
    extractor _getitem_handler (ops.py:2008-2028).
  • emit_if_else single-output precedent: _sample_handler
    (ops.py:3519-3713); Q4_K M-branch
    (backends/mlx/custom_kernel_ops/gguf/q4k/linear.py:453-477).
  • Export wiring to update (sort_experts → sort_cutoff):
    examples/models/qwen3_5_moe/export.py,
    .../mlx_source_transformations.py:55-78,324-364.
