Skip to content

Changing network architecture causes training errors. #77

@vavrines

Description

@vavrines

Describe the bug 🐞

Removing some activation functions in the DeepONet model will cause training errors.

Minimal Reproducible Example 👇

Working version:

using Optimization, Optimisers, OptimizationOptimisers
using Lux, NeuralOperators, Random, ComponentArrays
import LuxCUDA, Zygote

rng = MersenneTwister(1234)
dev = gpu_device()

function train(loss, θ, opt, iters, ad)
    optf = Optimization.OptimizationFunction((x, p) -> loss(x), ad)
    optprob = Optimization.OptimizationProblem(optf, θ)
    return Optimization.solve(optprob, opt; maxiters = iters)
end

begin
    m = 32
    batch_size = 64
    eval_points = 10

    us = rand(Float32, m, batch_size) |> dev
    ys = rand(Float32, 1, eval_points) |> dev
    vs = rand(Float32, eval_points, batch_size) |> dev
end

nn = DeepONet(
    Chain(Dense(m => 8, σ), Dense(8 => 8, σ), Dense(8 => 8, σ)),
    Chain(Dense(1 => 4, σ), Dense(4 => 8, σ)),
)

ps, st = Lux.setup(rng, nn) |> dev
pv = ComponentArray(cpu_device()(ps)) |> dev

function loss(p)
    pred, _ = Lux.apply(nn, (us, ys), p, st)
    err = pred - vs
    l = sum(abs2, err)

    return l
end

res = train(loss, pv, AdamW(), 10, AutoZygote())

Now, if we modify the DeepONet architecture, while the rest of the code is kept unchanged:

nn = DeepONet(
    Chain(Dense(m => 8, σ), Dense(8 => 8, σ), Dense(8 => 8)),
    Chain(Dense(1 => 4, σ), Dense(4 => 8)),
) # two σs are removed

Then an error message appears.

Error & Stacktrace ⚠️

ERROR: MethodError: mapreducedim!(::typeof(identity), ::typeof(Base.add_sum), ::CUDA.CuArray{…}, ::LinearAlgebra.Adjoint{…}) is ambiguous.

Candidates:
  mapreducedim!(f, op, R::GPUArraysCore.AnyGPUArray, A::AbstractArray)
    @ GPUArrays ~/.julia/packages/GPUArrays/u6tui/src/host/mapreduce.jl:10
  mapreducedim!(f, op::Union{typeof(&), typeof(+), typeof(Base._extrema_rf), typeof(Base.add_sum), typeof(max), typeof(min), typeof(|)}, B::AbstractArray, A::LinearAlgebra.Adjoint{T, <:AbstractMatrix} where T)
    @ LinearAlgebra ~/.julia/juliaup/julia-1.11.6+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/adjtrans.jl:443

Possible fix, define
  mapreducedim!(::Any, ::Union{…}, ::GPUArraysCore.AnyGPUArray, ::LinearAlgebra.Adjoint{…} where T)

