Skip to content

Commit 5e17c8b

Browse files
MCMC Sampling of partially-observed multivariate normal nodes (#1612)
* added getMixedDataNodeNames method for models * removed comments from NEWS * partial_mvn (#1543) `partial_mvn` sampler, to have the code here for posterity. * fixed unlist for graph node ids * updates * updates * updates * updates * removed extraneous tests of partial_mvn sampler * updates per comments on PR * partial_mvn sampler can assign barker sampler * Use Cholesky of Sigma22 in partial_mvn_pp. * fixed dimension mistake * Use Cholesky in sampler_partial_mvn_pp. * Add more testing for sampler_partial_mvn. * fix spacing again * remove extra tests, and fix use of expect_success * added caching of sigma and mu in partial_mvn_pp sampler * correction in caching calculations * Add partial_mvn_pp test. --------- Co-authored-by: Christopher Paciorek <paciorek@stat.berkeley.edu>
1 parent 7598d40 commit 5e17c8b

8 files changed

Lines changed: 631 additions & 151 deletions

File tree

packages/RELEASE_INSTRUCTIONS

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,13 @@ R CMD check --as-cran nimble_${VERSION}.tar.gz
8585

8686
Officially this should be done on a development build of R, so this is just a first pass and we check with r-devel via winbuilder or r-hub below.
8787

88-
Recently I've been getting this warning on my Ubuntu machine:
88+
Recently I've been getting warnings like this on my Ubuntu machine:
8989

9090
```
9191
Compilation used the following non-portable flag(s):
9292
‘-Wdate-time’ ‘-Werror=format-security’ ‘-Wformat’
93+
Compilation used the following non-portable flag(s):
94+
‘-mno-omit-leaf-frame-pointer’
9395
```
9496

9597
Seems like it might just be my local system as it doesn't show up on winbuilder or r-hub checking.

packages/nimble/R/BUGS_model.R

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,22 @@ Details: Multiple logical input arguments may be used simultaneously. For examp
499499
}
500500
return(ans)
501501
},
502+
getMixedDataNodeNames = function(returnType = 'names') {
503+
multivariateStochBool <- sapply(modelDef$declInfo,
504+
function(di) di$type == 'stoch' && grepl(':', deparse(di$targetExpr)))
505+
multivariateStochIDsList <- lapply(modelDef$declInfo[multivariateStochBool], `[[`, 'graphIDs')
506+
multivariateStochIDs <- unlist(multivariateStochIDsList)
507+
if(length(multivariateStochIDs) == 0) multivariateStochIDs <- numeric() ## length=0 case, make a numeric vector
508+
isDataResult <- isDataFromGraphID(multivariateStochIDs, includeMixed = TRUE) ## values are in {0, 1, 2}, where 0=FALSE, 1=TRUE, 2=MIXED
509+
mixedDataIDs <- multivariateStochIDs[isDataResult == 2]
510+
if(returnType == 'ids') return(mixedDataIDs)
511+
if(returnType == 'names') return(modelDef$maps$graphID_2_nodeName[mixedDataIDs])
512+
stop('returnType argument to getMixedDataNodeNames was invalid')
513+
},
514+
isMixedData = function(nodeNames) {
515+
ids <- expandNodeNames(nodeNames, returnType = 'ids')
516+
return(isDataFromGraphID(ids, includeMixed = TRUE) == 2)
517+
},
502518
safeUpdateValidValues = function(validValues, idsVec_only, idsVec_exclude) {
503519
if(!missing(idsVec_only) && !missing(idsVec_exclude)) stop()
504520
if(!missing(idsVec_only)) {
@@ -787,15 +803,23 @@ Details: The variable or node names specified is expanded into a vector of model
787803
return(isDataFromGraphID(g_id))
788804
},
789805

790-
isDataFromGraphID = function(g_id){
791-
## returns TRUE if any elements are flagged as data
806+
isDataFromGraphID = function(g_id, includeMixed = FALSE) {
807+
## default behaviour: returns TRUE if any elements are flagged as data
808+
## when includeMixed=TRUE: 0 for FALSE, 1 for TRUE, or 2 for MIXED
792809
nodeNames <- modelDef$maps$graphID_2_nodeName[g_id]
793-
ret <- unlist(lapply(as.list(nodeNames),
794-
function(nn)
795-
return(any(eval(parse(text=nn, keep.source = FALSE)[[1]],
796-
envir=isDataEnv)))))
797-
if(is.null(ret)) ret <- logical(0)
798-
return(ret)
810+
if(includeMixed) {
811+
f <- function(nn) {
812+
vals <- eval(parse(text=nn, keep.source = FALSE)[[1]], envir=isDataEnv)
813+
if(!any(vals)) return(0)
814+
if(all(vals)) return(1)
815+
return(2)
816+
}
817+
} else {
818+
f <- function(nn) return(any(eval(parse(text=nn, keep.source = FALSE)[[1]], envir=isDataEnv)))
819+
}
820+
ret <- unlist(lapply(as.list(nodeNames), f))
821+
if(is.null(ret)) ret <- logical(0)
822+
return(ret)
799823
},
800824

801825
getDependenciesList = function(returnNames = TRUE, sort = TRUE) {

packages/nimble/R/MCMC_configuration.R

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,11 @@ print: A logical argument specifying whether to print the monitors and samplers.
245245
if(missing(nodes)) {
246246
nodes <- model$getNodeNames(stochOnly = TRUE, includeData = FALSE, includePredictive = samplePredictiveNodes)
247247
# Check of all(model$isStoch(nodes)) is not needed in this case
248+
# Check and adds any partially observed nodes
249+
mixedDataNodeNames <- model$getMixedDataNodeNames()
250+
if(length(mixedDataNodeNames)) {
251+
nodes <- model$topologicallySortNodes(c(nodes, mixedDataNodeNames))
252+
}
248253
} else if(is.null(nodes) || length(nodes)==0) {
249254
nodes <- character(0)
250255
} else nodes <- filterOutDataNodes(nodes) ## configureMCMC *never* assigns samplers to data nodes
@@ -431,6 +436,13 @@ For internal use. Adds default MCMC samplers to the specified nodes.
431436

432437
## for multivariate nodes, either add a conjugate sampler, RW_multinomial, or RW_block sampler
433438
if(nodeLength > 1) {
439+
if(model$isMixedData(node)) {
440+
if(nodeDist == 'dmnorm') {
441+
thisControlList <- c(controlDefaultsArg, multivariateNodesAsScalars = multivariateNodesAsScalars)
442+
addSampler(target = node, type = 'partial_mvn', control = thisControlList, allowData = TRUE) ; next
443+
}
444+
stop(paste0('The node ', node, ' is partially observed. NIMBLE only handles this case for multivariate normal distibutions.'))
445+
}
434446
if(useConjugacy) {
435447
conjugacyResult <- conjugacyResultsAll[[node]]
436448
if(!is.null(conjugacyResult)) {
@@ -807,7 +819,9 @@ For internal use only
807819

808820
filterOutDataNodes = function(nodes) {
809821
nodes <- model$expandNodeNames(nodes)
810-
return(nodes[!model$isData(nodes)])
822+
# We don't filter out partially observed data nodes
823+
gIDs <- model$modelDef$nodeName2GraphIDs(nodes)
824+
return(nodes[model$isDataFromGraphID(gIDs, includeMixed=TRUE) != 1])
811825
},
812826

813827
removeSamplers = function(..., ind, print = FALSE) {

packages/nimble/R/MCMC_samplers.R

Lines changed: 145 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3456,7 +3456,141 @@ sampler_barker <- nimbleFunction(
34563456
)
34573457

34583458

3459-
3459+
####################################################################
3460+
### partially observed multivariate normal sampler #################
3461+
####################################################################
3462+
3463+
#' @rdname samplers
3464+
#' @export
3465+
sampler_partial_mvn <- nimbleFunction(
3466+
name = 'sampler_partial_mvn',
3467+
contains = sampler_BASE,
3468+
setup = function(model, mvSaved, target, control) {
3469+
## control list extraction
3470+
multivariateNodesAsScalars <- extractControlElement(control, 'multivariateNodesAsScalars', default = getNimbleOption('MCMCmultivariateNodesAsScalars'))
3471+
## node list generation
3472+
targetAsScalar <- model$expandNodeNames(target, returnScalarComponents = TRUE)
3473+
isDataBool <- sapply(targetAsScalar, function(targetAsScalar) eval(parse(text = targetAsScalar)[[1]], envir = model$isDataEnv))
3474+
targetNonDataComponents <- targetAsScalar[!isDataBool]
3475+
if(length(model$getDependencies(target, self = FALSE, downstream = TRUE, dataOnly = TRUE)) > 0) {
3476+
predictiveBool <- sapply(targetNonDataComponents, function(n) length(model$getDependencies(n, self = FALSE, downstream = TRUE, dataOnly = TRUE)) == 0)
3477+
targetNonDataPP <- targetNonDataComponents[ predictiveBool] ## predictive
3478+
targetNonDataNP <- targetNonDataComponents[!predictiveBool] ## not predictive
3479+
} else {
3480+
targetNonDataPP <- targetNonDataComponents ## entirely predictive
3481+
targetNonDataNP <- character()
3482+
}
3483+
## nested function and function list definitions
3484+
samplerList <- nimbleFunctionList(sampler_BASE)
3485+
if(length(targetNonDataNP) > 0) {
3486+
if(multivariateNodesAsScalars) {
3487+
for(i in seq_along(targetNonDataNP)) {
3488+
samplerList[[i]] <- sampler_RW(model, mvSaved, targetNonDataNP[i], control)
3489+
}
3490+
} else {
3491+
if(length(targetNonDataNP) == 1) {
3492+
samplerList[[1]] <- sampler_RW(model, mvSaved, targetNonDataNP, control)
3493+
} else {
3494+
if(getNimbleOption('MCMCuseBarkerAsDefaultMV')) {
3495+
samplerList[[1]] <- sampler_barker (model, mvSaved, targetNonDataNP, control)
3496+
} else {
3497+
samplerList[[1]] <- sampler_RW_block(model, mvSaved, targetNonDataNP, control)
3498+
}
3499+
}
3500+
}
3501+
}
3502+
if(length(targetNonDataPP) > 0) {
3503+
samplerList[[ length(samplerList)+1 ]] <- sampler_partial_mvn_pp(model, mvSaved, targetNonDataPP)
3504+
}
3505+
## checks
3506+
if(model$getDistribution(target) != 'dmnorm') stop('The node ', target, ' is parially observed. NIMBLE only handles this case for multivariate normal distibutions.')
3507+
if(!model$isMixedData(target)) stop('The target node ', target, ' is not partially observed.')
3508+
},
3509+
run = function() {
3510+
for(i in seq_along(samplerList)) {
3511+
samplerList[[i]]$run()
3512+
}
3513+
},
3514+
methods = list(
3515+
reset = function() {
3516+
for(i in seq_along(samplerList)) {
3517+
samplerList[[i]]$reset()
3518+
}
3519+
}
3520+
)
3521+
)
3522+
3523+
sampler_partial_mvn_pp <- nimbleFunction(
3524+
name = 'sampler_partial_mvn_pp',
3525+
contains = sampler_BASE,
3526+
setup = function(model, mvSaved, target) {
3527+
## node list generation
3528+
mvNode <- model$expandNodeNames(target)
3529+
mvNodeComponents <- model$expandNodeNames(mvNode, returnScalarComponents = TRUE)
3530+
given <- setdiff(mvNodeComponents, target)
3531+
calcNodes <- model$getDependencies(target, downstream = TRUE, includePredictive = TRUE)
3532+
cholNode <- deparse(model$getParamExpr(mvNode, 'cholesky'))
3533+
meanNode <- deparse(model$getParamExpr(mvNode, 'mean' ))
3534+
## numeric value generation
3535+
n1 <- length(target)
3536+
n2 <- length(given)
3537+
n <- n1 + n2
3538+
mu <- array(0, c(n , 1))
3539+
mu1 <- array(0, c(n1, 1))
3540+
mu2 <- array(0, c(n2, 1))
3541+
Sigma <- array(0, c(n, n ))
3542+
Sigma11 <- array(0, c(n1, n1))
3543+
Sigma12 <- array(0, c(n1, n2))
3544+
Sigma21 <- array(0, c(n2, n1))
3545+
Sigma22 <- array(0, c(n2, n2))
3546+
tmp <- array(0, c(n2, n1))
3547+
ind1 <- match(target, mvNodeComponents)
3548+
ind2 <- match(given, mvNodeComponents)
3549+
sgConst <- length(model$getParents(cholNode, self = TRUE, upstream = TRUE, stochOnly = TRUE)) == 0
3550+
muConst <- length(model$getParents(meanNode, self = TRUE, upstream = TRUE, stochOnly = TRUE)) == 0
3551+
firstRun <- TRUE
3552+
## checks
3553+
if(length(mvNode) != 1) stop('unexpected error in sampler_partial_mvn_pp')
3554+
if(model$getDistribution(mvNode) != 'dmnorm') stop('unexpected error in sampler_partial_mvn_pp')
3555+
if(n != length(mvNodeComponents)) stop('unexpected error in sampler_partial_mvn_pp')
3556+
if(n1*n2 == 0) stop('unexpected error in sampler_partial_mvn_pp')
3557+
},
3558+
run = function() {
3559+
## Sigma12 <<- Sigma12 %*% inverse(Sigma22)
3560+
## mu1[,1] <<- mu1[,1] + (Sigma12 %*% (values(model,given) - mu2[,1]))[,1]
3561+
## Sigma11 <<- Sigma11 - Sigma12 %*% Sigma21
3562+
if(!sgConst | firstRun) {
3563+
Sigma <<- model$getParam(mvNode, 'cov' )
3564+
Sigma11[1:n1,1:n1] <<- Sigma[ind1, ind1]
3565+
Sigma12[1:n1,1:n2] <<- Sigma[ind1, ind2]
3566+
Sigma21[1:n2,1:n1] <<- Sigma[ind2, ind1]
3567+
Sigma22[1:n2,1:n2] <<- Sigma[ind2, ind2]
3568+
Sigma22 <<- chol(Sigma22)
3569+
tmp <<- forwardsolve(t(Sigma22), Sigma21)
3570+
Sigma11 <<- Sigma11 - t(tmp) %*% tmp
3571+
Sigma11 <<- chol(Sigma11)
3572+
}
3573+
if(!muConst | firstRun) {
3574+
mu[,1] <<- model$getParam(mvNode, 'mean')
3575+
mu1[,1] <<- mu[ind1,1]
3576+
mu2[,1] <<- mu[ind2,1]
3577+
}
3578+
if(!sgConst | !muConst | firstRun) {
3579+
mu1[,1] <<- mu1[,1] + (Sigma12 %*% backsolve(Sigma22, forwardsolve(t(Sigma22), values(model,given) - mu2[,1])))[,1]
3580+
}
3581+
if(firstRun) firstRun <<- FALSE
3582+
values(model, target) <<- rmnorm_chol(1, mu1[,1], Sigma11, prec_param = 0)
3583+
model$calculate(calcNodes)
3584+
nimCopy(from = model, to = mvSaved, row = 1, nodes = calcNodes, logProb = TRUE)
3585+
},
3586+
methods = list(
3587+
reset = function() {
3588+
firstRun <<- TRUE
3589+
}
3590+
)
3591+
)
3592+
3593+
34603594
#' MCMC Sampling Algorithms
34613595
#'
34623596
#' Details of the MCMC sampling algorithms provided with the NIMBLE MCMC engine; HMC samplers are in the \code{nimbleHMC} package and particle filter samplers are in the \code{nimbleSMC} package. Additional details, including some recommendations for samplers that may perform better than the samplers that NIMBLE assigns by default are provided in Section 7.11 of the User Manual.
@@ -3881,6 +4015,15 @@ sampler_barker <- nimbleFunction(
38814015
#' The posterior_predictive sampler functions by simulating new values for all downstream (dependent) nodes using their conditional distributions, as well as updating the associated model probabilities. A posterior_predictive sampler will automatically be assigned to all trailing non-data stochastic nodes in a model, or when possible, to any node at a point in the model after which all downstream (dependent) stochastic nodes are non-data.
38824016
#'
38834017
#' The posterior_predictive sampler accepts no control list arguments.
4018+
#'
4019+
#' @section partial_mvn sampler:
4020+
#'
4021+
#' The partial_mvn sampler is designed to sample multivariate normal distributions that are partially observed. That is, some dimensions of the target node are observed data values, some dimensions are not data. Sampling is accomplished using either univariate or multivariate random walk Metropolis Hastings of the unobserved dimensions, as determined by the \code{multivariateNodesAsScalars} argument.
4022+
#'
4023+
#' The \code{partial_mvn} sampler accepts the following control list elements:
4024+
#' \itemize{
4025+
#' \item multivariateNodesAsScalars. A logical argument, specifying whether the sampler should sample the unobserved parts of a partially observed node jointly or independently (default = FALSE).
4026+
#' }
38844027
#'
38854028
#' @section RJ_fixed_prior sampler:
38864029
#'
@@ -3896,7 +4039,7 @@ sampler_barker <- nimbleFunction(
38964039
#'
38974040
#' @name samplers
38984041
#'
3899-
#' @aliases sampler binary categorical prior_samples posterior_predictive RW RW_block RW_multinomial RW_dirichlet RW_wishart RW_llFunction slice AF_slice crossLevel RW_llFunction_block sampler_prior_samples sampler_posterior_predictive sampler_binary sampler_categorical sampler_RW sampler_RW_block sampler_RW_multinomial sampler_RW_dirichlet sampler_RW_wishart sampler_RW_llFunction sampler_slice sampler_AF_slice sampler_crossLevel sampler_RW_llFunction_block CRP CRP_concentration DPmeasure RJ_fixed_prior RJ_indicator RJ_toggled RW_PF RW_PF_block RW_lkj_corr_cholesky sampler_RW_lkj_corr_cholesky RW_block_lkj_corr_cholesky sampler_RW_block_lkj_corr_cholesky sampler_barker barker
4042+
#' @aliases sampler binary categorical prior_samples posterior_predictive RW RW_block RW_multinomial RW_dirichlet RW_wishart RW_llFunction slice AF_slice crossLevel RW_llFunction_block sampler_prior_samples sampler_posterior_predictive sampler_binary sampler_categorical sampler_RW sampler_RW_block sampler_RW_multinomial sampler_RW_dirichlet sampler_RW_wishart sampler_RW_llFunction sampler_slice sampler_AF_slice sampler_crossLevel sampler_RW_llFunction_block CRP CRP_concentration DPmeasure RJ_fixed_prior RJ_indicator RJ_toggled RW_PF RW_PF_block RW_lkj_corr_cholesky sampler_RW_lkj_corr_cholesky RW_block_lkj_corr_cholesky sampler_RW_block_lkj_corr_cholesky sampler_barker barker sampler_partial_mvn partial_mvn
39004043
#'
39014044
#' @examples
39024045
#' ## y[1] ~ dbern() or dbinom():

packages/nimble/R/initializeModel.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#' @author Daniel Turek
88
#' @details This nimbleFunction may be used at the beginning of nimble algorithms to perform model initialization.
99
#' The intended usage is to specialize an instance of this nimbleFunction in the setup function of an algorithm,
10-
#' then execute that specialied function at the beginning of the algorithm run function.
10+
#' then execute that specialized function at the beginning of the algorithm run function.
1111
#' The specialized function takes no arguments.
1212
#'
1313
#' Executing this function ensures that all right-hand-side only nodes have been assigned real values,

packages/nimble/man/initializeModel.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

packages/nimble/man/samplers.Rd

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)