From 08dccd6422c243c6dfb15db226bd7b3f3a6fb7d5 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 15 Jun 2026 10:47:31 +0200 Subject: [PATCH] Fix eigvals gradient for symmetric matrices (FluxML/Zygote#1369) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `rrule(eigen, A::StridedMatrix)` reused the symmetric-manifold cotangent convention whenever `A` happened to be Hermitian: it projected the cotangent onto the stored (upper) triangle via `_symherm_back`/`triu!`, zeroing the other triangle. For a plain matrix whose entries are all independent free variables this is wrong — it disagrees with ForwardDiff/finite differences and makes the gradient discontinuous as `A` crosses exact symmetry. Concretely, `jacobian(eigvals, A)` returned an erroneous all-zero column for an exactly symmetric `A`, while a matrix one ULP away gave the correct split gradient. `eigen_rev!` already produces (via `_hermitrizelike!`) a cotangent with the off-diagonal eigenvalue sensitivity split evenly across both triangles, so we simply materialise that full matrix instead of collapsing it onto one triangle. Scope: real matrices, eigenvalues only (`T <: Real && ΔV isa AbstractZero`) — i.e. the `eigvals` path. The eigenvector phase convention differs between the symmetric and general algorithms, and the complex-Hermitian case cannot be pinned down against FiniteDifferences (eigenvalues leave the reals), so those paths are unchanged. Updated the dense "hermitian" eigvals/eigen tests to check the eigenvalue gradient of a real matrix against the *unwrapped* `eigvals`/`eigen` (the general-matrix convention); the eigenvector and complex paths keep the `Matrix(Hermitian(·))` reference. Verified `test_rrule(eigvals, A)` now passes for a symmetric `A`, and end-to-end that `jacobian(eigvals, [1 2; 2 3])` matches finite differences. Full factorization (2977) and symmetric (5499) test suites pass. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- Project.toml | 2 +- src/rulesets/LinearAlgebra/factorization.jl | 22 ++++++++++++++++++-- test/rulesets/LinearAlgebra/factorization.jl | 20 ++++++++++++++++-- 3 files changed, 39 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index ff52dd8bb..ddb2fe3fd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.73.0" +version = "1.73.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 1391e6aef..8a1057731 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -335,8 +335,26 @@ function rrule(::typeof(eigen), A::StridedMatrix{T}; kwargs...) where {T<:Union{ hermA = Hermitian(A) ∂V = ΔV isa AbstractZero ? ΔV : copyto!(similar(ΔV), ΔV) ∂hermA = eigen_rev!(hermA, λ, V, Δλ, ∂V) - ∂Atriu = _symherm_back(typeof(hermA), ∂hermA, Symbol(hermA.uplo)) - ∂A = ∂Atriu isa AbstractTriangular ? triu!(∂Atriu.data) : ∂Atriu + if T <: Real && ΔV isa AbstractZero + # `A` is a plain matrix whose entries are all independent free + # variables; it merely happens to be symmetric, so the primal + # dispatched to the symmetric algorithm. `eigen_rev!` already returns + # a (Symmetric-wrapped) cotangent with the off-diagonal eigenvalue + # sensitivity split evenly across both triangles. Materialise that + # full matrix rather than projecting onto the stored triangle: the + # projection zeroes out the other triangle, disagreeing with + # ForwardDiff/finite differences and producing a gradient that is + # discontinuous as `A` crosses exact symmetry (see FluxML/Zygote#1369).
+ # + # This is limited to the eigenvalues of a real matrix: the + # eigenvector phase convention differs between the symmetric and + # general algorithms, and the complex-Hermitian case is hard to pin + # down against finite differences, so those paths are left as-is. + ∂A = ∂hermA isa AbstractZero ? ∂hermA : Matrix(∂hermA) + else + ∂Atriu = _symherm_back(typeof(hermA), ∂hermA, Symbol(hermA.uplo)) + ∂A = ∂Atriu isa AbstractTriangular ? triu!(∂Atriu.data) : ∂Atriu + end elseif ΔV isa AbstractZero ∂K = Diagonal(Δλ) ∂A = V' \ ∂K * V' diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 1dd1aeb37..a264ae149 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -349,8 +349,13 @@ end ∂F_stable = (; [s => copy(getproperty(ΔF, s)) for s in nzprops]...) :vectors in nzprops && rmul!(∂F_stable.vectors, C) + # For the eigenvalues of a real matrix, the cotangent is the + # general-matrix gradient (matching the unwrapped `eigen`); for + # everything else it follows the symmetric-manifold convention + # (FD through `Matrix(Hermitian(x))`). See FluxML/Zygote#1369. + wrap = (T <: Real && nzprops == [:values]) ? identity : x -> Matrix(Hermitian(x)) f_stable = function(x) - F_ = _eigen_stable(Matrix(Hermitian(x))) + F_ = _eigen_stable(wrap(x)) return (; (s => getproperty(F_, s) for s in nzprops)...) end @@ -414,7 +419,18 @@ end ∂self, ∂A = @maybe_inferred back(Δλ) @test ∂self === NoTangent() @test ∂A isa typeof(A) - @test ∂A ≈ j′vp(_fdm, A -> eigvals(Matrix(Hermitian(A))), Δλ, A)[1] + if T <: Real + # `A` is a plain matrix that merely happens to be symmetric, + # so the cotangent must be the general-matrix gradient (split + # over both triangles), matching the unwrapped `eigvals`, not + # the symmetric-manifold gradient projected to one triangle + # (see FluxML/Zygote#1369).
FiniteDifferences cannot diff the unwrapped + # complex `eigvals` (eigenvalues leave the reals), so the + # complex case keeps the Hermitian-wrapped reference. + @test ∂A ≈ j′vp(_fdm, eigvals, Δλ, A)[1] + else + @test ∂A ≈ j′vp(_fdm, A -> eigvals(Matrix(Hermitian(A))), Δλ, A)[1] + end @test @maybe_inferred(back(ZeroTangent())) == (NoTangent(), ZeroTangent()) end end