Skip to content

Commit 69d0de2

Browse files
committed
Renamed API combination test file and added python tests
1 parent 81d84bd commit 69d0de2

File tree

5 files changed

+394
-82
lines changed

5 files changed

+394
-82
lines changed

R/bart.R

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2124,7 +2124,6 @@ predict.bartmodel <- function(
21242124
X <- preprocessPredictionData(X, train_set_metadata)
21252125

21262126
# Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
2127-
has_rfx <- FALSE
21282127
if (predict_rfx) {
21292128
if (!is.null(rfx_group_ids)) {
21302129
rfx_unique_group_ids <- object$rfx_unique_group_ids
@@ -2135,7 +2134,6 @@ predict.bartmodel <- function(
21352134
)
21362135
}
21372136
rfx_group_ids <- as.integer(group_ids_factor)
2138-
has_rfx <- TRUE
21392137
}
21402138
}
21412139

R/posterior_transformation.R

Lines changed: 49 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -409,19 +409,31 @@ compute_contrast_bart_model <- function(
409409
"rfx_group_ids_0 and rfx_group_ids_1 must be provided for this model"
410410
)
411411
}
412-
if ((has_rfx) && (is.null(rfx_basis_0) || is.null(rfx_basis_1))) {
413-
stop(
414-
"rfx_basis_0 and rfx_basis_1 must be provided for this model"
415-
)
416-
}
417-
if (
418-
(object$model_params$num_rfx_basis > 0) &&
419-
((ncol(rfx_basis_0) != object$model_params$num_rfx_basis) ||
420-
(ncol(rfx_basis_1) != object$model_params$num_rfx_basis))
421-
) {
422-
stop(
423-
"rfx_basis_0 and / or rfx_basis_1 have a different dimension than the basis used to train this model"
424-
)
412+
if (has_rfx) {
413+
if (object$model_params$rfx_model_spec == "custom") {
414+
if ((is.null(rfx_basis_0) || is.null(rfx_basis_1))) {
415+
stop(
416+
"A user-provided basis (`rfx_basis_0` and `rfx_basis_1`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
417+
)
418+
}
419+
if (!is.matrix(rfx_basis_0) || !is.matrix(rfx_basis_1)) {
420+
stop("'rfx_basis_0' and 'rfx_basis_1' must be matrices")
421+
}
422+
if ((nrow(rfx_basis_0) != nrow(X)) || (nrow(rfx_basis_1) != nrow(X))) {
423+
stop(
424+
"'rfx_basis_0' and 'rfx_basis_1' must have the same number of rows as 'X'"
425+
)
426+
}
427+
if (
428+
(object$model_params$num_rfx_basis > 0) &&
429+
((ncol(rfx_basis_0) != object$model_params$num_rfx_basis) ||
430+
(ncol(rfx_basis_1) != object$model_params$num_rfx_basis))
431+
) {
432+
stop(
433+
"rfx_basis_0 and / or rfx_basis_1 have a different dimension than the basis used to train this model"
434+
)
435+
}
436+
}
425437
}
426438

427439
# Predict for the control arm
@@ -735,16 +747,18 @@ sample_bart_posterior_predictive <- function(
735747
"'rfx_group_ids' must have the same length as the number of rows in 'X'"
736748
)
737749
}
738-
if (is.null(rfx_basis)) {
739-
stop(
740-
"'rfx_basis' must be provided in order to compute the requested intervals"
741-
)
742-
}
743-
if (!is.matrix(rfx_basis)) {
744-
stop("'rfx_basis' must be a matrix")
745-
}
746-
if (nrow(rfx_basis) != nrow(X)) {
747-
stop("'rfx_basis' must have the same number of rows as 'X'")
750+
if (model_object$model_params$rfx_model_spec == "custom") {
751+
if (is.null(rfx_basis)) {
752+
stop(
753+
"A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
754+
)
755+
}
756+
if (!is.matrix(rfx_basis)) {
757+
stop("'rfx_basis' must be a matrix")
758+
}
759+
if (nrow(rfx_basis) != nrow(X)) {
760+
stop("'rfx_basis' must have the same number of rows as 'X'")
761+
}
748762
}
749763
}
750764

