11# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/data_utils.ipynb.
22
33# %% auto 0
4- __all__ = ['SeqRecord' , 'SeqRecords' , 'GroupwiseSeqRecords' , 'create_groupwise_seq_records' , 'fetch_seq_records_from_group_names ' ,
5- 'get_single_and_paired_seqs' , 'seq_records_tokenizer ' , 'compute_num_correct_pairings' ,
4+ __all__ = ['SeqRecord' , 'SeqRecords' , 'GroupwiseSeqRecords' , 'create_groupwise_seq_records' , 'pad_msas_with_dummy_sequences ' ,
5+ 'get_single_and_paired_seqs' , 'one_hot_encode_msa ' , 'compute_num_correct_pairings' ,
66 'compute_comparable_group_idxs' ]
77
8- # %% ../nbs/data_utils.ipynb 3
8+ # %% ../nbs/data_utils.ipynb 4
99from collections import defaultdict
1010from collections .abc import Sequence
1111from typing import Optional , Union
12+ from copy import deepcopy
1213
1314import numpy as np
1415
2122SeqRecords = list [SeqRecord ]
2223GroupwiseSeqRecords = dict [str , SeqRecords ]
2324
24- # %% ../nbs/data_utils.ipynb 4
25+ # %% ../nbs/data_utils.ipynb 5
2526def create_groupwise_seq_records (
2627 seq_records : dict [str , SeqRecords ],
2728 group_name_func : callable ,
@@ -41,18 +42,61 @@ def create_groupwise_seq_records(
4142 return data_group_by_group
4243
4344
44- def fetch_seq_records_from_group_names (
45- data_group_by_group : GroupwiseSeqRecords ,
46- group_names : Sequence [str ],
47- ) -> dict :
48- seq_records = []
49- group_sizes = []
def pad_msas_with_dummy_sequences(
    data_group_by_group_x: "GroupwiseSeqRecords",
    data_group_by_group_y: "GroupwiseSeqRecords",
    *,
    dummy_symbol: str = "-",
) -> "tuple[GroupwiseSeqRecords, GroupwiseSeqRecords]":
    """Pad two groupwise MSA collections with dummy (all-gap) sequences so that
    every group/species contains the same number of sequences in both.

    Parameters
    ----------
    data_group_by_group_x : mapping from group name to a list of
        ``(header, sequence)`` records; all sequences must share one length.
    data_group_by_group_y : same, for the second MSA collection.
    dummy_symbol : character used to build the dummy (gap) padding sequences.

    Returns
    -------
    A pair of new mappings (the inputs are not mutated) in which every group
    that appears in *either* input is padded up to the deeper of the two
    groups with records named ``"dummy_<i>"``.

    Raises
    ------
    ValueError : if sequences within either collection differ in length (or a
        collection is empty), since a single gap-only dummy sequence of a
        fixed length is only valid padding for an aligned MSA.
    """
    # All sequences within a collection must share one length so a single
    # gap-only dummy sequence of that length is valid padding.
    lengths_x = {
        len(seq)
        for recs_x in data_group_by_group_x.values()
        for _, seq in recs_x
    }
    lengths_y = {
        len(seq)
        for recs_y in data_group_by_group_y.values()
        for _, seq in recs_y
    }
    if len(lengths_x) != 1:
        raise ValueError(
            "Sequences in the first input collection must have the same lengths for padding with dummy gap sequences."
        )
    if len(lengths_y) != 1:
        raise ValueError(
            "Sequences in the second input collection must have the same lengths for padding with dummy gap sequences."
        )
    len_x = next(iter(lengths_x))
    len_y = next(iter(lengths_y))

    # Union of group names: a group may occur in only one collection, in which
    # case the other side must be padded from zero depth.
    group_names = set(data_group_by_group_x.keys()) | set(data_group_by_group_y.keys())

    data_group_by_group_x_padded = defaultdict(list)
    data_group_by_group_y_padded = defaultdict(list)
    # deepcopy so callers' record lists are never mutated by the `+=` below.
    data_group_by_group_x_padded.update(deepcopy(data_group_by_group_x))
    data_group_by_group_y_padded.update(deepcopy(data_group_by_group_y))
    for group_name in group_names:
        # BUGFIX: use .get(...) — indexing the plain input dicts raised
        # KeyError for any group present in only one of the two collections.
        depth_x = len(data_group_by_group_x.get(group_name, []))
        depth_y = len(data_group_by_group_y.get(group_name, []))
        max_depth = max(depth_x, depth_y)
        data_group_by_group_x_padded[group_name] += [
            (f"dummy_{i}", dummy_symbol * len_x) for i in range(max_depth - depth_x)
        ]
        data_group_by_group_y_padded[group_name] += [
            (f"dummy_{i}", dummy_symbol * len_y) for i in range(max_depth - depth_y)
        ]

    return data_group_by_group_x_padded, data_group_by_group_y_padded
56100
57101
58102def get_single_and_paired_seqs (
@@ -96,15 +140,16 @@ def get_single_and_paired_seqs(
96140 "xy_seqs_to_counts_by_group" : xy_seqs_to_counts_by_group ,
97141 }
98142
99- # %% ../nbs/data_utils.ipynb 5
100- def seq_records_tokenizer (
143+ # %% ../nbs/data_utils.ipynb 6
144+ def one_hot_encode_msa (
101145 seq_records : SeqRecords ,
102146 aa_to_int : Optional [dict [str , int ]] = None ,
103147 device : Optional [torch .device ] = None ,
104148) -> torch .Tensor :
105149 """
106- Given a list of records of the form (header, sequence), tokenize each
107- sequence and one-hot encode each token.
150+ Given a list of records of the form (header, sequence), assumed to be a parsed MSA,
151+ tokenize each sequence and one-hot encode each token. Return a 3D tensor representing the
152+ one-hot encoded MSA.
108153 """
109154 if aa_to_int is None :
110155 aa_to_int = DEFAULT_AA_TO_INT
@@ -120,7 +165,7 @@ def seq_records_tokenizer(
120165
121166 return tokenized_records_oh
122167
123- # %% ../nbs/data_utils.ipynb 6
168+ # %% ../nbs/data_utils.ipynb 8
124169def compute_num_correct_pairings (
125170 hard_perms_by_group : list [np .ndarray ],
126171 * ,
@@ -168,7 +213,7 @@ def compute_num_correct_pairings(
168213
169214 return correct
170215
171- # %% ../nbs/data_utils.ipynb 7
216+ # %% ../nbs/data_utils.ipynb 10
172217def compute_comparable_group_idxs (
173218 group_sizes_arr : np .ndarray , * , max_size_ratio : int , max_group_size : int
174219) -> np .ndarray :
0 commit comments