Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions src/GsvdInitialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ function gsvdnmf(X::AbstractMatrix, W::AbstractMatrix, H::AbstractMatrix, f;
return W, H
else
W_recover, H_recover = gsvdrecover(X, copy(W), copy(H), kadd, f)
result_recover = nnmf(X, n2; kwargs..., init=:custom, tol=tol_nmf, W0=W_recover, H0=H_recover)
return result_recover.W, result_recover.H
result_recover = nnmf(X, n2; kwargs..., init=:custom, tol=tol_nmf, W0=copy(W_recover), H0=copy(H_recover))
return result_recover, W_recover, H_recover
end
end
# Convenience method: instead of a precomputed factorization, the caller passes the
# integer `n2`. The factorization is obtained from `tsvd(X, n2)` and forwarded,
# together with all keyword arguments, to the `gsvdnmf` method that accepts it.
function gsvdnmf(X::AbstractMatrix, W::AbstractMatrix, H::AbstractMatrix, n2::Int; kwargs...)
    svd_factors = tsvd(X, n2)
    return gsvdnmf(X, W, H, svd_factors; kwargs...)
end
Expand Down Expand Up @@ -117,7 +117,7 @@ function gsvdrecover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kad
U0, S0, V0 = f
U0, S0, V0 = U0[:,1:n], S0[1:n], V0[:,1:n]
Hadd, Λ = init_H(U0, S0, V0, W0, H0, kadd)
Wadd, a = init_W(X, W0, H0, Hadd)
Wadd, a = init_W(X, W0, H0, Hadd)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line has trailing whitespace. You can configure VS Code to remove it automatically: https://www.codexcafe.com/blog/remove-trailing-whitespace/

Wadd_nn, Hadd_nn = NMF.nndsvd(X, kadd, initdata = (U = Wadd, S = ones(kadd), V = Hadd'))
W0_1, H0_1 = [repeat(a', m, 1).*W0 Wadd_nn], [H0; Hadd_nn]
cs = Wcols_modification(X, W0_1, H0_1)
Expand Down Expand Up @@ -176,4 +176,40 @@ function Wcols_modification(X::AbstractArray{T}, W::AbstractArray{T}, H::Abstrac
return β[:]
end

# Augment an existing factor pair (W0, H0) with extra components built from the
# factorization tuple `f`, using `init2r` to make the added factors non-negative.
#
# Arguments
#   X    - data array, passed through to `init_W` and `Wcols_modification`.
#   W0   - left factor; `m, n = size(W0)`.
#   H0   - right factor; stacked vertically with the added rows.
#   kadd - requested number of extra components; must satisfy `kadd <= n`.
#   f    - tuple destructured as `(U0, S0, V0)`; only the first `n` columns /
#          entries of each are used.
#
# Returns a 3-tuple:
#   * `kadd == 0`: `(W0, H0, 0)` with the inputs returned as-is.
#   * otherwise:   `(abs.(W), abs.(H), Λ)` where `Λ` is the second value
#                  returned by `init_H`.
#
# Throws `ArgumentError` when `kadd > n`.
function gsvdrecover_2r(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int, f::Tuple)
m, n = size(W0)
# NOTE(review): the check admits kadd == n, although the message says "less than".
kadd <= n || throw(ArgumentError("# of extra columns must less than 1st NMF components"))
if kadd == 0
# Nothing to add; the third value is the integer literal 0 rather than a Λ.
return W0, H0, 0
else
U0, S0, V0 = f
# Keep only the leading n columns of U0/V0 and the leading n entries of S0.
U0, S0, V0 = U0[:,1:n], S0[1:n], V0[:,1:n]
# `init_H` and `init_W` are defined elsewhere in this file; presumably they
# produce the candidate extra rows/columns plus scaling data `a` -- TODO confirm.
Hadd, Λ = init_H(U0, S0, V0, W0, H0, kadd)
Wadd, a = init_W(X, W0, H0, Hadd)
# @show Wadd, Hadd
# `init2r` splits the signed factors into positive and negative parts and drops
# empty components, so up to 2*kadd components may be appended.
Wadd_nn, Hadd_nn = init2r(Wadd, Hadd)
# Multiply W0 elementwise by `a'` repeated over m rows, then append the new
# columns to W and the new rows to H.
W0_1, H0_1 = [repeat(a', m, 1).*W0 Wadd_nn], [H0; Hadd_nn]
cs = Wcols_modification(X, W0_1, H0_1)
# @show cs
# Multiply the enlarged W elementwise by `cs'` repeated over m rows; H is unchanged.
W0_2, H0_2 = repeat(cs', m, 1).*W0_1, H0_1
# W0_2, H0_2 = W0_1, H0_1
# Elementwise absolute value of both factors. NOTE(review): `abs` flips the sign
# of any negative entry instead of clamping it to zero.
return abs.(W0_2), abs.(H0_2), Λ
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the abs needed? If it has tiny negative values, would max.(0, W0_2) be more appropriate? abs nominally runs the risk of turning a large negative number positive.

end
end

# Split a signed factor pair into non-negative parts.
#
# With `r = size(U, 2)`, a `size(U, 1) x 2r` matrix and a `2r x size(Vt, 2)` matrix
# are built:
#   * components 1:r    hold the positive parts of `U` (columns) and `Vt` (rows);
#   * components r+1:2r hold the negated negative parts.
# A component is kept only if both its column sum in the first matrix and its row
# sum in the second matrix are strictly positive; the filtered pair is returned.
function init2r(U, Vt)
    @assert size(U, 2) == size(Vt, 1)
    rank = size(U, 2)
    floor0 = zero(eltype(U))
    pos_idx = 1:rank
    neg_idx = (rank + 1):(2rank)

    Wsplit = similar(U, size(U, 1), 2rank)
    Hsplit = similar(Vt, 2rank, size(Vt, 2))

    Ublock = U[:, pos_idx]
    Vblock = Vt[pos_idx, :]

    # Positive parts go into the first `rank` components ...
    Wsplit[:, pos_idx] .= max.(floor0, Ublock)
    Hsplit[pos_idx, :] .= max.(floor0, Vblock)
    # ... and the sign-flipped negative parts into the remaining `rank` components.
    Wsplit[:, neg_idx] .= -1 .* min.(floor0, Ublock)
    Hsplit[neg_idx, :] .= -1 .* min.(floor0, Vblock)

    # Keep a component only when both of its halves carry positive mass.
    colmass = vec(sum(Wsplit; dims=1))
    rowmass = vec(sum(Hsplit; dims=2))
    active = (colmass .> 0) .&& (rowmass .> 0)

    return Wsplit[:, active], Hsplit[active, :]
end

end
Loading