Skip to content

Commit 9997387

Browse files
Merge pull request #304 from StructuralEquationModels/fiml_meanstructure_argument
Fiml meanstructure argument
2 parents 90a8009 + 811ae0f commit 9997387

File tree

8 files changed

+105
-10
lines changed

8 files changed

+105
-10
lines changed

src/implied/RAM/generic.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ function RAM(;
9898
)
9999
ram_matrices = convert(RAMMatrices, specification)
100100

101+
check_meanstructure_specification(meanstructure, ram_matrices)
102+
101103
# get dimensions of the model
102104
n_par = nparams(ram_matrices)
103105
n_obs = nobserved_vars(ram_matrices)
@@ -126,11 +128,6 @@ function RAM(;
126128
# μ
127129
if meanstructure
128130
MS = HasMeanStruct
129-
!isnothing(ram_matrices.M) || throw(
130-
ArgumentError(
131-
"You set `meanstructure = true`, but your model specification contains no mean parameters.",
132-
),
133-
)
134131
M_pre = materialize(ram_matrices.M, rand_params)
135132
∇M = gradient_required ? sparse_gradient(ram_matrices.M) : nothing
136133
μ = zeros(n_obs)

src/implied/RAM/symbolic.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ function RAMSymbolic(;
9292
)
9393
ram_matrices = convert(RAMMatrices, specification)
9494

95+
check_meanstructure_specification(meanstructure, ram_matrices)
96+
9597
n_par = nparams(ram_matrices)
9698
par = (Symbolics.@variables θ[1:n_par])[1]
9799

src/implied/abstract.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,17 @@ function check_acyclic(A::AbstractMatrix; verbose::Bool = false)
3131
return A
3232
end
3333
end
34+
35+
# Verify that the `meanstructure` argument aligns with the model specification.
36+
function check_meanstructure_specification(meanstructure, ram_matrices)
37+
if meanstructure & isnothing(ram_matrices.M)
38+
throw(ArgumentError(
39+
"You set `meanstructure = true`, but your model specification contains no mean parameters."
40+
))
41+
end
42+
if !meanstructure & !isnothing(ram_matrices.M)
43+
throw(ArgumentError(
44+
"If your model specification contains mean parameters, you have to set `Sem(..., meanstructure = true)`."
45+
))
46+
end
47+
end

src/loss/ML/FIML.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,14 @@ end
4343
### Constructors
4444
############################################################################################
4545

46-
function SemFIML(; observed::SemObservedMissing, specification, kwargs...)
46+
function SemFIML(; observed::SemObservedMissing, implied, specification, kwargs...)
47+
48+
if implied.meanstruct isa NoMeanStruct
49+
throw(ArgumentError(
50+
"Full information maximum likelihood (FIML) can only be used with a meanstructure.
51+
Did you forget to set `Sem(..., meanstructure = true)`?"))
52+
end
53+
4754
inverses =
4855
[zeros(nmeasured_vars(pat), nmeasured_vars(pat)) for pat in observed.patterns]
4956
choleskys = Array{Cholesky{Float64, Array{Float64, 2}}, 1}(undef, length(inverses))

src/loss/ML/ML.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,21 @@ end
3939
############################################################################################
4040

4141
function SemML(; observed::SemObserved, approximate_hessian::Bool = false, kwargs...)
42+
43+
if observed isa SemObservedMissing
44+
throw(ArgumentError(
45+
"Normal maximum likelihood estimation can't be used with `SemObservedMissing`.
46+
Use full information maximum likelihood (FIML) estimation or remove missing
47+
values in your data.
48+
A FIML model can be constructed with
49+
Sem(
50+
...,
51+
observed = SemObservedMissing,
52+
loss = SemFIML,
53+
meanstructure = true
54+
)"))
55+
end
56+
4257
obsmean = obs_mean(observed)
4358
obscov = obs_cov(observed)
4459
meandiff = isnothing(obsmean) ? nothing : copy(obsmean)

src/loss/WLS/WLS.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,33 @@ SemWLS{HE}(args...) where {HE <: HessianEval} =
5151

5252
function SemWLS(;
5353
observed,
54+
implied,
5455
wls_weight_matrix = nothing,
5556
wls_weight_matrix_mean = nothing,
5657
approximate_hessian = false,
5758
meanstructure = false,
5859
kwargs...,
5960
)
61+
62+
if observed isa SemObservedMissing
63+
throw(ArgumentError(
64+
"WLS estimation can't be used with `SemObservedMissing`.
65+
Use full information maximum likelihood (FIML) estimation or remove missing
66+
values in your data.
67+
A FIML model can be constructed with
68+
Sem(
69+
...,
70+
observed = SemObservedMissing,
71+
loss = SemFIML,
72+
meanstructure = true
73+
)"))
74+
end
75+
76+
if !(implied isa RAMSymbolic)
77+
throw(ArgumentError(
78+
"WLS estimation is only available with the implied type RAMSymbolic at the moment."))
79+
end
80+
6081
nobs_vars = nobserved_vars(observed)
6182
tril_ind = filter(x -> (x[1] >= x[2]), CartesianIndices(obs_cov(observed)))
6283
s = obs_cov(observed)[tril_ind]

src/observed/data.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,24 @@ function SemObservedData(;
3838
observed_var_prefix::Union{Symbol, AbstractString} = :obs,
3939
kwargs...,
4040
)
41+
4142
data, obs_vars, _ =
4243
prepare_data(data, observed_vars, specification; observed_var_prefix)
4344
obs_mean, obs_cov = mean_and_cov(data, 1)
4445

46+
if any(ismissing.(data))
47+
throw(ArgumentError(
48+
"Your dataset contains missing values.
49+
Remove missing values or use full information maximum likelihood (FIML) estimation.
50+
A FIML model can be constructed with
51+
Sem(
52+
...,
53+
observed = SemObservedMissing,
54+
loss = SemFIML,
55+
meanstructure = true
56+
)"))
57+
end
58+
4559
return SemObservedData(data, obs_vars, obs_cov, vec(obs_mean), size(data, 1))
4660
end
4761

test/unit_tests/model.jl

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,13 @@ function test_params_api(semobj, spec::SemSpecification)
4646
@test @inferred(param_labels(semobj)) == param_labels(spec)
4747
end
4848

49-
@testset "Sem(implied=$impliedtype, loss=$losstype)" for impliedtype in (RAM, RAMSymbolic),
50-
losstype in (SemML, SemWLS)
49+
@testset "Sem(implied=$impliedtype, loss=SemML)" for impliedtype in (RAM, RAMSymbolic)
5150

5251
model = Sem(
5352
specification = ram_matrices,
5453
observed = obs,
5554
implied = impliedtype,
56-
loss = losstype,
55+
loss = SemML,
5756
)
5857

5958
@test model isa Sem
@@ -68,7 +67,33 @@ end
6867

6968
@test @inferred(loss(model)) isa SemLoss
7069
semloss = loss(model).functions[1]
71-
@test semloss isa losstype
70+
@test semloss isa SemML
7271

7372
@test @inferred(nsamples(model)) == nsamples(obs)
7473
end
74+
75+
@testset "Sem(implied=RAMSymbolic, loss=SemWLS)" begin
76+
77+
model = Sem(
78+
specification = ram_matrices,
79+
observed = obs,
80+
implied = RAMSymbolic,
81+
loss = SemWLS,
82+
)
83+
84+
@test model isa Sem
85+
@test @inferred(implied(model)) isa RAMSymbolic
86+
@test @inferred(observed(model)) isa SemObserved
87+
88+
test_vars_api(model, ram_matrices)
89+
test_params_api(model, ram_matrices)
90+
91+
test_vars_api(implied(model), ram_matrices)
92+
test_params_api(implied(model), ram_matrices)
93+
94+
@test @inferred(loss(model)) isa SemLoss
95+
semloss = loss(model).functions[1]
96+
@test semloss isa SemWLS
97+
98+
@test @inferred(nsamples(model)) == nsamples(obs)
99+
end

0 commit comments

Comments
 (0)