Skip to content

Commit c71ad65

Browse files
committed
Updating for mlp() engine specific args
1 parent 4ac31fb commit c71ad65

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

R/mlp.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ update.mlp <-
120120
hidden_units = NULL, penalty = NULL, dropout = NULL,
121121
epochs = NULL, activation = NULL,
122122
fresh = FALSE, ...) {
123-
update_dot_check(...)
123+
124+
eng_args <- update_engine_parameters(object$eng_args, ...)
124125

125126
if (!is.null(parameters)) {
126127
parameters <- check_final_param(parameters)
@@ -139,12 +140,15 @@ update.mlp <-
139140
# TODO make these blocks into a function and document well
140141
if (fresh) {
141142
object$args <- args
143+
object$eng_args <- eng_args
142144
} else {
143145
null_args <- map_lgl(args, null_value)
144146
if (any(null_args))
145147
args <- args[!null_args]
146148
if (length(args) > 0)
147149
object$args[names(args)] <- args
150+
if (length(eng_args) > 0)
151+
object$eng_args[names(eng_args)] <- eng_args
148152
}
149153

150154
new_model_spec(

tests/testthat/test_mlp.R

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,21 +131,23 @@ test_that('updating', {
131131
expr1_exp <- mlp(mode = "regression", hidden_units = 2) %>%
132132
set_engine("nnet", Hess = FALSE, abstol = varying())
133133

134-
expr2 <- mlp(mode = "regression", hidden_units = 7) %>% set_engine("nnet")
135-
expr2_exp <- mlp(mode = "regression", hidden_units = 7) %>% set_engine("nnet", Hess = FALSE)
134+
expr2 <- mlp(mode = "regression") %>% set_engine("nnet", Hess = varying())
135+
expr2_exp <- mlp(mode = "regression") %>% set_engine("nnet", Hess = FALSE)
136136

137137
expr3 <- mlp(mode = "regression", hidden_units = 7, epochs = varying()) %>% set_engine("keras")
138138

139139
expr3_exp <- mlp(mode = "regression", hidden_units = 2) %>% set_engine("keras")
140140

141141
expr4 <- mlp(mode = "classification", hidden_units = 2) %>% set_engine("nnet", Hess = FALSE, abstol = varying())
142-
expr4_exp <- mlp(mode = "classification", hidden_units = 2) %>% set_engine("nnet", Hess = FALSE, abstol = varying())
142+
expr4_exp <- mlp(mode = "classification", hidden_units = 2) %>% set_engine("nnet", Hess = FALSE, abstol = 1e-3)
143143

144144
expr5 <- mlp(mode = "classification", hidden_units = 2) %>% set_engine("nnet", Hess = FALSE)
145145
expr5_exp <- mlp(mode = "classification", hidden_units = 2) %>% set_engine("nnet", Hess = FALSE, abstol = varying())
146146

147147
expect_equal(update(expr1, hidden_units = 2), expr1_exp)
148+
expect_equal(update(expr2, Hess = FALSE), expr2_exp)
148149
expect_equal(update(expr3, hidden_units = 2, fresh = TRUE), expr3_exp)
150+
expect_equal(update(expr4, abstol = 1e-3), expr4_exp)
149151

150152
param_tibb <- tibble::tibble(hidden_units = 3, dropout = .1)
151153
param_list <- as.list(param_tibb)

0 commit comments

Comments
 (0)