Skip to content

Commit e683019

Browse files
2x speedup with smarter JITing on SNMF
1 parent fbeba62 commit e683019

1 file changed

Lines changed: 10 additions & 7 deletions

File tree

turftopic/models/_snmf.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""This file implements semi-NMF, where doc_topic proportions are not allowed to be negative, but components are unbounded."""
22

33
import warnings
4+
from functools import partial
45
from typing import Optional
56

67
import numpy as np
@@ -33,20 +34,17 @@ def init_G(
3334
return G + constant
3435

3536

36-
@jit
3737
def separate(A):
3838
abs_A = jnp.abs(A)
3939
pos = (abs_A + A) / 2
4040
neg = (abs_A - A) / 2
4141
return pos, neg
4242

4343

44-
@jit
4544
def update_F(X, G):
4645
return X @ G @ jnp.linalg.inv(G.T @ G)
4746

4847

49-
@jit
5048
def update_G(X, G, F, sparsity=0):
5149
pos_xtf, neg_xtf = separate(X.T @ F)
5250
pos_gftf, neg_gftf = separate(G @ (F.T @ F))
@@ -59,12 +57,18 @@ def update_G(X, G, F, sparsity=0):
5957
return G
6058

6159

62-
@jit
6360
def rec_err(X, F, G):
6461
err = X - (F @ G.T)
6562
return jnp.linalg.norm(err)
6663

6764

65+
def step(G, F, X, sparsity=0):
66+
G = update_G(X.T, G, F, sparsity)
67+
F = update_F(X.T, G)
68+
error = rec_err(X.T, F, G)
69+
return G, F, error
70+
71+
6872
class SNMF(TransformerMixin, BaseEstimator):
6973
def __init__(
7074
self,
@@ -89,14 +93,13 @@ def fit_transform(self, X: np.ndarray, y=None):
8993
F = update_F(X.T, G)
9094
error_at_init = rec_err(X.T, F, G)
9195
prev_error = error_at_init
96+
_step = jit(partial(step, sparsity=self.sparsity, X=X))
9297
for i in trange(
9398
self.max_iter,
9499
desc="Iterative updates.",
95100
disable=not self.progress_bar,
96101
):
97-
G = update_G(X.T, G, F, self.sparsity)
98-
F = update_F(X.T, G)
99-
error = rec_err(X.T, F, G)
102+
G, F, error = _step(G, F)
100103
difference = prev_error - error
101104
if (error < error_at_init) and (
102105
(prev_error - error) / error_at_init

0 commit comments

Comments
 (0)