From 270970accbba803b1c350f4bfac46629ccdc7f8a Mon Sep 17 00:00:00 2001 From: Aki Vehtari Date: Wed, 8 Apr 2026 14:42:12 +0300 Subject: [PATCH 1/2] fix: handle Stan tuple and complex types --- R/args.R | 76 +++- R/csv.R | 50 ++- R/data.R | 38 +- R/fit.R | 23 +- R/utils.R | 108 ++++- .../resources/stan/tuple_complex.stan | 40 ++ tests/testthat/test-tuple-complex.R | 388 ++++++++++++++++++ 7 files changed, 694 insertions(+), 29 deletions(-) create mode 100644 tests/testthat/resources/stan/tuple_complex.stan create mode 100644 tests/testthat/test-tuple-complex.R diff --git a/R/args.R b/R/args.R index 9c135025..7829e6d6 100644 --- a/R/args.R +++ b/R/args.R @@ -1096,37 +1096,77 @@ process_init.draws <- function(init, num_procs, model_variables = NULL, method ="simple_no_replace") } draws_rvar = posterior::as_draws_rvars(draws) - variable_names <- variable_names[variable_names %in% names(draws_rvar)] - draws_rvar <- posterior::subset_draws(draws_rvar, variable = variable_names) + + # Separate tuple and non-tuple parameters. Tuple parameters use leaf names + # in draws (e.g., "b_tuple:1:1") rather than the Stan-level name ("b_tuple"), + # so they need special handling via build_tuple_init_value(). + is_tuple <- if (!is.null(model_variables)) { + vapply(variable_names, function(nm) { + is_tuple_type(model_variables$parameters[[nm]]) + }, logical(1)) + } else { + rep(FALSE, length(variable_names)) + } + tuple_names <- variable_names[is_tuple] + scalar_names <- variable_names[!is_tuple] + + # Filter non-tuple names to those present in draws + scalar_names <- scalar_names[scalar_names %in% names(draws_rvar)] + + # For tuple names, check that their leaf draws exist + rvar_names <- names(draws_rvar) + tuple_names <- tuple_names[vapply(tuple_names, function(nm) { + any(startsWith(rvar_names, paste0(nm, ":"))) + }, logical(1))] + + all_names <- c(scalar_names, tuple_names) + + if (length(scalar_names) > 0) { + draws_rvar <- posterior::subset_draws( + draws_rvar, + variable = expand_stan_params_to_leaves(all_names, rvar_names) + ) + } + inits = lapply(1:num_procs, function(draw_iter) { - init_i = lapply(variable_names, function(var_name) { - x = .remove_leftmost_dim(posterior::draws_of( - posterior::subset_draws(draws_rvar[[var_name]], draw=draw_iter))) + bad_names <- character(0) + + # Extract non-tuple parameters + init_i = lapply(scalar_names, function(var_name) { + x = .extract_draw_value(var_name, draws_rvar, draw_iter) + if (any(is.infinite(x)) || any(is.na(x))) { + bad_names[[length(bad_names) + 1L]] <<- var_name + } if (model_variables$parameters[[var_name]]$dimensions == 0) { return(as.double(x)) } else { return(x) } }) - bad_names = unlist(lapply(variable_names, function(var_name) { - x = drop(posterior::draws_of(drop( - posterior::subset_draws(draws_rvar[[var_name]], draw=draw_iter)))) - if (any(is.infinite(x)) || any(is.na(x))) { - return(var_name) + names(init_i) <- scalar_names + + # Extract tuple parameters (build_tuple_init_value also validates) + for (var_name in tuple_names) { + tuple_result <- build_tuple_init_value( + var_name, model_variables$parameters[[var_name]], + draws_rvar, draw_iter + ) + init_i[[var_name]] <- tuple_result$value + if (length(tuple_result$bad_leaves) > 0) { + bad_names <- c(bad_names, var_name) } - return("") - })) - any_na_or_inf = bad_names != "" - if (any(any_na_or_inf)) { - err_msg = paste0(paste(bad_names[any_na_or_inf], collapse = ", "), " contains NA or Inf values!") - if (length(any_na_or_inf) > 1) { + } + + if (length(bad_names) > 0) { + err_msg = paste0(paste(bad_names, collapse = ", "), " contains NA or Inf values!") + if (length(bad_names) > 1) { err_msg = paste0("Variables: ", err_msg) } else { err_msg = paste0("Variable: ", err_msg) } stop(err_msg) } - names(init_i) = variable_names + return(init_i) }) return(process_init(inits, num_procs, model_variables, warn_partial)) @@ -1245,7 +1285,7 @@ validate_fit_init = function(init, model_variables) { # Convert from data.table to data.frame if (all(init$return_codes() == 1)) { stop("We are unable to create initial values from a model with no samples. Please check the results of the model used for inits before continuing.") - } else if (!is.null(model_variables) &&!any(names(model_variables$parameters) %in% init$metadata()$stan_variables)) { + } else if (!is.null(model_variables) && !any(stan_param_has_leaf(names(model_variables$parameters), init$metadata()$stan_variables))) { stop("None of the names of the parameters for the model used for initial values match the names of parameters from the model currently running.") } } diff --git a/R/csv.R b/R/csv.R index 5fce18f4..986d6221 100644 --- a/R/csv.R +++ b/R/csv.R @@ -923,11 +923,36 @@ check_csv_metadata_matches <- function(csv_metadata) { } # convert names like beta.1.1 to beta[1,1] +# also handles complex suffixes (.real/.imag) and tuple separators (:) repair_variable_names <- function(names) { + # 1. Detect and strip .real/.imag suffix before dot conversion + complex_suffix <- ifelse( + grepl("\\.real$", names), ",real", + ifelse(grepl("\\.imag$", names), ",imag", "") + ) + names <- sub("\\.(real|imag)$", "", names) + + # 2. Standard dot-to-bracket conversion (remaining dots are numeric indices) names <- sub("\\.", "[", names) names <- gsub("\\.", ",", names) - names[grep("\\[", names)] <- - paste0(names[grep("\\[", names)], "]") + has_bracket <- grepl("\\[", names) + names[has_bracket] <- paste0(names[has_bracket], "]") + + # 3. Re-attach complex suffix + has_complex <- nzchar(complex_suffix) + has_both <- has_complex & has_bracket + has_complex_only <- has_complex & !has_bracket + # Had numeric indices: insert complex suffix before closing ] + names[has_both] <- paste0( + sub("\\]$", "", names[has_both]), + complex_suffix[has_both], "]" + ) + # No numeric indices: wrap in brackets + names[has_complex_only] <- paste0( + names[has_complex_only], "[", + sub("^,", "", complex_suffix[has_complex_only]), "]" + ) + names } @@ -994,8 +1019,25 @@ variable_dims <- function(variable_names = NULL) { var_indices <- var_names[grep(pattern, var_names)] var_indices <- gsub(pattern, "", var_indices) if (length(var_indices)) { - var_indices <- strsplit(var_indices[length(var_indices)], ",")[[1]] - dims[[var]] <- as.numeric(var_indices) + # Split the last index entry by comma to determine number of dimensions + last_indices <- strsplit(var_indices[length(var_indices)], ",")[[1]] + ndims <- length(last_indices) + dim_sizes <- integer(ndims) + for (d in seq_len(ndims)) { + num_idx <- suppressWarnings(as.integer(last_indices[d])) + if (!is.na(num_idx)) { + # Numeric index: the maximum value is the dimension size + dim_sizes[d] <- num_idx + } else { + # Non-numeric index (e.g., "real"/"imag" for complex, or + # tuple indices like "1:2"): count unique values across all + # entries for this dimension position + all_indices <- strsplit(var_indices, ",") + unique_vals <- unique(vapply(all_indices, `[`, character(1), d)) + dim_sizes[d] <- length(unique_vals) + } + } + dims[[var]] <- dim_sizes } else { dims[[var]] <- 1 } diff --git a/R/data.R b/R/data.R index d13605d7..f8f26d29 100644 --- a/R/data.R +++ b/R/data.R @@ -92,7 +92,11 @@ write_stan_json <- function(data, file, always_decimal = FALSE) { } else if (is.data.frame(var)) { var <- data.matrix(var) } else if (is.list(var)) { - var <- list_to_array(var, var_name) + if (is_tuple_list(var)) { + var <- prepare_tuple_for_json(var) + } else { + var <- list_to_array(var, var_name) + } } data[[var_name]] <- var } @@ -110,6 +114,38 @@ write_stan_json <- function(data, file, always_decimal = FALSE) { } +# Detect whether a list represents a Stan tuple value. +# Tuple lists are named lists with string-integer keys ("1", "2", ...) +# corresponding to the tuple element positions. +is_tuple_list <- function(x) { + nms <- names(x) + if (is.null(nms) || length(nms) == 0) { + return(FALSE) + } + expected <- as.character(seq_along(x)) + identical(nms, expected) +} + +# Recursively prepare a tuple value for JSON serialization. +# Processes sub-elements: nested tuple lists are recursed into, +# array-style lists (unnamed, homogeneous) are converted via list_to_array, +# and numeric/logical values are left as-is. +prepare_tuple_for_json <- function(x) { + for (i in seq_along(x)) { + val <- x[[i]] + if (is.list(val)) { + if (is_tuple_list(val)) { + x[[i]] <- prepare_tuple_for_json(val) + } else { + x[[i]] <- list_to_array(val) + } + } else if (is.logical(val)) { + mode(x[[i]]) <- "integer" + } + } + x +} + list_to_array <- function(x, name = NULL) { list_length <- length(x) if (list_length == 0) { diff --git a/R/fit.R b/R/fit.R index bf60a0e9..d0dc2e93 100644 --- a/R/fit.R +++ b/R/fit.R @@ -499,8 +499,12 @@ unconstrain_variables <- function(variables) { model_variables <- self$runset$args$model_variables # If zero-length parameters are present, they will be listed in model_variables - # but not in metadata()$variables - nonzero_length_params <- names(model_variables$parameters) %in% model_par_names + # but not in metadata()$variables. For tuple parameters, model_variables uses + # the Stan-level name (e.g., "b_tuple") while model_par_names uses leaf names + # with ":" separators (e.g., "b_tuple:1:1"), so we use prefix matching. + nonzero_length_params <- stan_param_has_leaf( + names(model_variables$parameters), model_par_names + ) model_par_names <- names(model_variables$parameters[nonzero_length_params]) model_pars_not_prov <- which(!(model_par_names %in% prov_par_names)) @@ -589,14 +593,23 @@ unconstrain_draws <- function(files = NULL, draws = NULL, model_variables <- self$runset$args$model_variables # If zero-length parameters are present, they will be listed in model_variables - # but not in metadata()$variables - nonzero_length_params <- names(model_variables$parameters) %in% model_par_names + # but not in metadata()$variables. For tuple parameters, model_variables uses + # the Stan-level name (e.g., "b_tuple") while model_par_names uses leaf names + # with ":" separators (e.g., "b_tuple:1:1"), so we use prefix matching. + nonzero_length_params <- stan_param_has_leaf( + names(model_variables$parameters), model_par_names + ) # Remove zero-length parameters from model_variables, otherwise process_init # warns about missing inputs pars <- names(model_variables$parameters[nonzero_length_params]) - draws <- posterior::subset_draws(draws, variable = pars) + # For subset_draws, we need to use the leaf-level names from stan_variables + # (e.g., "b_tuple:1:1") rather than Stan-level names (e.g., "b_tuple"), + # because posterior doesn't recognize Stan-level tuple names. + pars_for_draws <- expand_stan_params_to_leaves(pars, model_par_names) + + draws <- posterior::subset_draws(draws, variable = pars_for_draws) unconstrained <- private$model_methods_env_$unconstrain_draws(private$model_methods_env_$model_ptr_, draws) uncon_names <- private$model_methods_env_$unconstrained_param_names(private$model_methods_env_$model_ptr_, FALSE, FALSE) names(unconstrained) <- repair_variable_names(uncon_names) diff --git a/R/utils.R b/R/utils.R index 40fa62c3..aef561a9 100644 --- a/R/utils.R +++ b/R/utils.R @@ -867,6 +867,94 @@ initialize_model_pointer <- function(env, datafile_path, seed = 0) { invisible(NULL) } +# Check if Stan-level parameter names (which may include tuple names like +# "b_tuple") have a match among leaf-level variable names (which use ":" +# to separate tuple elements, e.g., "b_tuple:1:1", "b_tuple:1:2"). +# A Stan-level name matches if it appears directly in leaf_names, or if +# any leaf name starts with ":" (tuple expansion). +stan_param_has_leaf <- function(stan_names, leaf_names) { + vapply(stan_names, function(nm) { + nm %in% leaf_names || any(startsWith(leaf_names, paste0(nm, ":"))) + }, logical(1), USE.NAMES = FALSE) +} + +# Check if a parameter's type info represents a tuple. +# Tuples have $type as a list; non-tuples have $type as a string. +is_tuple_type <- function(var_info) { + is.list(var_info$type) +} + +# Reconstruct a tuple init value as a nested named list from flat leaf draws. +# Also validates that no leaf values contain NA or Inf. +# +# @param path The accumulated `:` path (e.g., "b_tuple", "b_tuple:1") +# @param var_info The type info at this level (from model_variables) +# @param draws_rvar The draws_rvars object containing leaf entries +# @param draw_iter Which draw iteration to extract +# @return A list with two elements: +# - `value`: nested named list suitable for CmdStan JSON +# - `bad_leaves`: character vector of leaf names with NA/Inf values +build_tuple_init_value <- function(path, var_info, draws_rvar, draw_iter) { + components <- var_info$type + result <- vector("list", length(components)) + names(result) <- as.character(seq_along(components)) + bad_leaves <- character(0) + for (i in seq_along(components)) { + child_path <- paste0(path, ":", i) + child_info <- components[[i]] + if (is_tuple_type(child_info)) { + child <- build_tuple_init_value( + child_path, child_info, draws_rvar, draw_iter + ) + result[[i]] <- child$value + bad_leaves <- c(bad_leaves, child$bad_leaves) + } else { + x <- .extract_draw_value(child_path, draws_rvar, draw_iter) + if (any(is.infinite(x)) || any(is.na(x))) { + bad_leaves <- c(bad_leaves, child_path) + } + if (child_info$dimensions == 0) { + result[[i]] <- as.double(x) + } else { + result[[i]] <- x + } + } + } + list(value = result, bad_leaves = bad_leaves) +} + +# Extract a single draw value from draws_rvar for a given variable name. +# Handles the subset → draws_of → remove_leftmost_dim pipeline. +.extract_draw_value <- function(var_name, draws_rvar, draw_iter) { + .remove_leftmost_dim(posterior::draws_of( + posterior::subset_draws(draws_rvar[[var_name]], draw = draw_iter) + )) +} + +# Expand Stan-level parameter names to their leaf-level equivalents in +# stan_variables. Non-tuple names pass through unchanged. Tuple names +# (e.g., "b_tuple") are expanded to all matching leaf names +# (e.g., "b_tuple:1:1", "b_tuple:1:2", "b_tuple:2"). +expand_stan_params_to_leaves <- function(stan_params, leaf_names) { + result <- character(0) + for (param in stan_params) { + if (param %in% leaf_names) { + result <- c(result, param) + } else { + # Find leaf-level names for this tuple parameter + prefix <- paste0(param, ":") + leaves <- leaf_names[startsWith(leaf_names, prefix)] + if (length(leaves) > 0) { + result <- c(result, leaves) + } else { + # No match found, include as-is (will be caught by subset_draws) + result <- c(result, param) + } + } + } + result +} + create_skeleton <- function(param_metadata, model_variables, transformed_parameters, generated_quantities) { target_params <- names(model_variables$parameters) @@ -878,7 +966,25 @@ create_skeleton <- function(param_metadata, model_variables, target_params <- c(target_params, names(model_variables$generated_quantities)) } - lapply(param_metadata[target_params], function(par_dims) { + # Expand target_params to match param_metadata leaf names. + # For tuple parameters, the Stan-level name (e.g., "b_tuple") maps to + # multiple leaf entries in param_metadata (e.g., "b_tuple.1.1", + # "b_tuple.1.2", "b_tuple.2"). We expand by matching the prefix. + meta_names <- names(param_metadata) + expanded_params <- character(0) + for (param in target_params) { + if (param %in% meta_names) { + expanded_params <- c(expanded_params, param) + } else { + # Find leaf entries with this prefix (tuple expansion) + prefix <- paste0(param, ".") + leaves <- meta_names[startsWith(meta_names, prefix)] + if (length(leaves) > 0) { + expanded_params <- c(expanded_params, leaves) + } + } + } + lapply(param_metadata[expanded_params], function(par_dims) { if ((length(par_dims) == 0)) { array(0, dim = 1) } else { diff --git a/tests/testthat/resources/stan/tuple_complex.stan b/tests/testthat/resources/stan/tuple_complex.stan new file mode 100644 index 00000000..7dfda200 --- /dev/null +++ b/tests/testthat/resources/stan/tuple_complex.stan @@ -0,0 +1,40 @@ +parameters { + real a_scalar; + tuple(tuple(array[2] real, vector[2]), matrix[2, 2]) b_tuple; + matrix[2, 2] c_matrix; + complex z; +} +model { + a_scalar ~ normal(0, 1); + b_tuple.1.1 ~ normal(0, 1); + b_tuple.1.2 ~ normal(0, 1); + to_vector(b_tuple.2) ~ normal(0, 1); + to_vector(c_matrix) ~ normal(0, 1); + get_real(z) ~ normal(0, 1); + get_imag(z) ~ normal(0, 1); +} +generated quantities { + tuple(real, tuple(tuple(array[2] real, vector[2]), matrix[2, 2]), matrix[2, 2]) d_tuple + = (a_scalar, b_tuple, c_matrix); + + real mu = normal_rng(a_scalar, 1); + matrix[2, 3] m = mu * to_matrix(linspaced_vector(6, 5, 11), 2, 3); + array[4] matrix[2, 3] threeD; + for (i in 1 : 4) { + threeD[i] = i * mu * to_matrix(linspaced_vector(6, 5, 11), 2, 3); + } + + complex_vector[2] zv = to_complex([mu * 3, mu * 5]', [mu * 4, mu * 6]'); + complex_matrix[2, 3] zm = to_complex(m, m + 1); + array[4] complex_matrix[2, 3] z3D; + for (i in 1 : 4) { + z3D[i] = to_complex(threeD[i], threeD[i] + 1); + } + + real base = normal_rng(0, 1); + int base_i = to_int(normal_rng(10, 10)); + + tuple(real, real) pair = (base, base * 2); + tuple(real, tuple(int, complex)) nested = (base * 3, (base_i, base * 4.0i)); + array[2] tuple(real, real) arr_pair = {pair, (base * 5, base * 6)}; +} diff --git a/tests/testthat/test-tuple-complex.R b/tests/testthat/test-tuple-complex.R new file mode 100644 index 00000000..cab232ee --- /dev/null +++ b/tests/testthat/test-tuple-complex.R @@ -0,0 +1,388 @@ +# Tests for tuple and complex variable handling + +# --- Unit tests for repair_variable_names --- + +test_that("repair_variable_names handles standard array/matrix variables", { + expect_equal( + repair_variable_names(c("mu", "beta.1", "beta.2", "sigma.1.2")), + c("mu", "beta[1]", "beta[2]", "sigma[1,2]") + ) +}) + +test_that("repair_variable_names handles complex scalar variables", { + expect_equal( + repair_variable_names(c("z.real", "z.imag")), + c("z[real]", "z[imag]") + ) +}) + +test_that("repair_variable_names handles complex vector variables", { + expect_equal( + repair_variable_names(c("zv.1.real", "zv.1.imag", "zv.2.real", "zv.2.imag")), + c("zv[1,real]", "zv[1,imag]", "zv[2,real]", "zv[2,imag]") + ) +}) + +test_that("repair_variable_names handles complex matrix variables", { + expect_equal( + repair_variable_names(c("zm.1.1.real", "zm.1.1.imag", "zm.2.3.real", "zm.2.3.imag")), + c("zm[1,1,real]", "zm[1,1,imag]", "zm[2,3,real]", "zm[2,3,imag]") + ) +}) + +test_that("repair_variable_names handles array of complex matrix variables", { + expect_equal( + repair_variable_names(c("z3D.1.1.1.real", "z3D.1.1.1.imag", "z3D.4.2.3.real", "z3D.4.2.3.imag")), + c("z3D[1,1,1,real]", "z3D[1,1,1,imag]", "z3D[4,2,3,real]", "z3D[4,2,3,imag]") + ) +}) + +test_that("repair_variable_names handles simple tuple variables", { + expect_equal( + repair_variable_names(c("pair:1", "pair:2")), + c("pair:1", "pair:2") + ) +}) + +test_that("repair_variable_names handles nested tuple variables", { + expect_equal( + repair_variable_names(c("nested:1", "nested:2:1", "nested:2:2.real", "nested:2:2.imag")), + c("nested:1", "nested:2:1", "nested:2:2[real]", "nested:2:2[imag]") + ) +}) + +test_that("repair_variable_names handles tuple with array indices", { + expect_equal( + repair_variable_names(c("b_tuple:1:1.1", "b_tuple:1:1.2", "b_tuple:2.1.1", "b_tuple:2.2.2")), + c("b_tuple:1:1[1]", "b_tuple:1:1[2]", "b_tuple:2[1,1]", "b_tuple:2[2,2]") + ) +}) + +test_that("repair_variable_names handles array of tuples", { + expect_equal( + repair_variable_names(c("arr_pair.1:1", "arr_pair.1:2", "arr_pair.2:1", "arr_pair.2:2")), + c("arr_pair[1:1]", "arr_pair[1:2]", "arr_pair[2:1]", "arr_pair[2:2]") + ) +}) + +# --- Unit tests for unrepair_variable_names (inverse) --- + +test_that("unrepair_variable_names is inverse of repair_variable_names", { + raw_names <- c( + "mu", "beta.1.2", + "z.real", "z.imag", + "zv.1.real", "zv.2.imag", + "zm.1.1.real", "zm.2.3.imag", + "z3D.1.1.1.real", "z3D.4.2.3.imag", + "pair:1", "pair:2", + "nested:1", "nested:2:1", "nested:2:2.real", "nested:2:2.imag", + "b_tuple:1:1.1", "b_tuple:2.1.1", + "arr_pair.1:1", "arr_pair.2:2" + ) + expect_equal(unrepair_variable_names(repair_variable_names(raw_names)), raw_names) +}) + +# --- Unit tests for variable_dims --- + +test_that("variable_dims handles complex scalar (no NA, no warning)", { + expect_silent(result <- variable_dims(c("z[real]", "z[imag]"))) + expect_equal(result, list(z = 2L)) +}) + +test_that("variable_dims handles complex vector", { + expect_silent( + result <- variable_dims(c("zv[1,real]", "zv[1,imag]", "zv[2,real]", "zv[2,imag]")) + ) + expect_equal(result, list(zv = c(2L, 2L))) +}) + +test_that("variable_dims handles complex matrix", { + vars <- c( + "zm[1,1,real]", "zm[1,1,imag]", "zm[2,1,real]", "zm[2,1,imag]", + "zm[1,2,real]", "zm[1,2,imag]", "zm[2,2,real]", "zm[2,2,imag]", + "zm[1,3,real]", "zm[1,3,imag]", "zm[2,3,real]", "zm[2,3,imag]" + ) + expect_silent(result <- variable_dims(vars)) + expect_equal(result, list(zm = c(2L, 3L, 2L))) +}) + +test_that("variable_dims handles tuple leaf variables", { + vars <- c("pair:1", "pair:2", "nested:1", "nested:2:1", + "nested:2:2[real]", "nested:2:2[imag]") + expect_silent(result <- variable_dims(vars)) + expect_equal(result$`pair:1`, 1) + expect_equal(result$`pair:2`, 1) + expect_equal(result$`nested:1`, 1) + expect_equal(result$`nested:2:1`, 1) + expect_equal(result$`nested:2:2`, 2L) +}) + +test_that("variable_dims handles array of tuples", { + vars <- c("arr_pair[1:1]", "arr_pair[1:2]", "arr_pair[2:1]", "arr_pair[2:2]") + expect_silent(result <- variable_dims(vars)) + expect_equal(result, list(arr_pair = 4L)) +}) + +test_that("variable_dims handles mixed standard and complex/tuple variables", { + vars <- c( + "lp__", "a_scalar", + "b_tuple:1:1[1]", "b_tuple:1:1[2]", + "b_tuple:1:2[1]", "b_tuple:1:2[2]", + "b_tuple:2[1,1]", "b_tuple:2[2,1]", "b_tuple:2[1,2]", "b_tuple:2[2,2]", + "c_matrix[1,1]", "c_matrix[2,1]", "c_matrix[1,2]", "c_matrix[2,2]", + "z[real]", "z[imag]" + ) + expect_silent(result <- variable_dims(vars)) + expect_equal(result$lp__, 1) + expect_equal(result$a_scalar, 1) + expect_equal(result$`b_tuple:1:1`, 2L) + expect_equal(result$`b_tuple:1:2`, 2L) + expect_equal(result$`b_tuple:2`, c(2L, 2L)) + expect_equal(result$c_matrix, c(2L, 2L)) + expect_equal(result$z, 2L) +}) + +# --- Helper function tests --- + +test_that("stan_param_has_leaf matches direct and tuple-expanded names", { + leaf_names <- c("a_scalar", "b_tuple:1:1", "b_tuple:1:2", "b_tuple:2", "c_matrix", "z") + expect_equal( + stan_param_has_leaf(c("a_scalar", "b_tuple", "c_matrix", "z"), leaf_names), + c(TRUE, TRUE, TRUE, TRUE) + ) + expect_equal( + stan_param_has_leaf(c("nonexistent", "b_tup"), leaf_names), + c(FALSE, FALSE) + ) +}) + +test_that("expand_stan_params_to_leaves expands tuple names", { + leaf_names <- c("a_scalar", "b_tuple:1:1", "b_tuple:1:2", "b_tuple:2", "c_matrix", "z") + expect_equal( + expand_stan_params_to_leaves(c("a_scalar", "b_tuple", "c_matrix", "z"), leaf_names), + c("a_scalar", "b_tuple:1:1", "b_tuple:1:2", "b_tuple:2", "c_matrix", "z") + ) + # Non-tuple names pass through + expect_equal( + expand_stan_params_to_leaves(c("a_scalar", "c_matrix"), leaf_names), + c("a_scalar", "c_matrix") + ) +}) + +# --- End-to-end tests with Stan model --- + +test_that("sampling model with tuple and complex types produces no warnings", { + mod <- testing_model("tuple_complex") + expect_no_warning( + utils::capture.output( + fit <- mod$sample(seed = 123, chains = 2, iter_sampling = 100, + iter_warmup = 100, refresh = 0) + ) + ) + + # Check metadata + meta <- fit$metadata() + expect_true("a_scalar" %in% meta$stan_variables) + expect_true("b_tuple:1:1" %in% meta$stan_variables) + expect_true("b_tuple:1:2" %in% meta$stan_variables) + expect_true("b_tuple:2" %in% meta$stan_variables) + expect_true("c_matrix" %in% meta$stan_variables) + expect_true("z" %in% meta$stan_variables) + expect_true("d_tuple:1" %in% meta$stan_variables) + expect_true("pair:1" %in% meta$stan_variables) + expect_true("pair:2" %in% meta$stan_variables) + expect_true("nested:1" %in% meta$stan_variables) + expect_true("nested:2:1" %in% meta$stan_variables) + expect_true("nested:2:2" %in% meta$stan_variables) + expect_true("arr_pair" %in% meta$stan_variables) + expect_true("zv" %in% meta$stan_variables) + expect_true("zm" %in% meta$stan_variables) + expect_true("z3D" %in% meta$stan_variables) + + # Check dimensions have no NAs + sizes <- meta$stan_variable_sizes + for (var_name in names(sizes)) { + expect_false( + any(is.na(sizes[[var_name]])), + info = paste("NA in stan_variable_sizes for", var_name) + ) + } + + # Check specific dimensions + expect_equal(sizes$z, 2L) + expect_equal(sizes$zv, c(2L, 2L)) + expect_equal(sizes$zm, c(2L, 3L, 2L)) + expect_equal(sizes$z3D, c(4L, 2L, 3L, 2L)) + expect_equal(sizes$`b_tuple:1:1`, 2L) + expect_equal(sizes$`b_tuple:2`, c(2L, 2L)) + expect_equal(sizes$`pair:1`, 1) + expect_equal(sizes$`nested:2:2`, 2L) + + # Check draws work with posterior + dr <- fit$draws() + expect_true("z[real]" %in% posterior::variables(dr)) + expect_true("z[imag]" %in% posterior::variables(dr)) + expect_true("pair:1" %in% posterior::variables(dr)) + expect_true("arr_pair[1:1]" %in% posterior::variables(dr)) +}) + +test_that("variable_skeleton works with tuple and complex parameters", { + mod <- cmdstan_model(testing_stan_file("tuple_complex"), force_recompile = TRUE) + utils::capture.output( + fit <- mod$sample(seed = 123, chains = 2, iter_sampling = 100, + iter_warmup = 100, refresh = 0) + ) + + skel <- fit$variable_skeleton() + expect_false(any(is.na(names(skel)))) + expect_true("a_scalar" %in% names(skel)) + expect_true("c_matrix" %in% names(skel)) + expect_true("z" %in% names(skel)) + # Tuple leaves should be expanded + expect_true("b_tuple.1.1" %in% names(skel)) + expect_true("b_tuple.1.2" %in% names(skel)) + expect_true("b_tuple.2" %in% names(skel)) + # Check dimensions + expect_equal(dim(skel$z), 2L) + expect_equal(dim(skel$c_matrix), c(2L, 2L)) + expect_equal(dim(skel$b_tuple.1.1), 2L) + expect_equal(dim(skel$b_tuple.2), c(2L, 2L)) +}) + +test_that("unconstrain_draws works with tuple and complex parameters", { + mod <- cmdstan_model(testing_stan_file("tuple_complex"), force_recompile = TRUE) + utils::capture.output( + fit <- mod$sample(seed = 123, chains = 2, iter_sampling = 100, + iter_warmup = 100, refresh = 0) + ) + + expect_no_error(udraws <- fit$unconstrain_draws()) + uvar_names <- posterior::variables(udraws) + # Should include the tuple leaf names and complex parts + expect_true("a_scalar" %in% uvar_names) + expect_true("b_tuple:1:1[1]" %in% uvar_names) + expect_true("b_tuple:1:1[2]" %in% uvar_names) + expect_true("b_tuple:2[1,1]" %in% uvar_names) + expect_true("c_matrix[1,1]" %in% uvar_names) + expect_true("z[real]" %in% uvar_names) + expect_true("z[imag]" %in% uvar_names) +}) + +# --- Tests for write_stan_json with tuples --- + +test_that("write_stan_json handles simple tuple values", { + f <- tempfile(fileext = ".json") + data <- list(pair = list("1" = 1.5, "2" = 3.4)) + write_stan_json(data, f) + json <- jsonlite::read_json(f) + expect_equal(json$pair[["1"]], 1.5) + expect_equal(json$pair[["2"]], 3.4) +}) + +test_that("write_stan_json handles nested tuple values", { + f <- tempfile(fileext = ".json") + data <- list( + b_tuple = list( + "1" = list("1" = c(1.1, 2.2), "2" = c(3.3, 4.4)), + "2" = matrix(c(1, 2, 3, 4), 2, 2) + ) + ) + write_stan_json(data, f) + json_text <- paste(readLines(f), collapse = "\n") + # Verify the JSON contains the expected nested structure + expect_true(grepl('"b_tuple"', json_text)) + expect_true(grepl('"1":\\s*\\{', json_text)) + expect_true(grepl('"2":\\s*\\[', json_text)) + # Verify inner arrays + expect_true(grepl("1.1,\\s*2.2", json_text)) + expect_true(grepl("3.3,\\s*4.4", json_text)) +}) + +test_that("write_stan_json still handles array lists correctly", { + f <- tempfile(fileext = ".json") + data <- list(x = list(1:3, 4:6)) + write_stan_json(data, f) + json <- jsonlite::read_json(f, simplifyVector = TRUE) + expect_equal(json$x[1, ], c(1, 2, 3)) + expect_equal(json$x[2, ], c(4, 5, 6)) +}) + +# --- Tests for build_tuple_init_value --- + +test_that("build_tuple_init_value reconstructs nested tuple from draws", { + mod <- cmdstan_model(testing_stan_file("tuple_complex"), force_recompile = TRUE) + utils::capture.output( + fit <- mod$sample(seed = 123, chains = 1, iter_sampling = 10, refresh = 0) + ) + dr <- fit$draws() + draws_rvar <- posterior::as_draws_rvars(dr) + mv <- mod$variables() + + tuple_result <- build_tuple_init_value("b_tuple", mv$parameters$b_tuple, + draws_rvar, 1) + expect_true(is.list(tuple_result)) + expect_equal(length(tuple_result$bad_leaves), 0) + result <- tuple_result$value + # Should be a named list with keys "1" and "2" + expect_true(is.list(result)) + expect_equal(names(result), c("1", "2")) + # Element "1" is a nested tuple with keys "1" and "2" + expect_true(is.list(result[["1"]])) + expect_equal(names(result[["1"]]), c("1", "2")) + # Element "1"."1" is array[2] real + expect_equal(length(result[["1"]][["1"]]), 2) + expect_true(is.numeric(result[["1"]][["1"]])) + # Element "2" is matrix[2,2] + expect_equal(dim(result[["2"]]), c(2, 2)) +}) + +# --- End-to-end init = fit with tuple parameters --- + +test_that("init = fit works with tuple parameters", { + mod <- cmdstan_model(testing_stan_file("tuple_complex"), force_recompile = TRUE) + utils::capture.output( + fit <- mod$sample(seed = 123, chains = 2, iter_sampling = 100, + iter_warmup = 100, refresh = 0) + ) + + # Check that init JSON includes b_tuple + init_files <- process_init(fit, num_procs = 1, + model_variables = mod$variables()) + json <- jsonlite::read_json(init_files) + expect_true("b_tuple" %in% names(json)) + expect_true("a_scalar" %in% names(json)) + expect_true("c_matrix" %in% names(json)) + expect_true("z" %in% names(json)) + # b_tuple should be a nested object + expect_true(is.list(json$b_tuple)) + expect_true("1" %in% names(json$b_tuple)) + expect_true("2" %in% names(json$b_tuple)) + + # Second sampling with init = fit should succeed without missing-param warnings + suppressMessages( + utils::capture.output( + fit2 <- mod$sample(seed = 456, chains = 2, iter_sampling = 100, + iter_warmup = 100, refresh = 0, init = fit) + ) + ) + expect_true(inherits(fit2, "CmdStanMCMC")) +}) + +test_that("manual init list with tuple values works", { + mod <- cmdstan_model(testing_stan_file("tuple_complex"), force_recompile = TRUE) + init_list <- list(list( + a_scalar = 0.5, + b_tuple = list( + "1" = list("1" = c(0.1, 0.2), "2" = c(0.3, 0.4)), + "2" = matrix(c(0.5, 0.6, 0.7, 0.8), 2, 2) + ), + c_matrix = matrix(c(0.1, 0.2, 0.3, 0.4), 2, 2), + z = c(0.1, 0.2) + )) + expect_no_error( + utils::capture.output( + fit <- mod$sample(seed = 123, chains = 1, iter_sampling = 100, + iter_warmup = 100, refresh = 0, init = init_list) + ) + ) +}) From db15c82b09abfa30454e061ac595130b5edabd7c Mon Sep 17 00:00:00 2001 From: Aki Vehtari Date: Fri, 10 Apr 2026 21:46:50 +0300 Subject: [PATCH 2/2] test: skip two model methods tests if os_is_wsl() --- tests/testthat/test-tuple-complex.R | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/testthat/test-tuple-complex.R b/tests/testthat/test-tuple-complex.R index cab232ee..46c97068 100644 --- a/tests/testthat/test-tuple-complex.R +++ b/tests/testthat/test-tuple-complex.R @@ -227,6 +227,7 @@ test_that("sampling model with tuple and complex types produces no warnings", { }) test_that("variable_skeleton works with tuple and complex parameters", { + skip_if(os_is_wsl()) mod <- cmdstan_model(testing_stan_file("tuple_complex"), force_recompile = TRUE) utils::capture.output( fit <- mod$sample(seed = 123, chains = 2, iter_sampling = 100, @@ -250,6 +251,7 @@ test_that("variable_skeleton works with tuple and complex parameters", { }) test_that("unconstrain_draws works with tuple and complex parameters", { + skip_if(os_is_wsl()) mod <- cmdstan_model(testing_stan_file("tuple_complex"), force_recompile = TRUE) utils::capture.output( fit <- mod$sample(seed = 123, chains = 2, iter_sampling = 100,