-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathACWS_filterCells.py
More file actions
executable file
·241 lines (220 loc) · 9.11 KB
/
ACWS_filterCells.py
File metadata and controls
executable file
·241 lines (220 loc) · 9.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Jan 15 16:01:27 2021
@author: smith
"""
import os
import re

import numpy as np
import pandas as pd
import scanpy as sc
import scipy
import scipy.sparse
def _expressionFrame(adata, use_raw=False):
    """Return the selected expression matrix of ``adata`` as a dense cells x genes DataFrame.

    With ``use_raw=True`` the matrix is ``adata.raw.X`` labelled by
    ``adata.raw.obs_names`` / ``adata.raw.var_names``; otherwise it is
    ``adata.X`` labelled by ``adata.obs.index`` / ``adata.var.index``.

    Raises
    ------
    TypeError
        If the selected matrix is neither a scipy sparse matrix nor a numpy ndarray.
    """
    if use_raw:
        X = adata.raw.X
        obsNames, varNames = adata.raw.obs_names, adata.raw.var_names
    else:
        X = adata.X
        obsNames, varNames = adata.obs.index, adata.var.index
    # Inspect the matrix that is actually converted (adata.raw.X may be stored
    # differently from adata.X). issparse/isinstance also accept CSC and other
    # sparse formats plus ndarray/sparse subclasses (e.g. view wrappers), which
    # an exact type() comparison would reject.
    if scipy.sparse.issparse(X):
        dense = X.toarray()
    elif isinstance(X, np.ndarray):
        dense = X
    else:
        raise TypeError(str(type(X)) + " is not a valid data type. Must be a scipy sparse matrix or numpy ndarray.")
    return pd.DataFrame(dense, index=obsNames, columns=varNames)


def _geneNotFound(gene, use_raw):
    """Build the KeyError raised when ``gene`` is not a column of the selected matrix."""
    if use_raw:
        return KeyError(str(gene) + " not found in dataset")
    # adata.X may hold only a subset of genes, so hint at the raw matrix.
    return KeyError(str(gene) + " not found, you may need to set use_raw=True.")


def findCellsByGeneCoex(adata, gene1, gene2=None, g1thresh=0.6, g2thresh=0.6, gene1up=True, gene2up=True, use_raw=False):
    """Find cells based on expression (or co-expression) of one or two genes.

    Cells are first filtered on ``gene1``. If ``gene2`` is given, the remaining
    cells are filtered again on ``gene2``. A summary line with the number of
    cells kept and their mean expression is printed.

    Parameters
    ----------
    adata : AnnData object
        Dataset to filter.
    gene1 : string
        First gene to filter by.
    gene2 : string, optional (default None)
        Second gene to filter by. If falsy, only gene1 is used.
    g1thresh : float, optional (default 0.6)
        Expression threshold for gene1. Be careful to know the data type of your input (e.g. has the data been log transformed?)
    g2thresh : float, optional (default 0.6)
        Expression threshold for gene2. Be careful to know the data type of your input (e.g. has the data been log transformed?)
    gene1up : bool, optional (default True)
        If True keep cells with gene1 >= g1thresh; if False keep cells with gene1 < g1thresh.
    gene2up : bool, optional (default True)
        If True keep cells with gene2 >= g2thresh; if False keep cells with gene2 < g2thresh.
    use_raw : bool, optional (default False)
        If True use adata.raw.X; if False use adata.X (which in this workflow
        contains only highly variable genes). The whole selected matrix is
        converted to a dense array, so raw data takes significantly longer and
        should only be used when necessary.

    Returns
    -------
    pandas DataFrame
        Dense expression values (cells x genes) of the cells passing the filter(s).

    Raises
    ------
    KeyError
        If gene1 or gene2 is not a column of the selected matrix.
    TypeError
        If the selected matrix is neither a scipy sparse matrix nor a numpy ndarray.

    Examples
    --------
    >>> result = findCellsByGeneCoex(adata, 'Sst', 'Gad1', 0.6, 1.0, True, True, False)
    >>> result = findCellsByGeneCoex(adata, gene1='Slc17a6', gene2='Emx1', g1thresh=1.5, g2thresh=0.6, gene1up=False, gene2up=True)
    """
    mtx = _expressionFrame(adata, use_raw)

    try:
        g1 = mtx[gene1]
    except KeyError as err:
        raise _geneNotFound(gene1, use_raw) from err
    df = mtx.loc[g1 >= g1thresh] if gene1up else mtx.loc[g1 < g1thresh]

    if not gene2:
        print(str(df.shape[0]) + " cells stored with mean " + str(gene1) + " expression " + str(round(np.mean(df[gene1]), 4)))
        return df

    try:
        g2 = df[gene2]
    except KeyError as err:
        raise _geneNotFound(gene2, use_raw) from err
    df1 = df.loc[g2 >= g2thresh] if gene2up else df.loc[g2 < g2thresh]

    print(str(df1.shape[0]) + " cells stored with mean " + str(gene1) + " expression " + str(round(np.mean(df1[gene1]), 4)) +
          " and mean " + str(gene2) + " expression " + str(round(np.mean(df1[gene2]), 4)))
    return df1
def filterCellsByCoex(adata, df):
    """Remove from adata every cell whose name appears in df's index.

    Parameters
    ----------
    adata : AnnData object.
        Data to filter.
    df : pandas DataFrame
        Dataframe of cells to remove (result of findCellsByGeneCoex); only its
        index (cell names) is used.

    Returns
    -------
    AnnData object
        A copy of adata containing only the cells not listed in df.index.
    """
    # Index.isin hashes df.index once, replacing the per-cell scan of a
    # Python list (O(n_cells * n_excluded)) with a single vectorised lookup.
    keep = ~adata.obs.index.isin(df.index)
    return adata[keep].copy()
def subsampleData(adata, df):
    """Return adata restricted to the cells identified by findCellsByGeneCoex.

    Parameters
    ----------
    adata : AnnData object.
        The scanPy AnnData object results were originally extracted from.
    df : Pandas DataFrame.
        The result dataframe from findCellsByGeneCoex; only its index (cell
        names) is used.

    Returns
    -------
    AnnData object
        ``adata`` indexed by a boolean mask of the cells in df.index. No
        ``.copy()`` is taken, so for AnnData this is a view of the original.

    Example:
    >>> result = findCellsByGeneCoex(adata, 'Gad1', 'Sst', 0.6, gene1up=True, gene2up=True, use_raw=True)
    >>> adata = subsampleData(adata, result)
    """
    return adata[adata.obs.index.isin(df.index)]
def countDEGs(file, directory, n_genes=1000, pcutoff=.05, plot=True, save=False, imageType='.png'):
    """Count number of differentially expressed genes per cluster in a scanpy result file.

    Every column whose name ends in ``'_p'`` is treated as a p-value column.
    The cluster id is the integer after the last space of the column name,
    once the ``'_p'`` suffix and any trailing ``')'`` are removed.

    Parameters
    ----------
    file : string
        Path to saved .xlsx or .csv file containing differential expression data.
    directory : string
        Directory to save results.
    n_genes : int, (optional, default 1000)
        Number of genes used in original data analysis; only the first n_genes rows are counted.
    pcutoff : float (optional, default .05)
        Alpha value for significance (p-values strictly below it are counted).
    plot : bool (optional, default True)
        Whether to save a bar plot to ``<directory>/figures/<name>_DEG_Counts<imageType>``
        (the figures directory is created if missing).
    save : bool (optional, default False)
        Whether to save or only return result.
    imageType : string (optional, default '.png')
        File extension, including the dot, used for the saved plot.

    Returns
    -------
    Pandas DataFrame indexed by 'Cluster' with the # of DEGs for each cluster in column 'DEGs'.

    Raises
    ------
    ValueError
        If file is neither a .xlsx nor a .csv file.
    """
    fname = os.path.splitext(os.path.basename(file))[0]
    if file.endswith('.xlsx'):
        df = pd.read_excel(file, index_col=0, engine='openpyxl')
    elif file.endswith('.csv'):
        df = pd.read_csv(file, index_col=0)
    else:
        raise ValueError("Unsupported file type for " + str(file) + "; expected .xlsx or .csv")
    df = df[:n_genes]

    suffix = '_p'
    clusters = []
    degs = []
    for col in df.columns:
        if not col.endswith(suffix):
            continue
        degs.append(int((df[col] < pcutoff).sum()))
        # Slice the suffix off: str.strip('_p') would strip any '_'/'p' characters from both ends.
        clusters.append(int(col[:-len(suffix)].split(' ')[-1].strip(')')))

    # Building from a dict also works when no p-value columns were found.
    res = pd.DataFrame({'Cluster': clusters, 'DEGs': degs}).set_index('Cluster')

    if plot:
        figDir = os.path.join(directory, 'figures')
        os.makedirs(figDir, exist_ok=True)
        ax = res.plot(kind='bar', grid=False)
        ax.get_figure().savefig(os.path.join(figDir, fname + '_DEG_Counts' + imageType))
    if save:
        res.to_excel(os.path.join(directory, fname + '_DEG_Counts.xlsx'))
    return res
def mergeGroupsCountDEGs(file1, file2, directory, n_genes=1000, pcutoff=.05, plot=True, save=False, imageType='.png'):
    """Count differentially expressed genes per cluster in two scanpy result files and merge them.

    Every column whose name ends in ``'_pval'`` is treated as a p-value column.
    The cluster id is parsed from the last space-separated token of the column
    name, accepting forms such as ``'3'`` or ``'(12)'``. The comparison name is
    the part of the file name after ``'DiffExp_'`` up to the next ``'_'``. The
    group id is whatever follows ``'Upregulated'`` in that comparison name.

    Parameters
    ----------
    file1 : string
        Path to first saved .xlsx or .csv file containing differential expression data.
    file2 : string
        Path to second .xlsx or .csv file with differential expression data.
    directory : string
        Directory to save results.
    n_genes : int, (optional, default 1000)
        Number of genes used in original data analysis; only the first n_genes rows are counted.
    pcutoff : float (optional, default .05)
        Alpha value for significance (p-values strictly below it are counted).
    plot : bool (optional, default True)
        Whether to save a bar plot of the merged counts under ``<directory>/figures/``.
    save : bool (optional, default False)
        Whether to save or only return result.
    imageType : string (optional, default '.png')
        File extension, including the dot, for the plot ('.pdf' is saved with a transparent background).

    Returns
    -------
    Pandas DataFrame indexed by cluster id ('Cluster') with one 'Up <group>' column per file.
    Clusters present in only one file get NaN in the other file's column.

    Raises
    ------
    ValueError
        If a file is neither .xlsx nor .csv, or a cluster id cannot be parsed.
    """
    suffix = '_pval'
    counts = {}
    comparisons = []
    for file in [file1, file2]:
        fname = os.path.splitext(os.path.basename(file))[0]
        comparison = fname.split('DiffExp_')[-1].split('_')[0]
        groupid = comparison.split('Upregulated')[-1]
        comparisons.append(comparison)
        if file.endswith('.xlsx'):
            df = pd.read_excel(file, index_col=0, engine='openpyxl')
        elif file.endswith('.csv'):
            df = pd.read_csv(file, index_col=0)
        else:
            raise ValueError("Unsupported file type for " + str(file) + "; expected .xlsx or .csv")
        df = df[:n_genes]

        clusters = []
        degs = []
        for col in df.columns:
            if not col.endswith(suffix):
                continue
            degs.append(int((df[col] < pcutoff).sum()))
            # Slice the suffix off: str.strip('_pval') strips a character set, not a suffix.
            token = col[:-len(suffix)].split(' ')[-1].strip(')')
            try:
                clu = int(token)
            except ValueError:
                # e.g. '(12' from '... (12)_pval'; take the whole digit run
                # rather than a single character so multi-digit ids survive.
                match = re.search(r'\d+', token)
                if match is None:
                    raise ValueError("Could not parse a cluster id from column " + repr(col)) from None
                clu = int(match.group())
            clusters.append(clu)
        # Index by cluster id so the two files align even when they list different clusters.
        counts["Up " + groupid] = pd.Series(degs, index=clusters, dtype='int64')

    cat = pd.DataFrame(counts)
    cat.index.name = 'Cluster'

    outName = 'Combined_' + comparisons[0] + '_' + comparisons[1] + '_Ngenes' + str(n_genes) + '_DEG_Counts'
    if plot:
        ax = cat.plot(kind='bar', grid=False)
        ax.get_figure().savefig(os.path.join(directory, 'figures/' + outName + imageType), dpi=300, transparent=imageType == '.pdf')
    if save:
        cat.to_excel(os.path.join(directory, outName + '.xlsx'))
    return cat
def meanMito(adata):
    """Return the mean of the ``percent_mito`` column of ``adata.obs``.

    Parameters
    ----------
    adata : AnnData object
        Dataset whose ``obs`` table has a 'percent_mito' column.

    Returns
    -------
    The value of ``adata.obs['percent_mito'].mean()``.
    """
    return adata.obs['percent_mito'].mean()