Skip to content

Commit 111c996

Browse files
committed
Update engine specific args for decision_tree()
1 parent 5058f38 commit 111c996

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

R/decision_tree.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ update.decision_tree <-
116116
parameters = NULL,
117117
cost_complexity = NULL, tree_depth = NULL, min_n = NULL,
118118
fresh = FALSE, ...) {
119-
update_dot_check(...)
119+
120+
eng_args <- update_engine_parameters(object$eng_args, ...)
120121

121122
if (!is.null(parameters)) {
122123
parameters <- check_final_param(parameters)
@@ -131,12 +132,15 @@ update.decision_tree <-
131132

132133
if (fresh) {
133134
object$args <- args
135+
object$eng_args <- eng_args
134136
} else {
135137
null_args <- map_lgl(args, null_value)
136138
if (any(null_args))
137139
args <- args[!null_args]
138140
if (length(args) > 0)
139141
object$args[names(args)] <- args
142+
if (length(eng_args) > 0)
143+
object$eng_args[names(eng_args)] <- eng_args
140144
}
141145

142146
new_model_spec(

tests/testthat/test_decision_tree.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,14 @@ test_that('updating', {
9999
expr1 <- decision_tree() %>% set_engine("rpart", model = FALSE)
100100
expr1_exp <- decision_tree(cost_complexity = .1) %>% set_engine("rpart", model = FALSE)
101101

102-
expr2 <- decision_tree(cost_complexity = varying()) %>% set_engine("rpart")
102+
expr2 <- decision_tree(cost_complexity = varying()) %>% set_engine("rpart", model = varying())
103103
expr2_exp <- decision_tree(cost_complexity = varying()) %>% set_engine("rpart", model = FALSE)
104104

105105
expr3 <- decision_tree(cost_complexity = 1, min_n = varying())
106106
expr3_exp <- decision_tree(cost_complexity = 1)
107107

108108
expect_equal(update(expr1, cost_complexity = .1), expr1_exp)
109+
expect_equal(update(expr2, model = FALSE), expr2_exp)
109110
expect_equal(update(expr3, cost_complexity = 1, fresh = TRUE), expr3_exp)
110111

111112
param_tibb <- tibble::tibble(cost_complexity = 0.1, min_n = 1)

0 commit comments

Comments
 (0)