Enzyme support #85
Test on v0.11.0. v0.11.1 has some odd bugs (which is why it's excluded in most downstream packages).
Locally I tested on Enzyme main and the tests still failed. CI will probably pull in 0.11.1. When 0.12 or 0.11.2 land, we can lower bound Enzyme here to the latest version that works.
Not sure I can help but just wanted to say I strongly support this PR! Thanks @mohamed82008
Wouldn't it be more natural to make
Test it. If faster or more accurate, happy to switch.
Co-authored-by: Seth Axen <[email protected]>
@mohamed82008 mind addressing the above comments and getting this merged?
Let me re-run the tests locally.
Tests still fail with the latest Enzyme on both M1 and Windows machines. Here is the Windows error.
Status `C:\Users\moham\AppData\Local\Temp\jl_yH181v\Project.toml`
[c29ec348] AbstractDifferentiation v0.6.0 `https://github.com/JuliaDiff/AbstractDifferentiation.jl.git#mt/enzyme`
[d360d2e6] ChainRulesCore v1.16.0
[163ba53b] DiffResults v1.1.0
[7da242da] Enzyme v0.11.6 `https://github.com/EnzymeAD/Enzyme.jl.git#main`
[e2ba6199] ExprTools v0.1.10
[26cc04aa] FiniteDifferences v0.12.29
[f6369f11] ForwardDiff v0.10.35
[ae029012] Requires v1.3.0
[37e2e3b7] ReverseDiff v1.15.0
[9f7883ad] Tracker v0.2.26
[e88e6eb3] Zygote v0.6.62
[37e2e46d] LinearAlgebra `@stdlib/LinearAlgebra`
[9a3f8284] Random `@stdlib/Random`
[8dfed614] Test `@stdlib/Test`
Status `C:\Users\moham\AppData\Local\Temp\jl_yH181v\Manifest.toml`
[c29ec348] AbstractDifferentiation v0.6.0 `https://github.com/JuliaDiff/AbstractDifferentiation.jl.git#mt/enzyme`
[621f4979] AbstractFFTs v1.5.0
[79e6a3ab] Adapt v3.6.2
[a9b6321e] Atomix v0.1.0
[fa961155] CEnum v0.4.2
[082447d4] ChainRules v1.53.0
[d360d2e6] ChainRulesCore v1.16.0
[bbf7d656] CommonSubexpressions v0.3.0
[34da2185] Compat v4.9.0
[9a962f9c] DataAPI v1.15.0
[e2d170a0] DataValueInterfaces v1.0.0
[163ba53b] DiffResults v1.1.0
[b552c78f] DiffRules v1.15.1
[ffbed154] DocStringExtensions v0.9.3
[7da242da] Enzyme v0.11.6 `https://github.com/EnzymeAD/Enzyme.jl.git#main`
[f151be2c] EnzymeCore v0.5.1
[e2ba6199] ExprTools v0.1.10
[1a297f60] FillArrays v1.5.0
[26cc04aa] FiniteDifferences v0.12.29
[f6369f11] ForwardDiff v0.10.35
[069b7b12] FunctionWrappers v1.1.3
[d9f16b24] Functors v0.4.5
[0c68f7d7] GPUArrays v8.8.1
[46192b85] GPUArraysCore v0.1.5
[61eb1bfa] GPUCompiler v0.21.4
[7869d1d1] IRTools v0.4.10
[92d709cd] IrrationalConstants v0.2.2
[82899510] IteratorInterfaceExtensions v1.0.0
[692b3bcd] JLLWrappers v1.4.1
[63c18a36] KernelAbstractions v0.9.8
[929cbde3] LLVM v6.1.0
[2ab3a3ac] LogExpFunctions v0.3.24
[1914dd2f] MacroTools v0.5.10
[872c559c] NNlib v0.9.4
[77ba4419] NaNMath v1.0.2
[d8793406] ObjectFile v0.4.0
[3bd65402] Optimisers v0.2.19
[bac558e1] OrderedCollections v1.6.2
[aea7be01] PrecompileTools v1.1.2
[21216c6a] Preferences v1.4.0
[c1ae055f] RealDot v0.1.0
[189a3867] Reexport v1.2.2
[ae029012] Requires v1.3.0
[37e2e3b7] ReverseDiff v1.15.0
[708f8203] Richardson v1.4.0
[6c6a2e73] Scratch v1.2.0
[276daf66] SpecialFunctions v2.3.0
[90137ffa] StaticArrays v1.6.2
[1e83bf80] StaticArraysCore v1.4.2
[09ab397b] StructArrays v0.6.15
[53d494c1] StructIO v0.3.0
[3783bdb8] TableTraits v1.0.1
[bd369af6] Tables v1.10.1
[a759f4b9] TimerOutputs v0.5.23
[9f7883ad] Tracker v0.2.26
[013be700] UnsafeAtomics v0.2.1
[d80eeb9a] UnsafeAtomicsLLVM v0.1.3
[e88e6eb3] Zygote v0.6.62
[700de1a5] ZygoteRules v0.2.3
[7cc45869] Enzyme_jll v0.0.78+0
[dad2f222] LLVMExtra_jll v0.0.23+0
[efe28fd5] OpenSpecFun_jll v0.5.5+0
[0dad84c5] ArgTools v1.1.1 `@stdlib/ArgTools`
[56f22d72] Artifacts `@stdlib/Artifacts`
[2a0f44e3] Base64 `@stdlib/Base64`
[ade2ca70] Dates `@stdlib/Dates`
[8ba89e20] Distributed `@stdlib/Distributed`
[f43a241f] Downloads v1.6.0 `@stdlib/Downloads`
[7b1f6079] FileWatching `@stdlib/FileWatching`
[b77e0a4c] InteractiveUtils `@stdlib/InteractiveUtils`
[4af54fe1] LazyArtifacts `@stdlib/LazyArtifacts`
[b27032c2] LibCURL v0.6.3 `@stdlib/LibCURL`
[76f85450] LibGit2 `@stdlib/LibGit2`
[8f399da3] Libdl `@stdlib/Libdl`
[37e2e46d] LinearAlgebra `@stdlib/LinearAlgebra`
[56ddb016] Logging `@stdlib/Logging`
[d6f4376e] Markdown `@stdlib/Markdown`
[ca575930] NetworkOptions v1.2.0 `@stdlib/NetworkOptions`
[44cfe95a] Pkg v1.9.2 `@stdlib/Pkg`
[de0858da] Printf `@stdlib/Printf`
[3fa0cd96] REPL `@stdlib/REPL`
[9a3f8284] Random `@stdlib/Random`
[ea8e919c] SHA v0.7.0 `@stdlib/SHA`
[9e88b42a] Serialization `@stdlib/Serialization`
[6462fe0b] Sockets `@stdlib/Sockets`
[2f01184e] SparseArrays `@stdlib/SparseArrays`
[10745b16] Statistics v1.9.0 `@stdlib/Statistics`
[fa267f1f] TOML v1.0.3 `@stdlib/TOML`
[a4e569a6] Tar v1.10.0 `@stdlib/Tar`
[8dfed614] Test `@stdlib/Test`
[cf7118a7] UUIDs `@stdlib/UUIDs`
[4ec0a83e] Unicode `@stdlib/Unicode`
[e66e0078] CompilerSupportLibraries_jll v1.0.5+0 `@stdlib/CompilerSupportLibraries_jll`
[deac9b47] LibCURL_jll v7.84.0+0 `@stdlib/LibCURL_jll`
[29816b5a] LibSSH2_jll v1.10.2+0 `@stdlib/LibSSH2_jll`
[c8ffd9c3] MbedTLS_jll v2.28.2+0 `@stdlib/MbedTLS_jll`
[14a3606d] MozillaCACerts_jll v2022.10.11 `@stdlib/MozillaCACerts_jll`
[4536629a] OpenBLAS_jll v0.3.21+4 `@stdlib/OpenBLAS_jll`
[05823500] OpenLibm_jll v0.8.1+0 `@stdlib/OpenLibm_jll`
[bea87d4a] SuiteSparse_jll v5.10.1+6 `@stdlib/SuiteSparse_jll`
[83775a58] Zlib_jll v1.2.13+0 `@stdlib/Zlib_jll`
[8e850b90] libblastrampoline_jll v5.8.0+0 `@stdlib/libblastrampoline_jll`
[8e850ede] nghttp2_jll v1.48.0+0 `@stdlib/nghttp2_jll`
[3f19e933] p7zip_jll v17.4.0+0 `@stdlib/p7zip_jll`
Precompiling project...
80 dependencies successfully precompiled in 67 seconds. 6 already precompiled.
Testing Running tests...
┌ Warning: `ForwardDiff.gradient(f, x)` within Zygote cannot track gradients with respect to `f`,
│ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
│ typeof(f) = AbstractDifferentiation.var"#3#4"{ForwardDiffBackend1, var"#19#21", Tuple{Vector{Float64}}}
└ @ Zygote C:\Users\moham\.julia\packages\Zygote\JeHtr\src\lib\forward.jl:142
o: %11 = call nonnull {} addrspace(10)* ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)*, {} addrspace(10)*, ...) @julia.call({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* noundef nonnull @jl_f__apply_iterate, {} addrspace(10)* noundef null, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140708473958160 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140708429727024 to {}*) to {} addrspace(10)*), {} addrspace(10)* nonnull %10) #11, !dbg !42
ot: {} addrspace(10)*
ir: %99 = call nonnull {} addrspace(10)* ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)*, {} addrspace(10)*, ...) @julia.call({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* noundef nonnull @jl_f__apply_iterate, {} addrspace(10)* noundef null, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140708473958160 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140708429727024 to {}*) to {} addrspace(10)*), {} addrspace(10)* nonnull %98) #11, !dbg !66
irt: {} addrspace(10)*
p: %"'dual_phi7" = phi [5 x {} addrspace(10)*] , !dbg !66
PT: [5 x {} addrspace(10)*]
newCall: %99 = call nonnull {} addrspace(10)* ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)*, {} addrspace(10)*, ...) @julia.call({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* noundef nonnull @jl_f__apply_iterate, {} addrspace(10)* noundef null, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140708473958160 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140708429727024 to {}*) to {} addrspace(10)*), {} addrspace(10)* nonnull %98) #11, !dbg !66
newCallT: {} addrspace(10)*
Assertion failed: invertedReturn->getType() == gutils->getShadowType(call.getType()), file /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h, line 8077
[21920] signal (22): SIGABRT
in expression starting at C:\Users\moham\.julia\packages\AbstractDifferentiation\cNUlO\test\enzyme.jl:10
crt_sig_handler at C:/workdir/src\signals-win.c:95
raise at C:\WINDOWS\System32\msvcrt.dll (unknown line)
abort at C:\WINDOWS\System32\msvcrt.dll (unknown line)
assert at C:\WINDOWS\System32\msvcrt.dll (unknown line)
visitCallInst at /workspace/srcdir/Enzyme/enzyme/Enzyme\AdjointGenerator.h:8076
delegateCallInst at /opt/x86_64-w64-mingw32/x86_64-w64-mingw32/sys-root/usr/local/include/llvm/IR\InstVisitor.h:302 [inlined]
visitCall at /opt/x86_64-w64-mingw32/x86_64-w64-mingw32/sys-root/usr/local/include/llvm/IR\Instruction.def:209 [inlined]
visit at /opt/x86_64-w64-mingw32/x86_64-w64-mingw32/sys-root/usr/local/include/llvm/IR\Instruction.def:209
visit at /opt/x86_64-w64-mingw32/x86_64-w64-mingw32/sys-root/usr/local/include/llvm/IR\InstVisitor.h:112 [inlined]
CreateForwardDiff at /workspace/srcdir/Enzyme/enzyme/Enzyme\EnzymeLogic.cpp:4686
EnzymeCreateForwardDiff at /workspace/srcdir/Enzyme/enzyme/Enzyme\CApi.cpp:565
EnzymeCreateForwardDiff at C:\Users\moham\.julia\packages\Enzyme\tHfGe\src\api.jl:142
enzyme! at C:\Users\moham\.julia\packages\Enzyme\tHfGe\src\compiler.jl:7575
unknown function (ip: 0000025b7f9f7931)
#codegen#438 at C:\Users\moham\.julia\packages\Enzyme\tHfGe\src\compiler.jl:9100
codegen at C:\Users\moham\.julia\packages\Enzyme\tHfGe\src\compiler.jl:8705 [inlined]
_thunk at C:\Users\moham\.julia\packages\Enzyme\tHfGe\src\compiler.jl:9652
_thunk at C:\Users\moham\.julia\packages\Enzyme\tHfGe\src\compiler.jl:9652 [inlined]
cached_compilation at C:\Users\moham\.julia\packages\Enzyme\tHfGe\src\compiler.jl:9686 [inlined]
#475 at C:\Users\moham\.julia\packages\Enzyme\tHfGe\src\compiler.jl:9749
JuliaContext at C:\Users\moham\.julia\packages\GPUCompiler\YO8Uj\src\driver.jl:47
unknown function (ip: 0000025b7f9f2816)
#s292#474 at C:\Users\moham\.julia\packages\Enzyme\tHfGe\src\compiler.jl:9704 [inlined]
#s292#474 at .\none:0
GeneratedFunctionStub at .\boot.jl:602
jl_apply at C:/workdir/src\julia.h:1879 [inlined]
jl_call_staged at C:/workdir/src\method.c:530
ijl_code_for_staged at C:/workdir/src\method.c:581
get_staged at .\compiler\utilities.jl:115
retrieve_code_info at .\compiler\utilities.jl:127 [inlined]
InferenceState at .\compiler\inferencestate.jl:354
typeinf_edge at .\compiler\typeinfer.jl:923
abstract_call_method at .\compiler\abstractinterpretation.jl:611
abstract_call_gf_by_type at .\compiler\abstractinterpretation.jl:152
abstract_call_known at .\compiler\abstractinterpretation.jl:1949
jfptr_abstract_call_known_20122.clone_1 at C:\Users\moham\.julia\juliaup\julia-1.9.2+0.x64.w64.mingw32\lib\julia\sys.dll (unknown line)
tojlinvoke21878.clone_1 at C:\Users\moham\.julia\juliaup\julia-1.9.2+0.x64.w64.mingw32\lib\julia\sys.dll (unknown line)
j_abstract_call_known_14207.clone_1 at C:\Users\moham\.julia\juliaup\julia-1.9.2+0.x64.w64.mingw32\lib\julia\sys.dll (unknown line)
abstract_call at .\compiler\abstractinterpretation.jl:2020
abstract_call at .\compiler\abstractinterpretation.jl:1999
abstract_eval_statement_expr at .\compiler\abstractinterpretation.jl:2183
abstract_eval_statement at .\compiler\abstractinterpretation.jl:2396
abstract_eval_basic_statement at .\compiler\abstractinterpretation.jl:2658
typeinf_local at .\compiler\abstractinterpretation.jl:2867
typeinf_nocycle at .\compiler\abstractinterpretation.jl:2955
_typeinf at .\compiler\typeinfer.jl:246
typeinf at .\compiler\typeinfer.jl:216
typeinf_ext at .\compiler\typeinfer.jl:1057
typeinf_ext_toplevel at .\compiler\typeinfer.jl:1090
typeinf_ext_toplevel at .\compiler\typeinfer.jl:1086
jfptr_typeinf_ext_toplevel_20761.clone_1 at C:\Users\moham\.julia\juliaup\julia-1.9.2+0.x64.w64.mingw32\lib\julia\sys.dll (unknown line)
_jl_invoke at C:/workdir/src\gf.c:2758 [inlined]
ijl_apply_generic at C:/workdir/src\gf.c:2940 [inlined]
jl_apply at C:/workdir/src\julia.h:1879 [inlined]
jl_type_infer at C:/workdir/src\gf.c:320
jl_generate_fptr_impl at C:/workdir/src\jitlayers.cpp:444
jl_compile_method_internal at C:/workdir/src\gf.c:2348
jl_compile_method_internal at C:/workdir/src\gf.c:2241 [inlined]
_jl_invoke at C:/workdir/src\gf.c:2750 [inlined]
ijl_apply_generic at C:/workdir/src\gf.c:2940
unknown function (ip: 0000025b7a720141)
Allocations: 1152562605 (Pool: 1152256428; Big: 306177); GC: 1532
ERROR: Package AbstractDifferentiation errored during testing (exit code: 3)
Try running the test file locally after commenting out the lines for other backends.
Did you apply the changes I suggested above in code review?
I applied the comments and more tests pass now! A few still fail or error. The Hessian ones crash Julia so I commented them out.
Added more comments on places where you should use autodiff directly. If there's a place where you can also use autodiff directly for the Hessian, I think that is in fact necessary, specifically because you need to use autodiff_deferred in the innermost call (see the Enzyme Julia docs for a Hessian example).
There are a few failed tests. These need to be isolated and reported upstream. Hessian tests also currently trigger a segfault on my machine so they are commented out. Overall this PR (and Enzyme) seems closer than ever before but not quite there yet.
dup = if y isa Real
    if Δ isa Real
        Enzyme.Duplicated([y], [Δ])
    elseif Δ isa Tuple{Real}
The tuple issue hits again...
Δ_xs = zero.(xs)
dup = if y isa Real
    if Δ isa Real
        Enzyme.Duplicated([y], [Δ])
This seems a bit strange - that's not something an Enzyme user would do AFAIK.
ya it's a quick and dirty hack to get it running, needs to be optimised
Maybe one can reuse some of the things I did in TuringLang/DistributionsAD.jl#254.
If this is a real or a tuple of reals, this should be an active argument [in reverse mode].
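To make the activity point concrete, here is a minimal sketch, assuming an Enzyme.jl release where `autodiff(Reverse, f, Active, args...)` returns the derivatives of `Active` arguments in a nested tuple. The functions `square` and `sumsq` are illustrative names and are not part of this PR. A scalar input is annotated `Active` and its derivative is returned, whereas an array is annotated `Duplicated` and its gradient is accumulated into the shadow buffer.

```julia
using Enzyme

# Hypothetical scalar function: its input is a plain Real.
square(x) = x * x

# Scalar input: annotate it as Active; the derivative comes back in the
# first element of the returned tuple.
dx = Enzyme.autodiff(Reverse, square, Active, Active(3.0))[1][1]

# Hypothetical array function: its input is mutable memory.
function sumsq(v)
    s = 0.0
    for x in v
        s += x * x
    end
    return s
end

# Array input: annotate it as Duplicated; the gradient is accumulated
# in place into the shadow `dv`, so no wrapping of scalars in arrays is needed.
v = [1.0, 2.0]
dv = zero(v)
Enzyme.autodiff(Reverse, sumsq, Active, Duplicated(v, dv))
```

With this split, the `[y]`/`[Δ]` wrapping in the hunk above would only be needed for outputs that really are arrays.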
Mutating(f),
Enzyme.Const,
dup,
Enzyme.Duplicated.(xs, Δ_xs)...,
That means users of AbstractDifferentiation miss a major feature of Enzyme. But maybe it's unavoidable and the current design of AbstractDifferentiation can't support it and the wrapper will always be less performant than Enzyme?
let's brainstorm solutions, I think it's possible to support partial pullback with an extended API
AD backend that uses reverse mode of Enzyme.jl.

!!! note
    To be able to use this backend, you have to load Enzyme.
Only true on Julia >= 1.9 I think?
    end
end
function AD.pushforward_function(::AD.EnzymeReverseBackend, f, xs...)
    return AD.pushforward_function(AD.EnzymeForwardBackend(), f, xs...)
This creates an inconsistency with the behaviour of other backends, where it is guaranteed that the specified backend is used for every operation. I think the better design might be to have dedicated Reverse+Forward wrappers that allow specifying different backends for forward- and reverse-mode operations and pick the best mode for every call.
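A rough sketch of what such a wrapper could look like, in plain Julia with no AD packages. Every name here (`ToyForward`, `ToyReverse`, `MixedModeBackend`, `toy_pushforward`, `toy_pullback`) is hypothetical and not part of AbstractDifferentiation, and a central finite difference stands in for a real AD engine. The point is only the dispatch: the single-mode backends implement just their own operation, and the wrapper routes each call to the backend the user chose for that mode.

```julia
# Hypothetical single-mode backends; not AbstractDifferentiation types.
struct ToyForward end
struct ToyReverse end

# Central finite difference, standing in for a real AD engine.
fd_derivative(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

# Each backend only implements the operation of its own mode.
toy_pushforward(::ToyForward, f, x) = v -> (mode=:forward, value=fd_derivative(f, x) * v)
toy_pullback(::ToyReverse, f, x) = w -> (mode=:reverse, value=w * fd_derivative(f, x))

# Wrapper that lets the user name one backend per mode explicitly.
struct MixedModeBackend{F,R}
    forward::F
    reverse::R
end

# Route each operation to the backend chosen for that mode.
toy_pushforward(b::MixedModeBackend, f, x) = toy_pushforward(b.forward, f, x)
toy_pullback(b::MixedModeBackend, f, x) = toy_pullback(b.reverse, f, x)

backend = MixedModeBackend(ToyForward(), ToyReverse())
jvp = toy_pushforward(backend, sin, 0.0)(2.0)
vjp = toy_pullback(backend, sin, 0.0)(2.0)
```

Under this design, calling `toy_pushforward(ToyReverse(), sin, 0.0)` raises a `MethodError` instead of silently switching modes, which keeps the guarantee that the specified backend is the one actually used.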
I agree. This was done to make some failed tests pass which likely fail due to an Enzyme correctness issue. We should change this before merge.
What is the test case of the correctness issue? Can you open an issue with it?
end

AD.@primitive function value_and_pullback_function(b::AD.EnzymeReverseBackend, f, xs...)
    y = f(xs...)
You should use ReverseSplitMode here: call the augmented forward pass to get that result, then use the reverse pass (with the tape created by the augmented forward pass) for the pullback.
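For reference, a minimal sketch of the split-mode pattern, assuming an Enzyme.jl release that exports `autodiff_thunk` and `ReverseSplitWithShadow` with the calling convention shown in the Enzyme docs of that era; the exact signature may differ between releases. The function `g` and the inputs are illustrative and not taken from this PR.

```julia
using Enzyme

# Illustrative function of one array argument and one scalar argument.
g(A, v) = A[1] * v

A = [2.2]
∂A = zero(A)   # shadow buffer that receives the gradient with respect to A
v = 3.3

# Build the augmented forward pass and the reverse pass once.
forward, reverse = autodiff_thunk(
    ReverseSplitWithShadow,
    Const{typeof(g)},
    Active,
    Duplicated{typeof(A)},
    Active{typeof(v)},
)

# The augmented forward pass computes the primal result and records a tape.
tape, result, shadow_result = forward(Const(g), Duplicated(A, ∂A), Active(v))

# The reverse pass consumes the tape and the output seed (1.0 here).
# Derivatives of Active inputs are returned; ∂A is updated in place.
_, ∂v = reverse(Const(g), Duplicated(A, ∂A), Active(v), 1.0, tape)[1]
```

In a `value_and_pullback_function`, the forward call would supply `y` and the returned closure would call `reverse` with the incoming cotangent as the seed, so `f` is not evaluated a second time.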
!!! note
    To be able to use this backend, you have to load Enzyme.
"""
struct EnzymeForwardBackend <: AbstractForwardMode end
[JuliaFormatter] reported by reviewdog 🐶
struct EnzymeForwardBackend <: AbstractForwardMode end
struct EnzymeForwardBackend <: AbstractForwardMode end
@mohamed82008 @devmotion the failures all look to me like issues with tuples in the AD interface, not anything that Enzyme fails to differentiate. If those can get fixed, I can fix whatever else arises from within Enzyme calls.
This is a draft PR because Enzyme gives incorrect gradients in the tests. I suspect that's because of global captures related to this thread https://discourse.julialang.org/t/whats-the-state-of-automatic-differentiation-in-julia-january-2023/92473/21.