Good First Issue: top-k filtering for mlx::sample (MLX backend) #20548

Description

@metascroy

Summary

mlx::sample already does on-device Gumbel-max token sampling with
temperature and top-p (nucleus) filtering. SamplingHead reserves a
top_k argument, but it is not implemented — passing anything other
than None raises:

# backends/mlx/llm/sampling.py:33-35
def forward(self, *args, temperature, top_k=None, top_p=1.0, seed=None, **kwargs):
    if top_k is not None:
        raise NotImplementedError("top_k sampling is not implemented")

This issue wires top_k through the same four layers that already carry
top_p: the SamplingHead wrapper, the mlx::sample custom op
(schema + CPU reference + fake), the emission/lowering handler in
ops.py, and the tests. top_k restricts sampling to the k most
likely tokens before drawing.

top_k=None means "no top-k filtering" — keep every token. This is
the off switch, exactly analogous to top_p=1.0.
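The keep-set semantics can be sketched in plain Python, with lists standing in for tensors. This is only an illustration, not the op's implementation: `top_k_keep_mask` is a hypothetical helper name, and it assumes `1 <= top_k <= len(probs)` when `top_k` is given.

```python
from typing import List, Optional


def top_k_keep_mask(probs: List[float], top_k: Optional[int]) -> List[bool]:
    """Return True for each token that survives top-k filtering (hypothetical helper)."""
    if top_k is None:
        # Off switch: no filtering, every token stays eligible.
        return [True] * len(probs)
    # The k-th largest probability is the threshold
    # (index k - 1 of the descending sort). Assumes 1 <= top_k <= len(probs).
    threshold = sorted(probs, reverse=True)[top_k - 1]
    # Keep tokens at or above the threshold; ties with the k-th value are all kept.
    return [p >= threshold for p in probs]


probs = [0.5, 0.3, 0.15, 0.05]
print(top_k_keep_mask(probs, 2))     # [True, True, False, False]
print(top_k_keep_mask(probs, None))  # [True, True, True, True]
```

Note that a threshold comparison keeps every token tied with the k-th value, so the kept set can be larger than k when probabilities repeat.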

Feasibility: HIGH (Python-only; no C++ or flatbuffer-schema changes)

top_p was added entirely in Python (op def + emission + tests) with
zero C++ runtime or flatbuffer-schema changes, because sampling is
decomposed into primitive MLX graph nodes the runtime already executes.
top_k follows the same path, and every primitive it needs already
exists and is wired end-to-end:

| Need | Status | Where |
| --- | --- | --- |
| descending-sorted probabilities | already emitted for top-p | backends/mlx/ops.py:3620-3627 (sorted_p) |
| pick the k-th largest prob as a threshold | TakeNode (index: IntOrVidOrTid, axis) | serialization/mlx_graph_schema.py:217-221; runtime MLXInterpreter.h |
| mask tokens below threshold | LessNode + WhereNode (already used by top-p) | backends/mlx/ops.py:3668-3686 |
| thread a runtime int (k) into the graph | ItemIntNode (already used for seed) | backends/mlx/ops.py:3556-3558 |
| (efficient alt) k-th largest without a full sort | PartitionNode / ArgPartitionNode (kth field) | serialization/mlx_graph_schema.py:853-861; MLXInterpreter.h:1817-1830 |

Because TakeNode and PartitionNode already have exec_* handlers in
the runtime, no schema.fbs edit, no generate.py re-run, and no
MLXInterpreter.h change are required.

Current state (what top_p looks like at each layer — the template)

  1. SamplingHead (backends/mlx/llm/sampling.py:33-40). Coerces
    top_p to a tensor and forwards it to the op:

    if not isinstance(top_p, torch.Tensor):
        top_p = torch.tensor(float(top_p))
    return torch.ops.mlx.sample(last, temperature, top_p, seed)

  2. Custom op (backends/mlx/custom_ops.py:396-444). Schema
    sample(logits, temperature, top_p, seed=None), a CPU reference
    implementation (sort → cumsum → prefix-mass → threshold → mask), and
    register_fake.

  3. Emission handler (backends/mlx/ops.py:3519-3711,
    _sample_handler). Lowers to an IfNode on temperature > 0:
    greedy ArgmaxNode vs. the sampling branch. The top-p nucleus mask is
    built at backends/mlx/ops.py:3615-3696 from SoftmaxNode,
    SortNode, CumsumNode, MinNode, Greater/Less/WhereNode. This is
    exactly where the top-k mask goes.

  4. Tests:

    • eager behavior — backends/mlx/test/test_sample.py:132-143
      (test_top_p_restricts_to_nucleus, test_top_p_one_keeps_all);
    • lowering / node-count — backends/mlx/test/test_ops.py:7729-7756
      (SampleTopPTest, with TopPSampleModel);
    • on-device end-to-end — backends/mlx/test/test_sample.py:203-219
      (test_top_p_end_to_end).

Proposed design

Make top_k an optional scalar int tensor input, mirroring seed:
None at export ⇒ no top-k nodes emitted (keep all tokens); a tensor ⇒
keep the top k. Append it after seed to keep the op signature
backward-compatible.

Part 1 — SamplingHead (backends/mlx/llm/sampling.py)

Drop the NotImplementedError, coerce top_k to a tensor when given,
and pass it through:

def forward(self, *args, temperature, top_k=None, top_p=1.0, seed=None, **kwargs):
    logits = self.model(*args, **kwargs)        # [B, S, vocab]
    last = logits[:, -1, :]                      # [B, vocab]
    if not isinstance(top_p, torch.Tensor):
        top_p = torch.tensor(float(top_p))
    if top_k is not None and not isinstance(top_k, torch.Tensor):
        top_k = torch.tensor(int(top_k), dtype=torch.int64)
    return torch.ops.mlx.sample(last, temperature, top_p, seed, top_k)

Update the docstring (sampling.py:17-27) so the top_k line reads
"keep only the k most likely tokens; None disables top-k (keep all)".

Part 2 — custom op (backends/mlx/custom_ops.py:396-444)

Add top_k: Optional[Tensor] = None to the op signature, the
register_fake signature, and the CPU reference. In the reference,
apply top-k as a threshold on the (descending) sorted probabilities
before the existing top-p step, e.g.:

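if top_k is not None:
    k = int(top_k.item())                      # assumes 1 <= k <= vocab (validation is out of scope)
    kth = s_probs[..., k - 1 : k]              # k-th largest prob per row, shape [..., 1]
    # Keep tokens at or above the threshold; ties with the k-th value are all kept.
    scaled = torch.where(probs >= kth, scaled, scaled.new_tensor(float("-inf")))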

Keep top-k and top-p composable: a token must pass both filters
(intersection of the two kept sets). The reference is the host contract
used by the distributional tests, so it should match the emitted graph's
semantics.

Part 3 — emission (backends/mlx/ops.py, _sample_handler)

  1. Widen the arity check at backends/mlx/ops.py:3531:
    require_args(args, 3, 4, ...) → require_args(args, 3, 5, ...), and
    read top_k = args[4] if len(args) > 4 else None.

  2. In emit_sample, after sorted_p is computed
    (backends/mlx/ops.py:3627, descending-sorted probabilities), and
    only when top_k is not None:

    • thread k to a Vid via ItemIntNode (same pattern as seed at
      backends/mlx/ops.py:3556-3558);
    • get the k-th largest probability as the threshold with a TakeNode
      on sorted_p along axis=-1 at index k-1
      (TakeNode.index accepts a Vid);
    • emit LessNode(probs, thresh_k) -> drop_k and fold it into the
      existing drop mask before the WhereNode at
      backends/mlx/ops.py:3678-3686 (drop a token if it fails top-p
      or top-k).

    When top_k is None, emit nothing extra so the graph is byte-for-byte
    the current top-p-only graph.

    Detail: TakeNode needs index k-1. Either subtract one on the
    Vid, or use Partition/ArgPartition with kth = vocab - k. The
    TakeNode-on-sorted route is simplest for a first issue since
    sorted_p already exists; PartitionNode is the efficiency follow-up.
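The two index conventions in the detail above can be checked with a small plain-Python sketch. It assumes the partition primitive uses ascending order (the element at position kth is the one a full ascending sort would place there); a full sort stands in for the partition here, and both helper names are hypothetical.

```python
from typing import List


def kth_largest_sorted(probs: List[float], k: int) -> float:
    """k-th largest via a descending sort, read at index k - 1 (the TakeNode route)."""
    return sorted(probs, reverse=True)[k - 1]


def kth_largest_ascending(probs: List[float], k: int) -> float:
    """k-th largest read at position len(probs) - k of an ascending order.

    Assumption: a partition with kth = vocab - k leaves the same element at
    that position as a full ascending sort; sorted() is only a stand-in.
    """
    return sorted(probs)[len(probs) - k]


probs = [0.15, 0.5, 0.05, 0.3]
for k in range(1, len(probs) + 1):
    # Both routes must agree on the threshold for every valid k.
    assert kth_largest_sorted(probs, k) == kth_largest_ascending(probs, k)
```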

Part 4 — tests

Mirror the top_p tests one-for-one:

  • Eager (backends/mlx/test/test_sample.py): add top_k to the
    _sample helper, add test_top_k_restricts_to_top_k (e.g. probs
    [0.5, 0.3, 0.15, 0.05], top_k=2 ⇒ tokens ⊆ {0, 1}) and
    test_top_k_none_keeps_all (the tail token is reachable). Add a
    test_top_k_and_top_p_compose covering the intersection.
  • Lowering (backends/mlx/test/test_ops.py): add a TopKSampleModel
    and a SampleTopKTest(OpTestCase) with expected_node_counts updated
    for the new TakeNode (+1) and the extra mask node(s). Confirm
    SampleSeededTest / SampleUnseededTest / SampleTopPTest counts are
    unchanged (top-k off ⇒ no new nodes).
  • End-to-end (backends/mlx/test/test_sample.py): add
    test_top_k_end_to_end modeled on test_top_p_end_to_end
    (test_sample.py:203-219), asserting the on-device token is within
    the top-k set.
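The shape of the eager tests above might look like the following plain-Python sketch. `sample_reference` is a hypothetical stand-in for the sampler (top-k threshold mask followed by Gumbel-max), not the real `_sample` helper or `torch.ops.mlx.sample`; the real tests would call those instead. It assumes strictly positive probabilities and `1 <= top_k <= len(probs)`.

```python
import math
import random
from typing import List, Optional


def sample_reference(probs: List[float], top_k: Optional[int], rng: random.Random) -> int:
    """Hypothetical stand-in sampler: top-k mask, then Gumbel-max over kept tokens."""
    keep = [True] * len(probs)
    if top_k is not None:
        threshold = sorted(probs, reverse=True)[top_k - 1]
        keep = [p >= threshold for p in probs]
    best_token, best_score = -1, -math.inf
    for token, (p, kept) in enumerate(zip(probs, keep)):
        if not kept:
            continue  # a dropped token behaves like a -inf logit
        u = max(rng.random(), 1e-12)          # guard against log(0)
        gumbel = -math.log(-math.log(u))      # standard Gumbel noise
        score = math.log(p) + gumbel
        if score > best_score:
            best_token, best_score = token, score
    return best_token


PROBS = [0.5, 0.3, 0.15, 0.05]


def test_top_k_restricts_to_top_k() -> None:
    rng = random.Random(0)
    tokens = {sample_reference(PROBS, 2, rng) for _ in range(2000)}
    assert tokens <= {0, 1}


def test_top_k_none_keeps_all() -> None:
    rng = random.Random(0)
    tokens = {sample_reference(PROBS, None, rng) for _ in range(2000)}
    assert 3 in tokens  # the tail token stays reachable when top-k is off


test_top_k_restricts_to_top_k()
test_top_k_none_keeps_all()
```

With 2000 draws, the chance of never seeing the 0.05-probability tail token is negligible, and the fixed seed keeps the check deterministic.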

Combined top-k + top-p semantics

When both are set, keep the intersection: a token survives only if it
is in the top-k and in the top-p nucleus. Equivalently, OR the two drop
masks before applying -inf. Make the reference impl and the emitted
graph agree, and cover it with test_top_k_and_top_p_compose.

Out of scope (follow-ups)

  • PartitionNode-based top-k (avoid the full sort when top-p is off).
    The primitive is already wired (MLXInterpreter.h:1817-1830); this is a
    pure performance optimization.
  • Per-request runtime k validation (e.g. k > vocab, k <= 0).
    Document the expected range; heavy validation is a separate concern.
  • Batched / multi-token sampling. Out of scope (matches the existing
    max_batch_size == 1 constraint).

Acceptance criteria

  • SamplingHead.forward accepts top_k (an int, a scalar int tensor,
    or None) and no longer raises NotImplementedError; top_k=None
    keeps all tokens.
  • mlx::sample schema, register_fake, and CPU reference all take an
    optional top_k, with top-k applied as a threshold on the sorted
    probabilities and composable with top-p (intersection).
  • _sample_handler emits the top-k mask only when top_k is present
    (using ItemIntNode + TakeNode + Less/Where), with no
    schema.fbs / generate.py / MLXInterpreter.h changes.
  • With top_k=None, the emitted graph is identical to today's (the
    existing Sample*Test node counts are unchanged).
  • New eager, lowering (SampleTopKTest), and on-device
    (test_top_k_end_to_end) tests pass, including a combined
    top-k + top-p case.

Pointers

  • Reserved slot to implement: backends/mlx/llm/sampling.py:33-40.
  • Op def + CPU reference + fake (mirror top_p):
    backends/mlx/custom_ops.py:396-444.
  • Emission handler + exact insertion point (next to the top-p mask):
    backends/mlx/ops.py:3519-3711 (mask block 3615-3696; arity check
    3531; seed → Vid precedent 3556-3558).
  • Primitives already available (no C++ needed): TakeNode
    (serialization/mlx_graph_schema.py:217-221), PartitionNode /
    ArgPartitionNode (mlx_graph_schema.py:853-861, runtime
    MLXInterpreter.h:1817-1830).
  • Tests to mirror: backends/mlx/test/test_sample.py:132-143,
    :203-219; backends/mlx/test/test_ops.py:7729-7756.
  • Prior art: backends/mlx/llm/SAMPLING_GFI.md (the original
    sampling GFI; predates top_p/top_k).
