@@ -70,15 +70,15 @@ def __init__(self) -> None:
7070
7171 def sample (
7272 self ,
73- X_train : Union [np .array , pd .DataFrame ],
74- y_train : np .array ,
75- leaf_basis_train : np .array = None ,
76- rfx_group_ids_train : np .array = None ,
77- rfx_basis_train : np .array = None ,
78- X_test : Union [np .array , pd .DataFrame ] = None ,
79- leaf_basis_test : np .array = None ,
80- rfx_group_ids_test : np .array = None ,
81- rfx_basis_test : np .array = None ,
73+ X_train : Union [np .ndarray , pd .DataFrame ],
74+ y_train : np .ndarray ,
75+ leaf_basis_train : Optional [ np .ndarray ] = None ,
76+ rfx_group_ids_train : Optional [ np .ndarray ] = None ,
77+ rfx_basis_train : Optional [ np .ndarray ] = None ,
78+ X_test : Optional [ Union [np .ndarray , pd .DataFrame ] ] = None ,
79+ leaf_basis_test : Optional [ np .ndarray ] = None ,
80+ rfx_group_ids_test : Optional [ np .ndarray ] = None ,
81+ rfx_basis_test : Optional [ np .ndarray ] = None ,
8282 num_gfr : int = 5 ,
8383 num_burnin : int = 0 ,
8484 num_mcmc : int = 100 ,
@@ -859,6 +859,13 @@ def sample(
859859 if num_features_subsample_variance is None :
860860 num_features_subsample_variance = X_train .shape [1 ]
861861
862+ # Runtime check for multivariate leaf regression
863+ if sample_sigma2_leaf and self .num_basis > 1 :
864+ warnings .warn (
865+ "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model."
866+ )
867+ sample_sigma2_leaf = False
868+
862869 # Preliminary runtime checks for probit link
863870 if not self .include_mean_forest :
864871 self .probit_outcome_model = False
@@ -872,15 +879,15 @@ def sample(
872879 raise ValueError (
873880 "You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1"
874881 )
875- if self .include_variance_forest :
876- raise ValueError (
877- "We do not support heteroskedasticity with a probit link"
878- )
879882 if sample_sigma2_global :
880883 warnings .warn (
881884 "Global error variance will not be sampled with a probit link as it is fixed at 1"
882885 )
883886 sample_sigma2_global = False
887+ if self .include_variance_forest :
888+ raise ValueError (
889+ "We do not support heteroskedasticity with a probit link"
890+ )
884891
885892 # Handle standardization, prior calibration, and initialization of forest
886893 # differently for binary and continuous outcomes
@@ -1217,7 +1224,7 @@ def sample(
12171224 else :
12181225 leaf_model_mean_forest = 2
12191226 leaf_dimension_mean = self .num_basis
1220-
1227+
12211228 # Sampling data structures
12221229 global_model_config = GlobalModelConfig (global_error_variance = current_sigma2 )
12231230 if self .include_mean_forest :
@@ -1900,6 +1907,9 @@ def predict(
19001907 if leaf_basis is not None :
19011908 if leaf_basis .ndim == 1 :
19021909 leaf_basis = np .expand_dims (leaf_basis , 1 )
1910+ if rfx_basis is not None :
1911+ if rfx_basis .ndim == 1 :
1912+ rfx_basis = np .expand_dims (rfx_basis , 1 )
19031913
19041914 # Covariate preprocessing
19051915 if not self ._covariate_preprocessor ._check_is_fitted ():
@@ -1958,21 +1968,18 @@ def predict(
19581968 mean_forest_predictions = mean_pred_raw * self .y_std + self .y_bar
19591969
19601970 # Random effects data checks
1961- if has_rfx :
1962- if rfx_group_ids is None :
1963- raise ValueError (
1964- "rfx_group_ids must be provided if rfx_basis is provided"
1965- )
1966- if rfx_basis is not None :
1967- if rfx_basis .ndim == 1 :
1968- rfx_basis = np .expand_dims (rfx_basis , 1 )
1969- if rfx_basis .shape [0 ] != X .shape [0 ]:
1970- raise ValueError ("X and rfx_basis must have the same number of rows" )
1971+ if predict_rfx and rfx_group_ids is None :
1972+ raise ValueError (
1973+ "Random effect group labels (rfx_group_ids) must be provided for this model"
1974+ )
1975+ if predict_rfx and rfx_basis is None and not rfx_intercept :
1976+ raise ValueError ("Random effects basis (rfx_basis) must be provided for this model" )
1977+ if self .num_rfx_basis > 0 and not rfx_intercept :
19711978 if rfx_basis .shape [1 ] != self .num_rfx_basis :
19721979 raise ValueError (
1973- "rfx_basis must have the same number of columns as the random effects basis used to sample this model"
1980+ "Random effects basis has a different dimension than the basis used to train this model"
19741981 )
1975-
1982+
19761983 # Random effects predictions
19771984 if predict_rfx or predict_rfx_intermediate :
19781985 if rfx_basis is not None :
@@ -1983,7 +1990,7 @@ def predict(
19831990 # Sanity check -- this branch should only occur if rfx_model_spec == "intercept_only"
19841991 if not rfx_intercept :
19851992 raise ValueError (
1986- "rfx_basis must be provided for random effects models with random slopes "
1993+ "A user-provided basis (` rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom' "
19871994 )
19881995
19891996 # Extract the raw RFX samples and scale by train set outcome standard deviation
@@ -2321,16 +2328,17 @@ def compute_posterior_interval(
23212328 raise ValueError (
23222329 "'rfx_group_ids' must have the same length as the number of rows in 'X'"
23232330 )
2324- if rfx_basis is None :
2325- raise ValueError (
2326- "'rfx_basis' must be provided in order to compute the requested intervals"
2327- )
2328- if not isinstance (rfx_basis , np .ndarray ):
2329- raise ValueError ("'rfx_basis' must be a numpy array" )
2330- if rfx_basis .shape [0 ] != X .shape [0 ]:
2331- raise ValueError (
2332- "'rfx_basis' must have the same number of rows as 'X'"
2333- )
2331+ if self .rfx_model_spec == "custom" :
2332+ if rfx_basis is None :
2333+ raise ValueError (
2334+ "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
2335+ )
2336+ if not isinstance (rfx_basis , np .ndarray ):
2337+ raise ValueError ("'rfx_basis' must be a numpy array" )
2338+ if rfx_basis .shape [0 ] != X .shape [0 ]:
2339+ raise ValueError (
2340+ "'rfx_basis' must have the same number of rows as 'X'"
2341+ )
23342342
23352343 # Compute posterior matrices for the requested model terms
23362344 predictions = self .predict (
@@ -2427,16 +2435,17 @@ def sample_posterior_predictive(
24272435 raise ValueError (
24282436 "'rfx_group_ids' must have the same length as the number of rows in 'X'"
24292437 )
2430- if rfx_basis is None :
2431- raise ValueError (
2432- "'rfx_basis' must be provided in order to compute the requested intervals"
2433- )
2434- if not isinstance (rfx_basis , np .ndarray ):
2435- raise ValueError ("'rfx_basis' must be a numpy array" )
2436- if rfx_basis .shape [0 ] != X .shape [0 ]:
2437- raise ValueError (
2438- "'rfx_basis' must have the same number of rows as 'X'"
2439- )
2438+ if self .rfx_model_spec == "custom" :
2439+ if rfx_basis is None :
2440+ raise ValueError (
2441+ "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
2442+ )
2443+ if not isinstance (rfx_basis , np .ndarray ):
2444+ raise ValueError ("'rfx_basis' must be a numpy array" )
2445+ if rfx_basis .shape [0 ] != X .shape [0 ]:
2446+ raise ValueError (
2447+ "'rfx_basis' must have the same number of rows as 'X'"
2448+ )
24402449
24412450 # Compute posterior predictive samples
24422451 bart_preds = self .predict (
0 commit comments