-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstru_decompose3.py
More file actions
210 lines (185 loc) · 8.93 KB
/
stru_decompose3.py
File metadata and controls
210 lines (185 loc) · 8.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import numpy as np
from collections import OrderedDict
import forgi.graph.bulge_graph as fgb
from multiprocessing import Pool
import multiprocessing as mp
from tqdm import tqdm
def is_valid_dot_bracket(structure):
    """
    Return True if the round parentheses in *structure* are balanced.

    Every ')' must close an earlier, still-open '(' and no '(' may remain
    open at the end. Only '(' and ')' are inspected; every other character
    (dots, '[', ']', '{', '}', letters, ...) is skipped. Pseudoknot brackets
    are therefore neither validated nor rejected.
    """
    depth = 0
    for symbol in structure:
        if symbol == '(':
            depth += 1
        elif symbol == ')':
            depth -= 1
            if depth < 0:
                # A ')' arrived with no open '(' left to close.
                return False
    return depth == 0
def nucleotide_to_feature_vector(nucleotide):
    """Encode one nucleotide character as a single-element integer feature list.

    'A' -> [1], 'C' -> [2], 'G' -> [3], 'U' and 'T' -> [4]. 'N' and any other
    value, including lower-case letters, -> [0].
    """
    codes = {'A': 1, 'C': 2, 'G': 3, 'U': 4, 'T': 4, 'N': 0}
    return [codes.get(nucleotide, 0)]
def edge_to_feature_vector(edge_type):
    """Encode an edge type as a single-element integer feature list.

    'pair' -> [0], 'backbone' -> [1], then the hyperedge element letters
    'F', 'T', 'S', 'H', 'I', 'M' -> [2] .. [7]. Any other value raises KeyError.
    """
    ordered_types = ('pair', 'backbone', 'F', 'T', 'S', 'H', 'I', 'M')
    type_codes = {name: code for code, name in enumerate(ordered_types)}
    return [type_codes[edge_type]]
