Summary
mlx::sample already does on-device Gumbel-max token sampling with
temperature and top-p (nucleus) filtering. SamplingHead reserves a
top_k argument, but it is not implemented — passing anything other
than None raises:
    # backends/mlx/llm/sampling.py:33-35
    def forward(self, *args, temperature, top_k=None, top_p=1.0, seed=None, **kwargs):
        if top_k is not None:
            raise NotImplementedError("top_k sampling is not implemented")
This issue wires top_k through the same four layers that already carry
top_p: the SamplingHead wrapper, the mlx::sample custom op
(schema + CPU reference + fake), the emission/lowering handler in
ops.py, and the tests. top_k restricts sampling to the k most
likely tokens before drawing.
top_k=None means "no top-k filtering" — keep every token. This is
the off switch, exactly analogous to top_p=1.0.
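For readers new to the technique, here is a minimal pure-Python sketch of Gumbel-max sampling with a top-k cutoff. It is not the mlx::sample code: the function name gumbel_max_sample and the rng parameter are hypothetical, and it thresholds the scaled logits rather than the probabilities (softmax is monotonic, so the kept set is the same).

```python
import math
import random


def gumbel_max_sample(logits, temperature, top_k=None, rng=None):
    """Draw one token index; hypothetical stand-in, not the mlx::sample code."""
    rng = rng or random.Random()
    if temperature <= 0:
        # Greedy branch: plain argmax, no noise.
        return max(range(len(logits)), key=lambda i: logits[i])

    scaled = [x / temperature for x in logits]
    if top_k is not None:
        # Cutoff is the k-th largest scaled logit; softmax is monotonic, so
        # this keeps the same tokens as thresholding the probabilities.
        cutoff = sorted(scaled, reverse=True)[top_k - 1]
        scaled = [x if x >= cutoff else -math.inf for x in scaled]

    def gumbel():
        # Clamp u away from 0 and 1 so both logs stay finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        return -math.log(-math.log(u))

    # argmax(scaled + Gumbel noise) is a sample from softmax(scaled).
    noisy = [x + gumbel() for x in scaled]
    return max(range(len(noisy)), key=lambda i: noisy[i])


rng = random.Random(0)
draws = {gumbel_max_sample([2.0, 1.5, 0.5, -1.0], 1.0, top_k=2, rng=rng) for _ in range(500)}
print(draws <= {0, 1})  # True: masked tokens can never win the argmax
```

With top_k=None the cutoff step is skipped entirely, which mirrors the "off switch" behavior described above.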
Feasibility: HIGH — Python-only, no C++ / schema changes
top_p was added entirely in Python (op def + emission + tests) with
zero C++ runtime or flatbuffer-schema changes, because sampling is
decomposed into primitive MLX graph nodes the runtime already executes.
top_k follows the same path, and every primitive it needs already
exists and is wired end-to-end:
| Need | Status | Where |
| --- | --- | --- |
| descending-sorted probabilities | already emitted for top-p | backends/mlx/ops.py:3620-3627 (sorted_p) |
| pick the k-th largest prob as a threshold | TakeNode (index: IntOrVidOrTid, axis) | serialization/mlx_graph_schema.py:217-221; runtime MLXInterpreter.h |
| mask tokens below threshold | LessNode + WhereNode (already used by top-p) | backends/mlx/ops.py:3668-3686 |
| thread a runtime int (k) into the graph | ItemIntNode (already used for seed) | backends/mlx/ops.py:3556-3558 |
| (efficient alt) k-th largest without a full sort | PartitionNode / ArgPartitionNode (kth field) | serialization/mlx_graph_schema.py:853-861; MLXInterpreter.h:1817-1830 |
Because TakeNode and PartitionNode already have exec_* handlers in
the runtime, no schema.fbs edit, generate.py re-run, or
MLXInterpreter.h change is required.
Current state (what top_p looks like at each layer — the template)
- SamplingHead: backends/mlx/llm/sampling.py:33-40. Coerces
  top_p to a tensor and forwards it to the op:

      if not isinstance(top_p, torch.Tensor):
          top_p = torch.tensor(float(top_p))
      return torch.ops.mlx.sample(last, temperature, top_p, seed)

- Custom op: backends/mlx/custom_ops.py:396-444. Schema
  sample(logits, temperature, top_p, seed=None), a CPU reference
  implementation (sort → cumsum → prefix-mass → threshold → mask), and
  register_fake.
- Emission handler: backends/mlx/ops.py:3519-3711
  (_sample_handler). Lowers to an IfNode on temperature > 0:
  greedy ArgmaxNode vs. the sampling branch. The top-p nucleus mask is
  built at backends/mlx/ops.py:3615-3696 from SoftmaxNode,
  SortNode, CumsumNode, MinNode, Greater/Less/WhereNode. This is
  exactly where the top-k mask goes.
- Tests:
  - eager behavior: backends/mlx/test/test_sample.py:132-143
    (test_top_p_restricts_to_nucleus, test_top_p_one_keeps_all);
  - lowering / node-count: backends/mlx/test/test_ops.py:7729-7756
    (SampleTopPTest, with TopPSampleModel);
  - on-device end-to-end: backends/mlx/test/test_sample.py:203-219
    (test_top_p_end_to_end).
Proposed design
Make top_k an optional scalar int tensor input, mirroring seed:
None at export ⇒ no top-k nodes emitted (keep all tokens); a tensor ⇒
keep the top k. Append it after seed to keep the op signature
backward-compatible.
Part 1 — SamplingHead (backends/mlx/llm/sampling.py)
Drop the NotImplementedError, coerce top_k to a tensor when given,
and pass it through:
    def forward(self, *args, temperature, top_k=None, top_p=1.0, seed=None, **kwargs):
        logits = self.model(*args, **kwargs)  # [B, S, vocab]
        last = logits[:, -1, :]               # [B, vocab]
        if not isinstance(top_p, torch.Tensor):
            top_p = torch.tensor(float(top_p))
        # Coerce a plain int to a scalar int64 tensor; None stays None (top-k off).
        if top_k is not None and not isinstance(top_k, torch.Tensor):
            top_k = torch.tensor(int(top_k), dtype=torch.int64)
        return torch.ops.mlx.sample(last, temperature, top_p, seed, top_k)
Update the docstring (sampling.py:17-27) so the top_k line reads
"keep only the k most likely tokens; None disables top-k (keep all)".
Part 2 — custom op (backends/mlx/custom_ops.py:396-444)
Add top_k: Optional[Tensor] = None to the op signature, the
register_fake signature, and the CPU reference. In the reference,
apply top-k as a threshold on the (descending) sorted probabilities
before the existing top-p step, e.g.:
    if top_k is not None:
        k = int(top_k.item())
        kth = s_probs[..., k - 1 : k]  # k-th largest prob per row
        scaled = torch.where(probs >= kth, scaled, scaled.new_tensor(float("-inf")))
Keep top-k and top-p composable: a token must pass both filters
(intersection of the two kept sets). The reference is the host contract
used by the distributional tests, so it should match the emitted graph's
semantics.
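One consequence of the threshold rule is worth spelling out: because the mask is probs >= kth, tokens tied with the k-th largest probability are all kept, so more than k tokens can survive. The sketch below shows this in plain Python; top_k_keep_mask is a hypothetical helper, not part of custom_ops.py.

```python
def top_k_keep_mask(probs, k):
    """Keep-mask under the threshold rule: keep p >= (k-th largest p)."""
    kth = sorted(probs, reverse=True)[k - 1]
    return [p >= kth for p in probs]


# No ties: exactly k tokens survive.
print(top_k_keep_mask([0.5, 0.3, 0.15, 0.05], 2))  # [True, True, False, False]

# Ties at the threshold: every token equal to the k-th value survives.
print(top_k_keep_mask([0.4, 0.2, 0.2, 0.2], 2))    # [True, True, True, True]
```

Whichever tie behavior is chosen, the reference and the emitted graph should agree on it, since both use the same comparison against the threshold.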
Part 3 — emission (backends/mlx/ops.py, _sample_handler)
- Widen the arity check at backends/mlx/ops.py:3531:
  require_args(args, 3, 4, ...) → require_args(args, 3, 5, ...), and
  read top_k = args[4] if len(args) > 4 and args[4] is not None else None.
- In emit_sample, after sorted_p is computed
  (backends/mlx/ops.py:3627, descending-sorted probabilities), and
  only when top_k is not None:
  - thread k to a Vid via ItemIntNode (same pattern as seed at
    backends/mlx/ops.py:3556-3558);
  - get the k-th largest probability as the threshold with a TakeNode
    on sorted_p along axis=-1 at index k-1
    (TakeNode.index accepts a Vid);
  - emit LessNode(probs, thresh_k) -> drop_k and fold it into the
    existing drop mask before the WhereNode at
    backends/mlx/ops.py:3678-3686 (drop a token if it fails top-p
    or top-k).

When top_k is None, emit nothing extra so the graph is byte-for-byte
the current top-p-only graph.

Detail: TakeNode needs index k-1. Either subtract one on the
Vid, or use Partition/ArgPartition with kth = vocab - k. The
TakeNode-on-sorted route is simplest for a first issue since
sorted_p already exists; PartitionNode is the efficiency follow-up.
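The two index conventions in the detail above can be checked with a small pure-Python sketch: index k-1 in a descending sort and index vocab-k in ascending order pick the same value. Both helpers are hypothetical, and a full ascending sort stands in for the partition here, on the assumption that a partition places the element at kth in its sorted position.

```python
def kth_largest_desc(values, k):
    """k-th largest via a descending sort: index k - 1 (the TakeNode route)."""
    return sorted(values, reverse=True)[k - 1]


def kth_largest_asc(values, k):
    """k-th largest via ascending order: index len(values) - k (the kth = vocab - k route)."""
    return sorted(values)[len(values) - k]


probs = [0.15, 0.5, 0.05, 0.3]
for k in range(1, len(probs) + 1):
    # Both routes must agree for every valid k in 1..vocab.
    assert kth_largest_desc(probs, k) == kth_largest_asc(probs, k)

print(kth_largest_desc(probs, 2))  # 0.3
```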
Part 4 — tests
Mirror the top_p tests one-for-one:
- Eager (backends/mlx/test/test_sample.py): add top_k to the
  _sample helper, add test_top_k_restricts_to_top_k (e.g. probs
  [0.5, 0.3, 0.15, 0.05], top_k=2 ⇒ tokens ⊆ {0, 1}) and
  test_top_k_none_keeps_all (the tail token is reachable). Add a
  test_top_k_and_top_p_compose covering the intersection.
- Lowering (backends/mlx/test/test_ops.py): add a TopKSampleModel
  and a SampleTopKTest(OpTestCase) with expected_node_counts updated
  for the new TakeNode (+1) and the extra mask node(s). Confirm
  SampleSeededTest / SampleUnseededTest / SampleTopPTest counts are
  unchanged (top-k off ⇒ no new nodes).
- End-to-end (backends/mlx/test/test_sample.py): add
  test_top_k_end_to_end modeled on test_top_p_end_to_end
  (test_sample.py:203-219), asserting the on-device token is within
  the top-k set.
Combined top-k + top-p semantics
When both are set, keep the intersection: a token survives only if it
is in the top-k and in the top-p nucleus. Equivalently, OR the two drop
masks before applying -inf. Make the reference impl and the emitted
graph agree, and cover it with test_top_k_and_top_p_compose.
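The equivalence between "intersection of kept sets" and "OR of drop masks" can be sketched in plain Python. Everything here is hypothetical: the helper names are made up, and top_p_drop uses a simplified nucleus rule (smallest descending prefix whose mass reaches top_p) that may differ in edge cases from the existing reference.

```python
def top_k_drop(probs, k):
    """Drop-mask for top-k: drop p below the k-th largest probability."""
    kth = sorted(probs, reverse=True)[k - 1]
    return [p < kth for p in probs]


def top_p_drop(probs, top_p):
    """Drop-mask for a simplified nucleus: keep the smallest descending prefix with mass >= top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    drop = [True] * len(probs)
    mass = 0.0
    for i in order:
        drop[i] = False
        mass += probs[i]
        if mass >= top_p:
            break
    return drop


def combined_keep(probs, k, top_p):
    """OR the two drop masks; the survivors are the intersection of both kept sets."""
    drop = [a or b for a, b in zip(top_k_drop(probs, k), top_p_drop(probs, top_p))]
    return [i for i, dropped in enumerate(drop) if not dropped]


probs = [0.5, 0.3, 0.15, 0.05]
print(combined_keep(probs, 2, 0.9))  # [0, 1]: top-k is the tighter filter
print(combined_keep(probs, 3, 0.5))  # [0]: top-p is the tighter filter
```

The two calls show that whichever filter is tighter wins, which is the behavior test_top_k_and_top_p_compose should pin down.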
Out of scope (follow-ups)
- PartitionNode-based top-k (avoid the full sort when top-p is off).
  The primitive is already wired (MLXInterpreter.h:1817-1830); this is a
  pure performance optimization.
- Per-request runtime k validation (e.g. k > vocab, k <= 0).
  Document the expected range; heavy validation is a separate concern.
- Batched / multi-token sampling. Out of scope (matches the existing
  max_batch_size == 1 constraint).
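For whoever picks up the validation follow-up, a host-side range check might look like the sketch below. The helper check_top_k is hypothetical and not part of this issue's scope. It is motivated by the pure-Python analogue of the k-1 index: there, k = 0 wraps to index -1 (the smallest probability, so every token is kept silently) and k > vocab raises an IndexError.

```python
def check_top_k(k, vocab_size):
    """Hypothetical host-side check: return k unchanged if 1 <= k <= vocab_size."""
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(k, bool) or not isinstance(k, int):
        raise TypeError(f"top_k must be an int, got {type(k).__name__}")
    if not 1 <= k <= vocab_size:
        raise ValueError(f"top_k must be in [1, {vocab_size}], got {k}")
    return k


print(check_top_k(2, 4))  # 2

for bad in (0, 5):
    try:
        check_top_k(bad, 4)
    except ValueError as exc:
        print(exc)
```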
Pointers
- Reserved slot to implement: backends/mlx/llm/sampling.py:33-40.
- Op def + CPU reference + fake (mirror top_p):
  backends/mlx/custom_ops.py:396-444.
- Emission handler + exact insertion point (next to the top-p mask):
  backends/mlx/ops.py:3519-3711 (mask block 3615-3696; arity check
  3531; seed→Vid precedent 3556-3558).
- Primitives already available (no C++ needed): TakeNode
  (serialization/mlx_graph_schema.py:217-221), PartitionNode /
  ArgPartitionNode (mlx_graph_schema.py:853-861, runtime
  MLXInterpreter.h:1817-1830).
- Tests to mirror: backends/mlx/test/test_sample.py:132-143,
  :203-219; backends/mlx/test/test_ops.py:7729-7756.
- Prior art: backends/mlx/llm/SAMPLING_GFI.md (the original
  sampling GFI; predates top_p/top_k).
Acceptance criteria
- SamplingHead.forward accepts top_k (int or scalar int tensor or
  None) and no longer raises NotImplementedError; top_k=None
  keeps all tokens.
- mlx::sample schema, register_fake, and CPU reference all take an
  optional top_k, with top-k applied as a threshold on the sorted
  probabilities and composable with top-p (intersection).
- _sample_handler emits the top-k mask only when top_k is present
  (using ItemIntNode + TakeNode + Less/Where), with no
  schema.fbs / generate.py / MLXInterpreter.h changes.
- With top_k=None, the emitted graph is identical to today's (the
  existing Sample*Test node counts are unchanged).
- Eager, lowering (SampleTopKTest), and on-device
  (test_top_k_end_to_end) tests pass, including a combined
  top-k + top-p case.