11"""This file implements semi-NMF, where doc_topic proportions are not allowed to be negative, but components are unbounded."""
22
33import warnings
4+ from functools import partial
45from typing import Optional
56
67import numpy as np
@@ -33,20 +34,17 @@ def init_G(
3334 return G + constant
3435
3536
36- @jit
3737def separate (A ):
3838 abs_A = jnp .abs (A )
3939 pos = (abs_A + A ) / 2
4040 neg = (abs_A - A ) / 2
4141 return pos , neg
4242
4343
44- @jit
4544def update_F (X , G ):
4645 return X @ G @ jnp .linalg .inv (G .T @ G )
4746
4847
49- @jit
5048def update_G (X , G , F , sparsity = 0 ):
5149 pos_xtf , neg_xtf = separate (X .T @ F )
5250 pos_gftf , neg_gftf = separate (G @ (F .T @ F ))
@@ -59,12 +57,18 @@ def update_G(X, G, F, sparsity=0):
5957 return G
6058
6159
62- @jit
6360def rec_err (X , F , G ):
6461 err = X - (F @ G .T )
6562 return jnp .linalg .norm (err )
6663
6764
65+ def step (G , F , X , sparsity = 0 ):
66+ G = update_G (X .T , G , F , sparsity )
67+ F = update_F (X .T , G )
68+ error = rec_err (X .T , F , G )
69+ return G , F , error
70+
71+
6872class SNMF (TransformerMixin , BaseEstimator ):
6973 def __init__ (
7074 self ,
@@ -89,14 +93,13 @@ def fit_transform(self, X: np.ndarray, y=None):
8993 F = update_F (X .T , G )
9094 error_at_init = rec_err (X .T , F , G )
9195 prev_error = error_at_init
96+ _step = jit (partial (step , sparsity = self .sparsity , X = X ))
9297 for i in trange (
9398 self .max_iter ,
9499 desc = "Iterative updates." ,
95100 disable = not self .progress_bar ,
96101 ):
97- G = update_G (X .T , G , F , self .sparsity )
98- F = update_F (X .T , G )
99- error = rec_err (X .T , F , G )
102+ G , F , error = _step (G , F )
100103 difference = prev_error - error
101104 if (error < error_at_init ) and (
102105 (prev_error - error ) / error_at_init
0 commit comments