44
55
66class SNMFOptimizer :
7- """A implementation of stretched NMF (sNMF), including sparse stretched NMF.
7+ """An implementation of stretched NMF (sNMF), including sparse stretched NMF.
88
99 Instantiating the SNMFOptimizer class runs all the analysis immediately.
1010 The results matrices can then be accessed as instance attributes
@@ -117,35 +117,38 @@ def __init__(
117117 self .rho = rho
118118 self .eta = eta
119119 # Capture matrix dimensions
120- self ._signal_len , self ._num_conditions = source_matrix .shape
120+ self .signal_length , self .n_signals = source_matrix .shape
121121 self .num_updates = 0
122122 self ._rng = np .random .default_rng (random_state )
123123
124124 # Enforce exclusive specification of n_components or Y0
125125 if (n_components is None and init_weights is None ) or (
126126 n_components is not None and init_weights is not None
127127 ):
128- raise ValueError ("Must provide exactly one of init_weights or n_components, but not both." )
128+ raise ValueError (
129+ "Conflicting source for n_components. Must provide either init_weights or n_components "
130+ "directly, but not both."
131+ )
129132
130133 # Initialize weights and determine number of components
131134 if init_weights is None :
132- self ._n_components = n_components
133- self .weights = self ._rng .beta (a = 2.5 , b = 1.5 , size = (self ._n_components , self ._num_conditions ))
135+ self .n_components = n_components
136+ self .weights = self ._rng .beta (a = 2.5 , b = 1.5 , size = (self .n_components , self .n_signals ))
134137 else :
135- self ._n_components = init_weights .shape [0 ]
138+ self .n_components = init_weights .shape [0 ]
136139 self .weights = init_weights
137140
138141 # Initialize stretching matrix if not provided
139142 if init_stretch is None :
140- self .stretch = np .ones ((self ._n_components , self ._num_conditions )) + self ._rng .normal (
141- 0 , 1e-3 , size = (self ._n_components , self ._num_conditions )
143+ self .stretch = np .ones ((self .n_components , self .n_signals )) + self ._rng .normal (
144+ 0 , 1e-3 , size = (self .n_components , self .n_signals )
142145 )
143146 else :
144147 self .stretch = init_stretch
145148
146149 # Initialize component matrix if not provided
147150 if init_components is None :
148- self .components = self ._rng .random ((self ._signal_len , self ._n_components ))
151+ self .components = self ._rng .random ((self .signal_length , self .n_components ))
149152 else :
150153 self .components = init_components
151154
@@ -155,7 +158,7 @@ def __init__(
155158
156159 # Second-order spline: Tridiagonal (-2 on diagonal, 1 on sub/superdiagonals)
157160 self .spline_smooth_operator = 0.25 * diags (
158- [1 , - 2 , 1 ], offsets = [0 , 1 , 2 ], shape = (self ._num_conditions - 2 , self ._num_conditions )
161+ [1 , - 2 , 1 ], offsets = [0 , 1 , 2 ], shape = (self .n_signals - 2 , self .n_signals )
159162 )
160163 self .spline_smooth_penalty = self .spline_smooth_operator .T @ self .spline_smooth_operator
161164
@@ -351,34 +354,34 @@ def apply_interpolation_matrix(self, components=None, weights=None, stretch=None
351354 stretch = self .stretch
352355
353356 # Compute scaled indices (MATLAB: AA = repmat(reshape(A',1,M*K).^-1, N,1))
354- stretch_flat = stretch .reshape (1 , self ._num_conditions * self ._n_components ) ** - 1
355- stretch_tiled = np .tile (stretch_flat , (self ._signal_len , 1 ))
357+ stretch_flat = stretch .reshape (1 , self .n_signals * self .n_components ) ** - 1
358+ stretch_tiled = np .tile (stretch_flat , (self .signal_length , 1 ))
356359
357360 # Compute `ii` (MATLAB: ii = repmat((0:N-1)',1,K*M).*tiled_stretch)
358361 fractional_indices = (
359- np .tile (np .arange (self ._signal_len )[:, None ], (1 , self ._num_conditions * self ._n_components ))
362+ np .tile (np .arange (self .signal_length )[:, None ], (1 , self .n_signals * self .n_components ))
360363 * stretch_tiled
361364 )
362365
363366 # Weighting matrix (MATLAB: YY = repmat(reshape(Y',1,M*K), N,1))
364- weights_flat = weights .reshape (1 , self ._num_conditions * self ._n_components )
365- weights_tiled = np .tile (weights_flat , (self ._signal_len , 1 ))
367+ weights_flat = weights .reshape (1 , self .n_signals * self .n_components )
368+ weights_tiled = np .tile (weights_flat , (self .signal_length , 1 ))
366369
367370 # Bias for indexing into reshaped X (MATLAB: bias = kron((0:K-1)*(N+1),ones(N,M)))
368371 # TODO break this up or describe what it does better
369372 bias = np .kron (
370- np .arange (self ._n_components ) * (self ._signal_len + 1 ),
371- np .ones ((self ._signal_len , self ._num_conditions ), dtype = int ),
372- ).reshape (self ._signal_len , self ._n_components * self ._num_conditions )
373+ np .arange (self .n_components ) * (self .signal_length + 1 ),
374+ np .ones ((self .signal_length , self .n_signals ), dtype = int ),
375+ ).reshape (self .signal_length , self .n_components * self .n_signals )
373376
374377 # Handle boundary conditions for interpolation (MATLAB: X1=[X;X(end,:)])
375378 components_bounded = np .vstack ([components , components [- 1 , :]]) # Duplicate last row (like MATLAB)
376379
377380 # Compute floor indices (MATLAB: II = floor(ii); II1=min(II+1,N+1); II2=min(II1+1,N+1))
378381 floor_indices = np .floor (fractional_indices ).astype (int )
379382
380- floor_ind_1 = np .minimum (floor_indices + 1 , self ._signal_len )
381- floor_ind_2 = np .minimum (floor_ind_1 + 1 , self ._signal_len )
383+ floor_ind_1 = np .minimum (floor_indices + 1 , self .signal_length )
384+ floor_ind_2 = np .minimum (floor_ind_1 + 1 , self .signal_length )
382385
383386 # Compute fractional part (MATLAB: iI = ii - II)
384387 fractional_floor_indices = fractional_indices - floor_indices
@@ -391,10 +394,10 @@ def apply_interpolation_matrix(self, components=None, weights=None, stretch=None
391394 # Note: this "-1" corrects an off-by-one error that may have originated in an earlier line
392395 # order = F uses FORTRAN, column major order
393396 components_val_1 = components_bounded .flatten (order = "F" )[(offset_floor_ind_1 - 1 ).ravel ()].reshape (
394- self ._signal_len , self ._n_components * self ._num_conditions
397+ self .signal_length , self .n_components * self .n_signals
395398 )
396399 components_val_2 = components_bounded .flatten (order = "F" )[(offset_floor_ind_2 - 1 ).ravel ()].reshape (
397- self ._signal_len , self ._n_components * self ._num_conditions
400+ self .signal_length , self .n_components * self .n_signals
398401 )
399402
400403 # Interpolation (MATLAB: Ax2=XI1.*(1-iI)+XI2.*(iI); stretched_components=Ax2.*YY)
@@ -435,30 +438,30 @@ def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None
435438
436439 # Compute scaling matrix (MATLAB: AA = repmat(reshape(A,1,M*K).^-1,Nindex,1))
437440 stretch_tiled = np .tile (
438- stretch .reshape (1 , self ._num_conditions * self ._n_components , order = "F" ) ** - 1 , (self ._signal_len , 1 )
441+ stretch .reshape (1 , self .n_signals * self .n_components , order = "F" ) ** - 1 , (self .signal_length , 1 )
439442 )
440443
441444 # Compute indices (MATLAB: ii = repmat((index-1)',1,K*M).*AA)
442- indices = np .arange (self ._signal_len )[:, None ] * stretch_tiled # Shape (N, M*K), replacing `index`
445+ indices = np .arange (self .signal_length )[:, None ] * stretch_tiled # Shape (N, M*K), replacing `index`
443446
444447 # Weighting coefficients (MATLAB: YY = repmat(reshape(Y,1,M*K),Nindex,1))
445448 weights_tiled = np .tile (
446- weights .reshape (1 , self ._num_conditions * self ._n_components , order = "F" ), (self ._signal_len , 1 )
449+ weights .reshape (1 , self .n_signals * self .n_components , order = "F" ), (self .signal_length , 1 )
447450 )
448451
449452 # Compute floor indices (MATLAB: II = floor(ii); II1 = min(II+1,N+1); II2 = min(II1+1,N+1))
450453 floor_indices = np .floor (indices ).astype (int )
451- floor_indices_1 = np .minimum (floor_indices + 1 , self ._signal_len )
452- floor_indices_2 = np .minimum (floor_indices_1 + 1 , self ._signal_len )
454+ floor_indices_1 = np .minimum (floor_indices + 1 , self .signal_length )
455+ floor_indices_2 = np .minimum (floor_indices_1 + 1 , self .signal_length )
453456
454457 # Compute fractional part (MATLAB: iI = ii - II)
455458 fractional_indices = indices - floor_indices
456459
457460 # Expand row indices (MATLAB: repm = repmat(1:K, Nindex, M))
458- repm = np .tile (np .arange (self ._n_components ), (self ._signal_len , self ._num_conditions ))
461+ repm = np .tile (np .arange (self .n_components ), (self .signal_length , self .n_signals ))
459462
460463 # Compute transformations (MATLAB: kro = kron(R(index,:), ones(1, K)))
461- kron = np .kron (residuals , np .ones ((1 , self ._n_components )))
464+ kron = np .kron (residuals , np .ones ((1 , self .n_components )))
462465
463466 # (MATLAB: kroiI = kro .* (iI); iIYY = (iI-1) .* YY)
464467 fractional_kron = kron * fractional_indices
@@ -467,16 +470,16 @@ def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None
467470 # Construct sparse matrices (MATLAB: sparse(II1_,repm,kro.*-iIYY,(N+1),K))
468471 x2 = coo_matrix (
469472 ((- kron * fractional_weights ).flatten (), (floor_indices_1 .flatten () - 1 , repm .flatten ())),
470- shape = (self ._signal_len + 1 , self ._n_components ),
473+ shape = (self .signal_length + 1 , self .n_components ),
471474 ).tocsc ()
472475 x3 = coo_matrix (
473476 ((fractional_kron * weights_tiled ).flatten (), (floor_indices_2 .flatten () - 1 , repm .flatten ())),
474- shape = (self ._signal_len + 1 , self ._n_components ),
477+ shape = (self .signal_length + 1 , self .n_components ),
475478 ).tocsc ()
476479
477480 # Combine the last row into previous, then remove the last row
478- x2 [self ._signal_len - 1 , :] += x2 [self ._signal_len , :]
479- x3 [self ._signal_len - 1 , :] += x3 [self ._signal_len , :]
481+ x2 [self .signal_length - 1 , :] += x2 [self .signal_length , :]
482+ x3 [self .signal_length - 1 , :] += x3 [self .signal_length , :]
480483 x2 = x2 [:- 1 , :]
481484 x3 = x3 [:- 1 , :]
482485
@@ -543,10 +546,10 @@ def update_components(self):
543546 stretched_components , _ , _ = self .apply_interpolation_matrix () # Skip the other two outputs (derivatives)
544547 # Compute RA and RR
545548 intermediate_reshaped = stretched_components .flatten (order = "F" ).reshape (
546- (self ._signal_len * self ._num_conditions , self ._n_components ), order = "F"
549+ (self .signal_length * self .n_signals , self .n_components ), order = "F"
547550 )
548551 reshaped_stretched_components = intermediate_reshaped .sum (axis = 1 ).reshape (
549- (self ._signal_len , self ._num_conditions ), order = "F"
552+ (self .signal_length , self .n_signals ), order = "F"
550553 )
551554 component_residuals = reshaped_stretched_components - self .source_matrix
552555 # Compute gradient `GraX`
@@ -603,11 +606,11 @@ def update_weights(self):
603606 Updates weights using matrix operations, solving a quadratic program via to do so.
604607 """
605608
606- for m in range (self ._num_conditions ):
607- t = np .zeros ((self ._signal_len , self ._n_components ))
609+ for m in range (self .n_signals ):
610+ t = np .zeros ((self .signal_length , self .n_components ))
608611
609612 # Populate T using apply_interpolation
610- for k in range (self ._n_components ):
613+ for k in range (self .n_components ):
611614 t [:, k ] = self .apply_interpolation (
612615 self .stretch [k , m ], self .components [:, k ], return_derivatives = True
613616 )[0 ].squeeze ()
@@ -635,21 +638,19 @@ def regularize_function(self, stretch=None):
635638
636639 # Compute residual
637640 intermediate_diff = stretched_components .flatten (order = "F" ).reshape (
638- (self ._signal_len * self ._num_conditions , self ._n_components ), order = "F"
639- )
640- stretch_difference = intermediate_diff .sum (axis = 1 ).reshape (
641- (self ._signal_len , self ._num_conditions ), order = "F"
641+ (self .signal_length * self .n_signals , self .n_components ), order = "F"
642642 )
643+ stretch_difference = intermediate_diff .sum (axis = 1 ).reshape ((self .signal_length , self .n_signals ), order = "F" )
643644 stretch_difference = stretch_difference - self .source_matrix
644645
645646 # Compute objective function
646647 reg_func = self .get_objective_function (stretch_difference , stretch )
647648
648649 # Compute gradient
649650 tiled_derivative = np .sum (
650- d_stretch_components * np .tile (stretch_difference , (1 , self ._n_components )), axis = 0
651+ d_stretch_components * np .tile (stretch_difference , (1 , self .n_components )), axis = 0
651652 )
652- der_reshaped = np .asarray (tiled_derivative ).reshape ((self ._num_conditions , self ._n_components ), order = "F" )
653+ der_reshaped = np .asarray (tiled_derivative ).reshape ((self .n_signals , self .n_components ), order = "F" )
653654 func_grad = (
654655 der_reshaped .T + self .rho * stretch @ self .spline_smooth_operator .T @ self .spline_smooth_operator
655656 )
0 commit comments