fix: Metal GPU compatibility for Apple Silicon (jax-metal 0.1.1) #207
Open
elmariachi111 wants to merge 1 commit into
jnp.linalg.eigh is not implemented on the Metal platform in jax-metal 0.1.1, blocking the AF2 structure module (quat_affine.py) on Apple Silicon Macs. jnp.linalg.svd internally calls eigh and is also blocked.

Fix 1 (colabdesign/af/alphafold/model/quat_affine.py): Replace jnp.linalg.eigh in rot_to_quat() with power iteration (50 iter fori_loop). The dominant eigenvector of the 4x4 symmetric K matrix is found iteratively; validated to <1e-7 error vs numpy eigh on 20 random rotation matrices. Canonical sign convention applied so the largest-magnitude component is always positive.

Fix 2 (colabdesign/shared/protein.py): Add _metal_safe_svd(): a power iteration SVD on A^T A with eigenvalue deflation for the 2nd vector and cross product for the 3rd (valid for the 3x3 Kabsch input). Validated to <2e-4 error vs numpy SVD on random 3x3 matrices (60-iteration power method). Replaces jnp.linalg.svd in _np_kabsch() when use_jax=True and wraps the result in jax.lax.stop_gradient to prevent gradient flow through fori_loop (which causes a separate jax-metal compiler bug).

Tested on: macOS 26.3 ARM64, jax==0.5.0 / jaxlib==0.5.0 / jax-metal==0.1.1 / Python 3.10

Enables AF2 forward inference on Apple Silicon Metal. Note: full backprop (value_and_grad + haiku RNG) is blocked by a separate unresolved jax-metal compiler bug and is not addressed here.

References: https://github.com/cytokineking/FreeBindCraft (Apple Silicon porting notes)
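The power iteration described in Fix 1 can be sketched roughly as follows. This is an illustration under stated assumptions, not the code from this PR: the names k_matrix and dominant_eigvec, the fixed start vector, and the scalar-first K layout are all hypothetical. For an exact rotation with unit quaternion q, this K equals (4 q qᵀ − I) / 3, so its eigenvalues are 1 and −1/3 (three times), which is why a few dozen iterations are enough to isolate the dominant eigenvector.

```python
import jax
import jax.numpy as jnp


def k_matrix(rot):
    """Build a symmetric 4x4 K matrix from a 3x3 rotation (scalar-first layout).

    Assumption of this sketch: the layout below is one standard
    rotation-to-quaternion construction, not copied from quat_affine.py.
    """
    (xx, xy, xz), (yx, yy, yz), (zx, zy, zz) = rot
    return jnp.array([
        [xx + yy + zz, zy - yz, xz - zx, yx - xy],
        [zy - yz, xx - yy - zz, xy + yx, xz + zx],
        [xz - zx, xy + yx, yy - xx - zz, yz + zy],
        [yx - xy, xz + zx, yz + zy, zz - xx - yy],
    ]) / 3.0


def dominant_eigvec(k, num_iters=50):
    """Approximate the dominant eigenvector of k using only matmul and a norm."""
    # Hypothetical fixed start vector. Power iteration stalls if the start is
    # exactly orthogonal to the target eigenvector, so a real implementation
    # may need a different choice or a fallback.
    v0 = jnp.array([1.0, 0.7, 0.4, 0.1], dtype=k.dtype)
    v0 = v0 / jnp.sqrt(jnp.sum(v0 * v0))

    def body(_, v):
        w = k @ v
        return w / jnp.sqrt(jnp.sum(w * w))

    v = jax.lax.fori_loop(0, num_iters, body, v0)

    # Canonical sign: make the largest-magnitude component positive.
    pivot = v[jnp.argmax(jnp.abs(v))]
    return v * jnp.sign(pivot)
```

For example, dominant_eigvec(k_matrix(jnp.eye(3))) should return a vector close to the identity quaternion (1, 0, 0, 0). Because power iteration converges to the eigenvector with the largest absolute eigenvalue, this sketch relies on K having a single dominant positive eigenvalue, which holds when the input is a proper rotation.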
Summary
Two patches that make ColabDesign's AlphaFold2 forward pass run on Apple Silicon Macs using jax-metal 0.1.1. jnp.linalg.eigh and jnp.linalg.svd are not implemented for the Metal platform in jax-metal 0.1.1, blocking the AF2 structure module entirely on macOS ARM64.

Patch 1: colabdesign/af/alphafold/model/quat_affine.py
jnp.linalg.eigh is called in rot_to_quat() to find the dominant eigenvector of a 4×4 symmetric matrix (the Shepperd K matrix used for rotation-to-quaternion conversion). On Metal this raises an error.
Fix: Replace it with a 50-iteration power iteration using jax.lax.fori_loop, which only requires matmul and norm, both of which work on Metal. Validated to <1e-7 error vs numpy.linalg.eigh on 20 random rotation matrices.

Patch 2: colabdesign/shared/protein.py
jnp.linalg.svd is called in _np_kabsch() for Kabsch structural alignment. On Metal, jnp.linalg.svd internally lowers to the eigh primitive and fails with the same error.
Fix: Add _metal_safe_svd(A), an SVD via power iteration on AᵀA: the first vector comes from plain power iteration, the second from power iteration after eigenvalue deflation, and the third from a cross product. It uses only jax.lax.fori_loop, matmul, and norm. Validated to <2e-4 error vs numpy.linalg.svd on random 3×3 matrices. The result is wrapped in jax.lax.stop_gradient, which is standard practice for Kabsch alignment and prevents Metal from attempting to differentiate through the fori_loop.

What this enables / what remains blocked
AF2 forward pass / structure prediction: works on Metal.
pLDDT / pTM confidence scoring: works on Metal.
Full binder hallucination (value_and_grad): blocked by a separate jax-metal compiler bug.
Test environment
jax==0.5.0 / jaxlib==0.5.0 / jax-metal==0.1.1
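For reviewers, the power-iteration SVD described for colabdesign/shared/protein.py can be sketched along the following lines. This is a hedged illustration, not the _metal_safe_svd from this PR: the name power_svd_3x3, the helper functions, the start vector, and the extra re-orthogonalization step inside the loop are assumptions made for the sketch. It assumes a full-rank 3×3 input with reasonably separated singular values.

```python
import jax
import jax.numpy as jnp


def _normalize(x, eps=1e-12):
    # Unit-normalize using only elementwise ops and a reduction.
    return x / (jnp.sqrt(jnp.sum(x * x)) + eps)


def _power_iterate(m, v0, num_iters, ortho=None):
    """Power iteration on symmetric PSD m, optionally kept orthogonal to ortho."""
    def body(_, v):
        w = m @ v
        if ortho is not None:
            # Added in this sketch: project out the already-found direction
            # so rounding error cannot drift back toward it.
            w = w - jnp.dot(w, ortho) * ortho
        return _normalize(w)

    return jax.lax.fori_loop(0, num_iters, body, _normalize(v0))


def power_svd_3x3(a, num_iters=60):
    """Approximate SVD of a 3x3 matrix; returns (u, s, vt) with a ~ (u * s) @ vt."""
    m = a.T @ a
    # Hypothetical fixed start vector; it must not be orthogonal to the target.
    start = jnp.array([1.0, 0.7, 0.4], dtype=a.dtype)

    # First right singular vector: dominant eigenvector of A^T A.
    v1 = _power_iterate(m, start, num_iters)
    lam1 = v1 @ m @ v1

    # Second: deflate the first eigenpair, then iterate again.
    m2 = m - lam1 * jnp.outer(v1, v1)
    v2 = _power_iterate(m2, start, num_iters, ortho=v1)

    # Third: cross product completes the right-handed orthonormal basis.
    v3 = jnp.cross(v1, v2)

    v = jnp.stack([v1, v2, v3], axis=1)   # right singular vectors as columns
    av = a @ v
    s = jnp.sqrt(jnp.sum(av * av, axis=0))  # singular values are column norms
    u = av / (s + 1e-12)                    # degenerates if a is rank-deficient
    return u, s, v.T
```

In a Kabsch setting the result would then be passed through jax.lax.stop_gradient, as the description above notes. Because u is computed as A V / s, the reconstruction (u * s) @ vt matches A whenever V is orthonormal; what the iteration count actually controls is how close u is to orthonormal.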