Unroll _transform_tuple to fix Enzyme.autodiff on tuples of length ≥ 33 #170
The recursive Base.tail fold in _transform_tuple makes Enzyme.autodiff
(Forward and Reverse) throw `AssertionError("conv == 37")` from
Enzyme/src/rules/jitrules.jl:2073 once the tuple has ≥ 33 entries
(EnzymeAD/Enzyme.jl#3104). Replace it with a @generated straight-line
unroll that produces the same outputs bit-for-bit while emitting no
self-invoke in the typed IR — which is what Enzyme trips on.
Verified against the full Pkg.test() suite (all Pass = Total) and a
35-entry SW07-Pfeifer-style NamedTuple prior (fwd + rev both succeed).
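
For readers unfamiliar with the two styles contrasted above, here is a minimal sketch on a toy function. The names `sum_squares_recursive` and `sum_squares_unrolled` are hypothetical and are not the package's `_transform_tuple`; the sketch only shows what a `Base.tail` recursion and a `@generated` straight-line unroll look like, and makes no claim about how Enzyme treats either one.

```julia
# Hypothetical toy example; NOT the package's actual `_transform_tuple`.

# Recursive style: peel off the first element and recurse on Base.tail.
sum_squares_recursive(::Tuple{}) = 0.0
function sum_squares_recursive(t::Tuple)
    return first(t)^2 + sum_squares_recursive(Base.tail(t))
end

# Generated style: build one straight-line expression per tuple length N.
@generated function sum_squares_unrolled(t::NTuple{N,Any}) where {N}
    N == 0 && return :(0.0)
    # One expression per element: t[1]^2, t[2]^2, ...
    terms = [:(t[$i]^2) for i in 1:N]
    # Left-associated sum: ((t[1]^2 + t[2]^2) + t[3]^2) + ...
    return foldl((a, b) -> :($a + $b), terms)
end

t = ntuple(i -> float(i), 40)
sum_squares_recursive(t) ≈ sum_squares_unrolled(t)
```

Note that in this toy the recursive version associates the additions to the right and the unrolled version to the left, so for general floating-point inputs the two are best compared with `≈` rather than `==`.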
Any chance that this will be merged here? I understand that the real fix should be on Enzyme's side, but that may be much harder. Thanks! PS: This is my real-world MWE that finally led me to this PR; maybe it is useful for someone.

```julia
using Distributions
using Enzyme
using TransformVariables

N = 33
dists = ntuple(i -> LogNormal(0.0, 1.0), N)
dists = NamedTuple{ntuple(i -> Symbol("x", i), N)}(dists)

function prior_transform(priors)
    transforms = map(priors) do prior
        left, right = extrema(support(prior))
        left = isinf(left) ? -TransformVariables.∞ : left
        right = isinf(right) ? TransformVariables.∞ : right
        TransformVariables.as(Real, left, right)
    end
    TransformVariables.as(transforms)
end

trans = prior_transform(dists)
q = fill(-0.1, TransformVariables.dimension(trans))
foo(q) = sum(values(TransformVariables.transform(trans, q)))
Enzyme.gradient(Enzyme.Reverse, foo, q) # AssertionError: conv == 37
```
@jlperla, thanks for this, @scheidan, thanks for the ping. I apologize for the delay in reviewing this. It is not strictly equivalent as, AFAIK, built-ins do not necessarily unroll above a certain tuple length. But given that the intention of using a tuple is to get type-stable code, I don't see a problem with this here. Also, EnzymeAD/Enzyme.jl#3104 indicates that this is an issue on the Julia side, so fixing it on our end may be the best option for now. @devmotion, this is fine with me, do you have any comments?
(closing and reopening to make CI run)
```julia
        for i in 1:N]
    ℓ_sum = foldl((a, b) -> :($a + $b), ℓs)
    return quote
        idx = index
```
Why is a separate idx variable needed? Couldn't we just operate with index?
Co-authored-by: Tamas K. Papp <tkpapp@gmail.com>
Just wanted to mention: This is wrong, I actually saw downstream test failures due to this PR. Previously, summation was performed using (basically)
I was waiting for CI to run. Did it?
In this PR? Sure, it ran before it was merged. My point was merely that it broke my downstream CI due to not yielding "same outputs bit-for-bit", and I wanted to point out for any future reader of this PR that this was an incorrect claim in the initial comment above.
I just want to clarify that this is not something this package ever promised. AFAIK very few packages in the Julia ecosystem have that kind of commitment. @devmotion, thanks for pointing this out though.
I am sorry to hear this, but if they were comparing exact output, those were the wrong kind of tests.
No, they were not comparing to exact TransformVariables output. The test failure was caused by different likelihood values of MLE estimates; apparently the tiny difference in the transform was sufficient to cause slightly different optimization trajectories. Not a big problem, of course, but that made me realize the incorrect claim above.
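
As a generic illustration of the point discussed above (not a reconstruction of the old or new code in this package), the sketch below shows that floating-point addition is not associative, so merely changing the order in which terms are summed can change the last bits of the result. The values are arbitrary and chosen only because the effect is visible with three terms.

```julia
xs = (0.1, 0.2, 0.3)

left  = foldl(+, xs)   # (0.1 + 0.2) + 0.3
right = foldr(+, xs)   # 0.1 + (0.2 + 0.3)

# The two association orders disagree in the last bit...
left == right          # false
# ...but agree within the default tolerance of isapprox.
left ≈ right           # true
```

A difference of this size is harmless for a tolerance-based comparison, but, as noted in the thread, it can still steer an iterative optimizer onto a slightly different trajectory.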
Replace the `Base.tail`-recursive `_transform_tuple` with a `@generated` straight-line unroll: same outputs bit-for-bit, but the typed IR no longer contains a self-invoke, which is what `Enzyme.autodiff` (Forward and Reverse) trips on at tuple length ≥ 33 with `AssertionError("conv == 37")` (EnzymeAD/Enzyme.jl#3104).