Skip to content

Commit 24a7947

Browse files
fix nobs var check in update_observed, keep previous args in update_observed, fix bs + tests
1 parent c27706d commit 24a7947

File tree

10 files changed

+71
-36
lines changed

10 files changed

+71
-36
lines changed

src/frontend/fit/standard_errors/bootstrap.jl

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""
22
bootstrap(
3-
fitted::SemFit;
3+
fitted::SemFit,
4+
specification::SemSpecification;
45
statistic = solution,
56
n_boot = 3000,
67
data = nothing,
7-
specification = nothing,
88
engine = :Optim,
99
parallel = false,
1010
fit_kwargs = Dict(),
@@ -14,12 +14,11 @@ Return bootstrap samples for `statistic`.
1414
1515
# Arguments
1616
- `fitted`: a fitted SEM.
17+
- `specification`: a `ParameterTable` or `RAMMatrices` object passed to `replace_observed`.
1718
- `statistic`: any function that can be called on a `SemFit` object.
1819
The output will be returned as the bootstrap sample.
1920
- `n_boot`: number of bootstrap samples
2021
- `data`: data to sample from. Only needed if different than the data from `sem_fit`
21-
- `specification`: a `ParameterTable` or `RAMMatrices` object passed to `replace_observed`.
22-
Necessary for FIML / WLS models.
2322
- `engine`: optimizer engine, passed to `fit`.
2423
- `parallel`: if `true`, run bootstrap samples in parallel on all available threads.
2524
The number of threads is controlled by the `JULIA_NUM_THREADS` environment variable or
@@ -40,11 +39,11 @@ bootstrap(
4039
```
4140
"""
4241
function bootstrap(
43-
fitted::SemFit;
42+
fitted::SemFit,
43+
specification::SemSpecification;
4444
statistic = solution,
4545
n_boot = 3000,
4646
data = nothing,
47-
specification = nothing,
4847
engine = :Optim,
4948
parallel = false,
5049
fit_kwargs = Dict(),
@@ -56,6 +55,7 @@ function bootstrap(
5655
# pre-allocations
5756
out = []
5857
conv = []
58+
errors = []
5959
n_failed = Ref(0)
6060
# fit to bootstrap samples
6161
if !parallel
@@ -73,8 +73,9 @@ function bootstrap(
7373
c = converged(new_fit)
7474
push!(out, sample)
7575
push!(conv, c)
76-
catch
76+
catch e
7777
n_failed[] += 1
78+
push!(errors, e)
7879
end
7980
end
8081
else
@@ -103,9 +104,10 @@ function bootstrap(
103104
push!(out, sample)
104105
push!(conv, c)
105106
end
106-
catch
107+
catch e
107108
lock(lk) do
108109
n_failed[] += 1
110+
push!(errors, e)
109111
end
110112
finally
111113
put!(model_pool, thread_model)
@@ -119,19 +121,19 @@ function bootstrap(
119121
return Dict(
120122
:samples => out,
121123
:n_boot => n_boot,
122-
:n_converged => sum(conv),
124+
:n_converged => isempty(conv) ? 0 : sum(conv),
123125
:converged => conv,
124126
:n_errored => n_failed[],
127+
:errors => errors
125128
)
126129
end
127130

128131
"""
129132
se_bootstrap(
130-
fitted::SemFit;
133+
fitted::SemFit,
134+
specification::SemSpecification;
131135
n_boot = 3000,
132136
data = nothing,
133-
specification = nothing,
134-
engine = :Optim,
135137
parallel = false,
136138
fit_kwargs = Dict(),
137139
replace_kwargs = Dict())
@@ -140,10 +142,9 @@ Return bootstrap standard errors.
140142
141143
# Arguments
142144
- `fitted`: a fitted SEM.
145+
- `specification`: a `ParameterTable` or `RAMMatrices` object passed to `replace_observed`.
143146
- `n_boot`: number of bootstrap samples
144147
- `data`: data to sample from. Only needed if different than the data from `sem_fit`
145-
- `specification`: a `ParameterTable` or `RAMMatrices` object passed to `replace_observed`.
146-
Necessary for FIML / WLS models.
147148
- `engine`: optimizer engine, passed to `fit`.
148149
- `parallel`: if `true`, run bootstrap samples in parallel on all available threads.
149150
The number of threads is controlled by the `JULIA_NUM_THREADS` environment variable or
@@ -165,10 +166,10 @@ se_bootstrap(
165166
```
166167
"""
167168
function se_bootstrap(
168-
fitted::SemFit;
169+
fitted::SemFit,
170+
specification::SemSpecification;
169171
n_boot = 3000,
170172
data = nothing,
171-
specification = nothing,
172173
engine = :Optim,
173174
parallel = false,
174175
fit_kwargs = Dict(),

src/implied/RAM/generic.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,13 @@ end
196196
############################################################################################
197197

198198
function update_observed(implied::RAM, observed::SemObserved; kwargs...)
199-
if nobserved_vars(observed) == size(implied.Σ, 1)
199+
if nobserved_vars(observed) == nobserved_vars(implied)
200200
return implied
201201
else
202-
return RAM(; observed = observed, kwargs...)
202+
return RAM(;
203+
observed = observed,
204+
gradient_required = !isnothing(implied.∇A),
205+
meanstructure = MeanStruct(implied) == HasMeanStruct,
206+
kwargs...)
203207
end
204208
end

src/implied/RAM/symbolic.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,10 +210,17 @@ end
210210
############################################################################################
211211

212212
function update_observed(implied::RAMSymbolic, observed::SemObserved; kwargs...)
213-
if nobserved_vars(observed) == size(implied.Σ, 1)
213+
if nobserved_vars(observed) == nobserved_vars(implied)
214214
return implied
215215
else
216-
return RAMSymbolic(; observed = observed, kwargs...)
216+
return RAMSymbolic(;
217+
observed = observed,
218+
vech = implied.Σ isa Vector,
219+
gradient = !isnothing(implied.∇Σ),
220+
hessian = !isnothing(implied.∇²Σ),
221+
meanstructure = MeanStruct(implied) == HasMeanStruct,
222+
approximate_hessian = isnothing(implied.∇²Σ),
223+
kwargs...)
217224
end
218225
end
219226

src/loss/ML/ML.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,9 @@ function update_observed(lossfun::SemML, observed::SemObserved; kwargs...)
237237
if size(lossfun.Σ⁻¹) == size(obs_cov(observed))
238238
return lossfun
239239
else
240-
return SemML(; observed = observed, kwargs...)
240+
return SemML(;
241+
observed = observed,
242+
approximate_hessian = HessianEval(lossfun) == ApproxHessian,
243+
kwargs...)
241244
end
242245
end

src/loss/WLS/WLS.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,5 +173,7 @@ end
173173
### Recommended methods
174174
############################################################################################
175175

176-
update_observed(lossfun::SemWLS, observed::SemObserved; kwargs...) =
177-
SemWLS(; observed = observed, kwargs...)
176+
update_observed(lossfun::SemWLS, observed::SemObserved; kwargs...) = SemWLS(;
177+
observed = observed,
178+
meanstructure = MeanStruct(kwargs[:implied]) == HasMeanStruct,
179+
kwargs...)

test/examples/helper.jl

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -136,17 +136,28 @@ function test_estimates(
136136
end
137137
end
138138

139-
function test_bootstrap(model_fit, spec; n_boot = 500)
139+
function test_bootstrap(model_fit, spec; compare_hessian = true, compare_bs = true, n_boot = 500)
140+
se_bs = se_bootstrap(model_fit, spec; n_boot = n_boot)
140141
# hessian and bootstrap se are close
141-
se_he = se_hessian(model_fit)
142-
se_bs = se_bootstrap(model_fit; specification = spec, n_boot = n_boot)
143-
@test isapprox(se_bs, se_he, rtol = 0.2)
142+
if compare_hessian
143+
se_he = se_hessian(model_fit)
144+
@test isapprox(se_bs, se_he, rtol = 0.2)
145+
end
144146
# se_bootstrap and bootstrap |> se are close
145-
bs_samples = bootstrap(model_fit; specification = spec, n_boot = n_boot)
146-
@test bs_samples[:n_converged] > 0.95*n_boot
147-
bs_samples = cat(bs_samples[:samples][BitVector(bs_samples[:converged])]..., dims = 2)
148-
se_bs_2 = sqrt.(var(bs_samples, corrected = false, dims = 2))
149-
@test isapprox(se_bs_2, se_bs, rtol = 0.05)
147+
if compare_bs
148+
bs_samples = bootstrap(model_fit, spec; n_boot = n_boot)
149+
@test bs_samples[:n_converged] > 0.95*n_boot
150+
bs_samples = cat(bs_samples[:samples][BitVector(bs_samples[:converged])]..., dims = 2)
151+
se_bs_2 = sqrt.(var(bs_samples, corrected = false, dims = 2))
152+
@test isapprox(se_bs_2, se_bs, rtol = 0.05)
153+
end
154+
end
155+
156+
function smoketest_bootstrap(model_fit, spec; n_boot = 5)
157+
# only check that se_bootstrap and bootstrap run; results are returned without comparison
158+
se_bs = se_bootstrap(model_fit, spec; n_boot = n_boot)
159+
bs_samples = bootstrap(model_fit, spec; n_boot = n_boot)
160+
return se_bs, bs_samples
150161
end
151162

152163
function smoketest_CI_z(model_fit, partable)

test/examples/multigroup/build_models.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ end
293293
lav_col = :se,
294294
lav_groups = Dict(:Pasteur => 1, :Grant_White => 2),
295295
)
296-
test_bootstrap(solution_ls, partable)
296+
test_bootstrap(solution_ls, partable; compare_bs = false)
297297
smoketest_CI_z(solution_ls, partable)
298298
end
299299

test/examples/multigroup/multigroup.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
using StructuralEquationModels, Test, FiniteDiff, Suppressor
22
using LinearAlgebra: diagind, LowerTriangular
33
using Statistics: var
4+
using Random
5+
6+
Random.seed!(948723)
47

58
const SEM = StructuralEquationModels
69

test/examples/political_democracy/constructor.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ end
161161
lav_col = :se,
162162
)
163163

164-
test_bootstrap(solution_ls, partable)
164+
test_bootstrap(solution_ls, partable; compare_bs = false)
165165
smoketest_CI_z(solution_ls, partable)
166166
end
167167

@@ -373,7 +373,8 @@ end
373373
lav_col = :se,
374374
)
375375

376-
test_bootstrap(solution_ls, partable_mean)
376+
test_bootstrap(solution_ls, partable_mean, compare_bs = false)
377+
# smoketest_bootstrap(solution_ls, partable_mean)
377378
smoketest_CI_z(solution_ls, partable_mean)
378379
end
379380

@@ -507,6 +508,7 @@ end
507508
lav_col = :se,
508509
)
509510

510-
test_bootstrap(solution_ml, partable_mean)
511+
# test_bootstrap(solution_ml, partable_mean) # too much compute
512+
smoketest_bootstrap(solution_ml, partable_mean)
511513
smoketest_CI_z(solution_ml, partable_mean)
512514
end

test/examples/political_democracy/political_democracy.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ using StructuralEquationModels, Test, Suppressor, FiniteDiff
22
using Statistics: cov, mean, var
33
using Random, NLopt
44

5+
Random.seed!(464577)
6+
57
SEM = StructuralEquationModels
68

79
include(

0 commit comments

Comments
 (0)