make realnvp and nsf layers as part of the pkg #53
Conversation
After the recent Mooncake update, given the following setup:

```julia
using Random, Distributions, LinearAlgebra
using Flux, Optimisers
using Bijectors
using Bijectors: partition, combine, PartitionMask
using Mooncake, Enzyme, ADTypes
import DifferentiationInterface as DI

# just define a MLP
function mlp3(
    input_dim::Int,
    hidden_dims::Int,
    output_dim::Int;
    activation=Flux.leakyrelu,
    paramtype::Type{T}=Float64,
) where {T<:AbstractFloat}
    m = Chain(
        Flux.Dense(input_dim, hidden_dims, activation),
        Flux.Dense(hidden_dims, hidden_dims, activation),
        Flux.Dense(hidden_dims, output_dim),
    )
    return Flux._paramtype(paramtype, m)  # convert weights to the requested precision
end

inputdim = 4
mask_idx = 1:2:inputdim
# create a masking layer
mask = PartitionMask(inputdim, mask_idx)
cdim = length(mask_idx)
x = randn(inputdim)
t_net = mlp3(cdim, 16, cdim; paramtype=Float64)
ps, st = Optimisers.destructure(t_net)
```

the following code runs perfectly:

```julia
function loss(ps, st, x, mask)
    t_net = st(ps)
    x₁, x₂, x₃ = partition(mask, x)  # split x according to the mask
    y₁ = x₁ .+ t_net(x₂)             # shift the transformed part, conditioned on x₂
    y = combine(mask, y₁, x₂, x₃)    # reassemble into the original layout
    # println("y = ", y)
    return sum(abs2, y)
end

loss(ps, st, x, mask) # returns 3.0167880799441793

val, grad = DI.value_and_gradient(
    loss,
    ADTypes.AutoMooncake(; config=Mooncake.Config()),
    ps, DI.Cache(st), DI.Constant(x), DI.Constant(mask),
)
```

But autograd fails if I wrap the mask and the network in a struct:

```julia
struct ACL
    mask::Bijectors.PartitionMask
    t::Flux.Chain
end

@functor ACL (t,)  # only the network is marked as trainable; the mask is not

acl = ACL(mask, t_net)
psacl, stacl = Optimisers.destructure(acl)
function loss_acl(ps, st, x)
    acl = st(ps)
    t_net = acl.t
    mask = acl.mask
    x₁, x₂, x₃ = partition(mask, x)
    y₁ = x₁ .+ t_net(x₂)
    y = combine(mask, y₁, x₂, x₃)
    return sum(abs2, y)
end
loss_acl(psacl, stacl, x) # returns 3.0167880799441793

val, grad = DI.value_and_gradient(
    loss_acl,
    ADTypes.AutoMooncake(; config=Mooncake.Config()),
    psacl, DI.Cache(stacl), DI.Constant(x),
)
```

which errors. The same happens with Enzyme:

```julia
val, grad = DI.value_and_gradient(
    loss_acl,
    ADTypes.AutoEnzyme(;
        mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
        function_annotation=Enzyme.Const,
    ),
    psacl, DI.Cache(stacl), DI.Constant(x),
)
```

which also errors. Any thoughts on this @yebai @willtebbutt?
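FWIW, the failure seems tied to where the mask lives: a closure-based variant that never puts the mask inside the destructured object behaves like the first, working snippet. A minimal sketch reusing the names above (`make_loss` is just an illustrative helper, not proposed API):

```julia
# Sketch of a closure-based workaround: keep the constant mask and input out of
# the destructured object entirely, so AD only sees the flat network parameters.
function make_loss(mask::Bijectors.PartitionMask, x::AbstractVector)
    return (ps, st) -> begin
        t_net = st(ps)
        x₁, x₂, x₃ = partition(mask, x)
        y = combine(mask, x₁ .+ t_net(x₂), x₂, x₃)
        return sum(abs2, y)
    end
end

ls_closed = make_loss(mask, x)
val, grad = DI.value_and_gradient(
    ls_closed,
    ADTypes.AutoMooncake(; config=Mooncake.Config()),
    ps, DI.Cache(st),
)
```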
---
Ah, looks like it only has an issue when just part of the fields in the struct is annotated by `@functor`. Annotating all fields works:

```julia
struct Holder
    t::Flux.Chain
end

@functor Holder  # all fields are functor children here

psh, sth = Optimisers.destructure(Holder(t_net))

function loss2(ps, st, x)
    holder = st(ps)
    t_net = holder.t
    y = x .+ t_net(x)
    return sum(abs2, y)
end

loss2(psh, sth, x) # returns 7.408352005690478

val, grad = DI.value_and_gradient(
    loss2,
    ADTypes.AutoMooncake(; config=Mooncake.Config()),
    psh, DI.Cache(sth), DI.Constant(x),
)
```
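In case it helps narrow things down: the idiomatic way to keep `mask` out of the flat parameter vector without a partial `@functor` annotation is an `Optimisers.trainable` overload. A sketch following the Optimisers.jl docs (`ACL2` is a hypothetical variant; I haven't verified whether it dodges the Mooncake error):

```julia
using Functors, Optimisers

# All fields are functor children, but only the network is treated as
# trainable, so destructure flattens just its parameters.
struct ACL2
    mask::Bijectors.PartitionMask
    t::Flux.Chain
end

@functor ACL2
Optimisers.trainable(a::ACL2) = (; t=a.t)

psacl2, stacl2 = Optimisers.destructure(ACL2(mask, t_net))
```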
---
@zuhengxu, can you help bisect which Mooncake version / Julia version introduced this bug?
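(For reference, one low-tech way to bisect: pin successive Mooncake releases in a throwaway environment and rerun the reproducer. The version number below is just a placeholder.)

```julia
using Pkg
Pkg.activate(; temp=true)                    # scratch environment
Pkg.add(name="Mooncake", version="0.4.52")   # hypothetical release to test
Pkg.add(["Bijectors", "Flux", "Optimisers", "ADTypes", "DifferentiationInterface"])
# then rerun the MWE above and check whether the gradient call errors
```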
Good point! I'll look at this today. |
---
It appears that the remaining issues with Mooncake are minor, likely due to a missing rule for a specific function. @sunxd3, can you help add a rule if one is needed?
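For anyone following along, this is the general shape of a hand-written Mooncake rule, following the pattern in Mooncake's rule-writing docs (`myscale` is a made-up stand-in; the real signature to target depends on which call is missing a rule):

```julia
using Mooncake
using Mooncake: CoDual, NoRData, MinimalCtx, zero_fcodual, @is_primitive

myscale(x::Float64) = 2x   # hypothetical function assumed to lack a rule

# Mark myscale as a primitive so Mooncake calls our rule instead of tracing it.
@is_primitive MinimalCtx Tuple{typeof(myscale),Float64}

function Mooncake.rrule!!(::CoDual{typeof(myscale)}, x::CoDual{Float64})
    myscale_pb(dy) = NoRData(), 2dy   # reverse pass: rdata for (myscale, x)
    return zero_fcodual(2 * Mooncake.primal(x)), myscale_pb
end
```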
---
I'll look into it 👍 |
Red-Portal left a comment:

Hi! I have some minor suggestions.

---
NormalizingFlows.jl documentation for PR #53 is available at:

---
Thank you again @yebai @sunxd3 @Red-Portal for the help and comments throughout this PR! Let me know if it looks good to you and I'll merge afterwards.

---
Sorry for the delay. Reviewing a paper for JMLR has been taking up all my bandwidth. I'll take a deeper look tomorrow.
Red-Portal left a comment:
I only have minor suggestions. Feel free to take a look and apply them if you agree. Otherwise, looks good to me.
---
Sorry for missing the tag, allow me to take a look later today or tomorrow.
sunxd3 left a comment:
A couple of tiny things; very happy to do another round of review.
---
Thank you @sunxd3 @Red-Portal for the review! I made the corresponding updates; let me know if the current version looks good to you!
sunxd3 left a comment:
Another couple of tiny things, nothing major beyond these.
---
Pretty much good to go from my end, but let's wait for Kyurae to take a look?
Red-Portal left a comment:
I only have a few minor comments that should be quick to handle!
---
@Red-Portal @sunxd3 Let me know if I can hit the big green button! Thanks for the quick feedback.
---
Alright, looks good to me now!
---

As discussed in #36 (see #36 (comment)), I'm moving `AffineCoupling` and `NeuralSplineLayer` from the example to `src/` so they can be called directly. Concretely, this PR:

- moves `AffineCoupling` and `NeuralSplineLayer` into `src`
- adds a `realnvp` and a `neuralsplineflow` constructor

For `realnvp`, it follows the default architecture mentioned in Advances in Black-Box VI: Normalizing Flows, Importance Weighting, and Optimization.
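For readers landing here, a rough sketch of what the RealNVP-style construction amounts to, using Bijectors' composition and a standard Gaussian base (the `AffineCoupling` constructor signature below is a guess for illustration; see `src/` in this PR for the real one):

```julia
using Distributions, LinearAlgebra
using Bijectors

# Hypothetical sketch of a RealNVP-style flow: alternate which coordinates are
# transformed across a stack of affine coupling layers over a Gaussian base.
function realnvp_sketch(dim::Int, hidden::Int, nlayers::Int)
    layers = map(1:nlayers) do i
        idx = isodd(i) ? (1:2:dim) : (2:2:dim)  # alternate the transformed half
        AffineCoupling(dim, hidden, idx)        # assumed constructor from this PR
    end
    # transformed(q0, b) pushes the base distribution q0 through the bijector b
    return Bijectors.transformed(MvNormal(zeros(dim), I), reduce(∘, layers))
end

flow = realnvp_sketch(4, 16, 6)
rand(flow)  # sample from the flow
```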