-
-
Notifications
You must be signed in to change notification settings - Fork 13
Open
Labels
bug — Something isn't working
Description
Describe the bug 🐞
Removing some activation functions in the DeepONet model will cause training errors.
Minimal Reproducible Example 👇
Working version:
using Optimization, Optimisers, OptimizationOptimisers
using Lux, NeuralOperators, Random, ComponentArrays
import LuxCUDA, Zygote
rng = MersenneTwister(1234)
dev = gpu_device()
"""
    train(loss, θ, opt, iters, ad)

Minimise `loss` starting from parameters `θ` with optimiser `opt` for
`iters` iterations, computing gradients with the AD backend `ad`.
Returns the `Optimization.solve` result.
"""
function train(loss, θ, opt, iters, ad)
    # Optimization.jl expects f(x, p); the hyper-parameter slot p is unused here.
    objective = Optimization.OptimizationFunction((x, p) -> loss(x), ad)
    problem = Optimization.OptimizationProblem(objective, θ)
    return Optimization.solve(problem, opt; maxiters = iters)
end
# Problem sizes and randomly generated training data, transferred to `dev`.
begin
    m = 32            # branch-net input dimension (number of sensor points)
    batch_size = 64
    eval_points = 10
    us = dev(rand(Float32, m, batch_size))            # branch inputs
    ys = dev(rand(Float32, 1, eval_points))           # trunk inputs (evaluation locations)
    vs = dev(rand(Float32, eval_points, batch_size))  # regression targets
end
# DeepONet: branch net maps sensor values (m → 8), trunk net maps evaluation
# coordinates (1 → 8); in this (working) version every Dense layer has a σ
# activation, including the final layer of each chain.
nn = DeepONet(
    Chain(Dense(m => 8, σ), Dense(8 => 8, σ), Dense(8 => 8, σ)),
    Chain(Dense(1 => 4, σ), Dense(4 => 8, σ)),
)
# Initialise parameters/state on the device. The flat parameter vector is
# built as a ComponentArray on the CPU first, then moved to the device,
# since ComponentArray construction from GPU-resident nested params is
# not supported directly.
ps, st = dev(Lux.setup(rng, nn))
pv = dev(ComponentArray(cpu_device()(ps)))
# Sum-of-squared-errors loss between the DeepONet prediction and the
# targets `vs`; closes over `nn`, `us`, `ys`, `vs`, and `st`.
function loss(p)
    pred, _ = Lux.apply(nn, (us, ys), p, st)
    return sum(abs2, pred - vs)
end
res = train(loss, pv, AdamW(), 10, AutoZygote())

Now, if we modify the DeepONet architecture, while the rest of the code is kept unchanged:
nn = DeepONet(
Chain(Dense(m => 8, σ), Dense(8 => 8, σ), Dense(8 => 8)),
Chain(Dense(1 => 4, σ), Dense(4 => 8)),
) # the two final σ activations are removed

