Skip to content

Commit b2012b0

Browse files
committed
fixup param_indices()
1 parent 165474c commit b2012b0

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

src/types.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,23 +29,19 @@ nparams(model::AbstractSem) = length(params(model))
2929

3030
params(model::AbstractSemSingle) = params(model.imply)
3131
nparams(model::AbstractSemSingle) = nparams(model.imply)
32+
3233
"""
3334
param_indices(semobj)
34-
param_indices(param_names, semobj)
3535
36-
Returns either a dict of parameter names and their indices in `semobj`.
37-
If `param_names` are provided, returns a vector their indices in `semobj` instead.
36+
Returns a dict of parameter names and their indices in `semobj`.
3837
3938
# Examples
4039
```julia
4140
parind = param_indices(my_fitted_sem)
4241
parind[:param_name]
43-
44-
parind = param_indices([:param_name_1, param_name_2], my_fitted_sem)
4542
```
4643
"""
47-
param_indices(semobj) = Dict(params(semobj) .=> 1:nparams(semobj))
48-
param_indices(param_names, semobj) = getindex.([Dict(params(semobj) .=> 1:nparams(semobj))], param_names)
44+
param_indices(semobj) = Dict(par => i for (i, par) in enumerate(params(semobj)))
4945

5046
"""
5147
SemLoss(args...; loss_weights = nothing, ...)

0 commit comments

Comments
 (0)