Skip to content

Commit 002da4b

Browse files
author
xfchen0912
committed
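Fix some bugs: pin ipython/numpy/pandas, add a `use_rep` option to model training, use `T_idx` in the temporal smoothness term, read motifs from `settings.scm_data`, and write the factor name into JASPAR headers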
1 parent 27d8b4e commit 002da4b

5 files changed

Lines changed: 29 additions & 9 deletions

File tree

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
.DS_Store
33
*~
44
buck-out/
5+
tmp/
56

67
# Compiled files
78
.venv/
@@ -22,7 +23,10 @@ __pycache__/
2223
# docs
2324
/docs/generated/
2425
/docs/_build/
26+
27+
# AI coding assistant files
2528
.aider*
29+
CLAUDE*
2630

2731
# tox
2832
.tox/

pyproject.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,12 @@ dependencies = [
2828
"decoupler>=2.1.1",
2929
"deprecated",
3030
"genomepy>=0.16.1",
31+
"ipython<9",
3132
"moods-python",
3233
"mudata<=0.2.3",
3334
"netgraph",
35+
"numpy<2",
36+
"pandas<=2.3.1",
3437
"pillow<12",
3538
"pooch",
3639
"pyarrow<=20",
@@ -73,11 +76,16 @@ optional-dependencies.doc = [
7376
"sphinxext-opengraph",
7477
]
7578

79+
optional-dependencies.jupyter = [
80+
"ipykernel",
81+
"ipywidgets",
82+
]
7683
optional-dependencies.test = [
7784
"coverage",
7885
"pytest",
7986
"tox",
8087
]
88+
8189
# https://docs.pypi.org/project_metadata/#project-urls
8290
urls.Documentation = "https://scMagnify.readthedocs.io/"
8391
urls.Homepage = "https://github.com/xfchen0912/scMagnify"

src/scmagnify/models/_train.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import decoupler as dc
66
import numpy as np
77
import pandas as pd
8-
import scanpy as sc
98
import torch
109
import torch.nn as nn
1110
import torch.nn.functional as F
@@ -82,6 +81,7 @@ def __init__(
8281
time_key: str = "palantir_pseudotime",
8382
gene_selected: list[str] | None = None,
8483
basal_grn: NDArray | None = None,
84+
use_rep: str = "X_pca",
8585
func: nn.Module = MSNGC,
8686
hidden: list[int] = [50],
8787
lag: int = 5,
@@ -152,7 +152,7 @@ def __init__(
152152
)
153153

154154
# Preprocess data.
155-
self.AX, self.Y, self.T = self._preprocess_data()
155+
self.AX, self.Y, self.T = self._preprocess_data(use_rep=use_rep)
156156

157157
self.n_reg = self.adata_fil[:, self.adata_fil.var["is_reg"]].shape[1]
158158
self.n_target = self.adata_fil[:, self.adata_fil.var["is_target"]].shape[1]
@@ -181,7 +181,7 @@ def __init__(
181181

182182
self.criterion = loss.MSELoss()
183183

184-
def _preprocess_data(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
184+
def _preprocess_data(self, use_rep="X_pca") -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
185185
"""
186186
Preprocess data for training.
187187
@@ -191,7 +191,7 @@ def _preprocess_data(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
191191
Preprocessed data tensors (AX, Y, T).
192192
"""
193193
# Preprocess data.
194-
sc.pp.neighbors(self.adata_fil, n_neighbors=30)
194+
# sc.pp.neighbors(self.adata_fil, n_neighbors=30, use_rep=use_rep)
195195
AX = partial_ordering(self.adata_fil[:, self.adata_fil.var["is_reg"]], dyn=self.time_key, lag=self.lag)
196196
Y = normalize_data(self.adata_fil[:, self.adata_fil.var["is_target"]].X.A)
197197
T = self.adata_fil.obs[self.time_key].values
@@ -272,7 +272,7 @@ def _train_epoch(
272272

273273
# Temporal smoothness penalty term
274274
T_idx = np.argsort(T_batch.detach().cpu().numpy())
275-
T_plus1 = T + 1
275+
T_plus1 = T_idx + 1
276276
AX_Tplus1 = AX[np.where(np.isin(T_idx, T_plus1))[0], :, :]
277277

278278
coeffs_Tplus1, _, _ = self.model(AX_Tplus1)

src/scmagnify/tools/_motif_scan.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from rich.progress import Progress, TaskID
1818
from rich.table import Table
1919

20-
import scmagnify as scm
2120
from scmagnify import logging as logg
2221
from scmagnify.settings import settings
2322
from scmagnify.utils import _list_to_str, d
@@ -41,7 +40,7 @@
4140
]
4241

4342
_BACKGROUND = Literal["subject", "genome", "even"]
44-
MOTIF_DIR = os.path.join(os.path.dirname(scm.__file__), "data", "motifs")
43+
MOTIF_DIR = os.path.join(settings.scm_data, "motifs")
4544

4645

4746
def _add_peak_seq(
@@ -133,7 +132,7 @@ def _add_peak_info(
133132
"start": start,
134133
"end": end,
135134
"width": width,
136-
"GC": gc_content,
135+
"GC_bin": gc_content,
137136
"N": n_content,
138137
}
139138
)
@@ -723,7 +722,9 @@ def write_jaspar(motif_dict: dict[str, pd.DataFrame], file_path: str, pseudo_cou
723722
"""
724723
with open(file_path, "w") as f:
725724
for motif_id, df in motif_dict.items():
726-
f.write(f">{motif_id}\n")
725+
# Header line is ">{motif_id} {factor_name}", e.g. ">MA0007.3 Ar"
726+
factor_name = motif_id.split("_")[-1] # Extract motif name if needed
727+
f.write(f">{motif_id} {factor_name}\n")
727728
# Convert probabilities to pseudo-counts
728729
counts_df = (df * pseudo_counts).round().astype(int)
729730

src/scmagnify/tools/_peak_gene_corr.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,13 @@ def connect_peaks_genes(
438438

439439
sc.pp.filter_genes(meta_rna_adata, min_cells=3)
440440

441+
# # Check GC content in ATAC data
442+
# if "GC" not in meta_atac_adata.var.columns:
443+
# from scmagnify.tools._motif_scan import _add_peak_info
444+
445+
# logg.info("Adding GC content to ATAC data...")
446+
# _add_peak_info(meta_atac_adata)
447+
441448
if gene_selected is None:
442449
adata = _get_data_modal(data, rna_key)
443450
if "significant_genes" in adata.var.keys():

0 commit comments

Comments
 (0)