Multi-backend refactor + MLX backend MVP; bump to 0.3a1#14
Merged
Conversation
* Introduce pyproject.toml defining package metadata and development dependencies. * Add core design documentation (DESIGN.md) explaining Sequence primitives. PiperPending-RevId: 923592270 PiperOrigin-RevId: 923592270
* Define backend-agnostic structural protocols and behavior tests in specs/. Co-authored-by: David Braun <davidbraun@google.com> Co-authored-by: Kehang Han <kehanghan@google.com> PiperPending-RevId: 923278026 PiperOrigin-RevId: 923278026
* Align Flax-based JAX implementations to inherit from specs protocols. * Align MLX-based implementations to inherit from specs protocols. * Delete obsolete mlx/basic_types.py. * Abstract and implement backend-agnostic array/nn operations (xp, nn). * Resolve JAX attention namespace collision by aligning imports. * Add complete multi-backend coding guides (AGENTS.md, evolved DESIGN.md). Co-authored-by: David Braun <davidbraun@google.com> Co-authored-by: Kehang Han <kehanghan@google.com> PiperPending-RevId: 923278027 PiperOrigin-RevId: 923278027
* Abstract testing structures into backend-agnostic specification tests. * Implement concrete, reusable testing utilities in specs/test_utils.py. * Inherit spec tests in JAX and MLX backend test suites for strict equivalence. Co-authored-by: David Braun <davidbraun@google.com> Co-authored-by: Kehang Han <kehanghan@google.com> PiperPending-RevId: 923278025 PiperOrigin-RevId: 923278025
Co-authored-by: David Braun <davidbraun@google.com> Co-authored-by: Kehang Han <kehanghan@google.com> PiperPending-RevId: 924617734 PiperOrigin-RevId: 924617734
…s for specs architecture Co-authored-by: David Braun <davidbraun@google.com> Co-authored-by: Kehang Han <kehanghan@google.com> PiperPending-RevId: 924617736 PiperOrigin-RevId: 924617736
…olerances Co-authored-by: David Braun <davidbraun@google.com> Co-authored-by: Kehang Han <kehanghan@google.com> PiperPending-RevId: 924617737 PiperOrigin-RevId: 924617737
…errides Co-authored-by: David Braun <davidbraun@google.com> Co-authored-by: Kehang Han <kehanghan@google.com> PiperPending-RevId: 924617731 PiperOrigin-RevId: 924617731
Co-authored-by: David Braun <davidbraun@google.com> Co-authored-by: Kehang Han <kehanghan@google.com> PiperPending-RevId: 924617735 PiperOrigin-RevId: 924617735
…foundation TAG=agy CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
…s to mx.array TAG=agy CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
- Ported JAX signal utility tests (hann, hamming, inv_stft) to MLX with SciPy parity. - Created MLX utils unit tests covering make_layer, latency, and delay. - Fixed make_layer to respect custom config.make() implementations (e.g. for gated units). - Enforced strict latency validation in MLX utils.get_output_latency to match JAX. - Included legacy JAX conditioning base class fix from previous rebase step. TAG=agy CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
…kward compatibility - Deleted `AttentionInputProjectionHelper` from JAX attention `common.py` to purge dead namespace-flattening helper code. - Stripped `AttentionInputProjectionHelper` from the inheritance base lists of all JAX self and cross attention layers, ensuring strict backward compatibility of Flax parameter PyTree namespaces with the `main` branch. TAG=agy CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
…g and position layers - Restated all inherited config dataclass fields in `Conditioning.Config` (in `conditioning.py`) and `AddTimingSignal.Config`, `ApplyRotaryPositionalEncoding.Config` (in `position.py`). - This ensures uniform coding style across JAX configurations, improves self-documentation and IDE autocomplete, and removes any risk of implicit dataclass inheritance behavior. TAG=agy CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
rryan
approved these changes
Jun 1, 2026
…age roots TAG=agy CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
TAG=agy CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
TAG=agy CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
TAG=agy CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
TAG=agy CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
TAG=agy CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
TAG=agy CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
TAG=agy CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
…ff with main TAG=agy CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
mlx.nn.quantize only quantizes modules that define a to_quantized() method. The MLX EinsumDense layer — used for the attention head-combining output projection (equation '...nh,dnh->...d') — had no to_quantized(), so under nn.quantize it silently stayed full-precision (bf16) while the rest of the model was int8-quantized. That makes exported models larger and slower and changes numerics versus an all-int8 model. Add to_quantized() for the '...nh,dnh->...d' equation: flatten the [d, n, h] kernel to [d, n*h], mx.quantize it, and rebind layer() to flatten the [..., n, h] input and use mx.quantized_matmul. Other equations are returned unchanged.
…SelfAttention Quantizes attention layers when using combined projections (CombinedQueryKeyValueProjection layout, which is the default in mrt2 samplers). Splits the combined bias into q_bias and kv_bias to remain fully compatible with downstream evaluation functions. TAG=agy CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
Refactors SerialCombinatorMixin and its subclasses (Serial, SerialModules) to use a public `mlx_layers` backing attribute instead of dynamic `setattr` loops or private `_layers` lists. This is required because MLX nn.Module only tracks submodules that are stored in public attributes (without a leading underscore) for parameter collection. Since `layers` is a read-only property in the shared spec, we use `mlx_layers` as the backing attribute and have the mixin's `layers` property return it. Also updates the JAX-to-MLX weight converter to use index-based access (`layers[i]`) instead of dynamic attribute lookup. TAG=agy CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This is an MVP of the multi-backend SequenceLayers refactor, demonstrated through adding MLX support plus unified tests and signatures.
Also added affordances for agentic code (AGENTS.md, DESIGN.md).
NOTE: This part of the commit history will be overwritten again before the final v0.3 after some cleaning up. However, the APIs for JAX should stay fixed and MLX mostly fixed.