Skip to content

Commit 19c7fac

Browse files
committed
add find_peak_program
1 parent 7a410ff commit 19c7fac

7 files changed

Lines changed: 101 additions & 51 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ dependencies = [
3636
[project.scripts]
3737
SCALEX = "scalex.function:main"
3838
scalex = "scalex.function:main"
39+
frag = "scalex.atac.fragments:main"
3940

4041
[project.optional-dependencies]
4142
dev = [

scalex/analysis.py

Lines changed: 86 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ def enrich_analysis(gene_names, organism='hsapiens', gene_sets='GO_Biological_Pr
5757

5858
results = pd.DataFrame()
5959
for group, genes in gene_names.items():
60+
# print(group, genes)
61+
genes = list(genes)
6062
enr = gp.enrichr(genes, gene_sets=gene_sets, cutoff=cutoff).results
6163
enr['cell_type'] = group # Add the group label to the results
6264
results = pd.concat([results, enr])
@@ -157,21 +159,34 @@ def flatten_dict(markers):
157159
return flatten_markers
158160

159161

160-
def find_gene_program(adata, groupby='cell_type', processed=False, n_clusters=None, top_n=300):
162+
def filter_marker_dict(markers, vars):
    """
    Restrict every cluster's marker list to names found in ``vars``.

    Parameters
    ----------
    markers
        Mapping of cluster label -> iterable of marker names.
    vars
        Container tested with ``in``; only names it contains are kept.

    Returns
    -------
    dict
        Same cluster labels in the same order, each mapped to a list of the
        surviving names in their original order. A cluster with no surviving
        names maps to an empty list rather than being dropped.
    """
    return {
        label: [name for name in candidates if name in vars]
        for label, candidates in markers.items()
    }
167+
168+
def rename_marker_dict(markers, rename_dict):
    """
    Re-key a marker dictionary using ``rename_dict``.

    Parameters
    ----------
    markers
        Mapping of cluster label -> marker collection.
    rename_dict
        Mapping of old cluster label -> new cluster label. Every key of
        ``markers`` must be present, otherwise a ``KeyError`` is raised.

    Returns
    -------
    dict
        New dictionary keyed by the renamed labels; values are the same
        objects as in ``markers`` (not copied). If several old labels map to
        the same new label, the last one encountered wins.
    """
    return {
        rename_dict[old_label]: members
        for old_label, members in markers.items()
    }
173+
174+
175+
def find_gene_program(adata, groupby='cell_type', processed=False, n_clusters=None, top_n=300, filter_pseudo=True, **kwargs):
161176
"""
162177
Find gene program for each cell type
163178
"""
164179
adata = adata.copy()
165180
from scalex.data import aggregate_data
166-
adata_avg = aggregate_data(adata, groupby=groupby, processed=processed)
181+
adata_avg = aggregate_data(adata, groupby=groupby, processed=processed, scale=True)
167182

168-
markers = get_markers(adata, groupby=groupby, processed=processed, top_n=top_n)
183+
markers = get_markers(adata, groupby=groupby, processed=processed, top_n=top_n, filter_pseudo=filter_pseudo, **kwargs)
169184
for cluster, genes in markers.items():
170185
print(cluster, len(genes))
171186

172187
marker_list = flatten_dict(markers)
173188

174-
sc.pp.scale(adata_avg, zero_center=True)
189+
# sc.pp.scale(adata_avg, zero_center=True)
175190
adata_avg_ = adata_avg[:, marker_list].copy()
176191

177192
if n_clusters is None:
@@ -180,7 +195,7 @@ def find_gene_program(adata, groupby='cell_type', processed=False, n_clusters=No
180195
from sklearn.cluster import KMeans
181196
kmeans = KMeans(n_clusters=n_clusters, random_state=0)
182197
adata_avg_.var['cluster'] = np.array(kmeans.fit_predict(adata_avg_.X.T)).astype(str)
183-
print(adata_avg_.var)
198+
# print(adata_avg_.var)
184199

185200
gene_cluster_dict = adata_avg_.var.groupby('cluster').groups
186201
gene_cluster_dict = {k: v.tolist() for k, v in gene_cluster_dict.items()}
@@ -192,20 +207,40 @@ def find_gene_program(adata, groupby='cell_type', processed=False, n_clusters=No
192207
return gene_cluster_dict, adata_avg
193208

194209

210+
def find_peak_program(adata, groupby='cell_type', processed=False, n_clusters=None, top_n=-1, pvalue_cutoff=0.05, logfc_cutoff=1., filter_pseudo=False, **kwargs):
    """
    Find peak program for each cell type.

    Thin wrapper around ``find_gene_program`` with peak-oriented defaults
    (``top_n=-1``, ``filter_pseudo=False``) and explicit significance cutoffs.

    Parameters
    ----------
    adata
        Annotated data object handed to ``find_gene_program``.
    groupby
        Grouping key forwarded to ``find_gene_program``.
    processed
        Forwarded unchanged to ``find_gene_program``.
    n_clusters
        Number of feature clusters; forwarded to ``find_gene_program``.
    top_n
        Forwarded to ``find_gene_program``.
        NOTE(review): ``-1`` presumably means "no limit" -- confirm in get_markers.
    pvalue_cutoff
        Forwarded under the keyword name ``pval_cutoff``.
    logfc_cutoff
        Forwarded as ``logfc_cutoff``.
    filter_pseudo
        Forwarded to ``find_gene_program``.
    **kwargs
        Extra keyword arguments passed through to ``find_gene_program``.

    Returns
    -------
    Whatever ``find_gene_program`` returns.
    """
    return find_gene_program(
        adata,
        groupby=groupby,
        processed=processed,
        # Previously accepted but never forwarded, so a caller-supplied
        # cluster count was silently ignored.
        n_clusters=n_clusters,
        top_n=top_n,
        filter_pseudo=filter_pseudo,
        # The public parameter is `pvalue_cutoff`; downstream expects `pval_cutoff`.
        pval_cutoff=pvalue_cutoff,
        logfc_cutoff=logfc_cutoff,
        **kwargs,
    )
215+
216+
217+
def find_consensus_program(adata, groupby='cell_type', across=None, set_type='gene', top_n=-1, **kwargs):
    """
    Find consensus program for each cell type.

    When ``across`` is given, the grouping is refined to the combination of
    ``groupby`` and ``across``: a new column named ``f"{groupby}_{across}"`` is
    written to ``adata.obs`` (the caller's object is modified) and used as the
    grouping key.

    Parameters
    ----------
    adata
        Annotated data object; must expose ``.obs`` with the ``groupby``
        (and, if used, ``across``) columns.
    groupby
        Name of the obs column to group by.
    across
        Optional second obs column; if not ``None`` the two columns are
        joined with ``'_'`` to form the grouping.
    set_type
        ``'gene'`` dispatches to ``find_gene_program``; ``'peak'`` dispatches
        to ``find_peak_program``.
    top_n
        Forwarded to the selected finder.
    **kwargs
        Extra keyword arguments forwarded to the selected finder.

    Returns
    -------
    The return value of the selected finder.

    Raises
    ------
    ValueError
        If ``set_type`` is neither ``'gene'`` nor ``'peak'``.
    """
    # Validate before touching `adata` so a bad argument has no side effects.
    # The original fell through and returned None, which only surfaced later
    # when callers tried to use the result.
    if set_type not in ('gene', 'peak'):
        raise ValueError(f"set_type must be 'gene' or 'peak', got {set_type!r}")

    if across is not None:
        combined_key = groupby + '_' + across
        # The combined label is a per-observation annotation and must live in
        # `.obs`; the original assigned to `adata[...]`, which does not create
        # an obs column.
        adata.obs[combined_key] = adata.obs[groupby].astype(str) + '_' + adata.obs[across].astype(str)
        groupby = combined_key

    if set_type == 'gene':
        return find_gene_program(adata, groupby=groupby, top_n=top_n, **kwargs)
    return find_peak_program(adata, groupby=groupby, top_n=top_n, **kwargs)
229+
230+
195231
def annotate(
196232
adata,
197233
cell_type='leiden',
198234
color = ['cell_type', 'leiden', 'tissue', 'donor'],
199235
cell_type_markers='macrophage', #None,
200-
marker_dict = None,
201236
show_markers=False,
202237
gene_sets='GO_Biological_Process_2023',
203-
n_tops = [], #[100],
204-
options = ['pos'], # ['pos', 'neg']
205238
additional={},
206239
go=True,
207-
out_dir = None, #'../../results/go_and_pathway/NSCLC_macrophage/'
208-
cutoff = 0.05
240+
out_dir = None,
241+
cutoff = 0.05,
242+
processed=False,
243+
top_n=300,
209244
):
210245

211246
color = [i for i in color if i in adata.obs.columns]
@@ -221,28 +256,31 @@ def annotate(
221256
sc.pl.dotplot(adata, cell_type_markers_, groupby=cell_type, standard_scale='var', cmap='coolwarm')
222257
sc.pl.heatmap(adata, cell_type_markers_, groupby=cell_type, show_gene_labels=True, vmax=6)
223258

224-
if marker_dict is None:
225-
sc.tl.rank_genes_groups(adata, groupby=cell_type, key_added=cell_type, dendrogram=False)
226-
sc.pl.rank_genes_groups_dotplot(adata, n_genes=5, cmap='coolwarm', key=cell_type, standard_scale='var', figsize=(22, 5), dendrogram=False)
227-
marker = pd.DataFrame(adata.uns[cell_type]['names'])
228-
# marker_dict = marker.head(5).to_dict(orient='list')
229-
plt.show()
230-
else:
231-
pass
259+
marker_genes = get_markers(adata, groupby=cell_type, processed=processed, top_n=top_n)
260+
# print(marker_genes)
261+
enrich_and_plot(marker_genes, gene_sets=gene_sets, cutoff=cutoff, out_dir=out_dir)
262+
# if marker_dict is None:
263+
# sc.tl.rank_genes_groups(adata, groupby=cell_type, key_added=cell_type, dendrogram=False)
264+
# sc.pl.rank_genes_groups_dotplot(adata, n_genes=5, cmap='coolwarm', key=cell_type, standard_scale='var', figsize=(22, 5), dendrogram=False)
265+
# marker = pd.DataFrame(adata.uns[cell_type]['names'])
266+
# # marker_dict = marker.head(5).to_dict(orient='list')
267+
# plt.show()
268+
# else:
269+
# pass
232270
# sc.pl.heatmap(adata, marker_dict, groupby=cell_type, show_gene_labels=True, vmax=6)
233271

234-
if show_markers:
235-
for k, v in marker_dict.items():
236-
print(k)
237-
sc.pl.umap(adata, color=v, ncols=5)
272+
# if show_markers:
273+
# for k, v in marker_dict.items():
274+
# print(k)
275+
# sc.pl.umap(adata, color=v, ncols=5)
238276

239-
if marker_dict is not None:
240-
enrich_and_plot(marker_dict, gene_sets=gene_sets, cutoff=cutoff, out_dir=out_dir)
241-
elif len(n_tops) > 0:
242-
for n_top in n_tops:
243-
print('-'*20+'\n', n_top, '\n'+'-'*20)
244-
marker_dict = marker.head(n_top).to_dict(orient='list')
245-
enrich_and_plot(marker_dict, gene_sets=gene_sets, cutoff=cutoff, out_dir=out_dir)
277+
# if marker_dict is not None:
278+
# enrich_and_plot(marker_dict, gene_sets=gene_sets, cutoff=cutoff, out_dir=out_dir)
279+
# elif len(n_tops) > 0:
280+
# for n_top in n_tops:
281+
# print('-'*20+'\n', n_top, '\n'+'-'*20)
282+
# marker_dict = marker.head(n_top).to_dict(orient='list')
283+
# enrich_and_plot(marker_dict, gene_sets=gene_sets, cutoff=cutoff, out_dir=out_dir)
246284
# if go:
247285
# for option in options:
248286
# if option == 'pos':
@@ -271,26 +309,26 @@ def annotate(
271309
# go_results[['Gene_set','Term','Overlap', 'Adjusted P-value', 'Genes', 'cell_type']].to_csv(out_dir + f'/{option}_go_results_{n_top}.csv')
272310
# plt.show()
273311

274-
for pathway_name, pathways in additional.items():
275-
try:
276-
pathway_results = enrich_analysis(marker_dict, gene_sets=pathways)
277-
except:
278-
continue
279-
ax = dotplot(pathway_results,
280-
column="Adjusted P-value",
281-
x='cell_type', # set x axis, so you could do a multi-sample/library comparsion
282-
# size=10,
283-
top_term=10,
284-
figsize=(8,10),
285-
title = pathway_name,
286-
xticklabels_rot=45, # rotate xtick labels
287-
show_ring=False, # set to False to revmove outer ring
288-
marker='o',
289-
cutoff=0.05,
290-
cmap='viridis'
291-
)
312+
# for pathway_name, pathways in additional.items():
313+
# try:
314+
# pathway_results = enrich_analysis(marker_dict, gene_sets=pathways)
315+
# except:
316+
# continue
317+
# ax = dotplot(pathway_results,
318+
# column="Adjusted P-value",
319+
# x='cell_type', # set x axis, so you could do a multi-sample/library comparsion
320+
# # size=10,
321+
# top_term=10,
322+
# figsize=(8,10),
323+
# title = pathway_name,
324+
# xticklabels_rot=45, # rotate xtick labels
325+
# show_ring=False, # set to False to revmove outer ring
326+
# marker='o',
327+
# cutoff=0.05,
328+
# cmap='viridis'
329+
# )
292330

293-
plt.show()
331+
# plt.show()
294332

295333

296334

scalex/atac/snapatac2/_misc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def aggregate_X(
5959
from natsort import natsorted
6060
from anndata import AnnData
6161

62+
adata = adata.copy()
6263
def norm(x):
6364
if normalize is None:
6465
return x

scalex/data.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,13 @@ def concat_data(
183183
adata = load_files(root)
184184
adata_list.append(adata)
185185

186-
# if batch_categories is None:
187186
# batch_categories = list(map(str, range(len(adata_list))))
188187
# else:
189188
# assert len(adata_list) == len(batch_categories)
190189

191190
adata_concat = concat(adata_list, join=join, label=batch_key, keys=batch_categories, index_unique=index_unique)
191+
if batch_categories is None:
192+
adata_concat.obs['batch'] = 'batch'
192193
return adata_concat
193194
# [print(b, adata.shape) for adata,b in zip(adata_list, batch_categories)]
194195
# concat = AnnData.concatenate(*adata_list, join=join, batch_key=batch_key,
@@ -197,7 +198,7 @@ def concat_data(
197198
# concat.write(save, compression='gzip')
198199
# return concat
199200

200-
def aggregate_data(rna, groupby='cell_type', processed=False):
201+
def aggregate_data(rna, groupby='cell_type', processed=False, scale=False):
201202
if processed:
202203
if rna.raw is not None:
203204
rna = rna.raw.to_adata()
@@ -210,6 +211,8 @@ def aggregate_data(rna, groupby='cell_type', processed=False):
210211
sc.pp.normalize_total(rna_agg, target_sum=1e4)
211212
sc.pp.log1p(rna_agg)
212213

214+
if scale:
215+
sc.pp.scale(rna_agg, zero_center=True)
213216
return rna_agg
214217

215218
def preprocessing_rna(

scalex/function.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def SCALEX(
4040
outdir:str=None,
4141
projection:str=None,
4242
repeat:bool=False,
43+
name:str=None,
4344
impute:str=None,
4445
chunk_size:int=20000,
4546
ignore_umap:bool=False,
@@ -136,6 +137,9 @@ def SCALEX(
136137
device='cpu'
137138

138139
if outdir:
140+
if name is not None and projection is not None:
141+
outdir = os.path.join(projection, 'projection', name)
142+
os.makedirs(outdir, exist_ok=True)
139143
# outdir = outdir+'/'
140144
os.makedirs(os.path.join(outdir, 'checkpoint'), exist_ok=True)
141145
log = create_logger('SCALEX', fh=os.path.join(outdir, 'log.txt'), overwrite=True)
@@ -362,6 +366,7 @@ def main():
362366
parser.add_argument('--eval', action='store_true')
363367
parser.add_argument('--num_workers', type=int, default=4)
364368
parser.add_argument('--keep_mt', action='store_true')
369+
parser.add_argument('--name', type=str, default=None)
365370
# parser.add_argument('--version', type=int, default=2)
366371
# parser.add_argument('--k', type=str, default=30)
367372
# parser.add_argument('--embed', type=str, default='UMAP')
@@ -403,6 +408,7 @@ def main():
403408
chunk_size=args.chunk_size,
404409
ignore_umap=args.ignore_umap,
405410
repeat=args.repeat,
411+
name=args.name,
406412
verbose=True,
407413
assess=args.assess,
408414
eval=args.eval,

scalex/net/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(self, patience=10, verbose=False, checkpoint_file=''):
5454
self.counter = 0
5555
self.best_score = None
5656
self.early_stop = False
57-
self.loss_min = np.Inf
57+
self.loss_min = np.inf
5858
self.checkpoint_file = checkpoint_file
5959

6060
def __call__(self, loss, model):

scalex/plot.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def plot_meta2(
224224
use_rep='latent',
225225
color='cell_type',
226226
batch='batch',
227+
groupby='cell_type',
227228
color_map=None,
228229
figsize=(10, 10),
229230
cmap='Blues',

0 commit comments

Comments
 (0)