Skip to content

Commit 5058f38

Browse files
committed
Update engine specific arguments for boost_tree()
1 parent 93233c4 commit 5058f38

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

R/boost_tree.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,8 @@ update.boost_tree <-
163163
loss_reduction = NULL, sample_size = NULL,
164164
stop_iter = NULL,
165165
fresh = FALSE, ...) {
166-
update_dot_check(...)
166+
167+
eng_args <- update_engine_parameters(object$eng_args, ...)
167168

168169
if (!is.null(parameters)) {
169170
parameters <- check_final_param(parameters)
@@ -185,12 +186,15 @@ update.boost_tree <-
185186
# TODO make these blocks into a function and document well
186187
if (fresh) {
187188
object$args <- args
189+
object$eng_args <- eng_args
188190
} else {
189191
null_args <- map_lgl(args, null_value)
190192
if (any(null_args))
191193
args <- args[!null_args]
192194
if (length(args) > 0)
193195
object$args[names(args)] <- args
196+
if (length(eng_args) > 0)
197+
object$eng_args[names(eng_args)] <- eng_args
194198
}
195199

196200
new_model_spec(

tests/testthat/test_boost_tree.R

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,14 +109,19 @@ test_that('updating', {
109109
expr1 <- boost_tree() %>% set_engine("xgboost", verbose = 0)
110110
expr1_exp <- boost_tree(trees = 10) %>% set_engine("xgboost", verbose = 0)
111111

112-
expr2 <- boost_tree(trees = varying()) %>% set_engine("xgboost")
113-
expr2_exp <- boost_tree(trees = varying()) %>% set_engine("xgboost", verbose = 0)
112+
expr2 <- boost_tree(trees = varying()) %>% set_engine("C5.0", bands = varying())
113+
expr2_exp <- boost_tree(trees = varying()) %>% set_engine("C5.0", bands = 10)
114114

115115
expr3 <- boost_tree(trees = 1, sample_size = varying())
116116
expr3_exp <- boost_tree(trees = 1)
117117

118+
expr4 <- boost_tree() %>% set_engine("C5.0", noGlobalPruning = varying())
119+
expr4_exp <- boost_tree() %>% set_engine("C5.0", noGlobalPruning = TRUE)
120+
118121
expect_equal(update(expr1, trees = 10), expr1_exp)
122+
expect_equal(update(expr2, bands = 10), expr2_exp)
119123
expect_equal(update(expr3, trees = 1, fresh = TRUE), expr3_exp)
124+
expect_equal(update(expr4, noGlobalPruning = TRUE), expr4_exp)
120125

121126
param_tibb <- tibble::tibble(trees = 7, mtry = 1)
122127
param_list <- as.list(param_tibb)

0 commit comments

Comments
 (0)