@@ -57,6 +57,8 @@ def enrich_analysis(gene_names, organism='hsapiens', gene_sets='GO_Biological_Pr
5757
5858 results = pd .DataFrame ()
5959 for group , genes in gene_names .items ():
60+ # print(group, genes)
61+ genes = list (genes )
6062 enr = gp .enrichr (genes , gene_sets = gene_sets , cutoff = cutoff ).results
6163 enr ['cell_type' ] = group # Add the group label to the results
6264 results = pd .concat ([results , enr ])
@@ -157,21 +159,34 @@ def flatten_dict(markers):
157159 return flatten_markers
158160
159161
160- def find_gene_program (adata , groupby = 'cell_type' , processed = False , n_clusters = None , top_n = 300 ):
162+ def filter_marker_dict (markers , vars ):
163+ marker_dict = {}
164+ for cluster , genes in markers .items ():
165+ marker_dict [cluster ] = [i for i in genes if i in vars ]
166+ return marker_dict
167+
168+ def rename_marker_dict (markers , rename_dict ):
169+ marker_dict = {}
170+ for cluster , genes in markers .items ():
171+ marker_dict [rename_dict [cluster ]] = genes
172+ return marker_dict
173+
174+
175+ def find_gene_program (adata , groupby = 'cell_type' , processed = False , n_clusters = None , top_n = 300 , filter_pseudo = True , ** kwargs ):
161176 """
162177 Find gene program for each cell type
163178 """
164179 adata = adata .copy ()
165180 from scalex .data import aggregate_data
166- adata_avg = aggregate_data (adata , groupby = groupby , processed = processed )
181+ adata_avg = aggregate_data (adata , groupby = groupby , processed = processed , scale = True )
167182
168- markers = get_markers (adata , groupby = groupby , processed = processed , top_n = top_n )
183+ markers = get_markers (adata , groupby = groupby , processed = processed , top_n = top_n , filter_pseudo = filter_pseudo , ** kwargs )
169184 for cluster , genes in markers .items ():
170185 print (cluster , len (genes ))
171186
172187 marker_list = flatten_dict (markers )
173188
174- sc .pp .scale (adata_avg , zero_center = True )
189+ # sc.pp.scale(adata_avg, zero_center=True)
175190 adata_avg_ = adata_avg [:, marker_list ].copy ()
176191
177192 if n_clusters is None :
@@ -180,7 +195,7 @@ def find_gene_program(adata, groupby='cell_type', processed=False, n_clusters=No
180195 from sklearn .cluster import KMeans
181196 kmeans = KMeans (n_clusters = n_clusters , random_state = 0 )
182197 adata_avg_ .var ['cluster' ] = np .array (kmeans .fit_predict (adata_avg_ .X .T )).astype (str )
183- print (adata_avg_ .var )
198+ # print(adata_avg_.var)
184199
185200 gene_cluster_dict = adata_avg_ .var .groupby ('cluster' ).groups
186201 gene_cluster_dict = {k : v .tolist () for k , v in gene_cluster_dict .items ()}
@@ -192,20 +207,40 @@ def find_gene_program(adata, groupby='cell_type', processed=False, n_clusters=No
192207 return gene_cluster_dict , adata_avg
193208
194209
210+ def find_peak_program (adata , groupby = 'cell_type' , processed = False , n_clusters = None , top_n = - 1 , pvalue_cutoff = 0.05 , logfc_cutoff = 1. , filter_pseudo = False , ** kwargs ):
211+ """
212+ Find peak program for each cell type
213+ """
214+ return find_gene_program (adata , groupby = groupby , processed = processed , top_n = top_n , filter_pseudo = filter_pseudo , pval_cutoff = pvalue_cutoff , logfc_cutoff = logfc_cutoff , ** kwargs )
215+
216+
217+ def find_consensus_program (adata , groupby = 'cell_type' , across = None , set_type = 'gene' , top_n = - 1 , ** kwargs ):
218+ """
219+ Find consensus program for each cell type
220+ """
221+ if across is not None :
222+ adata [groupby + '_' + across ] = adata .obs [groupby ].astype (str ) + '_' + adata .obs [across ].astype (str )
223+ groupby = groupby + '_' + across
224+
225+ if set_type == 'gene' :
226+ return find_gene_program (adata , groupby = groupby , top_n = top_n , ** kwargs )
227+ elif set_type == 'peak' :
228+ return find_peak_program (adata , groupby = groupby , top_n = top_n , ** kwargs )
229+
230+
195231def annotate (
196232 adata ,
197233 cell_type = 'leiden' ,
198234 color = ['cell_type' , 'leiden' , 'tissue' , 'donor' ],
199235 cell_type_markers = 'macrophage' , #None,
200- marker_dict = None ,
201236 show_markers = False ,
202237 gene_sets = 'GO_Biological_Process_2023' ,
203- n_tops = [], #[100],
204- options = ['pos' ], # ['pos', 'neg']
205238 additional = {},
206239 go = True ,
207- out_dir = None , #'../../results/go_and_pathway/NSCLC_macrophage/'
208- cutoff = 0.05
240+ out_dir = None ,
241+ cutoff = 0.05 ,
242+ processed = False ,
243+ top_n = 300 ,
209244):
210245
211246 color = [i for i in color if i in adata .obs .columns ]
@@ -221,28 +256,31 @@ def annotate(
221256 sc .pl .dotplot (adata , cell_type_markers_ , groupby = cell_type , standard_scale = 'var' , cmap = 'coolwarm' )
222257 sc .pl .heatmap (adata , cell_type_markers_ , groupby = cell_type , show_gene_labels = True , vmax = 6 )
223258
224- if marker_dict is None :
225- sc .tl .rank_genes_groups (adata , groupby = cell_type , key_added = cell_type , dendrogram = False )
226- sc .pl .rank_genes_groups_dotplot (adata , n_genes = 5 , cmap = 'coolwarm' , key = cell_type , standard_scale = 'var' , figsize = (22 , 5 ), dendrogram = False )
227- marker = pd .DataFrame (adata .uns [cell_type ]['names' ])
228- # marker_dict = marker.head(5).to_dict(orient='list')
229- plt .show ()
230- else :
231- pass
259+ marker_genes = get_markers (adata , groupby = cell_type , processed = processed , top_n = top_n )
260+ # print(marker_genes)
261+ enrich_and_plot (marker_genes , gene_sets = gene_sets , cutoff = cutoff , out_dir = out_dir )
262+ # if marker_dict is None:
263+ # sc.tl.rank_genes_groups(adata, groupby=cell_type, key_added=cell_type, dendrogram=False)
264+ # sc.pl.rank_genes_groups_dotplot(adata, n_genes=5, cmap='coolwarm', key=cell_type, standard_scale='var', figsize=(22, 5), dendrogram=False)
265+ # marker = pd.DataFrame(adata.uns[cell_type]['names'])
266+ # # marker_dict = marker.head(5).to_dict(orient='list')
267+ # plt.show()
268+ # else:
269+ # pass
232270 # sc.pl.heatmap(adata, marker_dict, groupby=cell_type, show_gene_labels=True, vmax=6)
233271
234- if show_markers :
235- for k , v in marker_dict .items ():
236- print (k )
237- sc .pl .umap (adata , color = v , ncols = 5 )
272+ # if show_markers:
273+ # for k, v in marker_dict.items():
274+ # print(k)
275+ # sc.pl.umap(adata, color=v, ncols=5)
238276
239- if marker_dict is not None :
240- enrich_and_plot (marker_dict , gene_sets = gene_sets , cutoff = cutoff , out_dir = out_dir )
241- elif len (n_tops ) > 0 :
242- for n_top in n_tops :
243- print ('-' * 20 + '\n ' , n_top , '\n ' + '-' * 20 )
244- marker_dict = marker .head (n_top ).to_dict (orient = 'list' )
245- enrich_and_plot (marker_dict , gene_sets = gene_sets , cutoff = cutoff , out_dir = out_dir )
277+ # if marker_dict is not None:
278+ # enrich_and_plot(marker_dict, gene_sets=gene_sets, cutoff=cutoff, out_dir=out_dir)
279+ # elif len(n_tops) > 0:
280+ # for n_top in n_tops:
281+ # print('-'*20+'\n', n_top, '\n'+'-'*20)
282+ # marker_dict = marker.head(n_top).to_dict(orient='list')
283+ # enrich_and_plot(marker_dict, gene_sets=gene_sets, cutoff=cutoff, out_dir=out_dir)
246284 # if go:
247285 # for option in options:
248286 # if option == 'pos':
@@ -271,26 +309,26 @@ def annotate(
271309 # go_results[['Gene_set','Term','Overlap', 'Adjusted P-value', 'Genes', 'cell_type']].to_csv(out_dir + f'/{option}_go_results_{n_top}.csv')
272310 # plt.show()
273311
274- for pathway_name , pathways in additional .items ():
275- try :
276- pathway_results = enrich_analysis (marker_dict , gene_sets = pathways )
277- except :
278- continue
279- ax = dotplot (pathway_results ,
280- column = "Adjusted P-value" ,
281- x = 'cell_type' , # set x axis, so you could do a multi-sample/library comparsion
282- # size=10,
283- top_term = 10 ,
284- figsize = (8 ,10 ),
285- title = pathway_name ,
286- xticklabels_rot = 45 , # rotate xtick labels
287- show_ring = False , # set to False to revmove outer ring
288- marker = 'o' ,
289- cutoff = 0.05 ,
290- cmap = 'viridis'
291- )
312+ # for pathway_name, pathways in additional.items():
313+ # try:
314+ # pathway_results = enrich_analysis(marker_dict, gene_sets=pathways)
315+ # except:
316+ # continue
317+ # ax = dotplot(pathway_results,
318+ # column="Adjusted P-value",
319+ # x='cell_type', # set x axis, so you could do a multi-sample/library comparsion
320+ # # size=10,
321+ # top_term=10,
322+ # figsize=(8,10),
323+ # title = pathway_name,
324+ # xticklabels_rot=45, # rotate xtick labels
325+ # show_ring=False, # set to False to revmove outer ring
326+ # marker='o',
327+ # cutoff=0.05,
328+ # cmap='viridis'
329+ # )
292330
293- plt .show ()
331+ # plt.show()
294332
295333
296334
0 commit comments