Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions src/datafit.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,23 @@
"""
_pprior_samples(chain, i)

Extract the posterior samples of `pprior[i]` from a Turing sampling result, supporting
both chain backends across Turing versions: the `FlexiChains.VNChain` returned by newer
Turing (indexed by `@varname(pprior[i])`) and the legacy `MCMCChains.Chains`
(indexed by the `"pprior[i]"` string key).
"""
function _pprior_samples(chain, i)
vn = @varname(pprior[i])
samples = try
chain[vn]
catch err
err isa Union{MethodError, KeyError, ArgumentError} ||
rethrow(err)
chain["pprior[" * string(i) * "]"]
end
return collect(samples)[:]
end

function l2loss(pvals, (prob, pkeys, t, data)::Tuple{Vararg{Any, 4}})
p = Pair.(pkeys, pvals)
prob = remake(prob, tspan = (prob.tspan[1], t[end]), p = p)
Expand All @@ -6,7 +26,7 @@ function l2loss(pvals, (prob, pkeys, t, data)::Tuple{Vararg{Any, 4}})
for pairs in data
tot_loss += sum((sol[pairs.first] .- pairs.second) .^ 2)
end
return tot_loss, sol
return tot_loss
end

function l2loss(pvals, (prob, pkeys, data)::Tuple{Vararg{Any, 3}})
Expand All @@ -22,7 +42,7 @@ function l2loss(pvals, (prob, pkeys, data)::Tuple{Vararg{Any, 3}})
for i in 1:length(ts)
tot_loss += sum((sol(ts[i]; idxs = datakeys[i]) .- timeseries[i]) .^ 2)
end
return tot_loss, sol
return tot_loss
end

function relative_l2loss(pvals, (prob, pkeys, t, data)::Tuple{Vararg{Any, 4}})
Expand All @@ -33,7 +53,7 @@ function relative_l2loss(pvals, (prob, pkeys, t, data)::Tuple{Vararg{Any, 4}})
for pairs in data
tot_loss += sum(((sol[pairs.first] .- pairs.second) ./ sol[pairs.first]) .^ 2)
end
return tot_loss, sol
return tot_loss
end

function relative_l2loss(pvals, (prob, pkeys, data)::Tuple{Vararg{Any, 3}})
Expand All @@ -50,7 +70,7 @@ function relative_l2loss(pvals, (prob, pkeys, data)::Tuple{Vararg{Any, 3}})
vals = sol(ts[i]; idxs = datakeys[i])
tot_loss += sum(((vals .- timeseries[i]) ./ vals) .^ 2)
end
return tot_loss, sol
return tot_loss
end

"""
Expand Down Expand Up @@ -223,7 +243,7 @@ Turing.@model function bayesianODE(prob, t, pdist, pkeys, data, noise_prior)
prob = remake(prob, tspan = (prob.tspan[1], t[end]), p = Pair.(pkeys, pprior))
sol = solve(prob, saveat = t)
if !SciMLBase.successful_retcode(sol)
Turing.DynamicPPL.acclogp!!(__varinfo__, -Inf)
Turing.@addlogprob! -Inf
return nothing
end
for i in eachindex(data)
Expand All @@ -249,7 +269,7 @@ Turing.@model function bayesianODE(
prob = remake(prob, tspan = (prob.tspan[1], lastt), p = Pair.(pkeys, pprior))
sol = solve(prob)
if !SciMLBase.successful_retcode(sol)
Turing.DynamicPPL.acclogp!!(__varinfo__, -Inf)
Turing.@addlogprob! -Inf
return nothing
end
for i in eachindex(datakeys)
Expand Down Expand Up @@ -308,7 +328,7 @@ function bayesian_datafit(
progress = true
)
return [
Pair(p[i].first, collect(chain["pprior[" * string(i) * "]"])[:])
Pair(p[i].first, _pprior_samples(chain, i))
for i in eachindex(p)
]
end
Expand All @@ -333,7 +353,7 @@ function bayesian_datafit(
progress = true
)
return [
Pair(p[i].first, collect(chain["pprior[" * string(i) * "]"])[:])
Pair(p[i].first, _pprior_samples(chain, i))
for i in eachindex(p)
]
end
Expand Down
23 changes: 21 additions & 2 deletions src/ensemble.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ dataset on which the ensembler should be trained on.
function ensemble_weights(sol::EnsembleSolution, data_ensem)
obs = first.(data_ensem)
predictions = reduce(
vcat, reduce(hcat, [sol[i][s] for i in 1:length(sol)]) for s in obs
vcat, reduce(hcat, [sol.u[i][s] for i in 1:length(sol.u)]) for s in obs
)
data = reduce(
vcat,
Expand All @@ -31,6 +31,23 @@ function ensemble_weights(sol::EnsembleSolution, data_ensem)
return weights = predictions \ data
end

"""
EnsembleProbForwarder(all_probs)

