Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 53 additions & 6 deletions examples/models/gemma4_31b/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,15 @@ def export_and_lower(
output_dir: str,
backend: str = "cuda",
use_turboquant: bool = False,
sample: bool = False,
) -> None:
"""Export and lower the model to ExecuTorch for the given backend."""
if backend == "cuda":
_export_cuda(model, config, output_dir, use_turboquant=use_turboquant)
elif backend == "mlx":
_export_mlx(model, config, output_dir, use_turboquant=use_turboquant)
_export_mlx(
model, config, output_dir, use_turboquant=use_turboquant, sample=sample
)
else:
raise ValueError(
f"Unsupported backend: {backend!r}. Supported: {_SUPPORTED_BACKENDS}."
Expand Down Expand Up @@ -306,11 +309,30 @@ def _export_cuda(
print("Done.")


class _MLXSampleWrapper(nn.Module):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally SamplingHead from backends/mlx/llm would be directly usable. The fact that we keep defining wrappers suggest that maybe it should be refactored to work?

In terms of shape, I expect sampling to work on (B, vocab) and return (B) (tokens).

I don't see a case for (B, S, vocab).

"""Wrap the model so ``forward`` returns a sampled token id.

The MLX source transforms make ``forward`` return last-token logits
``(B, vocab)``, so sample directly. Temperature, top_p, and seed are runtime
scalar inputs so the same .pte serves any sampling request; the runner
increments the seed per token.
"""

def __init__(self, model: nn.Module):
super().__init__()
self.model = model

def forward(self, tokens, input_pos, temperature, top_p, seed):
logits = self.model(tokens, input_pos)
return torch.ops.mlx.sample(logits, temperature, top_p, seed)


def _export_mlx(
model: Gemma4_31B,
config: Gemma4_31BConfig,
output_dir: str,
use_turboquant: bool = False,
sample: bool = False,
) -> None:
"""Export to .pte via torch.export + MLX backend.

Expand Down Expand Up @@ -358,15 +380,28 @@ def _export_mlx(

seq_dim = Dim("seq_len", min=1, max=max_prefill)

example_tokens = torch.tensor([[0, 1]], dtype=torch.long)
example_input_pos = torch.tensor([0, 1], dtype=torch.long)
if sample:
model = _MLXSampleWrapper(model)
example_args = (
example_tokens,
example_input_pos,
torch.tensor(1.0, dtype=torch.float32),
torch.tensor(1.0, dtype=torch.float32),
torch.tensor(0, dtype=torch.int64),
)
dynamic_shapes = ({1: seq_dim}, {0: seq_dim}, None, None, None)
else:
example_args = (example_tokens, example_input_pos)
dynamic_shapes = ({1: seq_dim}, {0: seq_dim})

print(f"Exporting (T in [1, {max_prefill}])...")
with torch.no_grad():
exported = export(
model,
(
torch.tensor([[0, 1]], dtype=torch.long),
torch.tensor([0, 1], dtype=torch.long),
),
dynamic_shapes=({1: seq_dim}, {0: seq_dim}),
example_args,
dynamic_shapes=dynamic_shapes,
strict=True,
)

Expand All @@ -390,6 +425,7 @@ def _export_mlx(
"use_kv_cache": True,
"use_sdpa_with_kv_cache": False,
"enable_dynamic_shape": True,
"use_sampling": sample,
},
)

Expand Down Expand Up @@ -474,11 +510,21 @@ def main() -> None:
"sliding layers keep their default cache. Supported on both "
"--backend mlx and --backend cuda.",
)
parser.add_argument(
"--sample",
action="store_true",
help="MLX only: sample the next token on-device (Gumbel-max with "
"temperature/top_p/seed runtime inputs) instead of returning logits "
"for host-side sampling.",
)
args = parser.parse_args()

if args.backend == "cuda" and not torch.cuda.is_available():
parser.error("CUDA is required for the cuda backend.")

if args.sample and args.backend != "mlx":
parser.error("--sample is only supported with --backend mlx")

if args.prequantized:
model, config = load_prequantized_model(
args.prequantized,
Expand All @@ -505,6 +551,7 @@ def main() -> None:
args.output_dir,
backend=args.backend,
use_turboquant=args.turboquant,
sample=args.sample,
)


Expand Down
114 changes: 105 additions & 9 deletions examples/models/gemma4_31b/gemma4_31b_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,15 @@ constexpr const char* kDecodeMethod = "decode";

constexpr const char* kMaxPrefillChunk = "get_max_prefill_chunk";
constexpr const char* kMinPrefillChunk = "get_min_prefill_chunk";
constexpr const char* kUseSampling = "use_sampling";

Result<uint64_t> read_sampled_token(
const executorch::aten::Tensor& output,
float temperature) {
float temperature,
bool use_sampling) {
#ifdef EXECUTORCH_BUILD_CUDA
(void)temperature;
(void)use_sampling;
const void* ptr = output.const_data_ptr();
cudaPointerAttributes attrs{};
const bool on_device = cudaPointerGetAttributes(&attrs, ptr) == cudaSuccess &&
Expand Down Expand Up @@ -98,6 +101,9 @@ Result<uint64_t> read_sampled_token(
static_cast<int>(output.scalar_type()));
return Error::InvalidArgument;
#else
if (use_sampling) {
return static_cast<uint64_t>(output.const_data_ptr<int64_t>()[0]);
}
return static_cast<uint64_t>(
logits_to_token(output, temperature < 0.0f ? 0.0f : temperature));
#endif
Expand Down Expand Up @@ -257,6 +263,17 @@ class Gemma4_31BSession : public LLMSession {
auto temp_host =
from_blob(&temp_val_, {1}, executorch::aten::ScalarType::Float);
temp_tensor_dev_ = clone_tensor_ptr_to(temp_host, cuda_device_);
#endif
#ifdef EXECUTORCH_BUILD_MLX
if (auto it = metadata_.find(kUseSampling); it != metadata_.end()) {
use_sampling_ = it->second != 0;
}
temp_tensor_mlx_ =
from_blob(&temp_val_mlx_, {}, executorch::aten::ScalarType::Float);
top_p_tensor_ =
from_blob(&top_p_val_, {}, executorch::aten::ScalarType::Float);
seed_tensor_ =
from_blob(&seed_val_, {}, executorch::aten::ScalarType::Long);
#endif
}

Expand All @@ -278,15 +295,27 @@ class Gemma4_31BSession : public LLMSession {
}
float first_token_temp = temperature_;
if (initial_sampling != nullptr) {
if (initial_sampling->top_p != 1.0f || initial_sampling->top_k != 0 ||
initial_sampling->seed != 0) {
if (initial_sampling->top_k != 0) {
ET_LOG(Error, "prefill_tokens: top_k is not implemented");
return Error::NotSupported;
}
if (!use_sampling_ &&
(initial_sampling->top_p != 1.0f || initial_sampling->seed != 0)) {
ET_LOG(
Error,
"Gemma4_31BSession: only temperature is supported; top_p/top_k/seed "
"are not implemented");
"prefill_tokens: top_p/seed require a sampling model "
"(export with --sample); only temperature is supported otherwise");
return Error::NotSupported;
}
first_token_temp = initial_sampling->temperature;
if (use_sampling_) {
if (!valid_top_p(initial_sampling->top_p)) {
ET_LOG(Error, "prefill_tokens: top_p must be in (0, 1]");
return Error::InvalidArgument;
}
top_p_ = initial_sampling->top_p;
seed_ = initial_sampling->seed;
}
}
if (!valid_temperature(first_token_temp)) {
ET_LOG(Error, "prefill_tokens: temperature must be -1 or in [0, 2]");
Expand Down Expand Up @@ -326,15 +355,24 @@ class Gemma4_31BSession : public LLMSession {
offset += chunk;
}
prev_decode_token_ = tokens.back();
#ifdef EXECUTORCH_BUILD_MLX
if (use_sampling_) {
seed_ += 1;
}
#endif
return Error::Ok;
}

Result<DecodeResult> decode_one(const SamplingConfig& sampling) override {
if (sampling.top_p != 1.0f || sampling.top_k != 0 || sampling.seed != 0) {
if (sampling.top_k != 0) {
ET_LOG(Error, "Gemma4_31BSession: top_k is not implemented");
return Error::NotSupported;
}
if (!use_sampling_ && (sampling.top_p != 1.0f || sampling.seed != 0)) {
ET_LOG(
Error,
"Gemma4_31BSession: only temperature is supported; top_p/top_k/seed "
"are not implemented");
"Gemma4_31BSession: top_p/seed require a sampling model "
"(export with --sample); only temperature is supported otherwise");
return Error::NotSupported;
}
if (!valid_temperature(sampling.temperature)) {
Expand All @@ -346,6 +384,13 @@ class Gemma4_31BSession : public LLMSession {
InvalidState,
"decode_one requires a pending token; call prefill_tokens() first");
temperature_ = sampling.temperature;
if (use_sampling_) {
if (!valid_top_p(sampling.top_p)) {
ET_LOG(Error, "decode_one: top_p must be in (0, 1]");
return Error::InvalidArgument;
}
top_p_ = sampling.top_p;
}

if (stop_.load(std::memory_order_relaxed)) {
return DecodeResult{0, "", /*is_eos=*/false, /*is_terminal=*/true};
Expand Down Expand Up @@ -393,13 +438,26 @@ class Gemma4_31BSession : public LLMSession {
#else
inputs.push_back(EValue(decode_tokens_));
inputs.push_back(EValue(decode_pos_));
#ifdef EXECUTORCH_BUILD_MLX
if (use_sampling_) {
set_sampling_inputs(temperature_, top_p_, seed_);
inputs.push_back(EValue(temp_tensor_mlx_));
inputs.push_back(EValue(top_p_tensor_));
inputs.push_back(EValue(seed_tensor_));
}
#endif
#endif
auto sampled =
run_locked(kDecodeMethod, inputs, temperature_, /*sync_after=*/false);
ET_CHECK_OK_OR_RETURN_ERROR(sampled.error());
pending_ = sampled.get();
prev_decode_token_ = token;
pos_ += 1;
#ifdef EXECUTORCH_BUILD_MLX
if (use_sampling_) {
seed_ += 1;
}
#endif
return DecodeResult{
token, std::move(text_piece), /*is_eos=*/false, /*is_terminal=*/false};
}
Expand All @@ -425,6 +483,18 @@ class Gemma4_31BSession : public LLMSession {
return temperature == -1.0f || (temperature >= 0.0f && temperature <= 2.0f);
}

static bool valid_top_p(float top_p) {
return top_p > 0.0f && top_p <= 1.0f;
}

#ifdef EXECUTORCH_BUILD_MLX
void set_sampling_inputs(float temp, float top_p, uint64_t seed) {
temp_val_mlx_ = (temp < 0.0f) ? 0.0f : temp;
top_p_val_ = top_p;
seed_val_ = static_cast<int64_t>(seed);
}
#endif

Result<uint64_t>
run_prefill_chunk(const uint64_t* tokens, int64_t T, float temperature) {
std::vector<int64_t> token_data(tokens, tokens + T);
Expand Down Expand Up @@ -457,6 +527,14 @@ class Gemma4_31BSession : public LLMSession {
(T >= min_prefill_chunk_) ? kPrefillMethod : kDecodeMethod;
#else
const char* method = kPrefillMethod;
#endif
#ifdef EXECUTORCH_BUILD_MLX
if (use_sampling_) {
set_sampling_inputs(temperature, top_p_, seed_);
inputs.push_back(EValue(temp_tensor_mlx_));
inputs.push_back(EValue(top_p_tensor_));
inputs.push_back(EValue(seed_tensor_));
}
#endif
return run_locked(method, inputs, temperature, /*sync_after=*/true);
}
Expand Down Expand Up @@ -560,7 +638,7 @@ class Gemma4_31BSession : public LLMSession {
: module_->execute(method, inputs);
ET_CHECK_OK_OR_RETURN_ERROR(res.error());
const auto& out_tensor = res.get()[0].toTensor();
auto sampled = read_sampled_token(out_tensor, temperature);
auto sampled = read_sampled_token(out_tensor, temperature, use_sampling_);
ET_CHECK_OK_OR_RETURN_ERROR(sampled.error());
#ifdef EXECUTORCH_BUILD_CUDA
ET_CHECK_OK_OR_RETURN_ERROR(
Expand Down Expand Up @@ -592,6 +670,10 @@ class Gemma4_31BSession : public LLMSession {
float temperature_ = -1.0f;
std::atomic<bool> stop_{false};

bool use_sampling_ = false;
float top_p_ = 1.0f;
uint64_t seed_ = 0;

int64_t decode_token_data_[1] = {0};
int64_t decode_pos_data_[1] = {0};
TensorPtr decode_tokens_;
Expand All @@ -609,6 +691,14 @@ class Gemma4_31BSession : public LLMSession {
TensorPtr decode_pos_dev_;
TensorPtr temp_tensor_dev_;
#endif
#ifdef EXECUTORCH_BUILD_MLX
float temp_val_mlx_ = 0.0f;
float top_p_val_ = 1.0f;
int64_t seed_val_ = 0;
TensorPtr temp_tensor_mlx_;
TensorPtr top_p_tensor_;
TensorPtr seed_tensor_;
#endif
};

} // namespace
Expand Down Expand Up @@ -655,6 +745,12 @@ Result<std::unique_ptr<Gemma4_31BEngine>> Gemma4_31BEngine::create(
metadata[kMaxPrefillChunk] = max_prefill_chunk;
}

#ifdef EXECUTORCH_BUILD_MLX
if (auto get_result = meta_module->get(kUseSampling); get_result.ok()) {
metadata[kUseSampling] = get_result->toScalar().to<int64_t>();
}
#endif

int64_t min_prefill_chunk = 1;
#ifdef EXECUTORCH_BUILD_CUDA
min_prefill_chunk = 5;
Expand Down
26 changes: 26 additions & 0 deletions examples/models/gemma4_31b/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <cstdio>
#include <fstream>
#include <optional>
#include <random>
#include <string>
#include <vector>

Expand Down Expand Up @@ -49,6 +50,17 @@ DEFINE_string(
"",
"Path to file containing prompt text (overrides --prompt).");
DEFINE_double(temperature, 0.8, "Sampling temperature (0 = near-greedy).");
DEFINE_double(
top_p,
1.0,
"Nucleus sampling top_p in (0, 1]; 1.0 = off. Requires a model exported "
"with --sample (MLX on-device sampling).");
DEFINE_int64(
seed,
-1,
"Base RNG seed for on-device sampling; the runner increments it per token. "
"-1 (default) draws a random seed each run; set a value for reproducible "
"output. Requires a model exported with --sample.");
DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");
DEFINE_int32(bos_id, 2, "BOS token id to prepend (Gemma convention: 2).");
DEFINE_int32(eos_id, 1, "EOS token id (Gemma convention: 1).");
Expand Down Expand Up @@ -162,6 +174,20 @@ int main(int argc, char** argv) {

llm::SamplingConfig sampling;
sampling.temperature = static_cast<float>(FLAGS_temperature);
sampling.top_p = static_cast<float>(FLAGS_top_p);
// Only a --sample model uses the seed; randomize an unset seed for those and
// leave non-sample models at 0 so they don't trip the top_p/seed guard.
const auto& md = engine->metadata();
const auto us_it = md.find("use_sampling");
const bool model_samples = us_it != md.end() && us_it->second != 0;
uint64_t base_seed = FLAGS_seed < 0 ? 0 : static_cast<uint64_t>(FLAGS_seed);
if (model_samples && FLAGS_seed < 0) {
base_seed = static_cast<uint64_t>(std::random_device{}());
}
sampling.seed = base_seed;
if (model_samples) {
printf("Sampling base seed: %" PRIu64 "\n", base_seed);
}
stats.inference_start_ms = llm::time_in_ms();
if (session->prefill_tokens(prompt_tokens, &sampling) != Error::Ok) {
ET_LOG(Error, "Prefill failed");
Expand Down
Loading
Loading