def decompose(dotbracket_struct):
    """
    Decompose a dot-bracket structure into forgi structural elements.

    The structure is parsed with forgi's BulgeGraph and serialised with
    to_bg_string(). The text is upper-cased, and every line starting with
    'DEFINE' becomes one entry of the returned dictionary; all other lines
    are ignored. forgi coordinates on those lines are converted to 0-based
    nucleotide indices, and most elements also include the neighbouring
    nucleotide(s), so adjacent elements share indices:
      F: the element's nucleotides plus the one right after it.
      T: the nucleotide right before the element plus the element.
      S: [first-strand indices, second-strand indices], exactly the element.
      H: the nucleotide before, the element, and the nucleotide after.
      I: [side-1 indices, side-2 indices], each side padded with the
         nucleotide before and after. When the DEFINE line has only one
         coordinate range, the empty side is represented by the two
         adjacent nucleotides of the connected stem.
      M: the nucleotide before, the segment, and the nucleotide after. When
         the DEFINE line has no coordinates, the entry is the two stem
         nucleotides that touch across the empty segment.

    :param dotbracket_struct: dot-bracket string accepted by
        forgi BulgeGraph.from_dotbracket
    :return: OrderedDict mapping upper-cased element ids (e.g. 'S0', 'H0')
        to either a flat list of indices or a two-element list of index
        lists, in the order the DEFINE lines appear
    :raises ValueError: if a one-sided interior loop or a coordinate-less
        multiloop segment cannot be matched against its stem(s)
    NOTE(review): 'I' and coordinate-less 'M' entries look up stems already
    stored in `hypernodes`. This assumes forgi writes stem DEFINE lines
    before the loops that reference them; otherwise a KeyError is raised.
    Confirm against the forgi version in use.
    """
    bg = fgb.BulgeGraph.from_dotbracket(dotbracket_struct)
    raw_hpgraph = bg.to_bg_string().rstrip().upper().split('\n')
    hypernodes = OrderedDict()
    for line in raw_hpgraph:
        if line.startswith('DEFINE'):
            # tokens = [element_id, coord, coord, ...] with 1-based forgi coordinates.
            tokens = line.split()[1:]
            hp_node_id = tokens[0]
            if hp_node_id.startswith('F'):
                hypernodes[hp_node_id] = list(range(int(tokens[1]) - 1, int(tokens[2]) + 1))
            elif hp_node_id.startswith('T'):
                hypernodes[hp_node_id] = list(range(int(tokens[1]) - 2, int(tokens[2])))
            elif hp_node_id.startswith('S'):
                hypernodes[hp_node_id] = [
                    list(range(int(tokens[1]) - 1, int(tokens[2]))),
                    list(range(int(tokens[3]) - 1, int(tokens[4])))]
            elif hp_node_id.startswith('H'):
                hypernodes[hp_node_id] = list(range(int(tokens[1]) - 2, int(tokens[2]) + 1))
            elif hp_node_id.startswith('I'):
                if len(tokens[1:]) == 2:
                    # Only one coordinate range: unpaired nucleotides on one side only.
                    # forgi element names are lower-case; the bg text was upper-cased above.
                    # NOTE(review): relies on connections() listing the enclosing stem
                    # first -- confirm against forgi's ordering.
                    stem_id = bg.connections(hp_node_id.lower())[0].upper()
                    if hypernodes[stem_id][0][-1] == int(tokens[1]) - 2:
                        # Loop starts right after the stem's first strand; the empty side is
                        # the stem's second-strand start and the nucleotide before it.
                        hypernodes[hp_node_id] = [
                            list(range(int(tokens[1]) - 2, int(tokens[2]) + 1)),
                            [hypernodes[stem_id][1][0] - 1, hypernodes[stem_id][1][0]]]
                    elif hypernodes[stem_id][1][0] == int(tokens[2]):
                        # Loop ends right before the stem's second strand; the empty side is
                        # the stem's first-strand end and the nucleotide after it.
                        hypernodes[hp_node_id] = [
                            [hypernodes[stem_id][0][-1], hypernodes[stem_id][0][-1] + 1],
                            list(range(int(tokens[1]) - 2, int(tokens[2]) + 1))]
                    else:
                        raise ValueError('Internal loop parsing error')
                else:
                    # Two coordinate ranges: pad each side with its flanking nucleotides.
                    hypernodes[hp_node_id] = [
                        list(range(int(tokens[1]) - 2, int(tokens[2]) + 1)),
                        list(range(int(tokens[3]) - 2, int(tokens[4]) + 1))]
            else:  # M: multiloop segment (every other element id lands here)
                if len(tokens) == 1:
                    # No coordinates: the segment is empty, so locate it from the two
                    # connected stems' strand index lists (stored as [strand 1, strand 2]).
                    stem_ids = [stem_id.upper() for stem_id in bg.connections(hp_node_id.lower())]
                    e_3_idx, e_5_idx = hypernodes[stem_ids[0]]
                    l_3_idx, l_5_idx = hypernodes[stem_ids[1]]
                    if e_3_idx[-1] + 1 == l_3_idx[0]:
                        # Stem 1 strand 1 is immediately followed by stem 2 strand 1.
                        hypernodes[hp_node_id] = [e_3_idx[-1], l_3_idx[0]]
                    elif e_5_idx[0] - 1 == l_5_idx[-1]:
                        # Stem 2 strand 2 ends immediately before stem 1 strand 2.
                        hypernodes[hp_node_id] = [l_5_idx[-1], e_5_idx[0]]
                    elif e_5_idx[-1] + 1 == l_3_idx[0]:
                        # Stem 1 strand 2 is immediately followed by stem 2 strand 1.
                        hypernodes[hp_node_id] = [e_5_idx[-1], l_3_idx[0]]
                    else:
                        raise ValueError('Multiloop parsing error:%s\n%s' % (hp_node_id, dotbracket_struct))
                else:
                    hypernodes[hp_node_id] = list(range(int(tokens[1]) - 2, int(tokens[2]) + 1))
    return hypernodes
