-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmodel_inference.py
More file actions
20 lines (19 loc) · 793 Bytes
/
model_inference.py
File metadata and controls
20 lines (19 loc) · 793 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
'''
Created on Jun 1, 2017
@author: zwieback
'''
import pymc3 as pm
import numpy as np
from pymc3.backends.base import merge_traces
def _chain_seed(seed, chain):
    """Return the random seed to use for NUTS chain number *chain*.

    Offsets *seed* by the chain index so each chain gets a distinct but
    reproducible seed. ``None`` and negative values are passed through
    unchanged and treated as "do not seed" (legacy pymc3 ``sample`` used -1
    as its unseeded default -- NOTE(review): confirm for the installed version).
    """
    if seed is None or seed < 0:
        return seed
    return seed + chain


def model_inference(model, niter=2000, nadvi=200000, ntraceadvi=1000, seed=123,
                    nchains=2, burn=None, thin=2):
    """Fit *model* with ADVI, then run NUTS chains scaled by the ADVI fit.

    ADVI is run first. Its per-parameter standard deviations are squared and
    passed to NUTS as a covariance scaling (``is_cov=True``). Each NUTS chain
    draws ``niter`` samples. The first ``burn`` draws are discarded and every
    ``thin``-th remaining draw is kept. The thinned chains are merged into a
    single trace.

    Parameters
    ----------
    model : pymc3.Model
        Model to fit; used as the context for every sampling call.
    niter : int
        Number of NUTS draws per chain, before burn-in and thinning.
    nadvi : int
        Number of ADVI iterations (``n`` argument of ``pm.variational.advi``).
    ntraceadvi : int
        Number of draws taken from the fitted ADVI approximation.
    seed : int or None
        Base random seed. ADVI and its sampling use ``seed`` directly.
        NUTS chain ``k`` uses ``seed + k`` so the chains are not identical.
        ``None`` or a negative value is passed through unchanged to every call.
    nchains : int
        Number of NUTS chains; must be at least 1.
    burn : int or None
        Number of draws discarded from the start of each chain.
        Defaults to ``niter // 2``, matching the previous hard-coded value.
    thin : int
        Keep every ``thin``-th draw after burn-in; must be at least 1.
        The default of 2 matches the previous hard-coded value.

    Returns
    -------
    tuple
        ``(trace, v_params, tracevi)``:
        - ``trace`` is the merged, burned-in and thinned NUTS trace.
        - ``v_params`` is the result of ``pm.variational.advi``.
        - ``tracevi`` holds the draws from ``pm.variational.sample_vp``.

    Raises
    ------
    ValueError
        If ``nchains < 1`` or ``thin < 1``.
    """
    # Validate up front: otherwise nchains=0 only fails later, inside
    # merge_traces([]), after the expensive ADVI fit has already run.
    if nchains < 1:
        raise ValueError("nchains must be at least 1, got %r" % (nchains,))
    if thin < 1:
        raise ValueError("thin must be at least 1, got %r" % (thin,))
    if burn is None:
        burn = niter // 2

    with model:
        v_params = pm.variational.advi(n=nadvi, random_seed=seed)
        tracevi = pm.variational.sample_vp(v_params, draws=ntraceadvi,
                                           random_seed=seed)
        # Squared ADVI stds serve as the NUTS scaling (a covariance, given
        # is_cov=True). The value does not depend on the chain, so compute it once.
        scaling = np.power(model.dict_to_array(v_params.stds), 2)

        traces = []
        for chain in range(nchains):
            # A fresh sampler per chain, as before. A distinct seed per chain
            # is essential: with a shared seed and the same default start
            # point, every chain would produce identical draws.
            step = pm.NUTS(scaling=scaling, is_cov=True, target_accept=0.95)
            trace = pm.sample(niter, chain=chain, step=step,
                              random_seed=_chain_seed(seed, chain))
            traces.append(trace[burn::thin])
        trace = merge_traces(traces)
    return trace, v_params, tracevi