Skip to content

Commit f5adad6

Browse files
committed
Updating for linear_reg() engine specific args
1 parent 111c996 commit f5adad6

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

R/linear_reg.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,8 @@ update.linear_reg <-
131131
parameters = NULL,
132132
penalty = NULL, mixture = NULL,
133133
fresh = FALSE, ...) {
134-
update_dot_check(...)
134+
135+
eng_args <- update_engine_parameters(object$eng_args, ...)
135136

136137
if (!is.null(parameters)) {
137138
parameters <- check_final_param(parameters)
@@ -145,12 +146,15 @@ update.linear_reg <-
145146

146147
if (fresh) {
147148
object$args <- args
149+
object$eng_args <- eng_args
148150
} else {
149151
null_args <- map_lgl(args, null_value)
150152
if (any(null_args))
151153
args <- args[!null_args]
152154
if (length(args) > 0)
153155
object$args[names(args)] <- args
156+
if (length(eng_args) > 0)
157+
object$eng_args[names(eng_args)] <- eng_args
154158
}
155159

156160
new_model_spec(

tests/testthat/test_linear_reg.R

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -166,11 +166,12 @@ test_that('updating', {
166166
expr1 <- linear_reg() %>% set_engine("lm", model = FALSE)
167167
expr1_exp <- linear_reg(mixture = 0) %>% set_engine("lm", model = FALSE)
168168

169-
expr2 <- linear_reg(mixture = varying()) %>% set_engine("glmnet")
170-
expr2_exp <- linear_reg(mixture = varying()) %>% set_engine("glmnet", nlambda = 10)
169+
expr2 <- linear_reg() %>% set_engine("glmnet", nlambda = varying())
170+
expr2_exp <- linear_reg() %>% set_engine("glmnet", nlambda = 10)
171171

172-
expr3 <- linear_reg(mixture = 0, penalty = varying()) %>% set_engine("glmnet")
173-
expr3_exp <- linear_reg(mixture = 1) %>% set_engine("glmnet")
172+
expr3 <- linear_reg(mixture = 0, penalty = varying()) %>% set_engine("glmnet", nlambda = varying())
173+
expr3_exp <- linear_reg(mixture = 0, penalty = varying()) %>% set_engine("glmnet", nlambda = 10)
174+
expr3_fre <- linear_reg(mixture = 1) %>% set_engine("glmnet", nlambda = 10)
174175

175176
expr4 <- linear_reg(mixture = 0) %>% set_engine("glmnet", nlambda = 10)
176177
expr4_exp <- linear_reg(mixture = 0) %>% set_engine("glmnet", nlambda = 10, pmax = 2)
@@ -179,7 +180,9 @@ test_that('updating', {
179180
expr5_exp <- linear_reg(mixture = 1) %>% set_engine("glmnet", nlambda = 10, pmax = 2)
180181

181182
expect_equal(update(expr1, mixture = 0), expr1_exp)
182-
expect_equal(update(expr3, mixture = 1, fresh = TRUE), expr3_exp)
183+
expect_equal(update(expr2, nlambda = 10), expr2_exp)
184+
expect_equal(update(expr3, mixture = 1, fresh = TRUE, nlambda = 10), expr3_fre)
185+
expect_equal(update(expr3, nlambda = 10), expr3_exp)
183186

184187
param_tibb <- tibble::tibble(mixture = 1/3, penalty = 1)
185188
param_list <- as.list(param_tibb)

0 commit comments

Comments
 (0)