Skip to content

Commit 3d49383

Browse files
committed
cleaned up code
1 parent d60b779 commit 3d49383

3 files changed

Lines changed: 70 additions & 30 deletions

File tree

Snakefile

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ rule evaluation:
396396
pr_curve_png = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', 'precision-recall-curve-ensemble-nodes.png']),
397397
pca_chosen_pr_file = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', "precision-recall-pca-chosen-pathway.txt"]),
398398
heatmap_edge_file = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', "jaccard_edge_heatmap.png"]),
399-
399+
heatmap_node_file = SEP.join([out_dir, '{dataset_gold_standard_pairs}-eval', "jaccard_node_heatmap.png"]),
400400
run:
401401
node_table = Evaluation.from_file(input.gold_standard_file).node_table
402402
edge_table = Evaluation.from_file(input.gold_standard_file).edge_table
@@ -406,7 +406,8 @@ rule evaluation:
406406
Evaluation.precision_recall_curve_node_ensemble(node_ensemble, node_table, output.pr_curve_png)
407407
pca_chosen_pathway = Evaluation.pca_chosen_pathway(input.pca_coordinates_file, out_dir)
408408
Evaluation.precision_and_recall_node(pca_chosen_pathway, node_table, algorithms, output.pca_chosen_pr_file)
409-
Evaluation.jaccard_edge_heatmap(input.pathways, edge_table, algorithms, output.pr_edge_file, output.heatmap_edge_file)
409+
Evaluation.jaccard_edge_heatmap(input.pathways, edge_table, output.heatmap_edge_file)
410+
Evaluation.jaccard_node_heatmap(input.pathways, node_table, output.heatmap_node_file)
410411

411412

412413
# Returns all pathways for a specific algorithm and dataset

config/synthetic.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,15 @@ algorithms:
5151

5252
- name: "omicsintegrator1"
5353
params:
54-
include: false
54+
include: true
5555
run1:
5656
b: [5, 6]
5757
w: np.linspace(0,5,2)
5858
d: [10]
5959

6060
- name: "omicsintegrator2"
6161
params:
62-
include: false
62+
include: true
6363
run1:
6464
b: [4]
6565
g: [0]
@@ -77,7 +77,7 @@ algorithms:
7777

7878
- name: "mincostflow"
7979
params:
80-
include: false
80+
include: true
8181
run1:
8282
flow: [1] # The flow must be an int
8383
capacity: [1]
@@ -125,7 +125,7 @@ gold_standards:
125125
node_files: ["gs_nodes2.txt"]
126126
edge_files: ["gs_edges.txt"]
127127
data_dir: "input"
128-
dataset_labels: ["data2", "data3"]
128+
dataset_labels: ["data2"]
129129

130130
# If we want to reconstruct then we should set run to true.
131131
# TODO: if include is true above but run is false here, algs are not run.

spras/evaluation.py

