diff --git a/examples/apple/coreml/gemma4/BUCK b/examples/apple/coreml/gemma4/BUCK new file mode 100644 index 00000000000..3c3ffac248b --- /dev/null +++ b/examples/apple/coreml/gemma4/BUCK @@ -0,0 +1,24 @@ +load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target") +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. + +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +fbcode_target(_kind = runtime.python_binary, + name = "export_gemma4_text_decoder_coreml", + srcs = [ + "export_gemma4_text_decoder_coreml.py", + ], + main_module = "executorch.examples.apple.coreml.gemma4.export_gemma4_text_decoder_coreml", + _is_external_target = True, + base_module = "executorch.examples.apple.coreml.gemma4", + visibility = ["PUBLIC"], + deps = [ + "//caffe2:torch", + "//executorch/backends/apple/coreml:backend", + "//executorch/backends/apple/coreml:partitioner", + "//executorch/examples/models/gemma4:text_decoder", + "//executorch/exir:lib", + "//executorch/extension/export_util:export_util", + ], +) diff --git a/examples/apple/coreml/gemma4/export_gemma4_text_decoder_coreml.py b/examples/apple/coreml/gemma4/export_gemma4_text_decoder_coreml.py new file mode 100644 index 00000000000..80540618805 --- /dev/null +++ b/examples/apple/coreml/gemma4/export_gemma4_text_decoder_coreml.py @@ -0,0 +1,310 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Export Gemma 4 text decoder to a CoreML-delegated ExecuTorch program. + +Gemma 4's hybrid sliding/full attention is structurally compatible with +CoreML's MLProgram backend: the existing Gemma4TextModel implementation +in ``examples/models/gemma4/text_decoder/`` lowers cleanly through +``torch.export`` and ``CoreMLPartitioner``. This script wraps that +pipeline with the CoreML-specific defaults (iOS18+ for stateful KV +caches, fp16, MQA-friendly mutable-buffer handling) so users do not +have to reassemble it themselves. + +Usage:: + + # From a HuggingFace checkpoint directory: + python export_gemma4_text_decoder_coreml.py \\ + --checkpoint_path /path/to/gemma4-e2b-it \\ + --output gemma4_text_decoder.pte + + # From a JSON config alone (random weights, smoke-test mode): + python export_gemma4_text_decoder_coreml.py \\ + --config_json /path/to/config.json --random_weights \\ + --max_seq_len 1024 --output gemma4_synthetic.pte + +The audio / vision encoders shipped with Gemma 4 are not part of this +export — for those the existing ``examples/models/gemma4`` ATen pipeline +is more appropriate. 
+""" + +import argparse +import json +import logging +import os +from typing import Optional, Tuple + +import coremltools as ct +import torch + +import executorch.exir +from executorch.backends.apple.coreml.compiler import CoreMLBackend +from executorch.backends.apple.coreml.partition import CoreMLPartitioner +from executorch.examples.models.gemma4.text_decoder.gemma4_config import Gemma4Config +from executorch.examples.models.gemma4.text_decoder.gemma4_transformer import ( + Gemma4TextModel, +) +from executorch.exir import EdgeCompileConfig +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.extension.export_util.utils import save_pte_program + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def _load_config( + checkpoint_path: Optional[str], + config_json: Optional[str], + max_seq_len: int, + sliding_window: Optional[int], + sliding_window_pattern: Optional[int], +) -> Gemma4Config: + """Build a Gemma4Config from a checkpoint dir, a JSON file, or defaults.""" + if checkpoint_path is not None: + config = Gemma4Config.from_json(os.path.join(checkpoint_path, "config.json")) + elif config_json is not None: + config = Gemma4Config.from_json(config_json) + else: + config = Gemma4Config() + + config.max_seq_len = max_seq_len + config.max_context_len = max_seq_len + if sliding_window is not None: + config.sliding_window = sliding_window + if sliding_window_pattern is not None: + config.sliding_window_pattern = sliding_window_pattern + return config + + +def _load_weights( + model: Gemma4TextModel, + config: Gemma4Config, + checkpoint_path: str, + dtype: torch.dtype, +) -> None: + """Load Gemma 4 text-decoder weights from a HuggingFace checkpoint dir. + + Reuses the same convert_weights flow that examples/models/gemma4 uses + so the loaded model exactly matches what ``examples/models/gemma4`` + would produce on the ATen path. 
+ """ + from executorch.examples.models.gemma4.text_decoder.convert_weights import ( + convert_hf_to_custom, + ) + + state_dict = convert_hf_to_custom(checkpoint_path, config, dtype=dtype) + missing, unexpected = model.load_state_dict(state_dict, strict=False) + if missing: + logger.warning( + "Missing %d keys when loading weights (first 5: %s)", + len(missing), + missing[:5], + ) + if unexpected: + logger.warning( + "Unexpected %d keys (first 5: %s)", len(unexpected), unexpected[:5] + ) + + +def build_model( + config: Gemma4Config, + checkpoint_path: Optional[str], + dtype: torch.dtype, +) -> Gemma4TextModel: + model = Gemma4TextModel(config).eval() + if checkpoint_path is not None: + _load_weights(model, config, checkpoint_path, dtype) + return model.to(dtype) + + +def _example_inputs(input_len: int) -> Tuple[torch.Tensor, ...]: + """Inputs for prefill: a single batch with `input_len` placeholder tokens.""" + return (torch.zeros(1, input_len, dtype=torch.long),) + + +def export( + model: Gemma4TextModel, + input_len: int, + minimum_deployment_target: ct.target, + compute_precision: ct.precision, + output_path: str, +) -> None: + """Run the Gemma 4 text-decoder model through to_edge_transform_and_lower.""" + example_inputs = _example_inputs(input_len) + + logger.info("Eager smoke-test (input_len=%d)...", input_len) + with torch.no_grad(): + model(*example_inputs) + + logger.info("torch.export...") + ep = torch.export.export(model, example_inputs, strict=False) + logger.info( + " exported program: %d nodes", + sum(1 for _ in ep.graph_module.graph.nodes), + ) + + compile_specs = CoreMLBackend.generate_compile_specs( + minimum_deployment_target=minimum_deployment_target, + compute_precision=compute_precision, + compute_unit=ct.ComputeUnit.CPU_AND_NE, + model_type=CoreMLBackend.MODEL_TYPE.MODEL, + ) + partitioner = CoreMLPartitioner( + compile_specs=compile_specs, + # Gemma 4's text decoder owns its KV caches as torch buffers; let + # CoreML take them over as iOS18+ stateful tensors. 
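+        # take_over_mutable_buffer hands those buffers to CoreML as model
+        # state, so the KV cache persists inside the delegate between calls
+        # instead of being copied in and out as ordinary tensors; this is
+        # also why --minimum_deployment_target defaults to iOS18.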
+ take_over_mutable_buffer=True, + ) + + logger.info("to_edge_transform_and_lower with CoreMLPartitioner...") + edge = executorch.exir.to_edge_transform_and_lower( + ep, + partitioner=[partitioner], + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + + fully_delegated = all( + node.op != "call_function" + or node.target.__name__ in ("executorch_call_delegate", "getitem") + for node in edge.exported_program().graph.nodes + ) + if fully_delegated: + logger.info(" fully delegated: every call_function is a CoreML call.") + else: + leftover = sorted( + { + node.target.__name__ + for node in edge.exported_program().graph.nodes + if node.op == "call_function" + and node.target.__name__ + not in ("executorch_call_delegate", "getitem") + } + ) + logger.warning( + " %d op type(s) fell back to portable: %s", + len(leftover), + leftover, + ) + + logger.info("to_executorch...") + program = edge.to_executorch( + ExecutorchBackendConfig(extract_delegate_segments=True) + ) + save_pte_program(program, output_path) + logger.info("Saved %s (%.2f MB)", output_path, os.path.getsize(output_path) / 1e6) + + +def main() -> int: + logging.basicConfig(level=logging.INFO, format="%(message)s") + + parser = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + parser.add_argument( + "--checkpoint_path", + type=str, + default=None, + help="Path to a HuggingFace Gemma 4 checkpoint directory.", + ) + parser.add_argument( + "--config_json", + type=str, + default=None, + help="Path to a Gemma 4 config.json (used if --checkpoint_path is omitted).", + ) + parser.add_argument( + "--random_weights", + action="store_true", + help="Skip checkpoint loading; use random weights (smoke-test only).", + ) + parser.add_argument( + "--output", + type=str, + default="gemma4_text_decoder.pte", + help="Output .pte path.", + ) + parser.add_argument("--max_seq_len", type=int, default=2048) + parser.add_argument( + "--input_len", + type=int, + default=64, + help="Prefill sequence length used to build example inputs for export.", + ) + parser.add_argument( + "--sliding_window", + type=int, + default=None, + help="Override the model's sliding window (default: from config).", + ) + parser.add_argument( + "--sliding_window_pattern", + type=int, + default=None, + help="Override the sliding/full attention pattern (default: from config).", + ) + parser.add_argument("--dtype", choices=["fp16", "fp32"], default="fp16") + parser.add_argument( + "--minimum_deployment_target", + type=str, + default="iOS18", + choices=["iOS17", "iOS18", "iOS26"], + help="Minimum CoreML deployment target. Stateful KV caches require iOS18+.", + ) + args = parser.parse_args() + + if args.random_weights and (args.checkpoint_path or args.config_json): + # Allow --random_weights with --config_json (for synthetic export); the + # combination with --checkpoint_path would be confusing because the + # checkpoint's config would be loaded but its weights ignored. 
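+        # (--config_json without --random_weights still falls through to the
+        # "either --checkpoint_path or --random_weights" check below.)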
+ if args.checkpoint_path: + parser.error("--random_weights conflicts with --checkpoint_path") + if not args.random_weights and not args.checkpoint_path: + parser.error("either --checkpoint_path or --random_weights is required") + + config = _load_config( + checkpoint_path=args.checkpoint_path if not args.random_weights else None, + config_json=args.config_json, + max_seq_len=args.max_seq_len, + sliding_window=args.sliding_window, + sliding_window_pattern=args.sliding_window_pattern, + ) + + dtype = {"fp16": torch.float16, "fp32": torch.float32}[args.dtype] + target = { + "iOS17": ct.target.iOS17, + "iOS18": ct.target.iOS18, + "iOS26": ct.target.iOS26, + }[args.minimum_deployment_target] + precision = {torch.float16: ct.precision.FLOAT16, torch.float32: ct.precision.FLOAT32}[dtype] + + logger.info("Gemma 4 text decoder export -> CoreML") + logger.info(" dtype=%s target=%s", args.dtype, args.minimum_deployment_target) + logger.info( + " layers=%d hidden=%d kv_heads=%d sliding_window=%d pattern=%d", + config.num_hidden_layers, + config.hidden_size, + config.num_key_value_heads, + config.sliding_window, + config.sliding_window_pattern, + ) + + model = build_model( + config, + checkpoint_path=args.checkpoint_path if not args.random_weights else None, + dtype=dtype, + ) + + export( + model, + input_len=args.input_len, + minimum_deployment_target=target, + compute_precision=precision, + output_path=args.output, + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/apple/coreml/gemma4/readme.md b/examples/apple/coreml/gemma4/readme.md new file mode 100644 index 00000000000..69956731a01 --- /dev/null +++ b/examples/apple/coreml/gemma4/readme.md @@ -0,0 +1,62 @@ +# Gemma 4 text decoder on CoreML + +This directory exports the Gemma 4 text decoder shipped with +`examples/models/gemma4` to a CoreML-delegated ExecuTorch program. + +Gemma 4's hybrid sliding/full attention, partial RoPE, per-layer +head_dim, MQA, and YOCO KV sharing are all expressed in plain PyTorch +in the upstream `examples/models/gemma4/text_decoder/` package, and that +implementation lowers cleanly through `torch.export` and +`CoreMLPartitioner` — every call is a single `executorch_call_delegate` +in the resulting `.pte`. This script assembles the small amount of +glue (CoreML compile specs, iOS18+ deployment target for stateful KV +caches, fp16 conversion) needed to run that lowering with sensible +defaults for on-device deployment. + +The audio and vision encoders are intentionally **not** exported here; +the existing ATen pipeline in `examples/models/gemma4` is more +appropriate for those. 
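+
+As a quick sanity check, the exported `.pte` can be loaded with the
+ExecuTorch Python runtime. The sketch below is illustrative rather than part
+of this example: it assumes the Python bindings were built with the CoreML
+backend (e.g. on macOS), uses the output name from the Usage section below,
+and feeds placeholder tokens whose length must equal `--input_len`, since the
+export uses static input shapes.
+
+```
+import torch
+from executorch.runtime import Runtime
+
+runtime = Runtime.get()
+program = runtime.load_program("gemma4_text_decoder.pte")
+method = program.load_method("forward")
+
+# Static shapes: sequence length must match the --input_len used at export.
+tokens = torch.zeros(1, 64, dtype=torch.long)
+outputs = method.execute([tokens])
+print(outputs[0].shape)
+```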
+ +## Usage + +### From a HuggingFace checkpoint + +``` +python export_gemma4_text_decoder_coreml.py \ + --checkpoint_path /path/to/gemma4-e2b-it \ + --output gemma4_text_decoder.pte +``` + +### Synthetic config (smoke test, no weights) + +``` +python export_gemma4_text_decoder_coreml.py \ + --random_weights \ + --max_seq_len 1024 \ + --output /tmp/gemma4_synthetic.pte +``` + +## Options + +| Option | Default | Description | +|---|---|---| +| `--checkpoint_path` | (required if no `--random_weights`) | HuggingFace Gemma 4 checkpoint dir | +| `--config_json` | (off) | Use this `config.json` instead of the checkpoint's | +| `--random_weights` | (off) | Skip weight loading; smoke-test only | +| `--max_seq_len` | 2048 | Maximum context length | +| `--input_len` | 64 | Prefill seqlen used for example inputs | +| `--sliding_window` | (from config) | Override sliding-attention window | +| `--sliding_window_pattern` | (from config) | Override hybrid pattern (P=5 for Gemma 4 E2B) | +| `--dtype` | `fp16` | `fp16` or `fp32`. ANE requires fp16. | +| `--minimum_deployment_target` | `iOS18` | iOS17 / iOS18 / iOS26. Stateful KV caches need iOS18+. | + +## Tests + +`test.py` builds a 10-layer synthetic Gemma 4 model (4 sliding + 1 full +× 2) and runs the full export pipeline, asserting that the resulting +`.pte` exists and is non-empty: + +``` +$ python -m pytest examples/apple/coreml/gemma4/test.py -v +============================== 2 passed in 15.32s ============================== +``` diff --git a/examples/apple/coreml/gemma4/test.py b/examples/apple/coreml/gemma4/test.py new file mode 100644 index 00000000000..62adddc171a --- /dev/null +++ b/examples/apple/coreml/gemma4/test.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""End-to-end smoke tests for the Gemma 4 → CoreML export pipeline. + +These tests use a tiny synthetic Gemma 4 config (random weights, ~10 layers) +so they finish in seconds and do not need a HuggingFace checkpoint. They +verify the assertion that this export script makes: the existing +``examples/models/gemma4`` text-decoder lowers cleanly through +``CoreMLPartitioner`` with no portable fallbacks. +""" + +import os +import sys +import tempfile +import unittest + +import torch + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from export_gemma4_text_decoder_coreml import build_model, export # noqa: E402 + +import coremltools as ct # noqa: E402 + +from executorch.examples.models.gemma4.text_decoder.gemma4_config import ( # noqa: E402 + Gemma4Config, +) + + +def _tiny_config() -> Gemma4Config: + """Return a 10-layer synthetic Gemma 4 config. + + Matches the layer pattern Gemma 4 ships with (4 sliding + 1 full, + repeated twice) and Gemma 4's MQA / partial RoPE / per-layer head_dim + structure, just at much smaller dimensions. 
+ """ + return Gemma4Config( + hidden_size=64, + num_hidden_layers=10, + num_attention_heads=4, + num_key_value_heads=1, + head_dim=16, + global_head_dim=32, + vocab_size=128, + intermediate_size=128, + sliding_window=64, + sliding_window_pattern=5, + layer_types=( + ["sliding_attention"] * 4 + + ["full_attention"] + + ["sliding_attention"] * 4 + + ["full_attention"] + ), + num_kv_shared_layers=4, + max_seq_len=128, + max_context_len=128, + hidden_size_per_layer_input=8, + vocab_size_per_layer_input=128, + ) + + +class TestGemma4CoreMLExport(unittest.TestCase): + def test_eager_forward_runs(self): + """The synthetic config produces a runnable Gemma4TextModel.""" + config = _tiny_config() + model = build_model(config, checkpoint_path=None, dtype=torch.float32) + with torch.no_grad(): + out = model(torch.zeros(1, 8, dtype=torch.long)) + self.assertEqual(out.shape, (1, 1, config.vocab_size)) + + def test_full_export_pipeline_lowers_to_coreml(self): + """Run the real export entry point and assert we got a fully delegated PTE.""" + config = _tiny_config() + # fp32 here — the on-device fp16 numerics path is exercised when the + # user passes --dtype fp16 to the CLI; this test is about the export + # plumbing, not numeric quality. + model = build_model(config, checkpoint_path=None, dtype=torch.float32) + + with tempfile.TemporaryDirectory() as tmpdir: + output_path = os.path.join(tmpdir, "tiny_gemma4.pte") + export( + model, + input_len=8, + minimum_deployment_target=ct.target.iOS18, + compute_precision=ct.precision.FLOAT32, + output_path=output_path, + ) + self.assertTrue(os.path.exists(output_path)) + self.assertGreater(os.path.getsize(output_path), 0) + + +if __name__ == "__main__": + unittest.main()