def dotbracket2hgraph(dotbracket_struct, sequence):
"""
Converts a dot-bracket structure and RNA sequence to hypergraph Data object
:param dotbracket_struct: Dot-bracket string (str)
:param sequence: RNA sequence string (str, e.g., 'AUGC...')
:return: (node_fvs, n_idx, e_idx, edge_fvs)
"""
# Validate inputs
if len(dotbracket_struct) != len(sequence):
raise ValueError("Dot-bracket structure and sequence must have the same length")
# Nucleotide nodes
node_fvs = []
for nucleotide in sequence:
node_fvs.append(nucleotide_to_feature_vector(nucleotide))
n_idx, e_idx, edge_fvs = [], [], []
num_edges = 0
# Pairing edges (from dot-bracket)
stack = []
pairs = []
for i, char in enumerate(dotbracket_struct):
if char == '(':
stack.append(i)
elif char == ')':
if stack:
j = stack.pop()
n_idx.extend([j, i]) # Pairing edge connects nucleotides j and i
e_idx.extend([num_edges, num_edges])
edge_fvs.append(edge_to_feature_vector('pair'))
pairs.append((j, i))
num_edges += 1
else:
raise ValueError("Invalid dot-bracket structure: unmatched closing parenthesis")
if stack:
raise ValueError("Invalid dot-bracket structure: unmatched opening parenthesis")
# Backbone edges (connect consecutive nucleotides)
for i in range(len(sequence) - 1):
n_idx.extend([i, i + 1])
e_idx.extend([num_edges, num_edges])
edge_fvs.append(edge_to_feature_vector('backbone'))
num_edges += 1
# Hyperedges (from decompose function)
hypernodes = decompose(dotbracket_struct)
for hp_node_id, nodes in hypernodes.items():
if isinstance(nodes, list) and nodes:
if any(isinstance(sublist, list) for sublist in nodes):
# Flatten all sublists into a single list of indices
flattened_indices = [idx for sublist in nodes for idx in sublist]
else:
flattened_indices = nodes
if flattened_indices: # Ensure indices are not empty
for idx in flattened_indices:
n_idx.append(idx)
e_idx.append(num_edges)
edge_fvs.append(edge_to_feature_vector(hp_node_id[0]))
num_edges += 1
return (node_fvs, n_idx, e_idx, edge_fvs)
def process_rna(args):
    """Worker for process_rna_parallel: convert one record without raising on conversion errors.

    :param args: (dotbracket_struct, rna_seq) pair; it is unpacked outside the
        error handler, so a malformed pair still raises
    :return: (node_fvs, n_idx, e_idx, edge_fvs, None) on success, or
        (None, None, None, None, message) if the conversion raised
    """
    structure, seq = args
    try:
        node_fvs, n_idx, e_idx, edge_fvs = dotbracket2hgraph(structure, seq)
    except Exception as exc:
        # Report the failure as data so one bad record does not abort the whole pool.
        return None, None, None, None, str(exc)
    return node_fvs, n_idx, e_idx, edge_fvs, None
def process_rna_parallel(dotbracket_structs, rna_seqs, num_processes=None):
    """
    Convert many (dot-bracket, sequence) pairs to hypergraphs in parallel.

    Each pair is handled by process_rna in a worker process. A failed pair is
    reported with print() using its input index and is then dropped. The
    returned list can therefore be shorter than the input, and its positions
    do not necessarily line up with input positions.

    Args:
        dotbracket_structs (list): dot-bracket structures.
        rna_seqs (list): RNA sequences, aligned element-wise with dotbracket_structs.
        num_processes (int): number of worker processes. Defaults to the CPU
            count, capped at the number of inputs.

    Returns:
        list: (node_fvs, n_idx, e_idx, edge_fvs) tuples for every pair that
        converted successfully, in input order.

    Raises:
        ValueError: if the two input lists differ in length.
    """
    if len(rna_seqs) != len(dotbracket_structs):
        raise ValueError("RNA sequences and dotbracket structures must have the same length")
    inputs = list(zip(dotbracket_structs, rna_seqs))
    if not inputs:
        # Nothing to convert: avoid starting a worker pool at all.
        return []
    if num_processes is None:
        # Workers beyond the number of tasks would only sit idle.
        num_processes = min(mp.cpu_count(), len(inputs))
    with Pool(processes=num_processes) as pool:
        # imap yields results in input order, so index i matches inputs[i].
        results = list(tqdm(pool.imap(process_rna, inputs), total=len(inputs), desc="Processing RNA sequences"))
    valid_results = []
    for i, (node_fvs, n_idx, e_idx, edge_fvs, error) in enumerate(results):
        if error is not None:
            print(f"Error processing RNA {i}: {error}")
            continue
        valid_results.append((node_fvs, n_idx, e_idx, edge_fvs))
    return valid_results
if __name__ == "__main__":
    # Demo: convert one example structure/sequence pair and dump the hypergraph lists.
    demo_structure = "...(((((((..((((((.........))))))......).((((((.......))))))..))))))..."
    demo_sequence = "CGCUUCAUAUAAUCCUAAUGAUAUGGUUUGGGAGUUUCUACCAAGAGCCUUAAACUCUUGAUUAUGAAGUG"
    node_fvs, n_idx, e_idx, edge_fvs = dotbracket2hgraph(demo_structure, demo_sequence)
    # Total number of edges (pairing + backbone + hyperedges).
    print(len(edge_fvs))
    report = (
        ("Node Feature Vectors (node_fvs):", node_fvs),
        ("\nNode Indices (n_idx):", n_idx),
        ("\nEdge Indices (e_idx):", e_idx),
        ("\nEdge Feature Vectors (edge_fvs):", edge_fvs),
    )
    for heading, payload in report:
        print(heading)
        print(payload)