Problem
The MLX MoE sorts tokens by expert so each expert's tokens are
contiguous for a grouped matmul (gather_mm/gather_qmm with
sorted_indices=True). This helps prefill (many tokens to batch) but
is pure overhead during decode (1 token → nothing to batch): the
argsort, activation gather, and output scatter all run for nothing, and
the sorted/segmented matmul kernel is a poor fit for a tiny batch.
The sort is decided by a static export-time flag (sort_experts),
so a single dynamic-shape .pte runs it in both phases
(backends/mlx/llm/switch.py:243-271,
examples/models/qwen3_5_moe/mlx_source_transformations.py:324-332).
Fix: decide sort vs. no-sort at runtime, based on token count M,
inside the backend lowering, using the existing emit_if_else helper.
Approach: two glue ops around the existing grouped matmuls
SwitchMLP.forward becomes:
    x_input, idx, sort_experts, inv_order = torch.ops.mlx.moe_gather_inputs(
        x, expert_indices, self.top_k, self.sort_cutoff
    )
    if self.fuse_gate_up:
        gate_up = self.gate_up_proj(x_input, idx, sorted_indices=sort_experts)
        gate = gate_up[..., : self.intermediate_size]
        up = gate_up[..., self.intermediate_size :]
    else:
        gate = self.gate_proj(x_input, idx, sorted_indices=sort_experts)
        up = self.up_proj(x_input, idx, sorted_indices=sort_experts)
    h = self.activation(gate) * up
    down = self.down_proj(h, idx, sorted_indices=sort_experts)
    down = torch.ops.mlx.moe_scatter_outputs(
        down, sort_experts, inv_order, self.top_k
    )
    return (down * expert_weights.unsqueeze(-1)).sum(dim=-2)
sort_experts is now a runtime 0-d int tensor (1 = sorted, 0 =
unsorted) produced by moe_gather_inputs. It feeds both the
grouped matmuls (sorted_indices=) and moe_scatter_outputs. The two new ops'
handlers branch on M via emit_if_else, comparing the runtime token
count M against the static sort_cutoff: sort when M > sort_cutoff,
skip the sort otherwise. sort_cutoff is a compile-time arg of the op
(like top_k), consulted during emission — sort_cutoff=1 recovers
"sort whenever there is more than one token." The grouped-matmul call
sites are structurally unchanged; only their sorted_indices argument now
carries the runtime sort_experts tensor.
Two design constraints (both must hold)
1. Both emit_if_else branches must produce the same shapes.
torch.export fixes each op output's rank/shape from its register_fake,
and the single downstream gather is traced against those shapes. So the
sorted and unsorted branches of moe_gather_inputs must write
identically-shaped outputs. We therefore unify on the sorted
layout — the unsorted branch replicates each token top_k times instead
of using the rank-4 broadcast form:
| output | shape | sorted (M > sort_cutoff) | unsorted (M ≤ sort_cutoff) |
|---|---|---|---|
| x_input | [N*top_k, 1, D] | x[order // top_k] then unsqueeze(-2) | repeat_interleave(x, top_k) then unsqueeze(-2) |
| idx | [N*top_k] int32 | flat_indices[order] | expert_indices.flatten() |
| sort_experts | [] (0-d) int32 | 1 | 0 |
| inv_order | [N*top_k] int32 (fake/sorted) | argsort(order) | 0-element sentinel (unread) |
The unsorted branch's repeat_interleave and the sorted branch's
x[order // top_k] both materialize M*top_k rows, so neither is cheaper
on that axis — the sort decision is governed by the matmul, not the
materialization (see sort_cutoff below).
moe_scatter_outputs(down, sort_experts, inv_order, top_k) takes
down [N*top_k, 1, H] → squeeze(-2) → returns [N, top_k, H]
(prefill: down[inv_order].reshape(N, top_k, -1); decode:
down.reshape(N, top_k, -1)).
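
The two layouts in the table can be checked against each other with a small dependency-free sketch. It is illustration only: tokens are plain floats, the grouped matmul is replaced by a per-expert scalar multiply (an assumption made to keep the example tiny), and gather_inputs, scatter_outputs, and moe are hypothetical pure-Python analogues of the two ops, not the real implementations.

```python
from typing import List, Tuple


def gather_inputs(
    x: List[float], expert_indices: List[List[int]], top_k: int, sort_cutoff: int
) -> Tuple[List[float], List[int], int, List[int]]:
    """Pure-Python analogue of moe_gather_inputs for scalar 'tokens'."""
    n = len(x)
    flat = [e for row in expert_indices for e in row]  # [N*top_k]
    if n > sort_cutoff:
        # Sorted path: stable argsort by expert id, then gather the tokens.
        order = sorted(range(len(flat)), key=lambda i: flat[i])
        inv_order = [0] * len(order)
        for pos, src in enumerate(order):
            inv_order[src] = pos  # inv_order = argsort(order)
        x_input = [x[i // top_k] for i in order]
        idx = [flat[i] for i in order]
        return x_input, idx, 1, inv_order
    # Unsorted path: replicate each token top_k times; inv_order is a sentinel.
    x_input = [x[i // top_k] for i in range(n * top_k)]
    return x_input, flat, 0, []


def scatter_outputs(
    down: List[float], sort_experts: int, inv_order: List[int], top_k: int
) -> List[List[float]]:
    """Pure-Python analogue of moe_scatter_outputs: undo the sort, then reshape."""
    if sort_experts:
        down = [down[i] for i in inv_order]
    return [down[i : i + top_k] for i in range(0, len(down), top_k)]


def moe(
    x: List[float],
    expert_indices: List[List[int]],
    expert_scale: List[float],
    top_k: int,
    sort_cutoff: int,
) -> List[List[float]]:
    x_input, idx, flag, inv_order = gather_inputs(x, expert_indices, top_k, sort_cutoff)
    # Stand-in for the grouped matmul: each row is scaled by its expert's weight.
    down = [expert_scale[e] * v for v, e in zip(x_input, idx)]
    return scatter_outputs(down, flag, inv_order, top_k)
```

Calling moe with the same inputs and a sort_cutoff below and above the token count should give the same [N, top_k] result, which is the property the two branches rely on.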
2. The grouped matmul's sorted_indices must be runtime-settable.
Today sorted_indices is a static bool baked onto
GatherMmNode/GatherQmmNode (backends/mlx/ops.py:1557,1595;
schema.fbs:951,966). It is an MLX kernel correctness contract, not a
no-op hint: sorted_indices=True tells the kernel the indices are grouped
by expert, so it must be True only when we actually sorted. Since the
sort decision is now at runtime, upgrade sorted_indices from bool
to IntOrVid (0 = unsorted, nonzero = sorted) — the same mechanism
PartitionNode.kth already uses.
The change
1. Upgrade sorted_indices: bool → IntOrVid
This is the enabling change and the only one touching C++/schema. It
mirrors PartitionNode.kth (kth: IntOrVid in schema.fbs:683,
resolved at runtime by resolve_int).
- Schema (backends/mlx/serialization/schema.fbs:945-967): change sorted_indices: bool = false; → sorted_indices: IntOrVid; on GatherMmNode and GatherQmmNode.
  BC note: schema.fbs:14 says never change an existing field's type. Two options: (A, recommended) change the type directly, which is acceptable because MLX delegate .pte files are regenerated from source on each export; (B, strict append-only) keep the legacy bool field, append a new sorted_indices_iov: IntOrVid; at the end of each table, and have the runtime prefer the IntOrVid when present.
- Codegen: run python backends/mlx/serialization/generate.py to regenerate mlx_graph_schema.py, the serializers/inspector, the _generated/ bindings, runtime/schema_generated.h, and the partial runtime/MLXLoader.{h,cpp}.
- Loader (backends/mlx/runtime/MLXLoader.cpp:2056,2085): change node.sorted_indices = fb->sorted_indices(); → node.sorted_indices = convert_int_or_vid(fb->sorted_indices()); (exactly like kth at MLXLoader.cpp:1452,1467). In MLXLoader.h:952,967 the field type changes from bool to IntOrVid.
- Runtime (backends/mlx/runtime/MLXInterpreter.h:875,909): resolve to a bool at exec time with bool sorted = resolve_int(n.sorted_indices, st) != 0; then pass sorted to gather_mm(...) / gather_qmm(...).
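
The runtime rule above (literal or value id in, bool out) can be modeled in a few lines of Python. This is a sketch only: the real resolve_int is C++ in MLXInterpreter.h, and the class layout, the values dictionary standing in for interpreter state, and the resolve_sorted helper are assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass(frozen=True)
class IntOrVid:
    """Hypothetical stand-in: a literal int, or the id of a runtime value."""

    literal: Optional[int] = None
    vid: Optional[int] = None

    @staticmethod
    def from_literal(value: int) -> "IntOrVid":
        return IntOrVid(literal=value)

    @staticmethod
    def from_vid(vid: int) -> "IntOrVid":
        return IntOrVid(vid=vid)


def resolve_int(iov: IntOrVid, values: Dict[int, int]) -> int:
    """Return the literal, or look the value id up in the runtime state."""
    if iov.vid is not None:
        return values[iov.vid]
    if iov.literal is None:
        raise ValueError("IntOrVid holds neither a literal nor a vid")
    return iov.literal


def resolve_sorted(iov: IntOrVid, values: Dict[int, int]) -> bool:
    """0 means unsorted, any nonzero value means sorted."""
    return resolve_int(iov, values) != 0
```

A graph exported without the runtime flag would carry IntOrVid.from_literal(0) and always resolve to unsorted; one that threads sort_experts through would carry a vid whose value is set per call.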
2. Custom ops accept a runtime sorted_indices
backends/mlx/custom_ops.py — change gather_mm (:282-322) and
gather_qmm (:325-393) sorted_indices: bool = False →
sorted_indices: Optional[Tensor] = None (0-d int; None/0 = unsorted).
The eager reference and register_fake already ignore sorted_indices
(it is layout-only with identical numerics — custom_ops.py:299-303,
347-369), so only the signatures change. Update SwitchLinear.forward
(backends/mlx/llm/switch.py:130-161) to pass the tensor straight
through.
3. Two new custom ops (backends/mlx/custom_ops.py)
The eager references branch on M on purpose — they are the
executable spec the lowering handler (section 4) mirrors branch-for-branch.
Sorting is an invertible permutation (identical numerics either way), so
the two paths exist for the lowering's sake, not the math's. Write them
to read exactly like the two MLX branches the handler will emit:
    from typing import Tuple

    import torch
    from torch import Tensor


    @torch.library.custom_op("mlx::moe_gather_inputs", mutates_args=())
    def moe_gather_inputs(
        x: Tensor, expert_indices: Tensor, top_k: int, sort_cutoff: int
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        N = x.shape[0]
        if N > sort_cutoff:  # SORTED path (handler: emit_sorted)
            flat = expert_indices.flatten()
            order = flat.argsort().to(torch.int32)
            inv_order = order.argsort().to(torch.int32)
            idx = flat[order].to(torch.int32)  # [N*top_k]
            x_input = x[(order // top_k).to(torch.int64)].unsqueeze(-2)  # [N*top_k, 1, D]
            sort_experts = torch.ones((), dtype=torch.int32)
        else:  # UNSORTED path (handler: emit_unsorted)
            x_input = x.repeat_interleave(top_k, dim=0).unsqueeze(-2)  # [N*top_k, 1, D]
            idx = expert_indices.flatten().to(torch.int32)  # [N*top_k]
            sort_experts = torch.zeros((), dtype=torch.int32)
            inv_order = torch.empty(0, dtype=torch.int32)  # sentinel: never read
        return x_input, idx, sort_experts, inv_order


    @torch.library.custom_op("mlx::moe_scatter_outputs", mutates_args=())
    def moe_scatter_outputs(
        down: Tensor, sort_experts: Tensor, inv_order: Tensor, top_k: int
    ) -> Tensor:
        down = down.squeeze(-2)  # [N*top_k, H]
        if sort_experts.item():  # prefill: scatter back (handler: emit_then)
            down = down[inv_order]
        else:  # decode: no scatter; inv_order is the unread sentinel (handler: emit_else)
            # custom_op outputs may not alias inputs, so copy instead of returning a view.
            down = down.clone()
        return down.reshape(down.shape[0] // top_k, top_k, -1)  # [N, top_k, H]
The reference body is opaque to torch.export (custom ops are leaf
nodes), so the Python if N > sort_cutoff is fine — it runs only in
eager/parity tests. The outputs consumed downstream — x_input, idx,
sort_experts — are identically-shaped across both branches.
inv_order is the exception ([N*top_k] when sorted, a 0-element
sentinel when unsorted), but it is never read on the unsorted path
(see moe_scatter_outputs), so the shape divergence is harmless.
Add a register_fake for each. Unlike the reference, the fake must not
branch on M — under export M is a symbolic SymInt, so data-dependent
control flow on it is illegal. The fake returns one shape for all M:
x_input [N*top_k, 1, D], idx [N*top_k], sort_experts [], and
inv_order [N*top_k] (the sorted-path shape). sort_experts must be a
0-d tensor (there is no scalar-bool slot); its runtime value (0/1) is
chosen by the handler branches in section 4. The mismatch between the fake's
[N*top_k] inv_order and the unsorted reference's 0-element sentinel is benign:
inv_order feeds no downstream fake (moe_scatter_outputs derives its output
shape from down and top_k). Because the unsorted path deliberately returns a
sentinel, run torch.library.opcheck with M > sort_cutoff (M > 1 when
sort_cutoff=1), or skip the inv_order shape check.
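
The shape contract described above can be written down as plain shape arithmetic. The sketch below uses concrete ints where export would pass a symbolic N, and the function names (fake_shapes, reference_shapes, mismatches) are hypothetical; it only shows that the branch-free fake and the branching reference agree on every output except inv_order on the unsorted path.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def fake_shapes(n: int, top_k: int, d: int) -> Dict[str, Shape]:
    """Branch-free output shapes, as the register_fake would report them."""
    rows = n * top_k
    return {
        "x_input": (rows, 1, d),
        "idx": (rows,),
        "sort_experts": (),  # 0-d tensor
        "inv_order": (rows,),  # always the sorted-path shape
    }


def reference_shapes(n: int, top_k: int, d: int, sort_cutoff: int) -> Dict[str, Shape]:
    """Output shapes of the eager reference, which does branch on n."""
    shapes = fake_shapes(n, top_k, d)
    if n <= sort_cutoff:
        shapes["inv_order"] = (0,)  # unsorted path returns the sentinel
    return shapes


def mismatches(n: int, top_k: int, d: int, sort_cutoff: int) -> List[str]:
    """Names of outputs whose fake shape differs from the reference shape."""
    fake = fake_shapes(n, top_k, d)
    ref = reference_shapes(n, top_k, d, sort_cutoff)
    return [name for name in fake if fake[name] != ref[name]]
```

With n above sort_cutoff the list is empty; at or below it the only entry is inv_order, which is the case the opcheck guidance works around.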
4. Handlers (backends/mlx/ops.py) — branch on M
Each handler is a node-for-node lowering of the section-3 reference:
emit_sorted mirrors the N > sort_cutoff branch, emit_unsorted mirrors
the else branch. Both handlers pre-allocate every output slot, and both
branches write all of them (the IfNode only selects a chain; downstream reads
fixed slot ids). This is the multi-output generalization of the
single-output emit_if_else pattern in _sample_handler
(ops.py:3537,3712) and the GGUF Q4_K linear
(backends/mlx/custom_kernel_ops/gguf/q4k/linear.py:453-477).
    @REGISTRY.register(target=[torch.ops.mlx.moe_gather_inputs.default])
    def _moe_gather_inputs_handler(P, n):
        x, expert_indices = P.args(n)[0], P.args(n)[1]
        top_k = P.args(n)[2]        # static int
        sort_cutoff = P.args(n)[3]  # static int, consulted here at emission
        out_slots = P.make_or_get_slots(n)  # (x_input, idx, sort_experts, inv_order)
        m_iov = emit_shape(P, n.args[0], x, end_dim=1)[0]  # M = N (token count)

        def emit_sorted():
            # M > cutoff: argsort -> gather -> sorted_idx -> inv_order; sort flag = 1
            ...  # ArgsortNode x2, FloorDivideNode (order // top_k), TakeNode/GatherNode,
                 # ExpandDimsNode, a 0-d const "1" -> write all four out_slots

        def emit_unsorted():
            # M <= cutoff: repeat token, flatten idx, sentinel inv_order; sort flag = 0
            ...  # BroadcastTo/Reshape + ExpandDimsNode, ReshapeNode (flatten idx),
                 # a 0-element const for inv_order (sentinel; never read when sort flag == 0,
                 # so no dynamic IotaNode needed), a 0-d const "0" -> write the SAME four out_slots

        # inv_order is only consumed by moe_scatter_outputs's sorted branch, which runs only when
        # sort_experts == 1. So emit_unsorted writes a 0-element sentinel for inv_order
        # (matching the eager reference's unsorted branch) instead of a real arange -- no
        # dynamic IotaNode needed. register_fake still declares inv_order as [N*top_k].

        # cond = (M - 1) // sort_cutoff: 0 (-> else/unsorted) for M <= sort_cutoff,
        # >= 1 (-> then/sorted) for M > sort_cutoff. The IfNode rule is nonzero -> then
        # (MLXInterpreter.h:1879-1881). There is no scalar-int compare node, but
        # SubtractIntNode + FloorDivideIntNode suffice. If M is a compile-time literal,
        # both fold and emit_if_else picks one branch -- no IfNode emitted.
        cond = emit_floordiv(
            P,
            emit_sub_int(P, m_iov, IntOrVid.from_literal(1)),
            IntOrVid.from_literal(sort_cutoff),
        )
        emit_if_else(P, cond, emit_sorted, emit_unsorted)
        return out_slots
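
The compare-free condition in the handler comment is easy to sanity-check in isolation. The sketch below assumes M >= 1 and sort_cutoff >= 1, which is the range the comment describes; the function names are hypothetical. It deliberately rejects M = 0, because Python's floor division gives -1 there (nonzero, so it would select the sorted branch), and how the runtime's integer division treats that case is not stated in this text.

```python
def sort_cond(m: int, sort_cutoff: int) -> int:
    """(M - 1) // sort_cutoff, the value fed to the IfNode."""
    if m < 1 or sort_cutoff < 1:
        raise ValueError("sketch assumes M >= 1 and sort_cutoff >= 1")
    return (m - 1) // sort_cutoff


def takes_sorted_branch(m: int, sort_cutoff: int) -> bool:
    """IfNode rule: nonzero selects the 'then' (sorted) branch."""
    return sort_cond(m, sort_cutoff) != 0


# Exhaustive check over a small range: the arithmetic matches M > sort_cutoff.
for cutoff in range(1, 9):
    for m in range(1, 65):
        assert takes_sorted_branch(m, cutoff) == (m > cutoff)
```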
moe_scatter_outputs handler: read the sort_experts 0-d tensor
to a Vid via ItemIntNode, then emit_if_else on it —
then: TakeNode(down, inv_order) then ReshapeNode; else: ReshapeNode
only (skip the gather). Both write the single output slot.
Grouped-matmul handlers (_gather_mm_handler ops.py:1545,
_gather_qmm_handler ops.py:1573): read sorted_indices; if it is a
runtime tensor, thread it to a Vid (ItemIntNode) and set
GatherMmNode.sorted_indices = IntOrVid.from_vid(...); if absent/None,
use IntOrVid.from_literal(0).
Available schema nodes (no new node types needed):
ArgsortNode (mlx_graph_schema.py:845), TakeNode (:216),
GatherNode (:560), FloorDivideNode (tensor order // top_k,
:1042), FloorDivideIntNode (scalar (M-1) // sort_cutoff, :293),
SubtractIntNode (:279), ReshapeNode (:531),
ExpandDimsNode (:194), SqueezeNode (:458),
BroadcastToNode (:507), ItemIntNode (:188),
SymSizeNode (:306), plus IdCopyNode for pass-through writes
(ops.py:2023). inv_order = argsort(argsort) = two ArgsortNodes,
matching switch.py:245-246. Add a small emit_floordiv helper
alongside emit_ceil_div (op_helpers.py:326-352) that folds when its
operand is literal.
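
One possible shape for the folding helper mentioned above, as a minimal sketch. The real emit_ceil_div and builder API are not shown in this text, so the Builder class, its emit method, and the node tuples recorded here are hypothetical, as is the simplified IntOrVid; the sketch folds only when both operands are literals and otherwise records a FloorDivideIntNode.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass(frozen=True)
class IntOrVid:
    """Hypothetical stand-in: a literal int, or the id of a runtime value."""

    literal: Optional[int] = None
    vid: Optional[int] = None


@dataclass
class Builder:
    """Hypothetical program builder that just records emitted nodes."""

    nodes: List[Tuple[str, IntOrVid, IntOrVid, int]] = field(default_factory=list)
    next_vid: int = 0

    def emit(self, op: str, a: IntOrVid, b: IntOrVid) -> IntOrVid:
        out = self.next_vid
        self.next_vid += 1
        self.nodes.append((op, a, b, out))
        return IntOrVid(vid=out)


def emit_floordiv(P: Builder, a: IntOrVid, b: IntOrVid) -> IntOrVid:
    """Fold a // b at emission when both are literals, else emit a runtime node."""
    if a.literal is not None and b.literal is not None:
        return IntOrVid(literal=a.literal // b.literal)
    return P.emit("FloorDivideIntNode", a, b)
```

With a static M the result stays a literal, which is what lets emit_if_else pick a branch without emitting an IfNode; with a dynamic M the result is a vid produced at runtime.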
5. Replace the static sort_experts flag with a sort_cutoff knob
The export-time sort_experts bool (always/never sort) becomes a
sort_cutoff int that is threaded the same way but now selects the
runtime threshold consulted during emission:
- SwitchMLP.__init__/forward (backends/mlx/llm/switch.py:220-271): replace the if sort_experts: / else: blocks with the two op calls; drop the sort_experts parameter and store self.sort_cutoff (passed to moe_gather_inputs).
- _sparse_moe_forward (mlx_source_transformations.py:55-78): drop the sort_experts=getattr(self, "_sort_experts", False) argument; SwitchMLP now reads self.sort_cutoff.
- _swap_sparse_moe (:324-332) and mlx_source_transformations (:335-364): swap the sort_experts bool for a sort_cutoff int (e.g. default 1).
- examples/models/qwen3_5_moe/export.py (~:52-58): pass sort_cutoff instead of sort_experts (default 1; tune per the matmul crossover).
sort_cutoff is a compile-time constant baked at export, not a
per-call runtime input — it sets the M threshold the IfNode compares
against. Pick it by benchmarking the sorted-vs-unsorted matmul crossover
for the model's num_experts/top_k; 1 reproduces "sort whenever
M > 1."
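
One way to turn benchmark results into a sort_cutoff, as a rough sketch. It assumes you have already timed the unsorted and sorted paths at a handful of token counts; the function name and the dictionary inputs (token count to milliseconds) are hypothetical, and token counts between measured points are not interpolated.

```python
from typing import Dict


def pick_sort_cutoff(unsorted_ms: Dict[int, float], sorted_ms: Dict[int, float]) -> int:
    """Largest measured M up to which the unsorted path is never slower.

    Scans the token counts present in both tables in increasing order and
    stops at the first one where sorting wins. The result is clamped to at
    least 1, since the sort is skipped only when M <= sort_cutoff and the
    (M - 1) // sort_cutoff condition needs a nonzero divisor.
    """
    cutoff = 1
    for m in sorted(set(unsorted_ms) & set(sorted_ms)):
        if unsorted_ms[m] > sorted_ms[m]:
            break
        cutoff = max(cutoff, m)
    return cutoff
```

The returned value would be passed as sort_cutoff at export time; it is baked into the program, so a different model or hardware target needs its own measurement.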
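Acceptance criteria
- sorted_indices is an IntOrVid on GatherMmNode/GatherQmmNode (schema + generate.py re-run + loader + interpreter via resolve_int), resolving to the correct kernel bool at runtime.
- mlx::moe_gather_inputs / mlx::moe_scatter_outputs exist with register_fake; their handlers branch on M with emit_if_else, both branches writing identically-shaped outputs.
- SwitchMLP.forward calls the two ops; the sort_experts bool is replaced by a sort_cutoff int threaded through _sparse_moe_forward, _swap_sparse_moe, and export.py.
- The branch is selected against sort_cutoff at emission via cond = (M-1) // sort_cutoff: the unsorted branch (M ≤ sort_cutoff) emits no argsort/scatter and runs the grouped matmul with sorted_indices == 0; the sorted branch (M > sort_cutoff) sorts and runs with sorted_indices == 1.
- When M is statically known, cond folds to a literal and no IfNode is emitted (the helper picks one branch).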
Testing
- Unit / lowering: backends/mlx/test/test_ops.py. Add an OpTestCase (model wrapping SwitchMLP.forward, or the two ops directly) with expected_node_counts, following the Sample*Test node-count pattern (test_ops.py:7680-7756). With a fixed sort_cutoff, assert both batch_size ≤ sort_cutoff (unsorted: sentinel inv_order, sorted_indices==0, no ArgsortNode) and batch_size > sort_cutoff (sorted: ArgsortNode×2, TakeNode/GatherNode, IfNode). Cover the static-M fold (a fixed-shape input of each kind emits no IfNode). The existing grouped-matmul tests GatherMmTest (test_ops.py:6563) and GatherQmmTest (:6669), each of which already has a batch_size=1 config, should be extended to pass a 0-d sorted_indices tensor through the new IntOrVid path.
- E2E CI gate: .github/workflows/mlx.yml → test-mlx-qwen35-moe:

      python -m executorch.examples.models.qwen3_5_moe.export \
        --tiny-test --backend mlx --qlinear 4w --qlinear-group-size 32 \
        --output-dir /tmp/qwen35_moe_mlx_tiny
      python -m executorch.examples.models.qwen3_5_moe.run \
        --pte /tmp/qwen35_moe_mlx_tiny/model.pte --prompt-len 4 --max-new-tokens 5

  Must keep the exact output Generated token ids: [167, 94, 253, 88, 227] (exercises prefill + decode, i.e. both branches).
- Perf check (optional): report decode_token_per_sec before/after.
Pointers
- MoE forward + the sort: backends/mlx/llm/switch.py:220-271 (SwitchMLP.forward), SwitchLinear.forward :130-161.
- Grouped-matmul ops: backends/mlx/custom_ops.py:282-393 (gather_mm/gather_qmm); handlers backends/mlx/ops.py:1545-1641. sorted_indices is a correctness contract, ignored by the eager reference and passed straight to MLX core (MLXInterpreter.h:875,909).
- IntOrVid precedent end-to-end (PartitionNode.kth): schema.fbs:683, mlx_graph_schema.py:856, MLXLoader.cpp:1452, MLXInterpreter.h:1819 (resolve_int).
- Branch helpers: backends/mlx/builder/op_helpers.py (emit_if_else :355, emit_sub_int :307, emit_shape :177, emit_product :232).
- Multi-output handler template: _split_with_sizes_handler (ops.py:1500-1542, make_or_get_slots + tuple return); extractor _getitem_handler (ops.py:2008-2028).
- emit_if_else single-output precedent: _sample_handler (ops.py:3519-3713); Q4_K M-branch (backends/mlx/custom_kernel_ops/gguf/q4k/linear.py:453-477).
- Export wiring to update (sort_experts → sort_cutoff): examples/models/qwen3_5_moe/export.py, .../mlx_source_transformations.py:55-78,324-364.
sort_experts(default1; tune per the matmul crossover).Acceptance criteria
sorted_indicesis anIntOrVidonGatherMmNode/GatherQmmNode(schema +
generate.pyre-run + loader + interpreter viaresolve_int), resolving to the correct kernelboolat runtime.mlx::moe_gather_inputs/mlx::moe_scatter_outputsexist with
register_fake; their handlers branch onMwithemit_if_else, both branches writing identically-shaped outputs.SwitchMLP.forwardcalls the two ops; thesort_expertsbool isreplaced by a
sort_cutoffint threaded through_sparse_moe_forward,_swap_sparse_moe, andexport.py.sort_cutoffat emission viacond = (M-1) // sort_cutoff: the unsorted branch (M ≤ sort_cutoff)emits no
argsort/scatter and runs the grouped matmul withsorted_indices == 0; the sorted branch (M > sort_cutoff) sorts andruns with
sorted_indices == 1.Mis statically known,condfolds to a literal and noIfNodeis emitted (the helper picks one branch).Testing