Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- 'lts'
- '1'
os:
- ubuntu-latest
arch:
- x64
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJIteration"
uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
version = "0.6.4"
version = "0.6.5"

[deps]
IterationControl = "b3c1a2ee-3fec-4384-bf48-272ea71de57c"
Expand All @@ -11,7 +11,7 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[compat]
IterationControl = "0.5"
MLJBase = "1.4"
MLJBase = "1.5"
julia = "1.6"

[extras]
Expand Down
18 changes: 13 additions & 5 deletions src/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ const IterationResamplingTypes =

## TYPES AND CONSTRUCTOR

mutable struct DeterministicIteratedModel{M<:Deterministic} <: MLJBase.Deterministic
mutable struct DeterministicIteratedModel{M<:Deterministic,L} <: MLJBase.Deterministic
model::M
controls
resampling # resampling strategy
Expand All @@ -16,9 +16,10 @@ mutable struct DeterministicIteratedModel{M<:Deterministic} <: MLJBase.Determini
check_measure::Bool
iteration_parameter::Union{Nothing,Symbol,Expr}
cache::Bool
logger::L
end

mutable struct ProbabilisticIteratedModel{M<:Probabilistic} <: MLJBase.Probabilistic
mutable struct ProbabilisticIteratedModel{M<:Probabilistic,L} <: MLJBase.Probabilistic
model::M
controls
resampling # resampling strategy
Expand All @@ -30,6 +31,7 @@ mutable struct ProbabilisticIteratedModel{M<:Probabilistic} <: MLJBase.Probabili
check_measure::Bool
iteration_parameter::Union{Nothing,Symbol,Expr}
cache::Bool
logger::L
end

const ERR_MISSING_TRAINING_CONTROL =
Expand All @@ -39,8 +41,8 @@ const ERR_MISSING_TRAINING_CONTROL =

const ERR_TOO_MANY_ARGUMENTS =
ArgumentError("At most one non-keyword argument allowed. ")
const EitherIteratedModel{M} =
Union{DeterministicIteratedModel{M},ProbabilisticIteratedModel{M}}
const EitherIteratedModel{M,L} =
Union{DeterministicIteratedModel{M,L},ProbabilisticIteratedModel{M,L}}
const ERR_NOT_SUPERVISED =
ArgumentError("Only `Deterministic` and `Probabilistic` "*
"model types supported.")
Expand Down Expand Up @@ -148,6 +150,10 @@ Available controls: $CONTROLS_LIST.
between iteration parameter increments; specify `cache=false` to prioritize memory over
speed.

- `logger=default_logger()`: a logger for externally reporting model performance
evaluations, such as an `MLJFlow.Logger` instance. On startup,
`default_logger()=nothing`; use `default_logger(logger)` to set a global logger.


# Training

Expand Down Expand Up @@ -236,7 +242,8 @@ function IteratedModel(args...;
retrain=false,
check_measure=true,
iteration_parameter=nothing,
cache=true)
cache=true,
logger=MLJBase.default_logger())

length(args) < 2 || throw(ArgumentError("At most one non-keyword argument allowed. "))
if length(args) === 1
Expand All @@ -260,6 +267,7 @@ function IteratedModel(args...;
check_measure,
iteration_parameter,
cache,
logger,
)

if atom isa Deterministic
Expand Down
3 changes: 2 additions & 1 deletion src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ function MLJBase.fit(iterated_model::EitherIteratedModel, verbosity, data...)
class_weights=iterated_model.class_weights,
operation=iterated_model.operation,
check_measure=iterated_model.check_measure,
cache=iterated_model.cache)
cache=iterated_model.cache,
logger=iterated_model.logger)
machine(resampler, data..., cache=false)
end

Expand Down
2 changes: 1 addition & 1 deletion src/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ for trait in [:supports_weights,
quote
# try to get trait at level of types ("failure" here just
# means falling back to `Unknown`):
MLJBase.$trait(::Type{<:$T{M}}) where M = MLJBase.$trait(M)
MLJBase.$trait(::Type{<:$T{M,L}}) where {M,L} = MLJBase.$trait(M)
end |> eval
end
end
119 changes: 119 additions & 0 deletions test/logger.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
module TestLogger

using Test
using MLJIteration
using MLJBase
using StatisticalMeasures
using ..DummyModel

X, y = make_dummy(N=20)

# A minimal logger that records each evaluation's first measurement into a buffer.
struct DummyLogger
buffer::IOBuffer
end

MLJBase.log_evaluation(logger::DummyLogger, performance_evaluation) =
write(logger.buffer, performance_evaluation.measurement[1])

@testset "explicit logger with Holdout" begin
buffer = IOBuffer()
logger = DummyLogger(buffer)

model = DummyIterativeModel(n=0)
controls = [Step(2), NumberLimit(5)]

imodel = IteratedModel(
model=model,
resampling=Holdout(fraction_train=0.7),
controls=controls,
measure=l2,
logger=logger,
)
mach = machine(imodel, X, y)
fit!(mach, verbosity=0)

# Each control cycle triggers one evaluate! call, which should log once.
# With Step(2) and NumberLimit(5), we get 5 control cycles.
seekstart(buffer)
logged_values = Float64[]
while !eof(buffer)
push!(logged_values, read(buffer, Float64))
end

@test length(logged_values) == 5
close(buffer)
end

@testset "logger=nothing produces no logging" begin
model = DummyIterativeModel(n=0)
controls = [Step(2), NumberLimit(3)]

imodel = IteratedModel(
model=model,
resampling=Holdout(fraction_train=0.7),
controls=controls,
measure=l2,
logger=nothing,
)
mach = machine(imodel, X, y)
# Should run without error; log_evaluation(::Nothing, ...) is a no-op.
fit!(mach, verbosity=0)
@test true
end

@testset "default_logger integration" begin
buffer = IOBuffer()
logger = DummyLogger(buffer)
default_logger(logger)

model = DummyIterativeModel(n=0)
controls = [Step(1), NumberLimit(3)]

# No explicit logger; should pick up the global default.
imodel = IteratedModel(
model=model,
resampling=Holdout(fraction_train=0.7),
controls=controls,
measure=l2,
)
mach = machine(imodel, X, y)
fit!(mach, verbosity=0)

seekstart(buffer)
logged_values = Float64[]
while !eof(buffer)
push!(logged_values, read(buffer, Float64))
end

@test length(logged_values) == 3

# Reset global default.
default_logger(nothing)
close(buffer)
end

@testset "logger not invoked when resampling=nothing" begin
buffer = IOBuffer()
logger = DummyLogger(buffer)

model = DummyIterativeModel(n=0)
controls = [Step(2), NumberLimit(3)]

imodel = IteratedModel(
model=model,
resampling=nothing,
controls=controls,
logger=logger,
)
mach = machine(imodel, X, y)
fit!(mach, verbosity=0)

# No Resampler means no evaluate! call, so nothing should be logged.
@test position(buffer) == 0
close(buffer)
end

end

true
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,7 @@ end
@testset "traits" begin
include("traits.jl")
end

@testset "logger" begin
include("logger.jl")
end
Loading