Skip to content

Commit 53bf71a

Browse files
committed
Rename data functions, add function for padding MSAs with dummy sequences
1 parent a4cc6f5 commit 53bf71a

8 files changed

Lines changed: 288 additions & 64 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# DiffPaSS
1+
# DiffPaSS – Differentiable Pairing using Soft Scores
22

33
<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->
44

diffpass/_modidx.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,12 @@
5454
'diffpass/data_utils.py'),
5555
'diffpass.data_utils.create_groupwise_seq_records': ( 'data_utils.html#create_groupwise_seq_records',
5656
'diffpass/data_utils.py'),
57-
'diffpass.data_utils.fetch_seq_records_from_group_names': ( 'data_utils.html#fetch_seq_records_from_group_names',
58-
'diffpass/data_utils.py'),
5957
'diffpass.data_utils.get_single_and_paired_seqs': ( 'data_utils.html#get_single_and_paired_seqs',
6058
'diffpass/data_utils.py'),
61-
'diffpass.data_utils.seq_records_tokenizer': ( 'data_utils.html#seq_records_tokenizer',
62-
'diffpass/data_utils.py')},
59+
'diffpass.data_utils.one_hot_encode_msa': ( 'data_utils.html#one_hot_encode_msa',
60+
'diffpass/data_utils.py'),
61+
'diffpass.data_utils.pad_msas_with_dummy_sequences': ( 'data_utils.html#pad_msas_with_dummy_sequences',
62+
'diffpass/data_utils.py')},
6363
'diffpass.entropy_ops': { 'diffpass.entropy_ops.pointwise_shannon': ( 'entropy_ops.html#pointwise_shannon',
6464
'diffpass/entropy_ops.py'),
6565
'diffpass.entropy_ops.smooth_mean_one_body_entropy': ( 'entropy_ops.html#smooth_mean_one_body_entropy',

diffpass/data_utils.py

Lines changed: 65 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/data_utils.ipynb.
22

33
# %% auto 0
4-
__all__ = ['SeqRecord', 'SeqRecords', 'GroupwiseSeqRecords', 'create_groupwise_seq_records', 'fetch_seq_records_from_group_names',
5-
'get_single_and_paired_seqs', 'seq_records_tokenizer', 'compute_num_correct_pairings',
4+
__all__ = ['SeqRecord', 'SeqRecords', 'GroupwiseSeqRecords', 'create_groupwise_seq_records', 'pad_msas_with_dummy_sequences',
5+
'get_single_and_paired_seqs', 'one_hot_encode_msa', 'compute_num_correct_pairings',
66
'compute_comparable_group_idxs']
77

8-
# %% ../nbs/data_utils.ipynb 3
8+
# %% ../nbs/data_utils.ipynb 4
99
from collections import defaultdict
1010
from collections.abc import Sequence
1111
from typing import Optional, Union
12+
from copy import deepcopy
1213

1314
import numpy as np
1415

@@ -21,7 +22,7 @@
2122
SeqRecords = list[SeqRecord]
2223
GroupwiseSeqRecords = dict[str, SeqRecords]
2324

24-
# %% ../nbs/data_utils.ipynb 4
25+
# %% ../nbs/data_utils.ipynb 5
2526
def create_groupwise_seq_records(
2627
seq_records: dict[str, SeqRecords],
2728
group_name_func: callable,
@@ -41,18 +42,61 @@ def create_groupwise_seq_records(
4142
return data_group_by_group
4243

4344

44-
def fetch_seq_records_from_group_names(
45-
data_group_by_group: GroupwiseSeqRecords,
46-
group_names: Sequence[str],
47-
) -> dict:
48-
seq_records = []
49-
group_sizes = []
45+
def pad_msas_with_dummy_sequences(
46+
data_group_by_group_x: GroupwiseSeqRecords,
47+
data_group_by_group_y: GroupwiseSeqRecords,
48+
*,
49+
dummy_symbol: str = "-",
50+
) -> tuple[GroupwiseSeqRecords, GroupwiseSeqRecords]:
51+
"""Pad MSAs with dummy sequences so that all groups/species contain the same
52+
number of sequences."""
53+
# Check that all sequences in the x and y MSAs have the same length
54+
lengths_x = set(
55+
[
56+
len(seq)
57+
for data_x_this_group in data_group_by_group_x.values()
58+
for _, seq in data_x_this_group
59+
]
60+
)
61+
lengths_y = set(
62+
[
63+
len(seq)
64+
for data_y_this_group in data_group_by_group_y.values()
65+
for _, seq in data_y_this_group
66+
]
67+
)
68+
if len(lengths_x) != 1:
69+
raise ValueError(
70+
"Sequences in the first input collection must have the same lengths for padding with dummy gap sequences."
71+
)
72+
if len(lengths_y) != 1:
73+
raise ValueError(
74+
"Sequences in the second input collection must have the same lengths for padding with dummy gap sequences."
75+
)
76+
len_x = next(iter(lengths_x))
77+
len_y = next(iter(lengths_y))
78+
79+
group_names = set(data_group_by_group_x.keys()) | set(data_group_by_group_y.keys())
80+
81+
data_group_by_group_x_padded = defaultdict(SeqRecords)
82+
data_group_by_group_y_padded = defaultdict(SeqRecords)
83+
data_group_by_group_x_padded.update(deepcopy(data_group_by_group_x))
84+
data_group_by_group_y_padded.update(deepcopy(data_group_by_group_y))
5085
for group_name in group_names:
51-
recs_this_group_name = data_group_by_group[group_name]
52-
seq_records.extend(recs_this_group_name)
53-
group_sizes.append(len(recs_this_group_name))
86+
max_depth = max(
87+
len(data_group_by_group_x[group_name]),
88+
len(data_group_by_group_y[group_name]),
89+
)
90+
data_group_by_group_x_padded[group_name] += [
91+
(f"dummy_{i}", dummy_symbol * len_x)
92+
for i in range(max_depth - len(data_group_by_group_x[group_name]))
93+
]
94+
data_group_by_group_y_padded[group_name] += [
95+
(f"dummy_{i}", dummy_symbol * len_y)
96+
for i in range(max_depth - len(data_group_by_group_y[group_name]))
97+
]
5498

55-
return {"seq_records": seq_records, "group_sizes": group_sizes}
99+
return data_group_by_group_x_padded, data_group_by_group_y_padded
56100

57101

58102
def get_single_and_paired_seqs(
@@ -96,15 +140,16 @@ def get_single_and_paired_seqs(
96140
"xy_seqs_to_counts_by_group": xy_seqs_to_counts_by_group,
97141
}
98142

99-
# %% ../nbs/data_utils.ipynb 5
100-
def seq_records_tokenizer(
143+
# %% ../nbs/data_utils.ipynb 6
144+
def one_hot_encode_msa(
101145
seq_records: SeqRecords,
102146
aa_to_int: Optional[dict[str, int]] = None,
103147
device: Optional[torch.device] = None,
104148
) -> torch.Tensor:
105149
"""
106-
Given a list of records of the form (header, sequence), tokenize each
107-
sequence and one-hot encode each token.
150+
Given a list of records of the form (header, sequence), assumed to be a parsed MSA,
151+
tokenize each sequence and one-hot encode each token. Return a 3D tensor representing the
152+
one-hot encoded MSA.
108153
"""
109154
if aa_to_int is None:
110155
aa_to_int = DEFAULT_AA_TO_INT
@@ -120,7 +165,7 @@ def seq_records_tokenizer(
120165

121166
return tokenized_records_oh
122167

123-
# %% ../nbs/data_utils.ipynb 6
168+
# %% ../nbs/data_utils.ipynb 8
124169
def compute_num_correct_pairings(
125170
hard_perms_by_group: list[np.ndarray],
126171
*,
@@ -168,7 +213,7 @@ def compute_num_correct_pairings(
168213

169214
return correct
170215

171-
# %% ../nbs/data_utils.ipynb 7
216+
# %% ../nbs/data_utils.ipynb 10
172217
def compute_comparable_group_idxs(
173218
group_sizes_arr: np.ndarray, *, max_size_ratio: int, max_group_size: int
174219
) -> np.ndarray:

diffpass/msa_parsing.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,24 @@
11
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/msa_parsing.ipynb.
22

33
# %% auto 0
4-
__all__ = ['deletekeys', 'translation', 'read_sequence', 'remove_insertions', 'read_msa']
4+
__all__ = ['SeqRecord', 'SeqRecords', 'deletekeys', 'translation', 'read_sequence', 'remove_insertions', 'read_msa']
55

6-
# %% ../nbs/msa_parsing.ipynb 2
7-
from typing import List, Tuple
6+
# %% ../nbs/msa_parsing.ipynb 3
87
import string
98
import itertools
109

1110
from Bio import SeqIO
1211

12+
SeqRecord = tuple[str, str]
13+
SeqRecords = list[SeqRecord]
1314

1415
deletekeys = dict.fromkeys(string.ascii_lowercase)
1516
deletekeys["."] = None
1617
deletekeys["*"] = None
1718
translation = str.maketrans(deletekeys)
1819

1920

20-
def read_sequence(filename: str) -> Tuple[str, str]:
21+
def read_sequence(filename: str) -> SeqRecord:
2122
"""Reads the first (reference) sequences from a fasta or MSA file."""
2223
record = next(SeqIO.parse(filename, "fasta"))
2324
return record.description, str(record.seq)
@@ -28,7 +29,7 @@ def remove_insertions(sequence: str) -> str:
2829
return sequence.translate(translation)
2930

3031

31-
def read_msa(filename: str, nseq: int) -> List[Tuple[str, str]]:
32+
def read_msa(filename: str, nseq: int) -> SeqRecords:
3233
"""Reads the first nseq sequences from an MSA file, automatically removes insertions."""
3334
if nseq == -1:
3435
nseq = len([elem.id for elem in SeqIO.parse(filename, "fasta")])

mutual_information_msa_pairing.ipynb

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -82,31 +82,31 @@
8282
},
8383
{
8484
"cell_type": "code",
85+
"execution_count": 3,
8586
"metadata": {},
87+
"outputs": [],
8688
"source": [
8789
"# DiffPaSS parsing and preprocessing utilities\n",
8890
"from diffpass.msa_parsing import read_msa\n",
89-
"from diffpass.data_utils import create_groupwise_seq_records, fetch_seq_records_from_group_names, seq_records_tokenizer, compute_num_correct_pairings\n",
91+
"from diffpass.data_utils import create_groupwise_seq_records, one_hot_encode_msa, compute_num_correct_pairings\n",
9092
"\n",
9193
"\n",
9294
"# Load prokaryotic datasets\n",
9395
"\n",
9496
"# HK-RR datasets\n",
9597
"msa_data = [\n",
96-
" read_msa(\"../data/HK-RR/HK_in_Concat_nnn.fasta\", -1),\n",
97-
" read_msa(\"../data/HK-RR/RR_in_Concat_nnn.fasta\", -1)\n",
98+
" read_msa(\"data/HK-RR/HK_in_Concat_nnn.fasta\", -1),\n",
99+
" read_msa(\"data/HK-RR/RR_in_Concat_nnn.fasta\", -1)\n",
98100
"]\n",
99101
"species_name_func = lambda header: header.split(\"|\")[1]\n",
100102
"\n",
101103
"## MALG-MALK datasets\n",
102104
"# msa_data = [\n",
103-
"# read_msa(\"../data/MALG-MALK/MALG_cov75_hmmsearch_extr5000_withLast_b.fasta\", -1),\n",
104-
"# read_msa(\"../data/MALG-MALK/MALK_cov75_hmmsearch_extr5000_withLast_b.fasta\", -1)\n",
105+
"# read_msa(\"data/MALG-MALK/MALG_cov75_hmmsearch_extr5000_withLast_b.fasta\", -1),\n",
106+
"# read_msa(\"data/MALG-MALK/MALK_cov75_hmmsearch_extr5000_withLast_b.fasta\", -1)\n",
105107
"# ]\n",
106108
"# species_name_func = lambda header: header.split(\"_\")[-1]"
107-
],
108-
"outputs": [],
109-
"execution_count": 3
109+
]
110110
},
111111
{
112112
"cell_type": "markdown",
@@ -123,15 +123,13 @@
123123
"metadata": {},
124124
"outputs": [],
125125
"source": [
126+
"# Organize the MSAs by species (\"groupwise\")\n",
126127
"msa_data_species_by_species = [\n",
127128
" create_groupwise_seq_records(msa, species_name_func, remove_groups_with_one_seq=True) \n",
128129
" for msa in msa_data\n",
129130
"]\n",
130131
"all_species = list(msa_data_species_by_species[0])\n",
131-
"assert all_species == list(msa_data_species_by_species[1])\n",
132-
"\n",
133-
"n_species_to_sample = 50\n",
134-
"species = np.random.choice(all_species, n_species_to_sample, replace=False)"
132+
"assert all_species == list(msa_data_species_by_species[1])"
135133
]
136134
},
137135
{
@@ -149,12 +147,15 @@
149147
}
150148
],
151149
"source": [
152-
"msa_data_and_species_sizes = [\n",
153-
" fetch_seq_records_from_group_names(msa_species_by_species, species)\n",
150+
"# Sample a few species to work with, and filter the MSAs to only include these species\n",
151+
"n_species_to_sample = 50\n",
152+
"species = np.random.choice(all_species, n_species_to_sample, replace=False)\n",
153+
"msa_data_species_by_species = [\n",
154+
" {sp: msa_species_by_species[sp] for sp in species}\n",
154155
" for msa_species_by_species in msa_data_species_by_species\n",
155156
"]\n",
156157
"\n",
157-
"species_sizes = msa_data_and_species_sizes[0][\"group_sizes\"]\n",
158+
"species_sizes = [len(records) for records in msa_data_species_by_species[0].values()]\n",
158159
"print(f\"Species sizes: {species_sizes}\")\n",
159160
"\n",
160161
"n_seqs = sum(species_sizes)\n",
@@ -167,8 +168,14 @@
167168
"metadata": {},
168169
"outputs": [],
169170
"source": [
170-
"x = seq_records_tokenizer(msa_data_and_species_sizes[0][\"seq_records\"], device=DEVICE)\n",
171-
"y = seq_records_tokenizer(msa_data_and_species_sizes[1][\"seq_records\"], device=DEVICE)"
171+
"# Bring data back into the original form (list of records)\n",
172+
"msa_data = [\n",
173+
" [record for records_this_species in msa_species_by_species.values() for record in records_this_species]\n",
174+
" for msa_species_by_species in msa_data_species_by_species\n",
175+
"]\n",
176+
"\n",
177+
"x = one_hot_encode_msa(msa_data[0], device=DEVICE)\n",
178+
"y = one_hot_encode_msa(msa_data[1], device=DEVICE)"
172179
]
173180
},
174181
{

0 commit comments

Comments (0)