Skip to content

Multi-backend refactor + MLX backend MVP; bump to 0.3a1#14

Merged
JulianSlzr merged 29 commits into
mainfrom
mlx
Jun 4, 2026
Merged

Multi-backend refactor + MLX backend MVP; bump to 0.3a1#14
JulianSlzr merged 29 commits into
mainfrom
mlx

Conversation

@JulianSlzr
Copy link
Copy Markdown
Collaborator

@JulianSlzr JulianSlzr commented Jun 1, 2026

This is an MVP of the multi-backend SequenceLayers refactor, demonstrated through adding MLX support plus unified tests and signatures.

Also added affordances for agentic code (AGENTS.md, DESIGN.md).

NOTE: This part of the commit history will be overwritten again before the final v0.3 after some cleaning up. However, the APIs for JAX should stay fixed and MLX mostly fixed.

JulianSlzr added 13 commits June 1, 2026 20:46
* Introduce pyproject.toml defining package metadata and development dependencies.
* Add core design documentation (DESIGN.md) explaining Sequence primitives.

PiperPending-RevId: 923592270
PiperOrigin-RevId: 923592270
* Define backend-agnostic structural protocols and behavior tests in specs/.

Co-authored-by: David Braun <davidbraun@google.com>
Co-authored-by: Kehang Han <kehanghan@google.com>

PiperPending-RevId: 923278026
PiperOrigin-RevId: 923278026
* Align Flax-based JAX implementations to inherit from specs protocols.
* Align MLX-based implementations to inherit from specs protocols.
* Delete obsolete mlx/basic_types.py.
* Abstract and implement backend-agnostic array/nn operations (xp, nn).
* Resolve JAX attention namespace collision by aligning imports.
* Add complete multi-backend coding guides (AGENTS.md, evolved DESIGN.md).

Co-authored-by: David Braun <davidbraun@google.com>
Co-authored-by: Kehang Han <kehanghan@google.com>

PiperPending-RevId: 923278027
PiperOrigin-RevId: 923278027
* Abstract testing structures into backend-agnostic specification tests.
* Implement concrete, reusable testing utilities in specs/test_utils.py.
* Inherit spec tests in JAX and MLX backend test suites for strict equivalence.

Co-authored-by: David Braun <davidbraun@google.com>
Co-authored-by: Kehang Han <kehanghan@google.com>

PiperPending-RevId: 923278025
PiperOrigin-RevId: 923278025
Co-authored-by: David Braun <davidbraun@google.com>
Co-authored-by: Kehang Han <kehanghan@google.com>

PiperPending-RevId: 924617734
PiperOrigin-RevId: 924617734
…s for specs architecture

Co-authored-by: David Braun <davidbraun@google.com>
Co-authored-by: Kehang Han <kehanghan@google.com>

PiperPending-RevId: 924617736
PiperOrigin-RevId: 924617736
…olerances

Co-authored-by: David Braun <davidbraun@google.com>
Co-authored-by: Kehang Han <kehanghan@google.com>

PiperPending-RevId: 924617737
PiperOrigin-RevId: 924617737
…errides

Co-authored-by: David Braun <davidbraun@google.com>
Co-authored-by: Kehang Han <kehanghan@google.com>

PiperPending-RevId: 924617731
PiperOrigin-RevId: 924617731
Co-authored-by: David Braun <davidbraun@google.com>
Co-authored-by: Kehang Han <kehanghan@google.com>

PiperPending-RevId: 924617735
PiperOrigin-RevId: 924617735
…foundation

TAG=agy
CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
…s to mx.array

TAG=agy
CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
- Ported JAX signal utility tests (hann, hamming, inv_stft) to MLX with SciPy parity.
- Created MLX utils unit tests covering make_layer, latency, and delay.
- Fixed make_layer to respect custom config.make() implementations (e.g. for gated units).
- Enforced strict latency validation in MLX utils.get_output_latency to match JAX.
- Included legacy JAX conditioning base class fix from previous rebase step.

TAG=agy
CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
@JulianSlzr JulianSlzr requested a review from rryan June 1, 2026 21:01
…kward compatibility

- Deleted `AttentionInputProjectionHelper` from JAX attention `common.py` to purge dead namespace-flattening helper code.
- Stripped `AttentionInputProjectionHelper` from the inheritance base lists of all JAX self and cross attention layers, ensuring strict backward compatibility of Flax parameter PyTree namespaces with the `main` branch.

TAG=agy
CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
…g and position layers

- Restated all inherited config dataclass fields in `Conditioning.Config` (in `conditioning.py`) and `AddTimingSignal.Config`, `ApplyRotaryPositionalEncoding.Config` (in `position.py`).
- This ensures uniform coding style across JAX configurations, improves self-documentation and IDE autocomplete, and removes any risk of implicit dataclass inheritance behavior.

TAG=agy
CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
Copy link
Copy Markdown
Collaborator

@rryan rryan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome work!

Comment thread sequence_layers/jax/attention/common.py Outdated
Comment thread sequence_layers/jax/__init__.py Outdated
Comment thread sequence_layers/jax/__init__.py Outdated
Comment thread sequence_layers/jax/conditioning.py
Comment thread sequence_layers/jax/dsp_test.py
Comment thread sequence_layers/jax/simple.py Outdated
Comment thread sequence_layers/jax/simple.py Outdated
Comment thread sequence_layers/jax/simple.py Outdated
Comment thread sequence_layers/jax/simple.py Outdated
Comment thread sequence_layers/jax/__init__.py Outdated
…age roots

TAG=agy
CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
TAG=agy
CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
TAG=agy
CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
TAG=agy
CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
TAG=agy
CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
TAG=agy
CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
TAG=agy
CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
TAG=agy
CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
…ff with main

TAG=agy
CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
DBraun and others added 3 commits June 2, 2026 14:49
mlx.nn.quantize only quantizes modules that define a to_quantized() method.
The MLX EinsumDense layer — used for the attention head-combining output
projection (equation '...nh,dnh->...d') — had no to_quantized(), so under
nn.quantize it silently stayed full-precision (bf16) while the rest of the
model was int8-quantized. That makes exported models larger and slower and
changes numerics versus an all-int8 model.

Add to_quantized() for the '...nh,dnh->...d' equation: flatten the [d, n, h]
kernel to [d, n*h], mx.quantize it, and rebind layer() to flatten the
[..., n, h] input and use mx.quantized_matmul. Other equations are returned
unchanged.
…SelfAttention

Quantizes attention layers when using combined projections (CombinedQueryKeyValueProjection layout, which is the default in mrt2 samplers). Splits the combined bias into q_bias and kv_bias to remain fully compatible with downstream evaluation functions.

TAG=agy
CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
Refactors SerialCombinatorMixin and its subclasses (Serial, SerialModules) to use a public `mlx_layers` backing attribute instead of dynamic `setattr` loops or private `_layers` lists.

This is required because MLX nn.Module only tracks submodules that are stored in public attributes (without a leading underscore) for parameter collection. Since `layers` is a read-only property in the shared spec, we use `mlx_layers` as the backing attribute and have the mixin's `layers` property return it.

Also updates the JAX-to-MLX weight converter to use index-based access (`layers[i]`) instead of dynamic attribute lookup.

TAG=agy
CONV=21ada17b-3411-4090-8450-e69d8ebfeae6
@JulianSlzr JulianSlzr changed the title Multi-backend refactor + MLX backend MVP; bump to 0.3.0rc1 Multi-backend refactor + MLX backend MVP; bump to 0.3a1 Jun 4, 2026
@JulianSlzr JulianSlzr merged commit 0f3ef62 into main Jun 4, 2026
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants