Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 58 additions & 18 deletions R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -1117,37 +1117,77 @@ process_init.draws <- function(init, num_procs, model_variables = NULL,
method ="simple_no_replace")
}
draws_rvar <- posterior::as_draws_rvars(draws)
variable_names <- variable_names[variable_names %in% names(draws_rvar)]
draws_rvar <- posterior::subset_draws(draws_rvar, variable = variable_names)

# Separate tuple and non-tuple parameters. Tuple parameters use leaf names
# in draws (e.g., "b_tuple:1:1") rather than the Stan-level name ("b_tuple"),
# so they need special handling via build_tuple_init_value().
is_tuple <- if (!is.null(model_variables)) {
vapply(variable_names, function(nm) {
is_tuple_type(model_variables$parameters[[nm]])
}, logical(1))
} else {
rep(FALSE, length(variable_names))
}
tuple_names <- variable_names[is_tuple]
scalar_names <- variable_names[!is_tuple]

# Filter non-tuple names to those present in draws
scalar_names <- scalar_names[scalar_names %in% names(draws_rvar)]

# For tuple names, check that their leaf draws exist
rvar_names <- names(draws_rvar)
tuple_names <- tuple_names[vapply(tuple_names, function(nm) {
any(startsWith(rvar_names, paste0(nm, ":")))
}, logical(1))]

all_names <- c(scalar_names, tuple_names)

if (length(all_names) > 0) {
draws_rvar <- posterior::subset_draws(
draws_rvar,
variable = expand_stan_params_to_leaves(all_names, rvar_names)
)
}

inits <- lapply(1:num_procs, function(draw_iter) {
init_i <- lapply(variable_names, function(var_name) {
x <- .remove_leftmost_dim(posterior::draws_of(
posterior::subset_draws(draws_rvar[[var_name]], draw=draw_iter)))
bad_names <- character(0)

# Extract non-tuple parameters
init_i <- lapply(scalar_names, function(var_name) {
x <- .extract_draw_value(var_name, draws_rvar, draw_iter)
if (any(is.infinite(x)) || any(is.na(x))) {
bad_names[[length(bad_names) + 1L]] <<- var_name
}
if (model_variables$parameters[[var_name]]$dimensions == 0) {
return(as.double(x))
} else {
return(x)
}
})
bad_names <- unlist(lapply(variable_names, function(var_name) {
x <- drop(posterior::draws_of(drop(
posterior::subset_draws(draws_rvar[[var_name]], draw=draw_iter))))
if (any(is.infinite(x)) || any(is.na(x))) {
return(var_name)
names(init_i) <- scalar_names

# Extract tuple parameters (build_tuple_init_value also validates)
for (var_name in tuple_names) {
tuple_result <- build_tuple_init_value(
var_name, model_variables$parameters[[var_name]],
draws_rvar, draw_iter
)
init_i[[var_name]] <- tuple_result$value
if (length(tuple_result$bad_leaves) > 0) {
bad_names <- c(bad_names, var_name)
}
return("")
}))
any_na_or_inf <- bad_names != ""
if (any(any_na_or_inf)) {
err_msg <- paste0(paste(bad_names[any_na_or_inf], collapse = ", "), " contains NA or Inf values!")
if (length(any_na_or_inf) > 1) {
}

if (length(bad_names) > 0) {
err_msg <- paste0(paste(bad_names, collapse = ", "), " contains NA or Inf values!")
if (length(bad_names) > 1) {
err_msg <- paste0("Variables: ", err_msg)
} else {
err_msg <- paste0("Variable: ", err_msg)
}
stop(err_msg)
}
names(init_i) <- variable_names

return(init_i)
})
return(process_init(inits, num_procs, model_variables, warn_partial))
Expand Down Expand Up @@ -1265,7 +1305,7 @@ process_init.function <- function(init, num_procs, model_variables = NULL,
# Check that a previous fit can serve as a source of initial values.
#
# @param init Fit-like object exposing `return_codes()` and `metadata()`;
#   `metadata()$stan_variables` is read for the variable names it contains.
# @param model_variables Variable info of the model about to run (only
#   `names(model_variables$parameters)` is used), or NULL to skip the
#   name comparison.
# @return Invisibly NULL when validation passes. Stops with an error when
#   every return code equals 1, or when none of the current model's
#   parameters can be found among the fit's variables.
validate_fit_init <- function(init, model_variables) {
  if (all(init$return_codes() == 1)) {
    stop("We are unable to create initial values from a model with no samples. Please check the results of the model used for inits before continuing.")
  }
  if (is.null(model_variables)) {
    return(invisible(NULL))
  }

  # A single, leaf-aware comparison: a parameter counts as present when it
  # appears verbatim or when a "<name>:..." tuple leaf exists. An exact-name
  # pre-check with an empty body must not precede this test, otherwise the
  # mismatch error below can never be reached.
  param_names <- names(model_variables$parameters)
  fit_variables <- init$metadata()$stan_variables
  if (!any(stan_param_has_leaf(param_names, fit_variables))) {
    stop("None of the names of the parameters for the model used for initial values match the names of parameters from the model currently running.")
  }
  invisible(NULL)
}
Expand Down
50 changes: 46 additions & 4 deletions R/csv.R
Original file line number Diff line number Diff line change
Expand Up @@ -923,11 +923,36 @@ check_csv_metadata_matches <- function(csv_metadata) {
}

# Convert CSV-style names such as "beta.1.1" to bracketed names such as
# "beta[1,1]".
#
# A trailing ".real" / ".imag" component is kept as the last index, e.g.
# "z.2.real" -> "z[2,real]" and "z.imag" -> "z[imag]". Names without any
# dot are returned unchanged; ":" characters are never touched by any of
# the substitutions here.
#
# @param names Character vector of variable names.
# @return Character vector of the same length with repaired names.
repair_variable_names <- function(names) {
  # 1. Remember a trailing .real/.imag component and strip it so that the
  #    remaining dots are the index separators only.
  complex_suffix <- rep("", length(names))
  complex_suffix[grepl("\\.real$", names)] <- "real"
  complex_suffix[grepl("\\.imag$", names)] <- "imag"
  names <- sub("\\.(real|imag)$", "", names)

  # 2. First dot opens the bracket, every further dot separates indices.
  names <- sub("\\.", "[", names)
  names <- gsub("\\.", ",", names)
  has_bracket <- grepl("\\[", names)
  has_complex <- nzchar(complex_suffix)

  # 3. Append the complex component: as an extra index when a bracket is
  #    already open, otherwise as the only index.
  with_indices <- has_complex & has_bracket
  names[with_indices] <- paste0(
    names[with_indices], ",", complex_suffix[with_indices]
  )
  complex_only <- has_complex & !has_bracket
  names[complex_only] <- paste0(
    names[complex_only], "[", complex_suffix[complex_only]
  )

  # 4. Close every opened bracket exactly once.
  needs_close <- has_bracket | has_complex
  names[needs_close] <- paste0(names[needs_close], "]")

  names
}

Expand Down Expand Up @@ -994,8 +1019,25 @@ variable_dims <- function(variable_names = NULL) {
var_indices <- var_names[grep(pattern, var_names)]
var_indices <- gsub(pattern, "", var_indices)
if (length(var_indices)) {
var_indices <- strsplit(var_indices[length(var_indices)], ",")[[1]]
dims[[var]] <- as.numeric(var_indices)
# Split the last index entry by comma to determine number of dimensions
last_indices <- strsplit(var_indices[length(var_indices)], ",")[[1]]
ndims <- length(last_indices)
dim_sizes <- integer(ndims)
for (d in seq_len(ndims)) {
num_idx <- suppressWarnings(as.integer(last_indices[d]))
if (!is.na(num_idx)) {
# Numeric index: the maximum value is the dimension size
dim_sizes[d] <- num_idx
} else {
# Non-numeric index (e.g., "real"/"imag" for complex, or
# tuple indices like "1:2"): count unique values across all
# entries for this dimension position
all_indices <- strsplit(var_indices, ",")
unique_vals <- unique(vapply(all_indices, `[`, character(1), d))
dim_sizes[d] <- length(unique_vals)
}
}
dims[[var]] <- dim_sizes
} else {
dims[[var]] <- 1
}
Expand Down
38 changes: 37 additions & 1 deletion R/data.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,11 @@ write_stan_json <- function(data, file, always_decimal = FALSE) {
} else if (is.data.frame(var)) {
var <- data.matrix(var)
} else if (is.list(var)) {
var <- list_to_array(var, var_name)
if (is_tuple_list(var)) {
var <- prepare_tuple_for_json(var)
} else {
var <- list_to_array(var, var_name)
}
}
data[[var_name]] <- var
}
Expand All @@ -110,6 +114,38 @@ write_stan_json <- function(data, file, always_decimal = FALSE) {
}


# Decide whether a list should be treated as a Stan tuple value.
# A tuple list is one whose names are exactly the position strings
# "1", "2", ..., one per element and in that order. Unnamed or empty
# inputs are never tuples.
is_tuple_list <- function(x) {
  keys <- names(x)
  # length(NULL) is 0, so this also covers the unnamed case
  length(keys) > 0 && identical(keys, as.character(seq_along(x)))
}

# Recursively prepare a tuple value for JSON serialization.
#
# Each element of the tuple list is processed in place:
#   - data frames are converted with data.matrix(), matching how
#     write_stan_json() treats top-level data frames;
#   - nested tuple lists (see is_tuple_list()) are recursed into;
#   - any other list is passed to list_to_array();
#   - logical values are converted to integer mode;
#   - everything else is left as-is.
#
# @param x A tuple list (names "1", "2", ...).
# @return `x` with the same length and names, elements converted as above.
prepare_tuple_for_json <- function(x) {
  for (i in seq_along(x)) {
    val <- x[[i]]
    if (is.data.frame(val)) {
      # Must be tested before is.list(): a data frame is also a list.
      val <- data.matrix(val)
    } else if (is.list(val)) {
      if (is_tuple_list(val)) {
        val <- prepare_tuple_for_json(val)
      } else {
        val <- list_to_array(val)
      }
    } else if (is.logical(val)) {
      mode(val) <- "integer"
    }
    # `x[[i]] <- NULL` would delete the element and shift the remaining
    # tuple positions; assigning through a length-one list keeps the slot
    # (and its name) even if a converted value is NULL.
    x[i] <- list(val)
  }
  x
}

list_to_array <- function(x, name = NULL) {
list_length <- length(x)
if (list_length == 0) {
Expand Down
23 changes: 18 additions & 5 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -499,8 +499,12 @@ unconstrain_variables <- function(variables) {
model_variables <- self$runset$args$model_variables

# If zero-length parameters are present, they will be listed in model_variables
# but not in metadata()$variables
nonzero_length_params <- names(model_variables$parameters) %in% model_par_names
# but not in metadata()$variables. For tuple parameters, model_variables uses
# the Stan-level name (e.g., "b_tuple") while model_par_names uses leaf names
# with ":" separators (e.g., "b_tuple:1:1"), so we use prefix matching.
nonzero_length_params <- stan_param_has_leaf(
names(model_variables$parameters), model_par_names
)
model_par_names <- names(model_variables$parameters[nonzero_length_params])

model_pars_not_prov <- which(!(model_par_names %in% prov_par_names))
Expand Down Expand Up @@ -589,14 +593,23 @@ unconstrain_draws <- function(files = NULL, draws = NULL,
model_variables <- self$runset$args$model_variables

# If zero-length parameters are present, they will be listed in model_variables
# but not in metadata()$variables
nonzero_length_params <- names(model_variables$parameters) %in% model_par_names
# but not in metadata()$variables. For tuple parameters, model_variables uses
# the Stan-level name (e.g., "b_tuple") while model_par_names uses leaf names
# with ":" separators (e.g., "b_tuple:1:1"), so we use prefix matching.
nonzero_length_params <- stan_param_has_leaf(
names(model_variables$parameters), model_par_names
)

# Remove zero-length parameters from model_variables, otherwise process_init
# warns about missing inputs
pars <- names(model_variables$parameters[nonzero_length_params])

draws <- posterior::subset_draws(draws, variable = pars)
# For subset_draws, we need to use the leaf-level names from stan_variables
# (e.g., "b_tuple:1:1") rather than Stan-level names (e.g., "b_tuple"),
# because posterior doesn't recognize Stan-level tuple names.
pars_for_draws <- expand_stan_params_to_leaves(pars, model_par_names)

draws <- posterior::subset_draws(draws, variable = pars_for_draws)
unconstrained <- private$model_methods_env_$unconstrain_draws(private$model_methods_env_$model_ptr_, draws)
uncon_names <- private$model_methods_env_$unconstrained_param_names(private$model_methods_env_$model_ptr_, FALSE, FALSE)
names(unconstrained) <- repair_variable_names(uncon_names)
Expand Down
108 changes: 107 additions & 1 deletion R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,94 @@ initialize_model_pointer <- function(env, datafile_path, seed = 0) {
invisible(NULL)
}

# For each Stan-level parameter name, report whether it is represented in
# a set of leaf-level variable names. A name is represented when it occurs
# verbatim in `leaf_names`, or when some leaf name begins with "<name>:"
# (the form used for tuple elements, e.g. "b_tuple:1:1").
# Returns an unnamed logical vector with one entry per element of
# `stan_names`.
stan_param_has_leaf <- function(stan_names, leaf_names) {
  has_tuple_leaf <- function(name) {
    any(startsWith(leaf_names, paste0(name, ":")))
  }

  found <- logical(length(stan_names))
  for (k in seq_along(stan_names)) {
    name <- stan_names[[k]]
    # Prefix search only runs when there is no exact match
    found[k] <- name %in% leaf_names || has_tuple_leaf(name)
  }
  found
}

# Tell whether a variable's type info describes a tuple.
# The `$type` entry is a list for tuples and a string otherwise, so the
# test is simply whether that entry is a list.
is_tuple_type <- function(var_info) {
  type_spec <- var_info$type
  is.list(type_spec)
}

# Rebuild one tuple's init value as a nested list from its flat leaf draws,
# recording which leaves hold NA or infinite values along the way.
#
# @param path Accumulated ":"-separated path of this tuple
#   (e.g., "b_tuple" at the top level, "b_tuple:1" for a nested tuple)
# @param var_info Type info for this level; `var_info$type` lists the
#   tuple's components
# @param draws_rvar Object holding the leaf draws, looked up by leaf path
# @param draw_iter Index of the draw to extract
# @return A list with two elements:
#   - `value`: list named "1", "2", ... with one entry per component;
#     nested tuples appear as nested lists of the same form
#   - `bad_leaves`: character vector of leaf paths whose values contain
#     NA or Inf (including those found in nested tuples)
build_tuple_init_value <- function(path, var_info, draws_rvar, draw_iter) {
  elements <- var_info$type
  value <- vector("list", length(elements))
  names(value) <- as.character(seq_along(elements))
  bad <- character(0)

  for (pos in seq_along(elements)) {
    element_info <- elements[[pos]]
    element_path <- paste0(path, ":", pos)

    # Nested tuple: recurse and merge its findings, then move on
    if (is_tuple_type(element_info)) {
      nested <- build_tuple_init_value(
        element_path, element_info, draws_rvar, draw_iter
      )
      value[[pos]] <- nested$value
      bad <- c(bad, nested$bad_leaves)
      next
    }

    # Leaf component: pull the draw and flag non-finite / missing values
    leaf <- .extract_draw_value(element_path, draws_rvar, draw_iter)
    if (any(is.infinite(leaf)) || any(is.na(leaf))) {
      bad <- c(bad, element_path)
    }
    # Zero-dimensional components are stored as plain doubles
    value[[pos]] <- if (element_info$dimensions == 0) as.double(leaf) else leaf
  }

  list(value = value, bad_leaves = bad)
}

# Pull the value of one variable at one draw out of `draws_rvar`.
# Steps: look the variable up by name, keep only draw `draw_iter`, take
# the underlying draws array, then pass it through .remove_leftmost_dim().
.extract_draw_value <- function(var_name, draws_rvar, draw_iter) {
  single_draw <- posterior::subset_draws(
    draws_rvar[[var_name]],
    draw = draw_iter
  )
  draw_array <- posterior::draws_of(single_draw)
  .remove_leftmost_dim(draw_array)
}

# Map Stan-level parameter names onto the leaf-level names found in
# `leaf_names`, preserving the order of `stan_params`:
#   - a name present verbatim in `leaf_names` is kept as-is;
#   - otherwise it is replaced by every leaf name starting with "<name>:"
#     (e.g., "b_tuple" -> "b_tuple:1:1", "b_tuple:1:2", "b_tuple:2");
#   - if no such leaf exists the name is kept unchanged, leaving it to the
#     caller (e.g., subset_draws) to report the unknown variable.
expand_stan_params_to_leaves <- function(stan_params, leaf_names) {
  expand_one <- function(param) {
    if (param %in% leaf_names) {
      return(param)
    }
    tuple_leaves <- leaf_names[startsWith(leaf_names, paste0(param, ":"))]
    if (length(tuple_leaves) == 0) param else tuple_leaves
  }

  pieces <- lapply(
    seq_along(stan_params),
    function(k) expand_one(stan_params[[k]])
  )
  # Seed with character(0) so an empty input still yields character(0)
  do.call(c, c(list(character(0)), pieces))
}

create_skeleton <- function(param_metadata, model_variables,
transformed_parameters, generated_quantities) {
target_params <- names(model_variables$parameters)
Expand All @@ -878,7 +966,25 @@ create_skeleton <- function(param_metadata, model_variables,
target_params <- c(target_params,
names(model_variables$generated_quantities))
}
lapply(param_metadata[target_params], function(par_dims) {
# Expand target_params to match param_metadata leaf names.
# For tuple parameters, the Stan-level name (e.g., "b_tuple") maps to
# multiple leaf entries in param_metadata (e.g., "b_tuple.1.1",
# "b_tuple.1.2", "b_tuple.2"). We expand by matching the prefix.
meta_names <- names(param_metadata)
expanded_params <- character(0)
for (param in target_params) {
if (param %in% meta_names) {
expanded_params <- c(expanded_params, param)
} else {
# Find leaf entries with this prefix (tuple expansion)
prefix <- paste0(param, ".")
leaves <- meta_names[startsWith(meta_names, prefix)]
if (length(leaves) > 0) {
expanded_params <- c(expanded_params, leaves)
}
}
}
lapply(param_metadata[expanded_params], function(par_dims) {
if ((length(par_dims) == 0)) {
array(0, dim = 1)
} else {
Expand Down
Loading
Loading