From 0fa22aec7d710bb6a97992d1516aaaebfec9bab2 Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Wed, 12 Jul 2023 16:04:53 -0400 Subject: [PATCH 1/3] Weighted Bayesian ensemblefits --- Project.toml | 1 + src/EasyModelAnalysis.jl | 1 + src/datafit.jl | 179 +++++++++++++++++++++++++++++++++++++-- 3 files changed, 174 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 84646e5d..c6ffe734 100644 --- a/Project.toml +++ b/Project.toml @@ -15,6 +15,7 @@ OptimizationBBO = "3e6eede4-6085-4f62-9a71-46d9bc1eb92b" OptimizationMOI = "fd9f6733-72f4-499f-8506-86b2bdd0dea1" OptimizationNLopt = "4e6fcdb7-1186-4e1f-a706-475e75c168bb" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLExpectations = "afe9f18d-7609-4d0e-b7b7-af0cb72b8ea8" diff --git a/src/EasyModelAnalysis.jl b/src/EasyModelAnalysis.jl index ae44063a..b72446da 100644 --- a/src/EasyModelAnalysis.jl +++ b/src/EasyModelAnalysis.jl @@ -10,6 +10,7 @@ using GlobalSensitivity, Turing using SciMLExpectations @reexport using Plots using SciMLBase.EnsembleAnalysis +using Random include("basics.jl") include("datafit.jl") diff --git a/src/datafit.jl b/src/datafit.jl index ed1322bc..23f090c7 100644 --- a/src/datafit.jl +++ b/src/datafit.jl @@ -192,16 +192,15 @@ function global_datafit(prob, pbounds, data; maxiters = 10000, loss = l2loss) Pair.(pkeys, res.u) end -function bayes_unpack_data(p, data) - pdist = getfield.(p, :second) - pkeys = getfield.(p, :first) +function bayes_unpack_data(p::AbstractVector{<:Pair}, data) + pdist, pkeys = bayes_unpack_data(p) ts = first.(last.(data)) lastt = maximum(last, ts) timeseries = last.(last.(data)) datakeys = first.(data) (pdist, pkeys, ts, lastt, timeseries, datakeys) end -function bayes_unpack_data(p) +function bayes_unpack_data(p::AbstractVector{<:Pair}) pdist = getfield.(p, :second) pkeys = getfield.(p, :first) (pdist, pkeys) @@ -249,6 +248,123 @@ Turing.@model function bayesianODE(prob, return nothing end +""" +Weights can be unbounded. Length of weights must be one less than the length of sols, to apply a sum-to-1 constraint. +Last `sol` is given the weight `1 - sum(weights)`. +""" +struct WeightedSol{T, S <: Tuple{Vararg{AbstractVector{T}}}, W} <: AbstractVector{T} + sols::S + weights::W + function WeightedSol{T}(sols::S, + weights::W) where {T, S <: Tuple{Vararg{AbstractVector{T}}}, W} + @assert length(sols) == length(weights) + 1 + new{T, S, W}(sols, weights) + end +end +Base.length(ws::WeightedSol) = length(first(ws.sols)) +Base.size(ws::WeightedSol) = (length(first(ws.sols)),) +function Base.getindex(ws::WeightedSol{T}, i::Int) where {T} + s = zero(T) + w = zero(T) + for j in eachindex(ws.weights) + w += ws.weights[j] + s += ws.weights[j] * ws.sols[j][i] + end + return s + (one(T) - w) * ws.sols[end][i] +end +function WeightedSol(sols, select, weights) + T = eltype(weights) + s = map(Base.Fix2(getindex, select), sols) + WeightedSol{T}(s, weights) +end +function bayes_unpack_data(p::Tuple{Vararg{<:AbstractVector{<:Pair}}}, data) + pdist, pkeys = bayes_unpack_data(p) + ts = first.(last.(data)) + lastt = maximum(last, ts) + timeseries = last.(last.(data)) + datakeys = first.(data) + (pdist, pkeys, ts, lastt, timeseries, datakeys) +end +function bayes_unpack_data(p::Tuple{Vararg{<:AbstractVector{<:Pair}}}) + unpacked = map(bayes_unpack_data, p) + map(first, unpacked), map(last, unpacked) +end + +struct Grouper{N} + sizes::NTuple{N, Int} +end +function (g::Grouper)(x) + i = Ref(0) + map(g.sizes) do N + _i = i[] + i[] = _i + N + view(x, (_i + 1):(_i + N)) + end +end +function flatten(x::Tuple) + reduce(vcat, x), Grouper(map(length, x)) +end + +Turing.@model function ensemblebayesianODE(probs::Union{Tuple, AbstractVector}, + t, + pdist, + grouppriorsfunc, + probspkeys, + data, + noise_prior) + σ ~ noise_prior + + ppriors ~ product_distribution(pdist) + # stdeviation = sqrt(1/length(weights)) + Nprobs = length(probs) + Nprobs⁻¹ = inv(Nprobs) + weights ~ MvNormal(Distributions.Fill(Nprobs⁻¹, Nprobs - 1), Nprobs⁻¹) + sols = map(probs, probspkeys, grouppriorsfunc(ppriors)) do prob, pkeys, pprior + solve(remake(prob, tspan = (prob.tspan[1], t[end]), p = Pair.(pkeys, pprior)), + saveat = t) + end + if !all(SciMLBase.successful_retcode, sols) + Turing.DynamicPPL.acclogp!!(__varinfo__, -Inf) + return nothing + end + for i in eachindex(data) + data[i].second ~ MvNormal(WeightedSol(sols, data[i].first, weights), σ^2 * I) + end + return nothing +end +Turing.@model function ensemblebayesianODE(probs::Union{Tuple, AbstractVector}, + pdist, + grouppriorsfunc, + probspkeys, + ts, + lastt, + timeseries, + datakeys, + noise_prior) + σ ~ noise_prior + + ppriors ~ product_distribution(pdist) + + sols = map(probs, probspkeys, grouppriorsfunc(ppriors)) do prob, pkeys, pprior + solve(remake(prob, tspan = (prob.tspan[1], lastt), p = Pair.(pkeys, pprior))) + end + # stdeviation = sqrt(1/length(weights)) + Nprobs = length(probs) + Nprobs⁻¹ = inv(Nprobs) + weights ~ MvNormal(Distributions.Fill(Nprobs⁻¹, Nprobs - 1), Nprobs⁻¹) + if !all(SciMLBase.successful_retcode, sols) + Turing.DynamicPPL.acclogp!!(__varinfo__, -Inf) + return nothing + end + for i in eachindex(datakeys) + vals = map(sols) do sol + sol(ts[i]; idxs = datakeys[i]) + end + timeseries[i] ~ MvNormal(WeightedSol(vals, weights), σ^2 * I) + end + return nothing +end + """ bayesian_datafit(prob, p, t, data) bayesian_datafit(prob, p, data) @@ -288,8 +404,8 @@ function bayesian_datafit(prob, mcmcensemble::AbstractMCMC.AbstractMCMCEnsemble = Turing.MCMCThreads(), nchains = 4, niter = 1000) - (pkeys, pdata) = bayes_unpack_data(p) - model = bayesianODE(prob, t, pkeys, pdata, data, noise_prior) + (pdist, pkeys) = bayes_unpack_data(p) + model = bayesianODE(prob, t, pdist, pkeys, data, noise_prior) chain = Turing.sample(model, Turing.NUTS(0.65), mcmcensemble, @@ -301,7 +417,7 @@ function bayesian_datafit(prob, end function bayesian_datafit(prob, - p, + p::Tuple{Vararg{<:AbstractVector{<:Pair}}}, data; noise_prior = InverseGamma(2, 3), mcmcensemble::AbstractMCMC.AbstractMCMCEnsemble = Turing.MCMCThreads(), @@ -318,6 +434,55 @@ function bayesian_datafit(prob, [Pair(p[i].first, collect(chain["pprior[" * string(i) * "]"])[:]) for i in eachindex(p)] end +function bayesian_datafit(probs::Union{Tuple, AbstractVector}, + p::Tuple{Vararg{<:AbstractVector{<:Pair}}}, + t, + data; + noise_prior = InverseGamma(2, 3), + mcmcensemble::AbstractMCMC.AbstractMCMCEnsemble = Turing.MCMCThreads(), + nchains = 4, + niter = 1000) + (pdist_, pkeys) = bayes_unpack_data(p) + pdist, grouppriorsfunc = flatten(pdist_) + + model = ensemblebayesianODE(probs, t, pdist, grouppriorsfunc, pkeys, data, noise_prior) + chain = Turing.sample(model, + Turing.NUTS(0.65), + mcmcensemble, + niter, + nchains; + progress = true) + [Pair(p[i].first, collect(chain["pprior[" * string(i) * "]"])[:]) + for i in eachindex(p)] +end + +function bayesian_datafit(probs::Union{Tuple, AbstractVector}, + p, + data; + noise_prior = InverseGamma(2, 3), + mcmcensemble::AbstractMCMC.AbstractMCMCEnsemble = Turing.MCMCThreads(), + nchains = 4, + niter = 1_000) + pdist_, pkeys, ts, lastt, timeseries, datakeys = bayes_unpack_data(p, data) + pdist, grouppriorsfunc = flatten(pdist_) + model = ensemblebayesianODE(probs, + pdist, + grouppriorsfunc, + pkeys, + ts, + lastt, + timeseries, + datakeys, + noise_prior) + chain = Turing.sample(model, + Turing.NUTS(0.65), + mcmcensemble, + niter, + nchains; + progress = true) + [Pair(p[i].first, collect(chain["pprior[" * string(i) * "]"])[:]) + for i in eachindex(p)] +end """ model_forecast_score(probs::AbstractVector, ts::AbstractVector, dataset::AbstractVector{<:Pair}) From 6055074e1ccc618fa4effce8346a82663fec2a62 Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Thu, 13 Jul 2023 10:09:48 -0400 Subject: [PATCH 2/3] separate `getsols` for easier introspection --- src/datafit.jl | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/src/datafit.jl b/src/datafit.jl index 23f090c7..2419a13a 100644 --- a/src/datafit.jl +++ b/src/datafit.jl @@ -305,6 +305,18 @@ function flatten(x::Tuple) reduce(vcat, x), Grouper(map(length, x)) end +function getsols(probs, probspkeys, ppriors, t::AbstractArray) + map(probs, probspkeys, ppriors) do prob, pkeys, pprior + solve(remake(prob, tspan = (prob.tspan[1], t[end]), p = Pair.(pkeys, pprior)), + saveat = t) + end +end +function getsols(probs, probspkeys, ppriors, lastt::Number) + map(probs, probspkeys, ppriors) do prob, pkeys, pprior + solve(remake(prob, tspan = (prob.tspan[1], lastt), p = Pair.(pkeys, pprior))) + end +end + Turing.@model function ensemblebayesianODE(probs::Union{Tuple, AbstractVector}, t, pdist, @@ -313,16 +325,12 @@ Turing.@model function ensemblebayesianODE(probs::Union{Tuple, AbstractVector}, data, noise_prior) σ ~ noise_prior - ppriors ~ product_distribution(pdist) - # stdeviation = sqrt(1/length(weights)) + Nprobs = length(probs) Nprobs⁻¹ = inv(Nprobs) weights ~ MvNormal(Distributions.Fill(Nprobs⁻¹, Nprobs - 1), Nprobs⁻¹) - sols = map(probs, probspkeys, grouppriorsfunc(ppriors)) do prob, pkeys, pprior - solve(remake(prob, tspan = (prob.tspan[1], t[end]), p = Pair.(pkeys, pprior)), - saveat = t) - end + sols = getsols(probs, probspkeys, grouppriorsfunc(ppriors), t) if !all(SciMLBase.successful_retcode, sols) Turing.DynamicPPL.acclogp!!(__varinfo__, -Inf) return nothing @@ -342,13 +350,10 @@ Turing.@model function ensemblebayesianODE(probs::Union{Tuple, AbstractVector}, datakeys, noise_prior) σ ~ noise_prior - ppriors ~ product_distribution(pdist) - sols = map(probs, probspkeys, grouppriorsfunc(ppriors)) do prob, pkeys, pprior - solve(remake(prob, tspan = (prob.tspan[1], lastt), p = Pair.(pkeys, pprior))) - end - # stdeviation = sqrt(1/length(weights)) + sols = getsols(probs, probspkeys, grouppriorsfunc(ppriors), lastt) + Nprobs = length(probs) Nprobs⁻¹ = inv(Nprobs) weights ~ MvNormal(Distributions.Fill(Nprobs⁻¹, Nprobs - 1), Nprobs⁻¹) From f0d7d4dae0964bde3552547a4a207140b9801acd Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Thu, 13 Jul 2023 11:54:12 -0400 Subject: [PATCH 3/3] Fix type signatures --- src/datafit.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/datafit.jl b/src/datafit.jl index 2419a13a..dfecd3b2 100644 --- a/src/datafit.jl +++ b/src/datafit.jl @@ -422,7 +422,7 @@ function bayesian_datafit(prob, end function bayesian_datafit(prob, - p::Tuple{Vararg{<:AbstractVector{<:Pair}}}, + p, data; noise_prior = InverseGamma(2, 3), mcmcensemble::AbstractMCMC.AbstractMCMCEnsemble = Turing.MCMCThreads(), @@ -462,7 +462,7 @@ function bayesian_datafit(probs::Union{Tuple, AbstractVector}, end function bayesian_datafit(probs::Union{Tuple, AbstractVector}, - p, + p::Tuple{Vararg{<:AbstractVector{<:Pair}}}, data; noise_prior = InverseGamma(2, 3), mcmcensemble::AbstractMCMC.AbstractMCMCEnsemble = Turing.MCMCThreads(),