Callable used as the `prob_func` of the `EnsembleProblem` returned by
[`bayesian_ensemble`](@ref). It selects the per-trajectory problem from the stored
`all_probs` vector. It supports both the `prob_func(prob, ctx)` interface of newer
SciMLBase (selecting via `ctx.sim_id`) and the legacy `prob_func(prob, i, repeat)`
interface (selecting via the integer index). Storing `all_probs` lets callers recover
the number of trajectories via `enprob.prob_func.all_probs`.
"""
struct EnsembleProbForwarder{P}
all_probs::P
end

(f::EnsembleProbForwarder)(prob, i::Integer, repeat) = f.all_probs[i]
(f::EnsembleProbForwarder)(prob, ctx) = f.all_probs[ctx.sim_id]

function bayesian_ensemble(
probs, ps, datas;
noise_prior = InverseGamma(2, 3),
Expand All @@ -56,5 +73,7 @@ function bayesian_ensemble(

@info "$(length(all_probs)) total models"

return enprob = EnsembleProblem(all_probs)
return enprob = EnsembleProblem(
all_probs[1]; prob_func = EnsembleProbForwarder(all_probs)
)
end
7 changes: 4 additions & 3 deletions src/sensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ function _get_sensitivity(prob, t, x, pbounds; samples)
boundvals = getfield.(pbounds, :second)
boundkeys = getfield.(pbounds, :first)
f = function (p)
prob_func(prob, i, repeat) = remake(prob; p = Pair.(boundkeys, p[:, i]))
prob_func(prob, i::Integer, repeat) = remake(prob; p = Pair.(boundkeys, p[:, i]))
prob_func(prob, ctx) = remake(prob; p = Pair.(boundkeys, p[:, ctx.sim_id]))
ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
sol = solve(
ensemble_prob, nothing, EnsembleThreads(); saveat = t,
Expand All @@ -11,11 +12,11 @@ function _get_sensitivity(prob, t, x, pbounds; samples)
out = zeros(size(p, 2))
if x isa Function
for i in 1:size(p, 2)
out[i] = x(sol[i])
out[i] = x(sol.u[i])
end
else
for i in 1:size(p, 2)
out[i] = sol[i](t; idxs = x)
out[i] = sol.u[i](t; idxs = x)
end
end
return out
Expand Down
35 changes: 29 additions & 6 deletions src/threshold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,30 @@ function get_threshold(prob, obs, threshold; alg = nothing, kw...)
return sol.t[end]
end

# Decompose a symbolic threshold inequality (e.g. `x > 10.0`) into the state it
# constrains, the numeric bound, and whether the violating side is the upper one
# (`maximum(state) > bound`) or the lower one (`minimum(state) < bound`).
# Symbolics canonicalizes comparisons, so `x > 10.0` is stored as `<(10.0, x)`:
# the constant can land on either side of the operator, which this normalizes.
function _threshold_violation(threshold)
v = ModelingToolkit.value(threshold)
op = operation(v)
args = arguments(v)
isconst(z) = ModelingToolkit.value(z) isa Number
if isconst(args[1])
bound = ModelingToolkit.value(args[1])
state = args[2]
# `bound op state`: `bound < state` ⟺ `state > bound` (upper violation).
upper = (op === <) || (op === <=)
else
bound = ModelingToolkit.value(args[2])
state = args[1]
# `state op bound`: `state > bound`/`state >= bound` is the upper violation.
upper = (op === >) || (op === >=)
end
return state, bound, upper
end

"""
prob_violating_thresholdd(prob, p, thresholds)

Expand All @@ -46,16 +70,15 @@ function prob_violating_threshold(prob, p, thresholds)
h(x, u, p) = u, remake(prob, p = Pair.(pkeys, [x...])).p # remake does not work well with static arrays
function g(sol, p)
for threshold in thresholds
if (threshold.val.f == >) || (threshold.val.f == >=)
if maximum(sol[threshold.val.arguments[1]]) > threshold.val.arguments[2]
state, bound, upper = _threshold_violation(threshold)
if upper
if maximum(sol[state]) > bound
return 1.0
end
elseif (threshold.val.f == <) || (threshold.val.f == <=)
if minimum(sol[threshold.val.arguments[1]]) < threshold.val.arguments[2]
else
if minimum(sol[state]) < bound
return 1.0
end
else
error()
end
end
return 0.0
Expand Down
21 changes: 13 additions & 8 deletions test/ensemble.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,18 @@ eqs = [

@mtkbuild sys3 = ODESystem(eqs, t)
prob3 = ODEProblem(sys3, [], tspan);
enprob = EnsembleProblem([prob, prob2, prob3])
probs = [prob, prob2, prob3]
prob_func(prob, i::Integer, repeat) = probs[i]
prob_func(prob, ctx) = probs[ctx.sim_id]
enprob = EnsembleProblem(probs[1]; prob_func = prob_func)

sol = solve(enprob; saveat = 1);
sol = solve(enprob, Tsit5(); saveat = 1, trajectories = length(probs));

weights = [0.2, 0.5, 0.3]

fullS = vec(sum(stack(weights .* sol[S, :]), dims = 2))
fullI = vec(sum(stack(weights .* sol[I, :]), dims = 2))
fullR = vec(sum(stack(weights .* sol[R, :]), dims = 2))
fullS = vec(sum(stack(weights .* [sol.u[i][S] for i in 1:length(sol.u)]), dims = 2))
fullI = vec(sum(stack(weights .* [sol.u[i][I] for i in 1:length(sol.u)]), dims = 2))
fullR = vec(sum(stack(weights .* [sol.u[i][R] for i in 1:length(sol.u)]), dims = 2))

t_train = 0:14
data_train = [
Expand All @@ -81,14 +84,16 @@ data_forecast = [
R => (t_forecast, fullR),
]

sol = solve(enprob; saveat = t_ensem);
sol = solve(enprob, Tsit5(); saveat = t_ensem, trajectories = length(probs));

@test ensemble_weights(sol, data_ensem) ≈ [0.2, 0.5, 0.3]

probs = [prob, prob2, prob3]
ps = [[β => Uniform(0.01, 10.0), γ => Uniform(0.01, 10.0)] for i in 1:3]
datas = [data_train, data_train, data_train]
enprobs = bayesian_ensemble(probs, ps, datas)

sol = solve(enprobs; saveat = t_ensem);
sol = solve(
enprobs, Tsit5(); saveat = t_ensem,
trajectories = length(enprobs.prob_func.all_probs)
);
ensemble_weights(sol, data_ensem)
Loading