Skip to content

Commit 286bab4

Browse files
authored
Merge pull request #1167 from stan-dev/get_cmdstan_args
New cmdstan_defaults() method for getting CmdStan's default argument values
2 parents 60c0dfa + b807f16 commit 286bab4

18 files changed

Lines changed: 472 additions & 11 deletions

R/model.R

Lines changed: 243 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ cmdstan_model <- function(stan_file = NULL, exe_file = NULL, compile = TRUE, ...
196196
#' [`$hpp_file()`][model-method-compile] | Return the file path to the `.hpp` file containing the generated C++ code. |
197197
#' [`$save_hpp_file()`][model-method-compile] | Save the `.hpp` file containing the generated C++ code. |
198198
#' [`$expose_functions()`][model-method-expose_functions] | Expose Stan functions for use in R. |
199+
#' [`$cmdstan_defaults()`][model-method-cmdstan_defaults] | Get CmdStan default argument values for a method. |
199200
#'
200201
#' ## Diagnostics
201202
#'
@@ -2209,6 +2210,51 @@ expose_functions = function(global = FALSE, verbose = FALSE) {
22092210
CmdStanModel$set("public", name = "expose_functions", value = expose_functions)
22102211

22112212

2213+
#' Get CmdStan default argument values
2214+
#'
2215+
#' @name model-method-cmdstan_defaults
2216+
#' @aliases cmdstan_defaults
2217+
#' @family CmdStanModel methods
2218+
#'
2219+
#' @description The `$cmdstan_defaults()` method of a [`CmdStanModel`]
2220+
#' object queries the compiled model binary for the default argument
2221+
#' values used by a given inference method. The returned list uses
2222+
#' cmdstanr-style argument names (e.g., `iter_sampling` instead of
2223+
#' CmdStan's `num_samples`).
2224+
#'
2225+
#' The model must be compiled before calling this method.
2226+
#'
2227+
#' @param method (string) The inference method whose defaults to
2228+
#' retrieve. One of `"sample"`, `"optimize"`, `"variational"`,
2229+
#' `"pathfinder"`, or `"laplace"`.
2230+
#' @return A named list of default argument values for the specified
2231+
#' method, with cmdstanr-style argument names.
2232+
#'
2233+
#' @template seealso-docs
2234+
#'
2235+
#' @examples
2236+
#' \dontrun{
2237+
#' mod <- cmdstan_model(file.path(cmdstan_path(),
2238+
#' "examples/bernoulli/bernoulli.stan"))
2239+
#' mod$cmdstan_defaults("sample")
2240+
#' mod$cmdstan_defaults("optimize")
2241+
#' }
2242+
#'
2243+
cmdstan_defaults <- function(method = c("sample", "optimize", "variational",
2244+
"pathfinder", "laplace")) {
2245+
method <- match.arg(method)
2246+
if (length(self$exe_file()) == 0 || !file.exists(self$exe_file())) {
2247+
stop(
2248+
"'$cmdstan_defaults()' requires a compiled model. ",
2249+
"Please compile the model first with '$compile()'.",
2250+
call. = FALSE
2251+
)
2252+
}
2253+
parse_cmdstan_args(self$exe_file(), method)
2254+
}
2255+
CmdStanModel$set("public", name = "cmdstan_defaults", value = cmdstan_defaults)
2256+
2257+
22122258

22132259
# internal ----------------------------------------------------------------
22142260
assert_valid_stanc_options <- function(stanc_options) {
@@ -2289,10 +2335,10 @@ model_variables <- function(stan_file, include_paths = NULL, allow_undefined = F
22892335
variables
22902336
}
22912337

2292-
22932338
is_variables_method_supported <- function(mod) {
22942339
mod$has_stan_file() && file.exists(mod$stan_file())
22952340
}
2341+
22962342
resolve_exe_path <- function(dir = NULL,
22972343
private_dir = NULL,
22982344
self_exe_file = NULL,
@@ -2329,3 +2375,199 @@ resolve_exe_path <- function(dir = NULL,
23292375
}
23302376
exe
23312377
}
2378+
2379+
# cmdstan_defaults() helpers
2380+
2381+
#' Parse CmdStan default argument values from model binary
2382+
#'
2383+
#' Runs a CmdStan model binary with `help-all` to extract valid arguments
2384+
#' and their default values for a given inference method, returning them
2385+
#' with cmdstanr argument names.
2386+
#'
2387+
#' @noRd
2388+
#' @param model_binary Path to the CmdStan model binary.
2389+
#' @param method Inference method: `"sample"`, `"optimize"`,
2390+
#' `"variational"`, `"pathfinder"`, or `"laplace"`.
2391+
#' @return A named list with cmdstanr-style argument names and default
2392+
#' values.
2393+
parse_cmdstan_args <- function(model_binary, method) {
2394+
withr::with_path(
2395+
c(
2396+
toolchain_PATH_env_var(),
2397+
tbb_path()
2398+
),
2399+
ret <- wsl_compatible_run(
2400+
command = wsl_safe_path(model_binary),
2401+
args = c(method, "help-all"),
2402+
error_on_status = FALSE
2403+
)
2404+
)
2405+
# CmdStan may write help text to stdout or stderr depending on the platform
2406+
raw <- paste0(ret$stdout, ret$stderr)
2407+
output <- strsplit(raw, "\r?\n")[[1]]
2408+
2409+
argument_map <- map_cmdstan_to_cmdstanr(method)
2410+
cmdstan_keys <- unname(argument_map)
2411+
public_names <- names(argument_map)
2412+
2413+
defaults <- list()
2414+
n <- length(output)
2415+
# Track the current hierarchical argument key using section indentation.
2416+
section_indents <- integer(0)
2417+
section_names <- character(0)
2418+
2419+
for (i in seq_len(n)) {
2420+
line <- output[i]
2421+
content <- trimws(line)
2422+
2423+
# Skip blank lines so they don't reset the section stack
2424+
if (!nzchar(content)) next
2425+
2426+
indent <- nchar(sub("^(\\s*).*", "\\1", line))
2427+
2428+
# Drop sections at deeper or equal indentation
2429+
while (length(section_indents) > 0 &&
2430+
section_indents[[length(section_indents)]] >= indent) {
2431+
section_indents <- section_indents[-length(section_indents)]
2432+
section_names <- section_names[-length(section_names)]
2433+
}
2434+
2435+
section_name <- parse_cmdstan_section_name(content)
2436+
if (!is.null(section_name)) {
2437+
section_indents <- c(section_indents, indent)
2438+
section_names <- c(section_names, section_name)
2439+
next
2440+
}
2441+
2442+
arg_name <- parse_cmdstan_arg_name(content)
2443+
if (!is.null(arg_name)) {
2444+
2445+
# Build the full dotted argument key: method.section1.section2...arg_name
2446+
# The top-level method heading (e.g. "sample") is tracked as a section,
2447+
# so it becomes the first segment of the key.
2448+
full_key <- paste(c(section_names, arg_name), collapse = ".")
2449+
2450+
# Check if this full argument key matches one of our target arguments
2451+
match_idx <- match(full_key, cmdstan_keys, nomatch = 0L)
2452+
2453+
if (match_idx > 0L) {
2454+
default_value <- find_cmdstan_default_value(output, i, n)
2455+
defaults[[public_names[[match_idx]]]] <- default_value
2456+
}
2457+
}
2458+
}
2459+
2460+
defaults
2461+
}
2462+
2463+
#' Parse CmdStan section name from a help-all line
2464+
#' @noRd
2465+
parse_cmdstan_section_name <- function(line) {
2466+
match <- regmatches(line, regexec("^([a-z_][a-z0-9_]*)$", line))[[1]]
2467+
if (length(match) >= 2) match[2] else NULL
2468+
}
2469+
2470+
#' Parse CmdStan argument name from a help-all line
2471+
#' @noRd
2472+
parse_cmdstan_arg_name <- function(line) {
2473+
match <- regmatches(line, regexec("^([a-z_][a-z0-9_]*)=", line))[[1]]
2474+
if (length(match) >= 2) match[2] else NULL
2475+
}
2476+
2477+
#' Find CmdStan default value following a help-all argument line
2478+
#' @noRd
2479+
find_cmdstan_default_value <- function(output, line_idx, n_lines) {
2480+
default_value <- NULL
2481+
2482+
for (j in (line_idx + 1):min(line_idx + 5, n_lines)) {
2483+
next_content <- trimws(output[j])
2484+
if (grepl("^Defaults to", next_content)) {
2485+
default_value <- parse_default_value(next_content)
2486+
break
2487+
}
2488+
# Stop if we hit another argument
2489+
if (grepl("^[a-z_][a-z0-9_]*=", next_content)) break
2490+
}
2491+
2492+
default_value
2493+
}
2494+
2495+
#' Parse default value from "Defaults to ..." line
2496+
#' @noRd
2497+
parse_default_value <- function(line) {
2498+
val_str <- sub("^Defaults to\\s*", "", line)
2499+
if (val_str %in% c("true", "false")) return(val_str == "true")
2500+
if (grepl("^-?[0-9]+$", val_str)) return(as.integer(val_str))
2501+
if (grepl("^-?[0-9]*\\.?[0-9]+([eE][+-]?[0-9]+)?$", val_str)) return(as.numeric(val_str))
2502+
val_str
2503+
}
2504+
2505+
#' Map CmdStan argument names to CmdStanR argument names
2506+
#' @noRd
2507+
map_cmdstan_to_cmdstanr <- function(method) {
2508+
switch(method,
2509+
sample = c(
2510+
iter_sampling = "sample.num_samples",
2511+
iter_warmup = "sample.num_warmup",
2512+
save_warmup = "sample.save_warmup",
2513+
thin = "sample.thin",
2514+
adapt_engaged = "sample.adapt.engaged",
2515+
adapt_delta = "sample.adapt.delta",
2516+
init_buffer = "sample.adapt.init_buffer",
2517+
term_buffer = "sample.adapt.term_buffer",
2518+
window = "sample.adapt.window",
2519+
save_metric = "sample.adapt.save_metric",
2520+
max_treedepth = "sample.hmc.nuts.max_depth",
2521+
metric = "sample.hmc.metric",
2522+
metric_file = "sample.hmc.metric_file",
2523+
step_size = "sample.hmc.stepsize"
2524+
),
2525+
optimize = c(
2526+
algorithm = "optimize.algorithm",
2527+
jacobian = "optimize.jacobian",
2528+
iter = "optimize.iter",
2529+
init_alpha = "optimize.lbfgs.init_alpha",
2530+
tol_obj = "optimize.lbfgs.tol_obj",
2531+
tol_rel_obj = "optimize.lbfgs.tol_rel_obj",
2532+
tol_grad = "optimize.lbfgs.tol_grad",
2533+
tol_rel_grad = "optimize.lbfgs.tol_rel_grad",
2534+
tol_param = "optimize.lbfgs.tol_param",
2535+
history_size = "optimize.lbfgs.history_size"
2536+
),
2537+
variational = c(
2538+
algorithm = "variational.algorithm",
2539+
iter = "variational.iter",
2540+
grad_samples = "variational.grad_samples",
2541+
elbo_samples = "variational.elbo_samples",
2542+
eta = "variational.eta",
2543+
adapt_engaged = "variational.adapt.engaged",
2544+
adapt_iter = "variational.adapt.iter",
2545+
tol_rel_obj = "variational.tol_rel_obj",
2546+
eval_elbo = "variational.eval_elbo",
2547+
draws = "variational.output_samples"
2548+
),
2549+
pathfinder = c(
2550+
init_alpha = "pathfinder.init_alpha",
2551+
tol_obj = "pathfinder.tol_obj",
2552+
tol_rel_obj = "pathfinder.tol_rel_obj",
2553+
tol_grad = "pathfinder.tol_grad",
2554+
tol_rel_grad = "pathfinder.tol_rel_grad",
2555+
tol_param = "pathfinder.tol_param",
2556+
history_size = "pathfinder.history_size",
2557+
draws = "pathfinder.num_psis_draws",
2558+
num_paths = "pathfinder.num_paths",
2559+
save_single_paths = "pathfinder.save_single_paths",
2560+
psis_resample = "pathfinder.psis_resample",
2561+
calculate_lp = "pathfinder.calculate_lp",
2562+
max_lbfgs_iters = "pathfinder.max_lbfgs_iters",
2563+
single_path_draws = "pathfinder.num_draws",
2564+
num_elbo_draws = "pathfinder.num_elbo_draws"
2565+
),
2566+
laplace = c(
2567+
jacobian = "laplace.jacobian",
2568+
draws = "laplace.draws"
2569+
),
2570+
character(0)
2571+
)
2572+
}
2573+

man/CmdStanModel.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/cmdstanr-package.Rd

Lines changed: 10 additions & 10 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/model-method-check_syntax.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/model-method-cmdstan_defaults.Rd

Lines changed: 65 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/model-method-compile.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)