diff --git a/src/models/rfe.jl b/src/models/rfe.jl index c87f1d4..b907a7a 100644 --- a/src/models/rfe.jl +++ b/src/models/rfe.jl @@ -95,6 +95,12 @@ Train the machine using `fit!(mach, rows=...)`. `transform(mach, X)` above and predict using the fitted base model on the transformed table. +!!! note + + Because models wrapped in `RecursiveFeatureElimanation` are `Supervised`, the output + of `predict` is propagated in MLJ pipelines. To make this `transform` instead, + additionally wrap in `Transformer` as shown in the example below. + # Fitted parameters The fields of `fitted_params(mach)` are: @@ -146,6 +152,17 @@ predict(mach, Xnew) # transform data with all features to the reduced feature set: transform(mach, Xnew) ``` + +To use `selector` as a transformer in an MLJ pipeline, you must explicitly wrap in +`Transformer`, for otherwise it is the output of `predict` and not `transform` that is +propagated to the next model in the pipeline: + +```julia +pipe = Transformer(selector) |> ConstantRegressor() +mach = machine(pipe, X, y) |> fit! +predict(mach, Xnew) # prediction of `ConstantRegressor()` based on reduced features. +``` + """ function RecursiveFeatureElimination( args...;