Then an error message appears.
Error & Stacktrace
ERROR: MethodError: mapreducedim!(::typeof(identity), ::typeof(Base.add_sum), ::CUDA.CuArray{…}, ::LinearAlgebra.Adjoint{…}) is ambiguous.
Candidates:
mapreducedim!(f, op, R::GPUArraysCore.AnyGPUArray, A::AbstractArray)
@ GPUArrays ~/.julia/packages/GPUArrays/u6tui/src/host/mapreduce.jl:10
mapreducedim!(f, op::Union{typeof(&), typeof(+), typeof(Base._extrema_rf), typeof(Base.add_sum), typeof(max), typeof(min), typeof(|)}, B::AbstractArray, A::LinearAlgebra.Adjoint{T, <:AbstractMatrix} where T)
@ LinearAlgebra ~/.julia/juliaup/julia-1.11.6+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/adjtrans.jl:443
Possible fix, define
mapreducedim!(::Any, ::Union{…}, ::GPUArraysCore.AnyGPUArray, ::LinearAlgebra.Adjoint{…} where T)
Stacktrace:
[1] sum!(f::Function, r::CUDA.CuArray{…}, A::LinearAlgebra.Adjoint{…}; init::Bool)
@ Base ./reducedim.jl:1006
[2] sum!(r::CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, A::LinearAlgebra.Adjoint{Float32, CUDA.CuArray{…}}; init::Bool)
@ Base ./reducedim.jl:1008
[3] reduce_sum(x::CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, y::LinearAlgebra.Adjoint{Float32, CUDA.CuArray{…}})
@ LuxLib.Impl ~/.julia/packages/LuxLib/bYUJG/src/impl/common_ops.jl:36
[4] ∇bias_add(b::CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Δ::LinearAlgebra.Adjoint{Float32, CUDA.CuArray{…}})
@ LuxLib.Impl ~/.julia/packages/LuxLib/bYUJG/src/impl/common_ops.jl:30
[5] (::LuxLib.Impl.var"#249#253"{LinearAlgebra.Adjoint{…}, ChainRulesCore.ProjectTo{…}, CUDA.CuArray{…}})()
@ LuxLib.Impl ~/.julia/packages/LuxLib/bYUJG/src/impl/matmul.jl:241
[6] unthunk
@ ~/.julia/packages/ChainRulesCore/XAgYn/src/tangent_types/thunks.jl:213 [inlined]
[7] (::ComponentArrays.var"#89#90"{ComponentVector{…}, Symbol})(Δ::ChainRulesCore.Thunk{LuxLib.Impl.var"#249#253"{…}})
@ ComponentArrays ~/.julia/packages/ComponentArrays/dB3Ra/src/compat/chainrulescore.jl:2
[8] (::Zygote.ZBack{ComponentArrays.var"#89#90"{…}})(dy::ChainRulesCore.Thunk{LuxLib.Impl.var"#249#253"{…}})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/chainrules.jl:222
[9] getproperty
@ ~/.julia/packages/Lux/FMMvw/src/extended_ops.jl:96 [inlined]
[10] (::Zygote.Pullback{Tuple{…}, Any})(Δ::ChainRulesCore.Thunk{LuxLib.Impl.var"#249#253"{…}})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[11] Dense
@ ~/.julia/packages/Lux/FMMvw/src/layers/basic.jl:361 [inlined]
[12] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{LinearAlgebra.Adjoint{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[13] apply
@ ~/.julia/packages/LuxCore/XUV80/src/LuxCore.jl:155 [inlined]
[14] applychain
@ ~/.julia/packages/Lux/FMMvw/src/layers/containers.jl:0 [inlined]
[15] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{LinearAlgebra.Adjoint{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[16] Chain
@ ~/.julia/packages/Lux/FMMvw/src/layers/containers.jl:509 [inlined]
[17] apply
@ ~/.julia/packages/LuxCore/XUV80/src/LuxCore.jl:155 [inlined]
[18] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{LinearAlgebra.Adjoint{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[19] applychain
@ ~/.julia/packages/Lux/FMMvw/src/layers/containers.jl:0 [inlined]
[20] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRulesCore.Thunk{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[21] Chain
@ ~/.julia/packages/Lux/FMMvw/src/layers/containers.jl:509 [inlined]
[22] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRulesCore.Thunk{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[23] apply
@ ~/.julia/packages/LuxCore/XUV80/src/LuxCore.jl:155 [inlined]
[24] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRulesCore.Thunk{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[25] applyparallel
@ ./tuple.jl:0 [inlined]
[26] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{LinearAlgebra.Adjoint{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[27] Parallel
@ ~/.julia/packages/Lux/FMMvw/src/layers/containers.jl:178 [inlined]
[28] apply
@ ~/.julia/packages/LuxCore/XUV80/src/LuxCore.jl:155 [inlined]
[29] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{LinearAlgebra.Adjoint{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[30] applychain
@ ~/.julia/packages/Lux/FMMvw/src/layers/containers.jl:0 [inlined]
[31] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{CUDA.CuArray{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[32] Chain
@ ~/.julia/packages/Lux/FMMvw/src/layers/containers.jl:509 [inlined]
[33] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{CUDA.CuArray{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[34] apply
@ ~/.julia/packages/LuxCore/XUV80/src/LuxCore.jl:155 [inlined]
[35] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{CUDA.CuArray{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[36] AbstractLuxWrapperLayer
@ ~/.julia/packages/LuxCore/XUV80/src/LuxCore.jl:269 [inlined]
[37] apply
@ ~/.julia/packages/LuxCore/XUV80/src/LuxCore.jl:155 [inlined]
[38] loss
@ ~/Coding/Engine/mwe2.jl:38 [inlined]
[39] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
[40] #11
@ ~/Coding/Engine/mwe2.jl:9 [inlined]
[41] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:97
[42] withgradient(::Function, ::ComponentVector{Float32, CUDA.CuArray{…}, Tuple{…}}, ::Vararg{Any})
@ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:219
[43] value_and_gradient
@ ~/.julia/packages/DifferentiationInterface/a7NWj/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:115 [inlined]
[44] value_and_gradient!(f::Function, grad::ComponentVector{…}, prep::DifferentiationInterface.NoGradientPrep{…}, backend::AutoZygote, x::ComponentVector{…}, contexts::DifferentiationInterface.Constant{…})
@ DifferentiationInterfaceZygoteExt ~/.julia/packages/DifferentiationInterface/a7NWj/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:131
[45] (::OptimizationZygoteExt.var"#fg!#16"{…})(res::ComponentVector{…}, θ::ComponentVector{…})
@ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/Lc8sB/ext/OptimizationZygoteExt.jl:53
[46] macro expansion
@ ~/.julia/packages/OptimizationOptimisers/vlc0v/src/OptimizationOptimisers.jl:110 [inlined]
[47] macro expansion
@ ~/.julia/packages/Optimization/kOLrw/src/utils.jl:32 [inlined]
[48] __solve(cache::OptimizationCache{…})
@ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/vlc0v/src/OptimizationOptimisers.jl:92
[49] solve!(cache::OptimizationCache{…})
@ SciMLBase ~/.julia/packages/SciMLBase/dld4W/src/solve.jl:232
[50] solve(::OptimizationProblem{…}, ::AdamW{…}; kwargs::@Kwargs{…})
@ SciMLBase ~/.julia/packages/SciMLBase/dld4W/src/solve.jl:130
[51] solve
@ ~/.julia/packages/SciMLBase/dld4W/src/solve.jl:127 [inlined]
[52] train(loss::typeof(loss), θ::ComponentVector{…}, opt::AdamW{…}, iters::Int64, ad::AutoZygote)
@ Main ~/Coding/Engine/mwe2.jl:11
Some type information was truncated. Use `show(err)` to see complete types.

Environment (please complete the following information):
- Output of
using Pkg; Pkg.status()
[052768ef] CUDA v5.8.2
[b0b7db55] ComponentArrays v0.15.29
[b2108857] Lux v1.16.0
[d0bbae9a] LuxCUDA v0.3.3
[ea5c82af] NeuralOperators v0.6.2
[3bd65402] Optimisers v0.4.6
[7f7a1694] Optimization v4.5.0
[42dfb2eb] OptimizationOptimisers v0.3.8
[e88e6eb3] Zygote v0.7.10
[02a925ec] cuDNN v1.4.3

Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
bug — Something isn't working