@@ -1172,16 +1186,18 @@ compute_bart_posterior_interval <- function(
11721186
"'rfx_group_ids' must have the same length as the number of rows in 'X'"
11731187
)
11741188
}
1175-
if (is.null(rfx_basis)) {
1176-
stop(
1177-
"'rfx_basis' must be provided in order to compute the requested intervals"
1178-
)
1179-
}
1180-
if (!is.matrix(rfx_basis)) {
1181-
stop("'rfx_basis' must be a matrix")
1182-
}
1183-
if (nrow(rfx_basis) != nrow(X)) {
1184-
stop("'rfx_basis' must have the same number of rows as 'X'")
1189+
if (model_object$model_params$rfx_model_spec == "custom") {
1190+
if (is.null(rfx_basis)) {
1191+
stop(
1192+
"A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
1193+
)
1194+
}
1195+
if (!is.matrix(rfx_basis)) {
1196+
stop("'rfx_basis' must be a matrix")
1197+
}
1198+
if (nrow(rfx_basis) != nrow(X)) {
1199+
stop("'rfx_basis' must have the same number of rows as 'X'")
1200+
}
11851201
}
11861202
}
11871203

stochtree/bart.py

Lines changed: 56 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,15 @@ def __init__(self) -> None:
7070

7171
def sample(
7272
self,
73-
X_train: Union[np.array, pd.DataFrame],
74-
y_train: np.array,
75-
leaf_basis_train: np.array = None,
76-
rfx_group_ids_train: np.array = None,
77-
rfx_basis_train: np.array = None,
78-
X_test: Union[np.array, pd.DataFrame] = None,
79-
leaf_basis_test: np.array = None,
80-
rfx_group_ids_test: np.array = None,
81-
rfx_basis_test: np.array = None,
73+
X_train: Union[np.ndarray, pd.DataFrame],
74+
y_train: np.ndarray,
75+
leaf_basis_train: Optional[np.ndarray] = None,
76+
rfx_group_ids_train: Optional[np.ndarray] = None,
77+
rfx_basis_train: Optional[np.ndarray] = None,
78+
X_test: Optional[Union[np.ndarray, pd.DataFrame]] = None,
79+
leaf_basis_test: Optional[np.ndarray] = None,
80+
rfx_group_ids_test: Optional[np.ndarray] = None,
81+
rfx_basis_test: Optional[np.ndarray] = None,
8282
num_gfr: int = 5,
8383
num_burnin: int = 0,
8484
num_mcmc: int = 100,
@@ -859,6 +859,13 @@ def sample(
859859
if num_features_subsample_variance is None:
860860
num_features_subsample_variance = X_train.shape[1]
861861

862+
# Runtime check for multivariate leaf regression
863+
if sample_sigma2_leaf and self.num_basis > 1:
864+
warnings.warn(
865+
"Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model."
866+
)
867+
sample_sigma2_leaf = False
868+
862869
# Preliminary runtime checks for probit link
863870
if not self.include_mean_forest:
864871
self.probit_outcome_model = False
@@ -872,15 +879,15 @@ def sample(
872879
raise ValueError(
873880
"You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1"
874881
)
875-
if self.include_variance_forest:
876-
raise ValueError(
877-
"We do not support heteroskedasticity with a probit link"
878-
)
879882
if sample_sigma2_global:
880883
warnings.warn(
881884
"Global error variance will not be sampled with a probit link as it is fixed at 1"
882885
)
883886
sample_sigma2_global = False
887+
if self.include_variance_forest:
888+
raise ValueError(
889+
"We do not support heteroskedasticity with a probit link"
890+
)
884891

885892
# Handle standardization, prior calibration, and initialization of forest
886893
# differently for binary and continuous outcomes
@@ -1217,7 +1224,7 @@ def sample(
12171224
else:
12181225
leaf_model_mean_forest = 2
12191226
leaf_dimension_mean = self.num_basis
1220-
1227+
12211228
# Sampling data structures
12221229
global_model_config = GlobalModelConfig(global_error_variance=current_sigma2)
12231230
if self.include_mean_forest:
@@ -1900,6 +1907,9 @@ def predict(
19001907
if leaf_basis is not None:
19011908
if leaf_basis.ndim == 1:
19021909
leaf_basis = np.expand_dims(leaf_basis, 1)
1910+
if rfx_basis is not None:
1911+
if rfx_basis.ndim == 1:
1912+
rfx_basis = np.expand_dims(rfx_basis, 1)
19031913

19041914
# Covariate preprocessing
19051915
if not self._covariate_preprocessor._check_is_fitted():
@@ -1958,21 +1968,18 @@ def predict(
19581968
mean_forest_predictions = mean_pred_raw * self.y_std + self.y_bar
19591969

19601970
# Random effects data checks
1961-
if has_rfx:
1962-
if rfx_group_ids is None:
1963-
raise ValueError(
1964-
"rfx_group_ids must be provided if rfx_basis is provided"
1965-
)
1966-
if rfx_basis is not None:
1967-
if rfx_basis.ndim == 1:
1968-
rfx_basis = np.expand_dims(rfx_basis, 1)
1969-
if rfx_basis.shape[0] != X.shape[0]:
1970-
raise ValueError("X and rfx_basis must have the same number of rows")
1971+
if predict_rfx and rfx_group_ids is None:
1972+
raise ValueError(
1973+
"Random effect group labels (rfx_group_ids) must be provided for this model"
1974+
)
1975+
if predict_rfx and rfx_basis is None and not rfx_intercept:
1976+
raise ValueError("Random effects basis (rfx_basis) must be provided for this model")
1977+
if self.num_rfx_basis > 0 and not rfx_intercept:
19711978
if rfx_basis.shape[1] != self.num_rfx_basis:
19721979
raise ValueError(
1973-
"rfx_basis must have the same number of columns as the random effects basis used to sample this model"
1980+
"Random effects basis has a different dimension than the basis used to train this model"
19741981
)
1975-
1982+
19761983
# Random effects predictions
19771984
if predict_rfx or predict_rfx_intermediate:
19781985
if rfx_basis is not None:
@@ -1983,7 +1990,7 @@ def predict(
19831990
# Sanity check -- this branch should only occur if rfx_model_spec == "intercept_only"
19841991
if not rfx_intercept:
19851992
raise ValueError(
1986-
"rfx_basis must be provided for random effects models with random slopes"
1993+
"A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
19871994
)
19881995

19891996
# Extract the raw RFX samples and scale by train set outcome standard deviation
@@ -2321,16 +2328,17 @@ def compute_posterior_interval(
23212328
raise ValueError(
23222329
"'rfx_group_ids' must have the same length as the number of rows in 'X'"
23232330
)
2324-
if rfx_basis is None:
2325-
raise ValueError(
2326-
"'rfx_basis' must be provided in order to compute the requested intervals"
2327-
)
2328-
if not isinstance(rfx_basis, np.ndarray):
2329-
raise ValueError("'rfx_basis' must be a numpy array")
2330-
if rfx_basis.shape[0] != X.shape[0]:
2331-
raise ValueError(
2332-
"'rfx_basis' must have the same number of rows as 'X'"
2333-
)
2331+
if self.rfx_model_spec == "custom":
2332+
if rfx_basis is None:
2333+
raise ValueError(
2334+
"A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
2335+
)
2336+
if not isinstance(rfx_basis, np.ndarray):
2337+
raise ValueError("'rfx_basis' must be a numpy array")
2338+
if rfx_basis.shape[0] != X.shape[0]:
2339+
raise ValueError(
2340+
"'rfx_basis' must have the same number of rows as 'X'"
2341+
)
23342342

23352343
# Compute posterior matrices for the requested model terms
23362344
predictions = self.predict(
@@ -2427,16 +2435,17 @@ def sample_posterior_predictive(
24272435
raise ValueError(
24282436
"'rfx_group_ids' must have the same length as the number of rows in 'X'"
24292437
)
2430-
if rfx_basis is None:
2431-
raise ValueError(
2432-
"'rfx_basis' must be provided in order to compute the requested intervals"
2433-
)
2434-
if not isinstance(rfx_basis, np.ndarray):
2435-
raise ValueError("'rfx_basis' must be a numpy array")
2436-
if rfx_basis.shape[0] != X.shape[0]:
2437-
raise ValueError(
2438-
"'rfx_basis' must have the same number of rows as 'X'"
2439-
)
2438+
if self.rfx_model_spec == "custom":
2439+
if rfx_basis is None:
2440+
raise ValueError(
2441+
"A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
2442+
)
2443+
if not isinstance(rfx_basis, np.ndarray):
2444+
raise ValueError("'rfx_basis' must be a numpy array")
2445+
if rfx_basis.shape[0] != X.shape[0]:
2446+
raise ValueError(
2447+
"'rfx_basis' must have the same number of rows as 'X'"
2448+
)
24402449

24412450
# Compute posterior predictive samples
24422451
bart_preds = self.predict(

0 commit comments

Comments
 (0)