forked from FLaTNNBio/GA-DPAMSA
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_benchmarks.py
More file actions
77 lines (57 loc) · 2.63 KB
/
run_benchmarks.py
File metadata and controls
77 lines (57 loc) · 2.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
from tqdm import tqdm
import config
import utils
import datasets.inference_dataset.encode_project_dataset_4x101bp as inference_dataset
"""
Benchmarking Script for MSA Methods
This script benchmarks different Multiple Sequence Alignment (MSA) methods, including:
- GA-DPAMSA (Genetic Algorithm-enhanced DPAMSA)
- DPAMSA (Deep Reinforcement Learning-based MSA)
- Other external MSA tools (ClustalW, MAFFT, MUSCLE, etc.)
It allows the user to select benchmarking options, executes the selected MSA methods,
and generates reports and performance visualizations.
Author: https://github.com/FLaTNNBio/GA-DPAMSA
"""
# ===========================
# Dataset and Model Configuration
# ===========================
# Ensure the dataset name matches the imported dataset module
DATASET_NAME = 'encode_project_dataset_4x101bp'
# Ensure DPAMSA model matches dataset size
DPAMSA_MODEL = 'model_3x30'
# Ensure GA-DPAMSA model matches 'AGENT_WINDOW_ROW' and 'AGENT_WINDOW_COLUMN' settings
GA_DPAMSA_MODEL = 'model_3x30'
def main():
"""
Main function to execute MSA benchmarking.
- Displays a selection menu for benchmarking options.
- Runs GA-DPAMSA inference (always executed).
- Runs DPAMSA inference if selected.
- Runs external MSA tools if selected.
- Saves results and generates performance plots.
"""
# Display selection menu
choice = utils.display_menu()
# Get dataset file paths
dataset_folder = os.path.join(config.FASTA_FILES_PATH, DATASET_NAME)
file_paths = [os.path.join(dataset_folder, file) for file in sorted(os.listdir(dataset_folder))]
# Dictionary to store CSV paths for each tool
tool_csv_paths = {}
# GA-DPAMSA must always be executed
ga_csv_path = utils.run_ga_dpamsa_inference('sp', inference_dataset, DATASET_NAME, GA_DPAMSA_MODEL)
tool_csv_paths['GA-DPAMSA'] = ga_csv_path
# Run DPAMSA if option 1 or 3 is selected
if choice == 1 or choice == 3:
dpamsa_csv_path = utils.run_dpamsa_inference(inference_dataset, DATASET_NAME, DPAMSA_MODEL)
tool_csv_paths['DPAMSA'] = dpamsa_csv_path
# Run external MSA tools if option 2 or 3 is selected
if choice == 2 or choice == 3:
for tool_name in tqdm(config.TOOLS.keys(), desc="Processing Tools"):
tool_results = utils.run_tool_and_generate_report(tool_name, file_paths, DATASET_NAME)
csv_path = utils.save_inference_csv(tool_results, tool_name, DATASET_NAME) # Save CSV and get path
tool_csv_paths[tool_name] = csv_path
# Generate performance plots for the selected tools
utils.plot_metrics(tool_csv_paths, DATASET_NAME)
if __name__ == "__main__":
main()