diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index a68dca083b3..96a91e5bdc1 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -147,12 +147,15 @@ def export_and_lower( output_dir: str, backend: str = "cuda", use_turboquant: bool = False, + sample: bool = False, ) -> None: """Export and lower the model to ExecuTorch for the given backend.""" if backend == "cuda": _export_cuda(model, config, output_dir, use_turboquant=use_turboquant) elif backend == "mlx": - _export_mlx(model, config, output_dir, use_turboquant=use_turboquant) + _export_mlx( + model, config, output_dir, use_turboquant=use_turboquant, sample=sample + ) else: raise ValueError( f"Unsupported backend: {backend!r}. Supported: {_SUPPORTED_BACKENDS}." @@ -306,11 +309,30 @@ def _export_cuda( print("Done.") +class _MLXSampleWrapper(nn.Module): + """Wrap the model so ``forward`` returns a sampled token id. + + The MLX source transforms make ``forward`` return last-token logits + ``(B, vocab)``, so sample directly. Temperature, top_p, and seed are runtime + scalar inputs so the same .pte serves any sampling request; the runner + increments the seed per token. + """ + + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + def forward(self, tokens, input_pos, temperature, top_p, seed): + logits = self.model(tokens, input_pos) + return torch.ops.mlx.sample(logits, temperature, top_p, seed) + + def _export_mlx( model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str, use_turboquant: bool = False, + sample: bool = False, ) -> None: """Export to .pte via torch.export + MLX backend. @@ -358,15 +380,28 @@ def _export_mlx( seq_dim = Dim("seq_len", min=1, max=max_prefill) + example_tokens = torch.tensor([[0, 1]], dtype=torch.long) + example_input_pos = torch.tensor([0, 1], dtype=torch.long) + if sample: + model = _MLXSampleWrapper(model) + example_args = ( + example_tokens, + example_input_pos, + torch.tensor(1.0, dtype=torch.float32), + torch.tensor(1.0, dtype=torch.float32), + torch.tensor(0, dtype=torch.int64), + ) + dynamic_shapes = ({1: seq_dim}, {0: seq_dim}, None, None, None) + else: + example_args = (example_tokens, example_input_pos) + dynamic_shapes = ({1: seq_dim}, {0: seq_dim}) + print(f"Exporting (T in [1, {max_prefill}])...") with torch.no_grad(): exported = export( model, - ( - torch.tensor([[0, 1]], dtype=torch.long), - torch.tensor([0, 1], dtype=torch.long), - ), - dynamic_shapes=({1: seq_dim}, {0: seq_dim}), + example_args, + dynamic_shapes=dynamic_shapes, strict=True, ) @@ -390,6 +425,7 @@ def _export_mlx( "use_kv_cache": True, "use_sdpa_with_kv_cache": False, "enable_dynamic_shape": True, + "use_sampling": sample, }, ) @@ -474,11 +510,21 @@ def main() -> None: "sliding layers keep their default cache. Supported on both " "--backend mlx and --backend cuda.", ) + parser.add_argument( + "--sample", + action="store_true", + help="MLX only: sample the next token on-device (Gumbel-max with " + "temperature/top_p/seed runtime inputs) instead of returning logits " + "for host-side sampling.", + ) args = parser.parse_args() if args.backend == "cuda" and not torch.cuda.is_available(): parser.error("CUDA is required for the cuda backend.") + if args.sample and args.backend != "mlx": + parser.error("--sample is only supported with --backend mlx") + if args.prequantized: model, config = load_prequantized_model( args.prequantized, @@ -505,6 +551,7 @@ def main() -> None: args.output_dir, backend=args.backend, use_turboquant=args.turboquant, + sample=args.sample, ) diff --git a/examples/models/gemma4_31b/gemma4_31b_engine.cpp b/examples/models/gemma4_31b/gemma4_31b_engine.cpp index 5813372abec..ac747556404 100644 --- a/examples/models/gemma4_31b/gemma4_31b_engine.cpp +++ b/examples/models/gemma4_31b/gemma4_31b_engine.cpp @@ -54,12 +54,15 @@ constexpr const char* kDecodeMethod = "decode"; constexpr const char* kMaxPrefillChunk = "get_max_prefill_chunk"; constexpr const char* kMinPrefillChunk = "get_min_prefill_chunk"; +constexpr const char* kUseSampling = "use_sampling"; Result read_sampled_token( const executorch::aten::Tensor& output, - float temperature) { + float temperature, + bool use_sampling) { #ifdef EXECUTORCH_BUILD_CUDA (void)temperature; + (void)use_sampling; const void* ptr = output.const_data_ptr(); cudaPointerAttributes attrs{}; const bool on_device = cudaPointerGetAttributes(&attrs, ptr) == cudaSuccess && @@ -98,6 +101,9 @@ Result read_sampled_token( static_cast(output.scalar_type())); return Error::InvalidArgument; #else + if (use_sampling) { + return static_cast(output.const_data_ptr()[0]); + } return static_cast( logits_to_token(output, temperature < 0.0f ? 0.0f : temperature)); #endif @@ -257,6 +263,17 @@ class Gemma4_31BSession : public LLMSession { auto temp_host = from_blob(&temp_val_, {1}, executorch::aten::ScalarType::Float); temp_tensor_dev_ = clone_tensor_ptr_to(temp_host, cuda_device_); +#endif +#ifdef EXECUTORCH_BUILD_MLX + if (auto it = metadata_.find(kUseSampling); it != metadata_.end()) { + use_sampling_ = it->second != 0; + } + temp_tensor_mlx_ = + from_blob(&temp_val_mlx_, {}, executorch::aten::ScalarType::Float); + top_p_tensor_ = + from_blob(&top_p_val_, {}, executorch::aten::ScalarType::Float); + seed_tensor_ = + from_blob(&seed_val_, {}, executorch::aten::ScalarType::Long); #endif } @@ -278,15 +295,27 @@ class Gemma4_31BSession : public LLMSession { } float first_token_temp = temperature_; if (initial_sampling != nullptr) { - if (initial_sampling->top_p != 1.0f || initial_sampling->top_k != 0 || - initial_sampling->seed != 0) { + if (initial_sampling->top_k != 0) { + ET_LOG(Error, "prefill_tokens: top_k is not implemented"); + return Error::NotSupported; + } + if (!use_sampling_ && + (initial_sampling->top_p != 1.0f || initial_sampling->seed != 0)) { ET_LOG( Error, - "Gemma4_31BSession: only temperature is supported; top_p/top_k/seed " - "are not implemented"); + "prefill_tokens: top_p/seed require a sampling model " + "(export with --sample); only temperature is supported otherwise"); return Error::NotSupported; } first_token_temp = initial_sampling->temperature; + if (use_sampling_) { + if (!valid_top_p(initial_sampling->top_p)) { + ET_LOG(Error, "prefill_tokens: top_p must be in (0, 1]"); + return Error::InvalidArgument; + } + top_p_ = initial_sampling->top_p; + seed_ = initial_sampling->seed; + } } if (!valid_temperature(first_token_temp)) { ET_LOG(Error, "prefill_tokens: temperature must be -1 or in [0, 2]"); @@ -326,15 +355,24 @@ class Gemma4_31BSession : public LLMSession { offset += chunk; } prev_decode_token_ = tokens.back(); +#ifdef EXECUTORCH_BUILD_MLX + if (use_sampling_) { + seed_ += 1; + } +#endif return Error::Ok; } Result decode_one(const SamplingConfig& sampling) override { - if (sampling.top_p != 1.0f || sampling.top_k != 0 || sampling.seed != 0) { + if (sampling.top_k != 0) { + ET_LOG(Error, "Gemma4_31BSession: top_k is not implemented"); + return Error::NotSupported; + } + if (!use_sampling_ && (sampling.top_p != 1.0f || sampling.seed != 0)) { ET_LOG( Error, - "Gemma4_31BSession: only temperature is supported; top_p/top_k/seed " - "are not implemented"); + "Gemma4_31BSession: top_p/seed require a sampling model " + "(export with --sample); only temperature is supported otherwise"); return Error::NotSupported; } if (!valid_temperature(sampling.temperature)) { @@ -346,6 +384,13 @@ class Gemma4_31BSession : public LLMSession { InvalidState, "decode_one requires a pending token; call prefill_tokens() first"); temperature_ = sampling.temperature; + if (use_sampling_) { + if (!valid_top_p(sampling.top_p)) { + ET_LOG(Error, "decode_one: top_p must be in (0, 1]"); + return Error::InvalidArgument; + } + top_p_ = sampling.top_p; + } if (stop_.load(std::memory_order_relaxed)) { return DecodeResult{0, "", /*is_eos=*/false, /*is_terminal=*/true}; @@ -393,6 +438,14 @@ class Gemma4_31BSession : public LLMSession { #else inputs.push_back(EValue(decode_tokens_)); inputs.push_back(EValue(decode_pos_)); +#ifdef EXECUTORCH_BUILD_MLX + if (use_sampling_) { + set_sampling_inputs(temperature_, top_p_, seed_); + inputs.push_back(EValue(temp_tensor_mlx_)); + inputs.push_back(EValue(top_p_tensor_)); + inputs.push_back(EValue(seed_tensor_)); + } +#endif #endif auto sampled = run_locked(kDecodeMethod, inputs, temperature_, /*sync_after=*/false); @@ -400,6 +453,11 @@ class Gemma4_31BSession : public LLMSession { pending_ = sampled.get(); prev_decode_token_ = token; pos_ += 1; +#ifdef EXECUTORCH_BUILD_MLX + if (use_sampling_) { + seed_ += 1; + } +#endif return DecodeResult{ token, std::move(text_piece), /*is_eos=*/false, /*is_terminal=*/false}; } @@ -425,6 +483,18 @@ class Gemma4_31BSession : public LLMSession { return temperature == -1.0f || (temperature >= 0.0f && temperature <= 2.0f); } + static bool valid_top_p(float top_p) { + return top_p > 0.0f && top_p <= 1.0f; + } + +#ifdef EXECUTORCH_BUILD_MLX + void set_sampling_inputs(float temp, float top_p, uint64_t seed) { + temp_val_mlx_ = (temp < 0.0f) ? 0.0f : temp; + top_p_val_ = top_p; + seed_val_ = static_cast(seed); + } +#endif + Result run_prefill_chunk(const uint64_t* tokens, int64_t T, float temperature) { std::vector token_data(tokens, tokens + T); @@ -457,6 +527,14 @@ class Gemma4_31BSession : public LLMSession { (T >= min_prefill_chunk_) ? kPrefillMethod : kDecodeMethod; #else const char* method = kPrefillMethod; +#endif +#ifdef EXECUTORCH_BUILD_MLX + if (use_sampling_) { + set_sampling_inputs(temperature, top_p_, seed_); + inputs.push_back(EValue(temp_tensor_mlx_)); + inputs.push_back(EValue(top_p_tensor_)); + inputs.push_back(EValue(seed_tensor_)); + } #endif return run_locked(method, inputs, temperature, /*sync_after=*/true); } @@ -560,7 +638,7 @@ class Gemma4_31BSession : public LLMSession { : module_->execute(method, inputs); ET_CHECK_OK_OR_RETURN_ERROR(res.error()); const auto& out_tensor = res.get()[0].toTensor(); - auto sampled = read_sampled_token(out_tensor, temperature); + auto sampled = read_sampled_token(out_tensor, temperature, use_sampling_); ET_CHECK_OK_OR_RETURN_ERROR(sampled.error()); #ifdef EXECUTORCH_BUILD_CUDA ET_CHECK_OK_OR_RETURN_ERROR( @@ -592,6 +670,10 @@ class Gemma4_31BSession : public LLMSession { float temperature_ = -1.0f; std::atomic stop_{false}; + bool use_sampling_ = false; + float top_p_ = 1.0f; + uint64_t seed_ = 0; + int64_t decode_token_data_[1] = {0}; int64_t decode_pos_data_[1] = {0}; TensorPtr decode_tokens_; @@ -609,6 +691,14 @@ class Gemma4_31BSession : public LLMSession { TensorPtr decode_pos_dev_; TensorPtr temp_tensor_dev_; #endif +#ifdef EXECUTORCH_BUILD_MLX + float temp_val_mlx_ = 0.0f; + float top_p_val_ = 1.0f; + int64_t seed_val_ = 0; + TensorPtr temp_tensor_mlx_; + TensorPtr top_p_tensor_; + TensorPtr seed_tensor_; +#endif }; } // namespace @@ -655,6 +745,12 @@ Result> Gemma4_31BEngine::create( metadata[kMaxPrefillChunk] = max_prefill_chunk; } +#ifdef EXECUTORCH_BUILD_MLX + if (auto get_result = meta_module->get(kUseSampling); get_result.ok()) { + metadata[kUseSampling] = get_result->toScalar().to(); + } +#endif + int64_t min_prefill_chunk = 1; #ifdef EXECUTORCH_BUILD_CUDA min_prefill_chunk = 5; diff --git a/examples/models/gemma4_31b/main.cpp b/examples/models/gemma4_31b/main.cpp index 081a85f2e12..61056528e73 100644 --- a/examples/models/gemma4_31b/main.cpp +++ b/examples/models/gemma4_31b/main.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -49,6 +50,17 @@ DEFINE_string( "", "Path to file containing prompt text (overrides --prompt)."); DEFINE_double(temperature, 0.8, "Sampling temperature (0 = near-greedy)."); +DEFINE_double( + top_p, + 1.0, + "Nucleus sampling top_p in (0, 1]; 1.0 = off. Requires a model exported " + "with --sample (MLX on-device sampling)."); +DEFINE_int64( + seed, + -1, + "Base RNG seed for on-device sampling; the runner increments it per token. " + "-1 (default) draws a random seed each run; set a value for reproducible " + "output. Requires a model exported with --sample."); DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate."); DEFINE_int32(bos_id, 2, "BOS token id to prepend (Gemma convention: 2)."); DEFINE_int32(eos_id, 1, "EOS token id (Gemma convention: 1)."); @@ -162,6 +174,20 @@ int main(int argc, char** argv) { llm::SamplingConfig sampling; sampling.temperature = static_cast(FLAGS_temperature); + sampling.top_p = static_cast(FLAGS_top_p); + // Only a --sample model uses the seed; randomize an unset seed for those and + // leave non-sample models at 0 so they don't trip the top_p/seed guard. + const auto& md = engine->metadata(); + const auto us_it = md.find("use_sampling"); + const bool model_samples = us_it != md.end() && us_it->second != 0; + uint64_t base_seed = FLAGS_seed < 0 ? 0 : static_cast(FLAGS_seed); + if (model_samples && FLAGS_seed < 0) { + base_seed = static_cast(std::random_device{}()); + } + sampling.seed = base_seed; + if (model_samples) { + printf("Sampling base seed: %" PRIu64 "\n", base_seed); + } stats.inference_start_ms = llm::time_in_ms(); if (session->prefill_tokens(prompt_tokens, &sampling) != Error::Ok) { ET_LOG(Error, "Prefill failed"); diff --git a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py index b26e2783aa6..5647293198d 100644 --- a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py @@ -244,6 +244,54 @@ def test_export_to_pte(self): export_and_lower(model, config, out_dir, backend="mlx") self.assertTrue(os.path.exists(os.path.join(out_dir, "model.pte"))) + def test_export_to_pte_with_sampling(self): + """--sample export: forward returns a seed-reproducible int64 token.""" + try: + from executorch.backends.mlx import MLXPartitioner # noqa: F401 + except ImportError: + self.skipTest("MLX backend not available") + + from executorch.examples.models.gemma4_31b.export import ( + export_and_lower, + load_prequantized_model, + ) + from executorch.runtime import Runtime, Verification + + with tempfile.TemporaryDirectory() as ckpt_dir, tempfile.TemporaryDirectory() as out_dir: + save_checkpoint(ckpt_dir) + with open(os.path.join(ckpt_dir, "config.json"), "w") as f: + json.dump(config_dict(), f) + + model, config = load_prequantized_model( + ckpt_dir, max_seq_len=TINY_CONFIG.max_seq_len, backend="mlx" + ) + export_and_lower(model, config, out_dir, backend="mlx", sample=True) + pte = os.path.join(out_dir, "model.pte") + self.assertTrue(os.path.exists(pte)) + + program = Runtime.get().load_program(pte, verification=Verification.Minimal) + self.assertIn("use_sampling", program.method_names) + self.assertTrue(bool(program.load_method("use_sampling").execute([])[0])) + + forward = program.load_method("forward") + tokens = torch.tensor([[1, 2, 3, 4]], dtype=torch.long) + input_pos = torch.arange(4, dtype=torch.long) + + def sample(seed): + return forward.execute( + [ + tokens, + input_pos, + torch.tensor(0.8, dtype=torch.float32), + torch.tensor(0.9, dtype=torch.float32), + torch.tensor(seed, dtype=torch.int64), + ] + )[0] + + token = sample(7) + self.assertEqual(token.dtype, torch.int64) + self.assertTrue(torch.equal(token, sample(7))) # same seed reproducible + class TestGgufMlxPipeline(unittest.TestCase): """Test GGUF → MLX loading path with synthetic Q6_K-like tensors."""