Skip to content

Commit 2bf7b04

Browse files
committed
Updating for logistic_reg() engine specific params
1 parent f5adad6 commit 2bf7b04

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

R/logistic_reg.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ update.logistic_reg <-
115115
parameters = NULL,
116116
penalty = NULL, mixture = NULL,
117117
fresh = FALSE, ...) {
118-
update_dot_check(...)
118+
119+
eng_args <- update_engine_parameters(object$eng_args, ...)
119120

120121
if (!is.null(parameters)) {
121122
parameters <- check_final_param(parameters)
@@ -129,12 +130,15 @@ update.logistic_reg <-
129130

130131
if (fresh) {
131132
object$args <- args
133+
object$eng_args <- eng_args
132134
} else {
133135
null_args <- map_lgl(args, null_value)
134136
if (any(null_args))
135137
args <- args[!null_args]
136138
if (length(args) > 0)
137139
object$args[names(args)] <- args
140+
if (length(eng_args) > 0)
141+
object$eng_args[names(eng_args)] <- eng_args
138142
}
139143

140144
new_model_spec(

tests/testthat/test_logistic_reg.R

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,11 +181,11 @@ test_that('updating', {
181181
expr1_exp <- logistic_reg(mixture = 0) %>%
182182
set_engine("glm", family = expr(binomial(link = "probit")))
183183

184-
expr2 <- logistic_reg(mixture = varying()) %>% set_engine("glmnet")
184+
expr2 <- logistic_reg(mixture = varying()) %>% set_engine("glmnet", nlambda = varying())
185185
expr2_exp <- logistic_reg(mixture = varying()) %>% set_engine("glmnet", nlambda = 10)
186186

187-
expr3 <- logistic_reg(mixture = 0, penalty = varying())
188-
expr3_exp <- logistic_reg(mixture = 1)
187+
expr3 <- logistic_reg(mixture = 0, penalty = varying()) %>% set_engine("glmnet", nlambda = varying())
188+
expr3_exp <- logistic_reg(mixture = 1) %>% set_engine("glmnet", nlambda = 10)
189189

190190
expr4 <- logistic_reg(mixture = 0) %>% set_engine("glmnet", nlambda = 10)
191191
expr4_exp <- logistic_reg(mixture = 0) %>% set_engine("glmnet", nlambda = 10, pmax = 2)
@@ -194,7 +194,8 @@ test_that('updating', {
194194
expr5_exp <- logistic_reg(mixture = 1) %>% set_engine("glmnet", nlambda = 10, pmax = 2)
195195

196196
expect_equal(update(expr1, mixture = 0), expr1_exp)
197-
expect_equal(update(expr3, mixture = 1, fresh = TRUE), expr3_exp)
197+
expect_equal(update(expr2, nlambda = 10), expr2_exp)
198+
expect_equal(update(expr3, mixture = 1, fresh = TRUE, nlambda = 10), expr3_exp)
198199

199200
param_tibb <- tibble::tibble(mixture = 1/3, penalty = 1)
200201
param_list <- as.list(param_tibb)

0 commit comments

Comments
 (0)