Lines changed: 63 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -94,17 +94,17 @@ def precision_and_recall_edge(file_paths: Iterable[Path], edge_table: pd.DataFra
9494
@param output_file: the filename to save the precision and recall of each pathway
9595
@param output_png (optional): the filename to plot the precision and recall of each pathway (not a PRC)
9696
"""
97-
gs_edges = set()
97+
y_true = set()
9898
for row in edge_table.itertuples():
99-
gs_edges.add((row[1], row[2]))
99+
y_true.add((row[1], row[2]))
100100
results = []
101101
for file in file_paths:
102102
df = pd.read_table(file, sep="\t", header=0, usecols=["Node1", "Node2"])
103103
y_pred = set()
104104
for row in df.itertuples():
105105
y_pred.add((row[1], row[2]))
106-
all_edges = set(gs_edges.union(y_pred))
107-
y_true_binary = [1 if (edge[0], edge[1]) in gs_edges or (edge[1], edge[0]) in gs_edges else 0 for edge in all_edges]
106+
all_edges = set(y_true.union(y_pred))
107+
y_true_binary = [1 if (edge[0], edge[1]) in y_true or (edge[1], edge[0]) in y_true else 0 for edge in all_edges]
108108
y_pred_binary = [1 if (edge[0], edge[1]) in y_pred or (edge[1], edge[0]) in y_pred else 0 for edge in all_edges]
109109
precision = precision_score(y_true_binary, y_pred_binary, zero_division=0.0)
110110
recall = recall_score(y_true_binary, y_pred_binary, zero_division=0.0)
@@ -211,27 +211,23 @@ def precision_and_recall_node(file_paths: Iterable[Path], node_table: pd.DataFra
211211
plt.title("Empty Pathway Files")
212212
plt.savefig(output_png)
213213

214-
215214
@staticmethod
216-
def jaccard_edge_heatmap(file_paths: Iterable[Path], edge_table: pd.DataFrame, algorithms: list, output_file: str, output_png:str=None):
215+
def jaccard_edge_heatmap(file_paths: Iterable[Path], edge_table: pd.DataFrame, output_png:str=None):
217216
"""
218217
Takes in file paths for a specific dataset and an associated gold standard edge table.
219-
Calculates precision and recall for each pathway file
220-
Returns output back to output_file
218+
Generates a jaccard index heatmap image that compares all the edge similarity between each dataset and the gold standard
219+
Returns output back to output_png
221220
@param file_paths: file paths of pathway reconstruction algorithm outputs
222221
@param edge_table: the gold standard edges
223-
@param algorithms: list of algorithms used in current run of SPRAS
224-
@param output_file: the filename to save the precision and recall of each pathway
225-
@param output_png (optional): the filename to plot the precision and recall of each pathway (not a PRC)
222+
@param output_png (optional): the filename to plot the heatmap (not a PRC)
226223
"""
227-
print("jaccard_heatmap")
228-
229224
gs_edges = set()
230225
for row in edge_table.itertuples():
231226
gs_edges.add((row[1], row[2]))
232-
# calculate all the jaccard edge index for each method against the gs
233-
jaccard_edges_indices_list = []
234-
algs = []
227+
228+
# calculate all the jaccard edge index for each method against the gold standard
229+
jaccard_edge_indices_list = []
230+
algorithms = []
235231
for file in file_paths:
236232
df = pd.read_table(file, sep="\t", header=0, usecols=["Node1", "Node2"])
237233
method_edges = set()
@@ -240,25 +236,68 @@ def jaccard_edge_heatmap(file_paths: Iterable[Path], edge_table: pd.DataFrame, a
240236
edge_union = gs_edges | method_edges
241237
edge_intersection = gs_edges & method_edges
242238
jaccard_edge_index = len(edge_intersection) / len(edge_union)
243-
jaccard_edges_indices_list.append(float(jaccard_edge_index))
244-
algs.append(file.split("/")[1].split("-")[1])
245-
246-
jaccard_edges_indices = np.asanyarray([jaccard_edges_indices_list])
239+
jaccard_edge_indices_list.append(float(jaccard_edge_index))
240+
algorithms.append(file.split("/")[1].split("-")[1])
247241

248-
print(algs)
242+
jaccard_edge_indices = np.asanyarray([jaccard_edge_indices_list])
249243

250244
plt.figure(figsize=(10, 8))
251245
sns.heatmap(
252-
jaccard_edges_indices,
246+
jaccard_edge_indices,
253247
annot=True,
254248
cmap="viridis",
255-
yticklabels=["Pathways"],
256-
xticklabels=algs,
249+
xticklabels=algorithms,
250+
yticklabels=[""],
257251
)
258252
plt.xlabel("Algorithms")
253+
plt.ylabel("Pathways")
259254
plt.title("Jaccard Index Edge Heatmap")
255+
plt.tick_params(axis='x', which='major', labelsize=7.5)
260256
plt.savefig(output_png, format="png", dpi=300)
261257

258+
@staticmethod
259+
def jaccard_node_heatmap(file_paths: Iterable[Path], node_table: pd.DataFrame, output_png:str=None):
260+
"""
261+
Takes in file paths for a specific dataset and an associated gold standard nodes table.
262+
Generates a jaccard index heatmap image that compares all the nodes similarity between each dataset and the gold standard
263+
Returns output back to output_png
264+
@param file_paths: file paths of pathway reconstruction algorithm outputs
265+
@param node_table: the gold standard nodes
266+
@param output_png (optional): the filename to plot the heatmap (not a PRC)
267+
"""
268+
gs_nodes = set(node_table['NODEID'])
269+
# calculate all the jaccard node index for each method against the gold standard
270+
jaccard_node_indices_list = []
271+
algorithms = []
272+
for file in file_paths:
273+
df = pd.read_table(file, sep="\t", header=0, usecols=["Node1", "Node2"])
274+
method_nodes = set()
275+
for row in df.itertuples():
276+
if row[1] not in method_nodes:
277+
method_nodes.add(row[1])
278+
if row[2] not in method_nodes:
279+
method_nodes.add(row[2])
280+
node_union = gs_nodes | method_nodes
281+
node_intersection = gs_nodes & method_nodes
282+
jaccard_node_index = len(node_intersection) / len(node_union)
283+
jaccard_node_indices_list.append(float(jaccard_node_index))
284+
algorithms.append(file.split("/")[1].split("-")[1])
285+
286+
jaccard_node_indices = np.asanyarray([jaccard_node_indices_list])
287+
288+
plt.figure(figsize=(10, 8))
289+
sns.heatmap(
290+
jaccard_node_indices,
291+
annot=True,
292+
cmap="viridis",
293+
xticklabels=algorithms,
294+
yticklabels=[""],
295+
)
296+
plt.xlabel("Algorithms")
297+
plt.ylabel("Pathways")
298+
plt.title("Jaccard Index Nodes Heatmap")
299+
plt.tick_params(axis='x', which='major', labelsize=7.5)
300+
plt.savefig(output_png, format="png", dpi=300)
262301

263302
def select_max_freq_and_node(row: pd.Series):
264303
"""

0 commit comments

Comments
 (0)