55import decoupler as dc
66import numpy as np
77import pandas as pd
8- import scanpy as sc
98import torch
109import torch .nn as nn
1110import torch .nn .functional as F
@@ -82,6 +81,7 @@ def __init__(
8281 time_key : str = "palantir_pseudotime" ,
8382 gene_selected : list [str ] | None = None ,
8483 basal_grn : NDArray | None = None ,
84+ use_rep : str = "X_pca" ,
8585 func : nn .Module = MSNGC ,
8686 hidden : list [int ] = [50 ],
8787 lag : int = 5 ,
@@ -152,7 +152,7 @@ def __init__(
152152 )
153153
154154 # Preprocess data.
155- self .AX , self .Y , self .T = self ._preprocess_data ()
155+ self .AX , self .Y , self .T = self ._preprocess_data (use_rep = use_rep )
156156
157157 self .n_reg = self .adata_fil [:, self .adata_fil .var ["is_reg" ]].shape [1 ]
158158 self .n_target = self .adata_fil [:, self .adata_fil .var ["is_target" ]].shape [1 ]
@@ -181,7 +181,7 @@ def __init__(
181181
182182 self .criterion = loss .MSELoss ()
183183
184- def _preprocess_data (self ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
184+ def _preprocess_data (self , use_rep = "X_pca" ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
185185 """
186186 Preprocess data for training.
187187
@@ -191,7 +191,7 @@ def _preprocess_data(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
191191 Preprocessed data tensors (AX, Y, T).
192192 """
193193 # Preprocess data.
194- sc .pp .neighbors (self .adata_fil , n_neighbors = 30 )
194+ # sc.pp.neighbors(self.adata_fil, n_neighbors=30, use_rep=use_rep )
195195 AX = partial_ordering (self .adata_fil [:, self .adata_fil .var ["is_reg" ]], dyn = self .time_key , lag = self .lag )
196196 Y = normalize_data (self .adata_fil [:, self .adata_fil .var ["is_target" ]].X .A )
197197 T = self .adata_fil .obs [self .time_key ].values
@@ -272,7 +272,7 @@ def _train_epoch(
272272
273273 # Temporal smoothness penalty term
274274 T_idx = np .argsort (T_batch .detach ().cpu ().numpy ())
275- T_plus1 = T + 1
275+ T_plus1 = T_idx + 1
276276 AX_Tplus1 = AX [np .where (np .isin (T_idx , T_plus1 ))[0 ], :, :]
277277
278278 coeffs_Tplus1 , _ , _ = self .model (AX_Tplus1 )
0 commit comments