Skip to content

Commit 962dd83

Browse files
authored
Merge pull request #1158 from stan-dev/fix-reloaded-model-with-methods
Don't error when fitting a model using reloaded CmdStanModel with compiled model methods
2 parents 009d225 + ed3dd1f commit 962dd83

File tree

9 files changed

+119
-8
lines changed

9 files changed

+119
-8
lines changed

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# cmdstanr (development version)
22

3+
* CmdStanModel objects created using `compile_model_methods = TRUE` that are
4+
then saved and reloaded no longer error in model fitting methods. Model methods
5+
are recompiled lazily if needed.
6+
37
* CmdStan versions older than 2.35.0 are no longer supported.
48
* Minimum R version increased to 4.0.0.
59
* Removed legacy Windows toolchain paths for older CmdStan releases.

R/fit.R

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ CmdStanFit <- R6::R6Class(
2323
assign(n, get(n, runset$model_methods_env()), private$model_methods_env_)
2424
}
2525
}
26+
drop_stale_model_methods(private$model_methods_env_)
2627

2728
self$functions <- new.env()
2829
if (!is.null(runset$standalone_env())) {
@@ -332,8 +333,11 @@ CmdStanFit$set("public", name = "init", value = init)
332333
#' `log_prob`, `grad_log_prob`, `hessian`, `constrain_variables`,
333334
#' `unconstrain_variables` and `unconstrain_draws` functions. These are then
334335
#' available as methods of the fitted model object. This requires the
335-
#' additional `Rcpp` package, which are not required for fitting models using
336-
#' CmdStanR.
336+
#' additional \pkg{Rcpp} package.
337+
#'
338+
#' If a model or fit object was saved with [base::saveRDS()] and later
339+
#' reloaded, any previously compiled model-method bindings will be rebuilt in
340+
#' the current R session when this method is called.
337341
#'
338342
#' Note: there may be many compiler warnings emitted during compilation but
339343
#' these can be ignored so long as they are warnings and not errors.
@@ -357,6 +361,7 @@ init_model_methods <- function(seed = 1, verbose = FALSE) {
357361
call. = FALSE)
358362
}
359363
require_suggested_package("Rcpp")
364+
drop_stale_model_methods(private$model_methods_env_)
360365
if (length(private$model_methods_env_$hpp_code_) == 0) {
361366
stop("Model methods cannot be used with a pre-compiled Stan executable, ",
362367
"the model must be compiled again", call. = FALSE)

R/model.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,11 @@ CmdStanModel <- R6::R6Class(
422422
#' via a global `cmdstanr_force_recompile` option.
423423
#' @param compile_model_methods (logical) Compile additional model methods
424424
#' (`log_prob()`, `grad_log_prob()`, `constrain_variables()`,
425-
#' `unconstrain_variables()`).
425+
#' `unconstrain_variables()`). Note: the compiled model-method bindings are
426+
#' not preserved in a usable form when saving a model object. If you plan to
427+
#' save and reload the model object before model fitting, we recommend instead
428+
#' waiting to compile the model methods until after fitting via
429+
#' [`fit$init_model_methods()`][fit-method-init_model_methods].
426430
#' @param compile_standalone (logical) Should functions in the Stan model be
427431
#' compiled for use in R? If `TRUE` the functions will be available via the
428432
#' `functions` field in the compiled model object. This can also be done after

R/utils.R

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,37 @@ rcpp_source_stan <- function(code, env, verbose = FALSE, ...) {
792792
invisible(NULL)
793793
}
794794

795+
# Detect serialized sourceCpp wrappers whose native symbol was lost after reload.
796+
source_cpp_native_symbol_is_null <- function(fun) {
797+
if (!is.function(fun)) {
798+
return(FALSE)
799+
}
800+
fun_body <- body(fun)
801+
if (!rlang::is_call(fun_body, ".Call") || length(fun_body) < 2) {
802+
return(FALSE)
803+
}
804+
# Rcpp::sourceCpp() wrappers call into a NativeSymbol via `.Call(...)`.
805+
# After reloading a serialized object that symbol can degrade to `<pointer: 0x0>`.
806+
symbol <- fun_body[[2]]
807+
if (!inherits(symbol, "NativeSymbol")) {
808+
return(FALSE)
809+
}
810+
symbol_text <- paste(capture.output(print(symbol)), collapse = "")
811+
grepl("<pointer: (0x0+|\\(nil\\))>", symbol_text)
812+
}
813+
814+
# Drop stale compiled bindings but keep the generated C++ so model methods
815+
# can be rebuilt lazily in the current session if they are later requested.
816+
# This avoids an error when a CmdStanModel object with compiled bindings is
817+
# loaded from an older session: https://github.com/stan-dev/cmdstanr/issues/1157
818+
drop_stale_model_methods <- function(env) {
819+
if (is.null(env$model_ptr) || !source_cpp_native_symbol_is_null(env$model_ptr)) {
820+
return(invisible(FALSE))
821+
}
822+
rm(list = setdiff(ls(env, all.names = TRUE), "hpp_code_"), envir = env)
823+
invisible(TRUE)
824+
}
825+
795826
expose_model_methods <- function(env, verbose = FALSE) {
796827
if (rlang::is_interactive()) {
797828
message("Compiling additional model methods...")

man/cmdstanr-package.Rd

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/fit-method-init_model_methods.Rd

Lines changed: 5 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/model-method-compile.Rd

Lines changed: 5 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-fit-shared.R

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,27 @@ test_that("save_object() method works", {
178178
expect_identical(fit$summary(), s)
179179
})
180180

181+
test_that("reloaded fits rebuild model methods lazily after save_object()", {
182+
skip_if(os_is_wsl())
183+
mod <- cmdstan_model(
184+
testing_stan_file("bernoulli_log_lik"),
185+
force_recompile = TRUE,
186+
compile_model_methods = TRUE
187+
)
188+
utils::capture.output(
189+
fit <- mod$optimize(data = testing_data("bernoulli"))
190+
)
191+
192+
temp_rds_file <- tempfile(fileext = ".RDS")
193+
fit$save_object(temp_rds_file)
194+
fit2 <- readRDS(temp_rds_file)
195+
196+
expect_no_error(
197+
lp <- fit2$log_prob(unconstrained_variables = c(0.1))
198+
)
199+
expect_equal(lp, -8.6327599208828509347)
200+
})
201+
181202
test_that("save_object() method works with qs2 format", {
182203
skip_if_not_installed("qs2")
183204
fit <- fits[["sample"]]

tests/testthat/test-model-methods.R

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,44 @@ test_that("Methods can be compiled with model", {
149149
expect_equal(unconstrained_variables, c(0.6))
150150
})
151151

152+
test_that("Reloaded models recompile model methods lazily after saveRDS/readRDS", {
153+
# Also tests that fitted model objects are returned without error after
154+
# saveRDS/readRDS when model methods are compiled: https://github.com/stan-dev/cmdstanr/issues/1157
155+
mod <- cmdstan_model(
156+
testing_stan_file("bernoulli_log_lik"),
157+
force_recompile = TRUE,
158+
compile_model_methods = TRUE
159+
)
160+
temp_rds_file <- tempfile(fileext = ".RDS")
161+
saveRDS(mod, temp_rds_file)
162+
mod2 <- readRDS(temp_rds_file)
163+
164+
expect_no_error(
165+
utils::capture.output(
166+
fit <- mod2$optimize(data = data_list)
167+
)
168+
)
169+
expect_equal(fit$log_prob(unconstrained_variables = c(0.1)),
170+
-8.6327599208828509347)
171+
})
172+
173+
test_that("stale model-method bindings are detected and dropped", {
174+
mod <- cmdstan_model(
175+
testing_stan_file("bernoulli_log_lik"),
176+
force_recompile = TRUE,
177+
compile_model_methods = TRUE
178+
)
179+
temp_rds_file <- tempfile(fileext = ".RDS")
180+
saveRDS(mod, temp_rds_file)
181+
mod2 <- readRDS(temp_rds_file)
182+
model_methods_env <- mod2$.__enclos_env__$private$model_methods_env_
183+
184+
expect_true(source_cpp_native_symbol_is_null(model_methods_env$model_ptr))
185+
expect_true(drop_stale_model_methods(model_methods_env))
186+
expect_equal(ls(model_methods_env, all.names = TRUE), "hpp_code_")
187+
expect_false(drop_stale_model_methods(model_methods_env))
188+
})
189+
152190
test_that("unconstrain_variables correctly handles zero-length containers", {
153191
model_code <- "
154192
data {

0 commit comments

Comments
 (0)