Stacktrace:
  [1] sum!(f::Function, r::CUDA.CuArray{…}, A::LinearAlgebra.Adjoint{…}; init::Bool)
    @ Base ./reducedim.jl:1006
  [2] sum!(r::CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, A::LinearAlgebra.Adjoint{Float32, CUDA.CuArray{…}}; init::Bool)
    @ Base ./reducedim.jl:1008
  [3] reduce_sum(x::CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, y::LinearAlgebra.Adjoint{Float32, CUDA.CuArray{…}})
    @ LuxLib.Impl ~/.julia/packages/LuxLib/bYUJG/src/impl/common_ops.jl:36
  [4] ∇bias_add(b::CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Δ::LinearAlgebra.Adjoint{Float32, CUDA.CuArray{…}})
    @ LuxLib.Impl ~/.julia/packages/LuxLib/bYUJG/src/impl/common_ops.jl:30
  [5] (::LuxLib.Impl.var"#249#253"{LinearAlgebra.Adjoint{}, ChainRulesCore.ProjectTo{}, CUDA.CuArray{}})()
    @ LuxLib.Impl ~/.julia/packages/LuxLib/bYUJG/src/impl/matmul.jl:241
  [6] unthunk
    @ ~/.julia/packages/ChainRulesCore/XAgYn/src/tangent_types/thunks.jl:213 [inlined]
  [7] (::ComponentArrays.var"#89#90"{ComponentVector{}, Symbol})(Δ::ChainRulesCore.Thunk{LuxLib.Impl.var"#249#253"{…}})
    @ ComponentArrays ~/.julia/packages/ComponentArrays/dB3Ra/src/compat/chainrulescore.jl:2
  [8] (::Zygote.ZBack{ComponentArrays.var"#89#90"{…}})(dy::ChainRulesCore.Thunk{LuxLib.Impl.var"#249#253"{…}})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/chainrules.jl:222
  [9] getproperty
    @ ~/.julia/packages/Lux/FMMvw/src/extended_ops.jl:96 [inlined]
 [10] (::Zygote.Pullback{Tuple{…}, Any})(Δ::ChainRulesCore.Thunk{LuxLib.Impl.var"#249#253"{…}})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [11] Dense
    @ ~/.julia/packages/Lux/FMMvw/src/layers/basic.jl:361 [inlined]
 [12] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{LinearAlgebra.Adjoint{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [13] apply
    @ ~/.julia/packages/LuxCore/XUV80/src/LuxCore.jl:155 [inlined]
 [14] applychain
    @ ~/.julia/packages/Lux/FMMvw/src/layers/containers.jl:0 [inlined]
 [15] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{LinearAlgebra.Adjoint{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [16] Chain
    @ ~/.julia/packages/Lux/FMMvw/src/layers/containers.jl:509 [inlined]
 [17] apply
    @ ~/.julia/packages/LuxCore/XUV80/src/LuxCore.jl:155 [inlined]
 [18] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{LinearAlgebra.Adjoint{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [19] applychain
    @ ~/.julia/packages/Lux/FMMvw/src/layers/containers.jl:0 [inlined]
 [20] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRulesCore.Thunk{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [21] Chain
    @ ~/.julia/packages/Lux/FMMvw/src/layers/containers.jl:509 [inlined]
 [22] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRulesCore.Thunk{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [23] apply
    @ ~/.julia/packages/LuxCore/XUV80/src/LuxCore.jl:155 [inlined]
 [24] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRulesCore.Thunk{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [25] applyparallel
    @ ./tuple.jl:0 [inlined]
 [26] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{LinearAlgebra.Adjoint{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [27] Parallel
    @ ~/.julia/packages/Lux/FMMvw/src/layers/containers.jl:178 [inlined]
 [28] apply
    @ ~/.julia/packages/LuxCore/XUV80/src/LuxCore.jl:155 [inlined]
 [29] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{LinearAlgebra.Adjoint{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [30] applychain
    @ ~/.julia/packages/Lux/FMMvw/src/layers/containers.jl:0 [inlined]
 [31] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{CUDA.CuArray{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [32] Chain
    @ ~/.julia/packages/Lux/FMMvw/src/layers/containers.jl:509 [inlined]
 [33] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{CUDA.CuArray{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [34] apply
    @ ~/.julia/packages/LuxCore/XUV80/src/LuxCore.jl:155 [inlined]
 [35] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{CUDA.CuArray{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [36] AbstractLuxWrapperLayer
    @ ~/.julia/packages/LuxCore/XUV80/src/LuxCore.jl:269 [inlined]
 [37] apply
    @ ~/.julia/packages/LuxCore/XUV80/src/LuxCore.jl:155 [inlined]
 [38] loss
    @ ~/Coding/Engine/mwe2.jl:38 [inlined]
 [39] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [40] #11
    @ ~/Coding/Engine/mwe2.jl:9 [inlined]
 [41] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:97
 [42] withgradient(::Function, ::ComponentVector{Float32, CUDA.CuArray{…}, Tuple{…}}, ::Vararg{Any})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:219
 [43] value_and_gradient
    @ ~/.julia/packages/DifferentiationInterface/a7NWj/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:115 [inlined]
 [44] value_and_gradient!(f::Function, grad::ComponentVector{…}, prep::DifferentiationInterface.NoGradientPrep{…}, backend::AutoZygote, x::ComponentVector{…}, contexts::DifferentiationInterface.Constant{…})
    @ DifferentiationInterfaceZygoteExt ~/.julia/packages/DifferentiationInterface/a7NWj/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:131
 [45] (::OptimizationZygoteExt.var"#fg!#16"{})(res::ComponentVector{…}, θ::ComponentVector{…})
    @ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/Lc8sB/ext/OptimizationZygoteExt.jl:53
 [46] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/vlc0v/src/OptimizationOptimisers.jl:110 [inlined]
 [47] macro expansion
    @ ~/.julia/packages/Optimization/kOLrw/src/utils.jl:32 [inlined]
 [48] __solve(cache::OptimizationCache{…})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/vlc0v/src/OptimizationOptimisers.jl:92
 [49] solve!(cache::OptimizationCache{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/dld4W/src/solve.jl:232
 [50] solve(::OptimizationProblem{…}, ::AdamW{…}; kwargs::@Kwargs{})
    @ SciMLBase ~/.julia/packages/SciMLBase/dld4W/src/solve.jl:130
 [51] solve
    @ ~/.julia/packages/SciMLBase/dld4W/src/solve.jl:127 [inlined]
 [52] train(loss::typeof(loss), θ::ComponentVector{…}, opt::AdamW{…}, iters::Int64, ad::AutoZygote)
    @ Main ~/Coding/Engine/mwe2.jl:11
Some type information was truncated. Use `show(err)` to see complete types.

Environment (please complete the following information):

  • Output of using Pkg; Pkg.status()
  [052768ef] CUDA v5.8.2
  [b0b7db55] ComponentArrays v0.15.29
  [b2108857] Lux v1.16.0
  [d0bbae9a] LuxCUDA v0.3.3
  [ea5c82af] NeuralOperators v0.6.2
  [3bd65402] Optimisers v0.4.6
  [7f7a1694] Optimization v4.5.0
  [42dfb2eb] OptimizationOptimisers v0.3.8
  [e88e6eb3] Zygote v0.7.10
  [02a925ec] cuDNN v1.4.3

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions