-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathprocess_clusters.py
More file actions
1947 lines (1760 loc) · 131 KB
/
process_clusters.py
File metadata and controls
1947 lines (1760 loc) · 131 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
VERSION = "0.4.11" # does not necessarily match Tree Nine git version
print(f"PROCESS CLUSTERS - VERSION {VERSION}")
# pylint: disable=too-many-statements,too-many-branches,simplifiable-if-expression,too-many-locals,too-complex,consider-using-tuple,broad-exception-caught
# pylint: disable=wrong-import-position,useless-suppression,multiple-statements,line-too-long,consider-using-sys-exit,duplicate-code
#
# Notes:
# * Eventually we may want persistent_cluster_meta to contain parent-child cluster IDs because that might
#   prevent hypothetical edge cases where a cluster itself doesn't change but its subcluster gets renamed,
#   resulting in the parent cluster not linking to the correct MR project anymore?
# * This script calls a persistent cluster script written by Marc Perry, which handles all the tricky logic for
#   assigning persistent cluster IDs to clusters that already exist. However, we also need to assign IDs to new
#   clusters, link parent-child clusters, and upload to Microreact, which is what all this Python does.
# * There may be edge cases where Marc's script's assignment of persistent cluster IDs is non-deterministic
# * This script also calls find_clusters.py in -jmatsu mode to get the distance matrix of a backmasked cluster,
#   and it sets the distance to this value because we don't want to fall back to the default 20 (which causes
#   confusion in the prints, since the debug_name() of the backmask clusters will all be 20), but we also want
#   to calculate matrix_max, which is skipped if distance is unsigned 32-bit max. Hence...
UINT32_MAX_MINUS_ONE = 4294967294  # (2**32 - 1) - 1: sentinel distance just below uint32 max, so matrix_max is still calculated
#
# * My script's assignment of brand-new cluster IDs is likely non-deterministic as it relies on sets and
#   unsorted polars dataframes. Additionally, if typical methods for assigning cluster IDs fail due to name
#   conflicts, my script will start calling random numbers to generate new cluster IDs.
# * This script is not super optimized, but it is performant (~1 minute) on laptops up to at least 4000 clusters
# * Some versions of polars are stricter than others in reading/writing TSVs and JSONs
# * 5 SNP clusters always have a 10 SNP parent, and 10 SNP clusters always have a 20 SNP parent
# * Persistent clusters can run into a Ship of Theseus situation over time
#
# Surprisingly important information r/e logging:
# As of mid-2025, after changing its backend to GCP Batch, Terra seems to have an issue where logging slows
# task execution to an extreme degree. We're talking "a task used to take 10 hours but now is less than
# halfway through at 48 hours on comparable inputs" levels of slowness. It genuinely appears to be faster to
# save intermediate dataframes as JSONs to the disk rather than print them to stdout, so that's what this script
# does via debug_logging_handler_df().
# Tradeoff: If this script crashes (or if Terra/GCP has an intermittent error that crashes the VM, which seems
# to happen about 5% of the time as of mid-2025), Terra may or may not delocalize these JSON files, potentially
# leaving you without valuable debug information. If this happens, try using whatever files you have from
# find_clusters.py and/or the previous task to do testing locally.
import io
import os
import csv
import json
import time
import random
import logging
import argparse
from datetime import datetime, timezone
import subprocess
import requests
import polars as pl # this is overkill and takes forever to import; too bad!
from polars.testing import assert_series_equal
# Configure polars to print dataframes without truncation (rows, columns, widths, nested lists),
# so the JSON/stdout debug dumps are complete rather than elided.
pl.Config.set_tbl_rows(-1)
pl.Config.set_tbl_cols(-1)
pl.Config.set_tbl_width_chars(200)
pl.Config.set_fmt_str_lengths(500)
pl.Config.set_fmt_table_cell_list_len(500)
# NOTE(review): this is a timezone-aware *datetime*, but main() later compares it against a
# *date* parsed from --today (datetime != date is always unequal) — confirm intended behavior.
today = datetime.now(timezone.utc) # I don't care if this runs past midnight, give everything the same day!
print(f"It's {today} in Thurles right now. Up Tipp!")
# Quiet the HTTP libraries so their DEBUG chatter doesn't hit the Terra logging slowdown described above.
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("requests").setLevel(logging.WARNING)
max_random_id_attempts = 500 # maximum attempts to fix invalid cluster IDs
# prefix used for the working-directory outputs of find_clusters.py
FIND_CLUSTERS_OUTFILE_PREFIX = "workdir"
# Locate the directory holding the supporting perl/shell scripts (Marc's persistent-ID
# script, strip_tsv.sh, equalize_tabs.sh). In the Docker image they live in /scripts;
# when running locally they live in ./scripts relative to the working directory.
# (This detection was previously duplicated verbatim; it only needs to run once.)
if os.path.isfile("/scripts/marcs_incredible_script_update.pl"):
    script_path = "/scripts"
elif os.path.isfile("./scripts/marcs_incredible_script_update.pl"):
    script_path = "./scripts"
else:
    raise FileNotFoundError(
        "Could not find marcs_incredible_script_update.pl in /scripts or ./scripts"
    )
def main():
print("################# (1) INPUT HANDLING #################")
parser = argparse.ArgumentParser(description="Crunch data, extract trees, upload to MR, etc")
parser.add_argument('-s', '--shareemail', type=str, required=False, help="email (just one) for calling MR share API")
parser.add_argument('-to', '--token', type=str, required=False, help="TXT: MR token")
parser.add_argument('-as', '--allsamples', type=str, required=False, help='comma-delimited list of samples to consider for clustering (if absent, will do entire tree)')
parser.add_argument('-ls', '--latestsamples', type=str, help='TSV: latest sample information (as identified by find_clusters.py)')
parser.add_argument('-lm', '--latestclustermeta', type=str, required=False, help='TSV: metadata from find_clusters.py (only used for matrix_max)')
#parser.add_argument('-sm', '--samplemeta', type=str, help='TSV: sample metadata pulled from terra (including myco outs), one line per sample')
parser.add_argument('-pcm', '--persistentclustermeta', type=str, help='TSV: persistent cluster metadata from last full run of TB-D')
parser.add_argument('-pid', '--persistentids', type=str, help='TSV: persistent IDs from last full run of TB-D')
parser.add_argument('-mat', '--mat_tree', type=str, help='PB: tree')
#parser.add_argument('-cs', '--contextsamples', type=int, default=0, help="[UNUSED] int: Number of context samples for cluster subtrees")
parser.add_argument('-cd', '--combineddiff', type=str, help='diff: Maple-formatted combined diff file, needed for backmasking')
parser.add_argument('-dl', '--denylist', type=str, required=False, help='TXT: newline delimited list of cluster IDs to never use')
parser.add_argument('-mr', '--upload_to_microreact', action='store_true', help='upload clusters to MR (requires -to)')
parser.add_argument('-d', '--today', type=str, required=True, help='ISO 8601 date, YYYY-MM-DD')
parser.add_argument('-v', '--verbose', action='store_true', help='enable verbose logging to stdout (warning: extremely slow on Terra)')
parser.add_argument('--no_err_on_decimated_on_mr', action='store_true', help='do not error if a cluster on MR becomes decimated')
parser.add_argument('--no_cleanup', action='store_true', help="do not clean up input files (this may break delocalization on Terra; only use this for rapid debug runs)")
parser.add_argument('--mr_blank_template', type=str, help="JSON: template file for blank MR projects")
parser.add_argument('--mr_update_template', type=str, help="JSON: template file for in-use MR projects")
parser.add_argument('--mr_decimated_template', type=str, help="JSON: template file for in-use MR projects which have since lost all of their samples")
parser.add_argument('--no_upload_childless_20s', action='store_true', help="do not upload 20-clusters to MR if they have no children (ie, no subclusters)")
parser.add_argument('--skip_perl', action='store_true', help="skip the perl scripts to debug using existing rosetta_20/10/5 files (don't enable this for real runs!)")
args = parser.parse_args()
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
if args.persistentclustermeta and not args.persistentids:
raise ValueError("You provided --persistentclustermeta but no --persistentids, you need both or neither")
if args.persistentids and not args.persistentclustermeta:
raise ValueError("You provided --persistentids but no --persistentclustermeta, you need both or neither")
if not args.persistentids and not args.persistentclustermeta:
start_over = True
print("You have not provided persistent IDs nor persistent cluster metadata. This will restart clustering.")
else:
start_over = False
if args.upload_to_microreact and not (args.mr_blank_template and args.mr_update_template):
raise ValueError("You said --upload_to_microreact but didn't include --mr_blank_template and/or --mr_update_template")
all_latest_samples = pl.read_csv(args.latestsamples,
separator="\t",
schema_overrides={"latest_cluster_id": pl.Utf8}).filter(pl.col("latest_cluster_id").is_not_null())
if not start_over:
all_persistent_samples = pl.read_csv(args.persistentids,
separator="\t",
schema_overrides={"cluster_id": pl.Utf8}).filter(pl.col("cluster_id").is_not_null())
persistent_clusters_meta = pl.read_csv(args.persistentclustermeta,
separator="\t",
null_values="NULL",
try_parse_dates=True,
schema_overrides={"cluster_id": pl.Utf8}).filter(pl.col("cluster_id").is_not_null())
global today # pylint: disable=global-statement
args_today = datetime.strptime(args.today, "%Y-%m-%d").date()
if args_today != today:
# this is just to warn the user they might be using an old or cached date, but we have
# to use the user-provided date anyway for WDL output matching to work without glob()
# (we are avoiding WDL-glob() because it creates a random-name folder in GCP which is annoying)
logging.warning("The date you provided (%s, interpreted as type %s) doesn't match the date in Thurles.", args_today, type(args_today))
today = args_today
if args.upload_to_microreact and not args.token:
logging.error("You entered --upload_to_microreact but didn't provide a token file with --token")
raise ValueError
if args.token:
with open(args.token, 'r', encoding="utf-8") as file:
token = file.readline()
debug_logging_handler_df("Loaded all_latest_samples", all_latest_samples, "1_inputs")
if not start_over:
debug_logging_handler_df("Loaded all_persistent_samples", all_persistent_samples, "1_inputs")
# ensure each sample in latest-clusters has, at most, one 20 SNP, one 10 SNP, and one 05 SNP
temp_latest_groupby_sample = all_latest_samples.group_by("sample_id", maintain_order=True).agg(pl.col("cluster_distance"))
for row in temp_latest_groupby_sample.iter_rows(named=True):
if len(row["cluster_distance"]) == 1:
if row["cluster_distance"] == [-1]: # unclustered -- but not latestly in find_clusters.py
pass
else:
assert row["cluster_distance"] == [20], f"{row['sample_id']} has one cluster but it's not 20 SNP: {row['cluster_distance']}"
elif len(row["cluster_distance"]) == 2:
assert row["cluster_distance"] == [20,10], f"{row['sample_id']} has two clusters but it's not 20-10: {row['cluster_distance']}"
elif len(row["cluster_distance"]) == 3:
assert row["cluster_distance"] == [20,10,5], f"{row['sample_id']} has three clusters but it's not 20-10-5: {row['cluster_distance']}"
else:
logging.error(f"{row['sample_id']} has invalid clusters: {row['cluster_distance']}") #pylint: disable=logging-fstring-interpolation
raise ValueError
if not start_over:
# TODO: this method of detecting decimation isn't working correctly and should probably be replaced
persis_groupby_cluster = all_persistent_samples.group_by("cluster_id", maintain_order=True).agg(pl.col("sample_id"), pl.col("cluster_distance").unique().first())
# This is very tricky: We need to make sure that if any persistent clusters don't exist anymore, their IDs do not get reused.
# It should only happen when running on a subset of samples and/or if samples have been removed. (In theory something like
# this can also happen if two clusters get merged, but we'll cross that bridge later.)
# Previously we would iterate temp_latest_groupby_cluster rowwise and check if
# this_clusters_latest_samps.isdisjoint(persis_groupby_cluster.filter(pl.col("cluster_id") == latest_cluster_id)),
# but this was a terrible idea because it would fire whenever samples simply generated a different cluster ID.
# We instead want to start with a simple question:
# Are there any samples present in all_persistent_samples not present in all_latest_samples?
# If no: Literally who cares, the perl script will handle it
# If yes: Iterate the *persistent* clusters rowwise to make sure they aren't decimated
all_latest_samples_set = set(all_latest_samples["sample_id"].to_list())
all_persistent_samples_set = set(all_persistent_samples["sample_id"].to_list())
debug_logging_handler_txt("Set of all latest samples", "1_inputs", 10)
debug_logging_handler_txt(all_latest_samples_set, "1_inputs", 10)
debug_logging_handler_txt("Set of all persistent samples", "1_inputs", 10)
debug_logging_handler_txt(all_persistent_samples_set, "1_inputs", 10)
if all_persistent_samples_set.issubset(all_latest_samples_set):
debug_logging_handler_txt("All persistent samples is a subset of all latest samples", "1_inputs", 20)
else:
samples_missing_from_latest = all_persistent_samples_set - all_latest_samples_set # these are sets so this excludes samples exclusive to all_latest
debug_logging_handler_txt(f"Samples appear to be missing from the latest run: {samples_missing_from_latest}", "1_inputs", 30)
if args.allsamples:
all_input_samples_including_unclustered = args.allsamples.split(',')
else:
all_input_samples_including_unclustered = None
debug_logging_handler_txt("Missing args.allsamples; can't be sure if missing samples are dropped because they no longer cluster or if they were never input.", "1_inputs", 30)
for sample in samples_missing_from_latest:
if all_input_samples_including_unclustered is None:
pass
elif sample in all_input_samples_including_unclustered:
debug_logging_handler_txt(f"{sample} is newly unclustered", "1_inputs", 30)
else:
debug_logging_handler_txt(f"{sample} seems to have been dropped from inputs", "1_inputs", 30)
# get persistent cluster ID regardless
cluster_ids = get_cluster_ids_for_sample(all_persistent_samples, sample)
for cluster in cluster_ids:
if len(get_other_samples_in_cluster(all_persistent_samples, cluster, samples_missing_from_latest)) <= 1:
# In theory we could handle this, in practice it's a massive pain in the neck and very easy to mess up!!
debug_logging_handler_txt(f"{cluster} is decimated thanks to losing all samples (or all but one)", "1_inputs", 30)
# IF AND ONLY IF this is not on MR (which should only happen if this is a 20-cluster with no subclusters),
# we can live with this being a decimated cluster.
if not has_microreact_url(persistent_clusters_meta, cluster):
debug_logging_handler_txt(f"{cluster} already lacks a Microreact URL, so we can live with it being decimated", "1_inputs", 20)
elif args.no_err_on_decimated_on_mr:
debug_logging_handler_txt(f"{cluster} has an MR URL but we will accept it being decimated due to --no_err_on_decimated_on_mr", "1_inputs", 30)
else:
debug_logging_handler_txt(f"{cluster} has an MR URL and should never be decimated. Cannot continue.", "1_inputs", 40)
exit(55)
else:
debug_logging_handler_txt(f"Dropped {sample} from {cluster} but that seems to be okay", "1_inputs", 20)
print("################# (2) 𓅀 𓁪 THE MARC PERRY ZONE 𓁫 𓀂 #################")
all_latest_20 = all_latest_samples.filter(pl.col("cluster_distance") == 20).select(["sample_id", "latest_cluster_id"])
all_latest_10 = all_latest_samples.filter(pl.col("cluster_distance") == 10).select(["sample_id", "latest_cluster_id"])
all_latest_5 = all_latest_samples.filter(pl.col("cluster_distance") == 5).select(["sample_id", "latest_cluster_id"])
all_latest_unclustered = all_latest_samples.filter(pl.col("cluster_distance") == -1).select(["sample_id", "latest_cluster_id"]) # pylint: disable=unused-variable
all_persistent_20 = all_persistent_samples.filter(pl.col("cluster_distance") == 20).select(["sample_id", "cluster_id"])
all_persistent_10 = all_persistent_samples.filter(pl.col("cluster_distance") == 10).select(["sample_id", "cluster_id"])
all_persistent_5 = all_persistent_samples.filter(pl.col("cluster_distance") == 5).select(["sample_id", "cluster_id"])
all_persistent_unclustered = all_persistent_samples.filter(pl.col("cluster_distance") == -1).select(["sample_id", "cluster_id"]) # pylint: disable=unused-variable
# Marc's script requires that you input only sample IDs that are present in both the persistent cluster file
# and your latest clusters, so we need to do an inner join first -- after getting our "rosetta stone" we will
# modify the original dataframe.
# "But Ash!" I hear you say, "These merges give you two dataframes that each have a column of old IDs and a
# column of new IDs! That's a rosetta stone already, we don't need Marc's script!"
# You are a fool. Yes, we could stick to that... but then we wouldn't be able to handle situations where
# clusters merge, split, or generally get messy without reinventing the wheel Marc has already made for us.
debug_logging_handler_txt("Preparing to run the absolute legend's script...", "2_marc", 20)
filtered_latest_20 = all_latest_20.join(all_persistent_20.drop(['cluster_id']), on="sample_id", how="inner").rename({'latest_cluster_id': 'cluster_id'}).sort('cluster_id')
filtered_latest_10 = all_latest_10.join(all_persistent_10.drop(['cluster_id']), on="sample_id", how="inner").rename({'latest_cluster_id': 'cluster_id'}).sort('cluster_id')
filtered_latest_5 = all_latest_5.join(all_persistent_5.drop(['cluster_id']), on="sample_id", how="inner").rename({'latest_cluster_id': 'cluster_id'}).sort('cluster_id')
filtered_persistent_20 = all_persistent_20.join(all_latest_20.drop(['latest_cluster_id']), on="sample_id", how="inner").sort('cluster_id')
filtered_persistent_10 = all_persistent_10.join(all_latest_10.drop(['latest_cluster_id']), on="sample_id", how="inner").sort('cluster_id')
filtered_persistent_5 = all_persistent_5.join(all_latest_5.drop(['latest_cluster_id']), on="sample_id", how="inner").sort('cluster_id')
for distance, dataframe in {20: filtered_latest_20, 10: filtered_latest_10, 5: filtered_latest_5}.items():
debug_logging_handler_df(f"Filtered latest {distance}", dataframe, "2_marc")
for distance, dataframe in {20: filtered_persistent_20, 10: filtered_persistent_10, 5: filtered_persistent_5}.items():
debug_logging_handler_df(f"Filtered persistent {distance}", dataframe, "2_marc")
filtered_latest_20.select(["sample_id", "cluster_id"]).write_csv('filtered_latest_20.tsv', separator='\t', include_header=False)
filtered_latest_10.select(["sample_id", "cluster_id"]).write_csv('filtered_latest_10.tsv', separator='\t', include_header=False)
filtered_latest_5.select(["sample_id", "cluster_id"]).write_csv('filtered_latest_5.tsv', separator='\t', include_header=False)
filtered_persistent_20.select(["sample_id", "cluster_id"]).write_csv('filtered_persistent_20.tsv', separator='\t', include_header=False)
filtered_persistent_10.select(["sample_id", "cluster_id"]).write_csv('filtered_persistent_10.tsv', separator='\t', include_header=False)
filtered_persistent_5.select(["sample_id", "cluster_id"]).write_csv('filtered_persistent_5.tsv', separator='\t', include_header=False)
if not args.skip_perl:
debug_logging_handler_txt("Actually running scripts...", "2_marc", 20)
perl_20 = subprocess.run(f"perl {script_path}/marcs_incredible_script_update.pl filtered_persistent_20.tsv filtered_latest_20.tsv", shell=True, check=True, capture_output=True, text=True)
debug_logging_handler_txt(perl_20.stdout, "2_marc", 20)
subprocess.run("mv mapped_persistent_cluster_ids_to_new_cluster_ids.tsv rosetta_stone_20.tsv", shell=True, check=True)
perl_10 = subprocess.run(f"perl {script_path}/marcs_incredible_script_update.pl filtered_persistent_10.tsv filtered_latest_10.tsv", shell=True, check=True, capture_output=True, text=True)
debug_logging_handler_txt(perl_10.stdout, "2_marc", 20)
subprocess.run("mv mapped_persistent_cluster_ids_to_new_cluster_ids.tsv rosetta_stone_10.tsv", shell=True, check=True)
perl_5 = subprocess.run(f"perl {script_path}/marcs_incredible_script_update.pl filtered_persistent_5.tsv filtered_latest_5.tsv", shell=True, check=True, capture_output=True, text=True)
debug_logging_handler_txt(perl_5.stdout, "2_marc", 20)
subprocess.run("mv mapped_persistent_cluster_ids_to_new_cluster_ids.tsv rosetta_stone_5.tsv", shell=True, check=True)
# TODO: why are were we not running equalize tabs except when logging is debug?
# debug print basic rosetta stones
if logging.root.level == logging.DEBUG:
for rock in ['rosetta_stone_20.tsv', 'rosetta_stone_10.tsv', 'rosetta_stone_5.tsv']:
if os.path.isfile(rock):
with open(rock, 'r', encoding="utf-8") as file:
debug_logging_handler_txt(f"---------------------\nContents of {rock} (before strip_tsv and equalize_tabs):\n", "2_marc", 10)
debug_logging_handler_txt(list(file), "2_marc", 10)
#subprocess.run(f"/bin/bash {script_path}/equalize_tabs.sh {rock}", shell=True, check=True)
# get more information about merges... if we have any!
rock_pairs = {'rosetta_stone_20.tsv':'rosetta_stone_20_merges.tsv',
'rosetta_stone_10.tsv':'rosetta_stone_10_merges.tsv',
'rosetta_stone_5.tsv':'rosetta_stone_5_merges.tsv'}
# For each distance, check whether the perl step produced a companion "merges" TSV.
# If it did, log its contents (debug) and post-process the rosetta stone + merge file
# pair with strip_tsv.sh; if not, just note that no clusters merged at this distance.
for rock, merge_rock in rock_pairs.items():
    if os.path.isfile(merge_rock):
        debug_logging_handler_txt(f"Found {merge_rock}, indicating clusters merged at this distance", "2_marc", 20)
        debug_logging_handler_txt(f"---------------------\nContents of {merge_rock} (before strip_tsv and equalize_tabs):\n", "2_marc", 20)
        with open(merge_rock, 'r', encoding="utf-8") as file:
            # BUGFIX: was list(merge_rock), which iterated the *filename string* and logged
            # its individual characters rather than the lines of the opened file.
            debug_logging_handler_txt(list(file), "2_marc", 20)
        #subprocess.run(f"/bin/bash {script_path}/equalize_tabs.sh {rock}", shell=True, check=True)
        subprocess.run(f"/bin/bash {script_path}/strip_tsv.sh {rock} {merge_rock}", shell=True, check=True)
    else:
        debug_logging_handler_txt(f"Did not find {merge_rock}, indicating clusters didn't merge at this distance", "2_marc", 20)
# we need schema_overrides or else cluster IDs can become non-zfilled i64
# For some godforesaken reason, some versions of polars will throw `polars.exceptions.ComputeError: found more fields than defined in 'Schema'` even if we set
# infer_schema = True with a hella large infer_schema_length. Idk why because the exact same file works perfectly fine on my local installation of polars (polars==1.27.0)
# without even needing to set anything with infer_schema!! Not even a try-except with the except having a three column schema works!! Ugh!!!
# TODO: is this because the docker is polars==1.26.0?
# ---> WORKAROUND: equalize_tabs.sh
debug_logging_handler_txt("Processing perl outputs...", "2_marc", 20)

def _read_rosetta_stone(path):
    # Read one three-column rosetta stone TSV entirely as strings (schema_overrides keeps
    # zero-padded cluster IDs from collapsing to i64) and name the columns.
    return pl.read_csv(path, separator="\t", has_header=False,
                       schema_overrides={"column_1": pl.Utf8, "column_2": pl.Utf8, "column_3": pl.Utf8},
                       truncate_ragged_lines=True, ignore_errors=True, infer_schema_length=5000).rename(
        {'column_1': 'persistent_cluster_id', 'column_2': 'latest_cluster_id', 'column_3': 'special_handling'}
    )

# Previously these three identical reads were written out longhand; the helper removes the triplication.
rosetta_20 = _read_rosetta_stone("rosetta_stone_20.tsv")
rosetta_10 = _read_rosetta_stone("rosetta_stone_10.tsv")
rosetta_5 = _read_rosetta_stone("rosetta_stone_5.tsv")
# It seems theoretically possible that a (say) 20 SNP cluster could generate a persistent ID that matches a persistent ID
# already being used by (say) 10 SNP cluster. We'll call this "cross-distance ID sharing" because I love naming things.
# This function takes in persistent_clusters_meta so it can account for decimated cluster IDs (in theory).
# TODO: there should be a check like this in the ad-hoc case too, just in case find_clusters does an oopsies
debug_logging_handler_txt("Checking for cross-distance ID shares", "2_marc", 20)
rosetta_20, rosetta_10, rosetta_5 = fix_cross_distance_ID_shares(rosetta_20, rosetta_10, rosetta_5, persistent_clusters_meta, "marc_perry")
# TODO: Because we merge on latest_cluster_id here, and we only fixed the persistent ID, this merge could get funky?
# In theory everything should be fine...
# Full outer joins against each rosetta stone add one persistent_*_cluster_id column per distance;
# the duplicated right-hand join key is dropped after each join.
debug_logging_handler_txt("Joining all_latest_samples on rosetta_20 upon latest_cluster_id to generate persistent_20_cluster_id column...", "2_marc", 20)
latest_samples_translated = (all_latest_samples.join(rosetta_20, on="latest_cluster_id", how="full")).rename({'persistent_cluster_id': 'persistent_20_cluster_id'}).drop("latest_cluster_id_right")
debug_logging_handler_txt("Joining all_latest_samples on rosetta_10 upon latest_cluster_id to generate persistent_10_cluster_id column...", "2_marc", 20)
latest_samples_translated = (latest_samples_translated.join(rosetta_10, on="latest_cluster_id", how="full")).rename({'persistent_cluster_id': 'persistent_10_cluster_id'}).drop("latest_cluster_id_right")
debug_logging_handler_txt("Nullfilling the special_handling column...", "2_marc", 20)
# nullfill_LR presumably coalesces special_handling with the join-suffixed special_handling_right -- confirm
latest_samples_translated = nullfill_LR(latest_samples_translated, "special_handling", "special_handling_right")
debug_logging_handler_txt("Joining all_latest_samples on rosetta_5 upon latest_cluster_id to generate persistent_5_cluster_id column...", "2_marc", 20)
latest_samples_translated = (latest_samples_translated.join(rosetta_5, on="latest_cluster_id", how="full")).rename({'persistent_cluster_id': 'persistent_5_cluster_id'}).drop("latest_cluster_id_right")
debug_logging_handler_txt("Nullfilling the special_handling column (again)...", "2_marc", 20)
latest_samples_translated = nullfill_LR(latest_samples_translated, "special_handling", "special_handling_right")
debug_logging_handler_df("Early latest_samples_translated before polars expressions", latest_samples_translated, "2_marc")
all_latest_samples = None # stymie my silly tendency to reuse stale variables
debug_logging_handler_txt("Marking samples that were in 20, 10, or 5 clusters previously...", "2_marc", 20)
# For each distance D, in_D_cluster_last_run is True/False only on rows whose cluster_distance
# is D (null on all other rows), based on membership in the previous run's all_persistent_D frame.
latest_samples_translated = latest_samples_translated.with_columns(
    pl.when(pl.col("cluster_distance") == 20)
    .then(
        pl.when(latest_samples_translated["sample_id"].is_in(all_persistent_20["sample_id"]))
        .then(True)
        .otherwise(False)
    )
    .otherwise(None)
    .alias("in_20_cluster_last_run")
)
latest_samples_translated = latest_samples_translated.with_columns(
    pl.when(pl.col("cluster_distance") == 10)
    .then(
        pl.when(latest_samples_translated["sample_id"].is_in(all_persistent_10["sample_id"]))
        .then(True)
        .otherwise(False)
    )
    .otherwise(None)
    .alias("in_10_cluster_last_run")
)
latest_samples_translated = latest_samples_translated.with_columns(
    pl.when(pl.col("cluster_distance") == 5)
    .then(
        pl.when(latest_samples_translated["sample_id"].is_in(all_persistent_5["sample_id"]))
        .then(True)
        .otherwise(False)
    )
    .otherwise(None)
    .alias("in_5_cluster_last_run")
)
# A sample is brand new iff it was absent from the previous run's full persistent sample set.
latest_samples_translated = latest_samples_translated.with_columns(
    pl.when(latest_samples_translated["sample_id"].is_in(all_persistent_samples_set))
    .then(False)
    .otherwise(True)
    .alias("sample_brand_new")
)
# Previously we did
# ```
# latest_samples_translated.with_columns(
#     pl.coalesce('persistent_20_cluster_id', 'persistent_10_cluster_id', 'persistent_5_cluster_id', 'latest_cluster_id')
#     .alias("cluster_id")
# ```
# This was done because brand new clusters are not considered in Marc's script if they don't have any samples in a cluster
# at that distance, since we only input sample IDs that are also in the persistent list at that SNP distance. So for example,
# if brand new sample X and old sample Y formed brand new cluster 000033 at 10 SNPs, and Y was not in a 10 SNP cluster
# previously, Y would not included in the output of Marc's script. We wouldn't have a persistent ID for it, so
# might as well just use the latest_cluster_id (aka workdir cluster ID), right?
#
# But this is problematic, since 000033 could already exist as a persistent ID in use by other samples. So we'd end up with
# something like this:
#
# samp | dis | workdir | cluster_id
# ----------------------------------
# A    | 10  | 000015  | 000033
# B    | 10  | 000015  | 000033
# X    | 10  | 000033  | 000033
# Y    | 10  | 000033  | 000033
#
# We had a section dedicated to adjusting this, but it was confusing and I'm not fully confident it was error-proof, so I've
# decided to make handling the "brand new cluster" (more correctly "no persistent ID") situation more explicit.
#
# Coalesce the three per-distance persistent IDs into a single cluster_id; rows with none of
# them get the "NO_PERSIS_ID" sentinel, which section 3 resolves.
latest_samples_translated = (
    latest_samples_translated.with_columns(
        pl.coalesce(pl.col(['persistent_20_cluster_id', 'persistent_10_cluster_id', 'persistent_5_cluster_id']), pl.lit("NO_PERSIS_ID"))
        .alias("cluster_id")
    ).drop(['persistent_20_cluster_id', 'persistent_10_cluster_id', 'persistent_5_cluster_id'])
).rename({'latest_cluster_id': 'workdir_cluster_id'})
# To prevent samples that don't need to be tagged as renamed in special_handling even if their cluster ID matches the workdir cluster ID,
# we'll fill in these ahead of time.
latest_samples_translated = latest_samples_translated.with_columns(
    pl.when(pl.col('workdir_cluster_id') == pl.col('cluster_id'))
    .then(pl.lit('none')) # we don't say "unchanged" since the cluster's contents may have changed, nor do we use literal None
    .otherwise(pl.col('special_handling'))
    .alias('special_handling')
)
print("################# (3) SPECIAL HANDLING (of new clusters) #################")
# This section is for handling the brand-new-cluster situation, since it generated without a persistent ID, but the workdir ID
# it generated with could overlap with an existing persistent ID. In older versions we coalesced workdir cluster ID into (persistent)
# cluster ID in the previous section, then in this section, detected issues by checking how many workdir cluster IDs a given
# (persistent) cluster ID had. But it was kind of cringe so now we're handling this differently.
debug_logging_handler_txt("Handling clusters without a persistent ID (if any)", "3_new_clusters", 20)
latest_samples_translated = latest_samples_translated.with_columns( # this assumes all no-persistent-ids are brand new clusters; for now tis okay
    pl.when(pl.col('cluster_id') == pl.lit("NO_PERSIS_ID"))
    .then(pl.lit(True))
    .otherwise(pl.lit(False))
    .alias('cluster_brand_new')
)
no_persistent_id_yet = latest_samples_translated.filter(pl.col('cluster_id') == pl.lit("NO_PERSIS_ID"))
debug_logging_handler_df("samples with no persistent ID yet", no_persistent_id_yet, "3_new_clusters")
workdir_ids_of_no_persistent_ids = set(no_persistent_id_yet.select('workdir_cluster_id').to_series().to_list())
# Resolve each sentinel cluster one workdir ID at a time. NOTE: the ID sets are recomputed
# inside the loop on purpose -- earlier iterations may have assigned new cluster_ids.
for possibly_problematic_id in workdir_ids_of_no_persistent_ids:
    # check for nonsense
    n_samps_in_full_df = len(latest_samples_translated.filter(pl.col('workdir_cluster_id') == pl.lit(possibly_problematic_id)))
    n_samps_in_filtered_df = len(no_persistent_id_yet.filter(pl.col('workdir_cluster_id') == pl.lit(possibly_problematic_id)))
    assert n_samps_in_full_df == n_samps_in_filtered_df
    # See if this is actually a problem -- does the workdir cluster ID overlap with a persistent ID?
    # By including IDs from persistent cluster meta, this should account for decimated samples too.
    all_current_persistent_cluster_ids = set(latest_samples_translated.select('cluster_id').cast(pl.Utf8).to_series().to_list())
    all_previous_persistent_cluster_ids = set(persistent_clusters_meta.select('cluster_id').cast(pl.Utf8).to_series().to_list())
    all_cluster_ids = all_current_persistent_cluster_ids.union(all_previous_persistent_cluster_ids)
    # Yes, there's an overlap, let's generate a new ID and call that the persistent ID
    # (idk why we need to do set([str(possibly_problematic_id)]) instead of just set(str(possibly_problematic_id)) but we do, ugh)
    # (note: set(str(x)) would split x into characters; set([str(x)]) is a one-element set)
    if set([str(possibly_problematic_id)]) & all_cluster_ids:
        new_id = generate_new_cluster_id(all_cluster_ids, "3_new_clusters")
        debug_logging_handler_txt(f"Workdir ID of brand new cluster {possibly_problematic_id} already exists as persistent ID, will change to {new_id}", "3_new_clusters", 20)
        latest_samples_translated = latest_samples_translated.with_columns([
            pl.when(pl.col('workdir_cluster_id') == pl.lit(possibly_problematic_id))
            .then(pl.lit(new_id))
            .otherwise(pl.col('cluster_id'))
            .alias('cluster_id'),
            pl.when(pl.col('workdir_cluster_id') == pl.lit(possibly_problematic_id))
            .then(pl.lit('brand new (renamed)'))
            .otherwise(pl.col('special_handling'))
            .alias('special_handling')
        ])
    # No overlap, let's use the workdir cluster ID as the persistent ID
    else:
        debug_logging_handler_txt(f"Workdir ID of brand new cluster {possibly_problematic_id} doesn't exist as persistent ID", "3_new_clusters", 20)
        latest_samples_translated = latest_samples_translated.with_columns([
            pl.when(pl.col('workdir_cluster_id') == pl.lit(possibly_problematic_id))
            .then(pl.col('workdir_cluster_id'))
            .otherwise(pl.col('cluster_id'))
            .alias('cluster_id'),
            pl.when(pl.col('workdir_cluster_id') == pl.lit(possibly_problematic_id))
            .then(pl.lit('brand new (no conflict)'))
            .otherwise(pl.col('special_handling'))
            .alias('special_handling')
        ])
# Check for B.S.
# Sanity check: membership in a cluster at a tighter SNP distance last run implies
# membership at every looser distance too. Raise on the first violated pairing.
for tighter, looser in ((10, 20), (5, 10), (5, 20)):
    offenders = latest_samples_translated.filter(
        pl.col(f"in_{tighter}_cluster_last_run") & ~pl.col(f"in_{looser}_cluster_last_run")
    )["sample_id"].to_list()
    if offenders:
        raise ValueError(f"These samples were in a {tighter} SNP cluster last time, but not a {looser} SNP cluster: {', '.join(offenders)}")
debug_logging_handler_df("latest_samples_translated after pl.coalesce and check (sorted by workdir_cluster_id in this view)",
    latest_samples_translated.sort('workdir_cluster_id'), "3_new_clusters")
# ad-hoc (no-persistent-IDs) case
# NOTE(review): this `else` pairs with a conditional that begins above this excerpt
# (the branch where persistent data exists) -- confirm against the full file.
else:
    # With no previous run to compare against, every cluster keeps its workdir ID as its
    # cluster_id and everything is flagged brand new / "restart".
    latest_samples_translated = all_latest_samples.with_columns([
        pl.col("latest_cluster_id").alias("workdir_cluster_id"),
        pl.col("latest_cluster_id").alias("cluster_id"),
        pl.lit(True).alias("cluster_brand_new"),
        pl.lit("restart").alias("special_handling"),
        pl.lit(False).alias("in_20_cluster_last_run"),
        pl.lit(False).alias("in_10_cluster_last_run"),
        pl.lit(False).alias("in_5_cluster_last_run"),
        pl.lit(True).alias("sample_brand_new")
    ])
print("################# (4) LINK PARENTS AND CHILDREN #################")
# Possible ways to speed this up:
# * more native polars expressions
# * acting on the grouped dataframe instead of latest_samples_translated
#
# We actually do this twice, once on latest samples and once on grouped-by-persistent. In the future,
# we may want to instead pass in persistent parent-child information as metadata so we don't need to
# recalculate every time...
debug_logging_handler_txt("Preparing to link parents and children...", "4_calc_paternity", 20)
latest_samples_translated = latest_samples_translated.sort(["cluster_distance", "cluster_id"])
debug_logging_handler_df("latest_samples_translated at start of step 4", latest_samples_translated, "4_calc_paternity")
if not start_over:
    all_persistent_samples = all_persistent_samples.sort(["cluster_distance", "cluster_id"])
    debug_logging_handler_df("all_persistent_samples at start of step 4", all_persistent_samples, "4_calc_paternity")
debug_logging_handler_txt("Linking samples...", "4_calc_paternity", 20)
# build_sample_map / establish_parenthood are defined elsewhere; establish_parenthood returns a
# list of (cluster_id, column, value) update tuples that are applied in part 6, after grouping.
sample_map_latest = build_sample_map(latest_samples_translated)
if not start_over:
    sample_map_previous = build_sample_map(all_persistent_samples)
# TODO: This works, but I feel like there's bound to be another/better/faster way to do this using polars expressions
debug_logging_handler_txt("Iterating latest_samples_translated's rows...", "4_calc_paternity", 20)
parental_latest = establish_parenthood(latest_samples_translated, sample_map_latest)
debug_logging_handler_txt(f"Generated latest parenthood list (len {len(parental_latest)} values), but won't update dataframe yet", "4_calc_paternity", 20)
if not start_over:
    parental_previous = establish_parenthood(all_persistent_samples, sample_map_previous)
    debug_logging_handler_txt(f"Generated old parenthood list (len {len(parental_previous)} values), but won't update dataframe yet", "4_calc_paternity", 20)
# We don't actually do the updates until after the group, because dealing with an agg'd list() column in polars is a mess
# Also, acting on the grouped dataframe should be a little faster too (even if bigO doesn't really change)
print("################# (5) GROUP #################")
# In this section, we're going to be grouping by persistent cluster ID in order to perform some checks,
# and get ready to check if clusters have been updated or not (however the final determination will rely
# on a join, which happens after this, in order to properly catch clusters that lose samples)
debug_logging_handler_txt("Grouping by persistent cluster ID...", "5_group", 20)
# add_col_if_not_there presumably adds a null "matrix_max" column when missing -- confirm
latest_samples_translated = add_col_if_not_there(latest_samples_translated, "matrix_max")
# One row per persistent cluster ID; per-sample columns become lists of unique values,
# whose lengths the checks below rely on.
grouped = latest_samples_translated.group_by("cluster_id").agg(
    pl.col("cluster_distance").unique(),
    pl.col("matrix_max").unique(),
    pl.col("cluster_brand_new").unique(),
    pl.col("sample_brand_new").unique(),
    pl.col("special_handling").unique(),
    pl.col("workdir_cluster_id").unique(),
    pl.col("in_20_cluster_last_run").unique(),
    pl.col("in_10_cluster_last_run").unique(),
    pl.col("in_5_cluster_last_run").unique(),
    pl.col("sample_id").unique(),
    pl.col("sample_id").n_unique().alias("n_samples")
)
# Check every cluster has at least two samples (because this is based of the "latest" samples dataframe and doesn't have any
# persistent metadata, we can do this check, since decimated clusters are excluded.)
if not (grouped["sample_id"].list.len() >= 2).all():
    logging.basicConfig(level=logging.DEBUG) # effectively overrides global verbose
    debug_logging_handler_txt("Found cluster with less than two samples (decimated clusters are excluded in this check)", "5_group", 40)
    # BUGFIX: the filter previously used `.list.len() > 1`, which displayed the *valid*
    # clusters instead of the offending sub-two-sample clusters.
    debug_logging_handler_df("ERROR clusters with less than two samples", grouped.filter(pl.col("sample_id").list.len() < 2), "5_group")
    # BUGFIX: error message was missing its closing parenthesis.
    raise ValueError('Found cluster with less than two samples (decimated clusters are excluded in this check)')
debug_logging_handler_txt("Asserted all clusters have at least two samples (this check happens before we have any info about decimated clusters)", "5_group", 20)
# Check every cluster ID only has one workdir cluster ID (this is a relic of some older versions' handling of brand new clusters and should never fire)
if not (grouped["workdir_cluster_id"].list.len() <= 1).all():
    logging.basicConfig(level=logging.DEBUG) # effectively overrides global verbose
    debug_logging_handler_txt('Found non-zero number of "persistent" cluster IDs associated with multiple different workdir cluster IDs', "5_group", 40)
    debug_logging_handler_df("ERROR clusters with more than one workdir ID", grouped.filter(pl.col("workdir_cluster_id").list.len() > 1), "5_group")
    raise ValueError('Found non-zero number of "persistent" cluster IDs associated with multiple different workdir cluster IDs')
debug_logging_handler_txt("Asserted all persistent cluster IDs only associated with one or zero workdir IDs", "5_group", 20)
# Check only one distance per cluster ID (double checking cross-distance ID shares, this also should never fire)
if not (grouped["cluster_distance"].list.len() == 1).all():
    debug_logging_handler_txt('Found non-zero number of "persistent" cluster IDs associated with multiple SNP distances', "5_group", 40)
    debug_logging_handler_df("ERROR clusters with not one distance", grouped.filter(pl.col("cluster_distance").list.len() != 1), "5_group")
    raise ValueError("Some clusters have multiple unique cluster_distance values.")
debug_logging_handler_txt("Asserted all cluster_distance lists have a len of precisely 1", "5_group", 20)
# Check only one type of special handling per cluster ID (in theory that would actually be okay but given how we do it it shouldn't happen)
if not (grouped["special_handling"].list.len() == 1).all():
    debug_logging_handler_txt("Found different types of special handling in some clusters", "5_group", 40)
    debug_logging_handler_df("ERROR clusters with not one special_handling", grouped.filter(pl.col("special_handling").list.len() != 1), "5_group")
    raise ValueError("Some clusters have multiple unique special_handling values.")
debug_logging_handler_txt("Asserted all special_handling lists have a len of precisely 1", "5_group", 20)
# Check... you get the picture
if not (grouped["cluster_brand_new"].list.len() == 1).all():
    debug_logging_handler_txt("Found clusters that don't know if they're new or not", "5_group", 40)
    debug_logging_handler_df("ERROR clusters with not one cluster_brand_new", grouped.filter(pl.col("cluster_brand_new").list.len() != 1), "5_group")
    raise ValueError("Some clusters have multiple unique cluster_brand_new values.")
debug_logging_handler_txt("Asserted all cluster_brand_new lists have a len of precisely 1", "5_group", 20)
debug_logging_handler_txt("Converting lists to base types where possible...", "5_group", 20)
# Safe because the checks above proved each of these lists has exactly one (or <=1) element.
grouped = grouped.with_columns([
    pl.col("workdir_cluster_id").list.get(0).alias("workdir_cluster_id"),
    pl.col("cluster_distance").list.get(0).alias("cluster_distance"),
    pl.col("special_handling").list.get(0).alias("special_handling"),
    pl.col("cluster_brand_new").list.get(0).alias("cluster_brand_new")
])
# Collapse the 20/10/5 last run columns
# Pick the in_D_cluster_last_run list matching this cluster's own distance.
grouped = grouped.with_columns(
    pl.when(pl.col("cluster_distance") == 20)
    .then(pl.col("in_20_cluster_last_run"))
    .otherwise(
        pl.when(pl.col("cluster_distance") == 10)
        .then(pl.col("in_10_cluster_last_run"))
        .otherwise(
            pl.when(pl.col("cluster_distance") == 5)
            .then(pl.col("in_5_cluster_last_run"))
            .otherwise(None)
        )
    )
    .alias("samples_previously_in_cluster") # don't love this name but can't think of a better one (also see note in part 8)
).sort('cluster_id').drop(['in_20_cluster_last_run', 'in_10_cluster_last_run', 'in_5_cluster_last_run']) # will be readded upon join
debug_logging_handler_df("After grouping and then intager-a-fy", grouped, "5_group")
# Previously we dropped "sample_id" column here since we grouped a second time before joining on the persistent metadata/groupby files,
# but there isn't a reason to do that anymore.
print("################# (6) UPDATE PATERNITY #################")
# We already identified parents and children earlier, but now we're going to actually update the dataframe with the "updates" lists
debug_logging_handler_txt("Updating latest grouped dataframe with paternity information...", "6_update_paternity", 20)
grouped = grouped.with_columns(
    pl.lit(None).cast(pl.Utf8).alias("cluster_parent"),
    pl.lit([]).cast(pl.List(pl.Utf8)).alias("cluster_children") # intentionally not None (see part 8)
).sort(["cluster_distance", "cluster_id"])
# Apply the (cluster_id, column, value) tuples produced by establish_parenthood in part 4.
for cluster_id, col, value in parental_latest:
    #debug_logging_handler_txt(f"For cluster {cluster_id}, col {col}, val {value} in updates", "6_update_paternity", 10) # too verbose even for debug logging
    if col == "cluster_parent":
        # Beware: This is an overwrite, so we can't check if there's multiple cluster parents
        grouped = update_cluster_column(grouped, cluster_id, "cluster_parent", value)
    else:
        # Append the child to that cluster's cluster_children list, deduplicating as we go.
        grouped = grouped.with_columns(
            pl.when(pl.col("cluster_id") == cluster_id)
            .then((pl.col("cluster_children").list.concat(pl.lit(value))).list.unique())
            .otherwise(pl.col("cluster_children"))
            .alias("cluster_children")
        )
cluster_id, parental_latest = None, None
debug_logging_handler_df("grouped after linking parents and children", grouped, "6_update_paternity")
if not start_over:
    # Same update pass, but on the previous run's grouped-by-cluster dataframe.
    debug_logging_handler_txt("Updating previous run's dataframe with paternity information...", "6_update_paternity", 20)
    persis_groupby_cluster = persis_groupby_cluster.with_columns(
        pl.lit(None).cast(pl.Utf8).alias("cluster_parent"),
        pl.lit([]).cast(pl.List(pl.Utf8)).alias("cluster_children") # intentionally not None (see part 8)
    ).sort(["cluster_distance", "cluster_id"])
    for cluster_id, col, value in parental_previous:
        #debug_logging_handler_txt(f"For cluster {cluster_id}, col {col}, val {value} in updates", "6_update_paternity", 10) # too verbose even for debug logging
        if col == "cluster_parent":
            # Beware: This is an overwrite, so we can't check if there's multiple cluster parents
            persis_groupby_cluster = update_cluster_column(persis_groupby_cluster, cluster_id, "cluster_parent", value)
        else:
            persis_groupby_cluster = persis_groupby_cluster.with_columns(
                pl.when(pl.col("cluster_id") == cluster_id)
                .then((pl.col("cluster_children").list.concat(pl.lit(value))).list.unique())
                .otherwise(pl.col("cluster_children"))
                .alias("cluster_children")
            )
    cluster_id, parental_previous = None, None
    debug_logging_handler_df("persis_groupby_cluster after linking parents and children", persis_groupby_cluster, "6_update_paternity")
# Checks involving parent/child relationships
# The cluster_children check might change across versions of polars; right now we expect an empty list, as opposed to pl.Null or [pl.Null].
# Previously I'm pretty sure we had [pl.Null] since we inserted paternity before the group, and now we do it after.
# We don't use list len() because [null] is considered to have a length of 1 in some versions of polars but perhaps not others;
# see also https://github.com/pola-rs/polars/issues/18522
if start_over:
    check_dfs = [grouped]
else:
    check_dfs = [grouped, persis_groupby_cluster]
for df in check_dfs:
    assert ((df.filter(pl.col("cluster_distance") == pl.lit(5)))["cluster_parent"].is_not_null()).all(), "5-cluster with null cluster_parent"
    assert ((df.filter(pl.col("cluster_distance") == pl.lit(10)))["cluster_parent"].is_not_null()).all(), "10-cluster with null cluster_parent"
    assert ((df.filter(pl.col("cluster_distance") == pl.lit(20)))["cluster_parent"].is_null()).all(), "20-cluster with non-null cluster_parent"
    assert ((df.filter(pl.col("cluster_distance") == pl.lit(5)))["cluster_children"] == []).all(), "5-cluster with cluster_children"
debug_logging_handler_txt("Asserted no 5 clusters have children or no parent, no 10s lack parent, and no 20s have parent", "6_update_paternity", 20)
# convert [null] to null
# Actually, we don't do this anymore, because I want to use pl.col("col_a").list.unique().list.sort() after joining with the
# persistent grouped by dataframe, in order to see if a cluster got new/different children. I'm not confident that will
# work problem on null, so we're returning to empty lists.
# debug_logging_handler_txt("Converting [null] to null in cluster_children...", "6_update_paternity", 20)
# grouped = grouped.with_columns([
#     # previously: pl.when(pl.col("cluster_children").list.get(0).is_null())
#     # We used to handle paternity before grouping, resulting in empty children being [pl.Null], but now we handle
#     # paternity after grouping so empty children are now []. In fact, list.get(0) will error in our current version.
#     pl.when(pl.col("cluster_children") == pl.lit([]))
#     .then(None)
#     .otherwise(pl.col("cluster_children"))
#     .alias("cluster_children")
# ])
# if not start_over:
#     persis_groupby_cluster = persis_groupby_cluster.with_columns([
#         pl.when(pl.col("cluster_children") == pl.lit([]))
#         .then(None)
#         .otherwise(pl.col("cluster_children"))
#         .alias("cluster_children")
#     ])
print("################# (7) JOIN with persistent/latest information #################")
# First, we join with the persistent cluster metadata TSV to get first_found, last_update, jurisdictions, and microreact_url
# Then, we join with persis_groupby_cluster (which will tell us what samples clusters previously had)
# Only after doing these can we confidentally declare which clusters have actually been updated in some way
# Latest cluster meta is only used for matrix_max
if args.latestclustermeta:
debug_logging_handler_txt("Adding matrix_max metadata from args.latestclustermeta...", "7_join", 20)
latest_clusters_meta = pl.read_csv(args.latestclustermeta, separator="\t", schema_overrides={"latest_cluster_id": pl.Utf8})
latest_clusters_meta = latest_clusters_meta.rename({'latest_cluster_id': 'workdir_cluster_id'})
latest_clusters_meta = latest_clusters_meta.select(['workdir_cluster_id', 'matrix_max'])
# loglevel override (TODO: probably don't need this once we're sure of structure of metadata frame)
current_log_level = logging.getLogger().getEffectiveLevel()
if "matrix_max" in grouped.columns:
debug_logging_handler_txt("matrix_max already in grouped dataframe before merge?", "7_join", 30)
logging.basicConfig(level=logging.DEBUG) # force print debug frame, even if Terra hates that
debug_logging_handler_df("matrix_max already in grouped dataframe before merge", grouped, "7_join")
grouped = grouped.drop("matrix_max")
grouped = grouped.join(latest_clusters_meta, how="full", on="workdir_cluster_id", coalesce=True)
debug_logging_handler_df("grouped after join", grouped, "7_join")
# revert loglevel override
logging.basicConfig(level=logging.DEBUG if current_log_level == 10 else logging.INFO)
if "workdir_cluster_id_right" in grouped.columns:
grouped = grouped.drop("workdir_cluster_id_right")
if "matrix_max_right" in grouped.columns:
grouped = grouped.drop("matrix_max_right")
else:
# No matrix_max, but we can still have b_max
debug_logging_handler_txt("args.latestclustermeta not defined, matrix_max will be Null for all clusters", "7_join", 20)
# Part 7 (join): attach per-cluster bookkeeping metadata (first_found, needs_updating, last_json_update)
# either freshly (start over) or by joining against the persistent metadata from previous runs.
if start_over:
    # This sets first_found, needs_updating, and last_json_update to today in the start over case. In the persistent case,
    # these values come from either the persistent dataframe (in part 7), or are set to today if the cluster is brand new
    # (in part 8).
    debug_logging_handler_txt("Generating metadata fresh (since we're starting over)...", "7_join", 20)
    all_cluster_information = grouped.with_columns([
        pl.lit(today.isoformat()).alias("first_found"),
        pl.lit(True).alias("needs_updating"),
        pl.lit(today.isoformat()).alias("last_json_update")])
    debug_logging_handler_df("after adding relevant information", all_cluster_information, "7_join")
else:
    debug_logging_handler_txt("Joining with the persistent metadata TSV...", "7_join", 20)
    # Tag every persistent-side row as not-brand-new; after the full join this could be cross-checked against
    # grouped's own cluster_brand_new column (the _right copy is dropped further down).
    persistent_clusters_meta = persistent_clusters_meta.with_columns(pl.lit(False).alias("cluster_brand_new"))
    all_cluster_information = grouped.join(persistent_clusters_meta, how="full", on="cluster_id", coalesce=True)
    # TODO: this is gonna require we filter out the Nones for new and decimated samples; probably isn't worth the hassle
    #assert_series_equal(
    #    all_cluster_information.filter().select("cluster_brand_new").to_series(),
    #    all_cluster_information.filter().select("cluster_brand_new_right").to_series(),
    #    check_names=False, check_order=True
    #)
    # Persistent cluster meta can introduce decimated clusters which will have nulls for some columns, better
    # deal with those now
    all_cluster_information = all_cluster_information.with_columns(
        pl.col("cluster_brand_new").fill_null(False)
    )
    all_cluster_information = all_cluster_information.with_columns(
        pl.col("n_samples").fill_null(0)
    )
    # Since this is the persistent case, the only time first_found should be null is if the cluster wasn't in the persistent
    # dataframe (ie is brand new). (In the start over case, we already set first_found to today, so we don't need any other
    # first_found handling after this.)
    all_cluster_information = all_cluster_information.with_columns([
        pl.when(pl.col("first_found").is_null())
        .then(
            pl.when(pl.col("cluster_brand_new"))
            .then(pl.lit(today.isoformat()))
            .otherwise(pl.lit("UNKNOWN"))  # going forward this shouldn't happen but it did happen on older versions
        )
        .otherwise(pl.col("first_found"))
        .alias("first_found"),
    ])
    # Warn about stuff that has an unknown find date -- this shouldn't happen going forward but it did happen in the past
    # (which is why it's just a warning and not an error; for the time being I'm testing with the old JSONs)
    no_date = all_cluster_information.filter(pl.col("first_found") == pl.lit("UNKNOWN"))
    if len(no_date) > 0:
        debug_logging_handler_txt(f"Found {no_date.shape[0]} clusters with no clear first_found date", "7_join", 30)
        debug_logging_handler_df("WARNING no first_found date", no_date, "7_join")
    debug_logging_handler_df("after joining with persistent_clusters_meta", all_cluster_information, "7_join")
    # Now joined by the grouped-by-cluster persistent cluster ID information, which gives us the list of samples the clusters previously had
    # persis_groupby_cluster is only created if start_over is false, so we don't need to worry about unassigned vars here
    debug_logging_handler_txt("Joining persis_groupby_cluster...", "7_join", 20)
    all_cluster_information = all_cluster_information.join(persis_groupby_cluster, how="full", on="cluster_id", coalesce=True) # pylint: disable=possibly-used-before-assignment
    # Coalesce cluster_distance: prefer the current value, fall back to the persistent (_right) one
    all_cluster_information = all_cluster_information.with_columns(
        pl.when(pl.col("cluster_distance").is_null())
        .then(
            pl.when(pl.col("cluster_distance_right").is_null())
            .then(pl.lit(None))  # should never happen, except in older decimated clusters
            .otherwise(pl.col("cluster_distance_right"))
        )
        .otherwise(pl.col("cluster_distance"))
        .alias("cluster_distance")
    ).drop("cluster_distance_right")
    # The remaining _right columns from this join hold the previous run's values; rename them to say so
    all_cluster_information = all_cluster_information.rename({
        'sample_id_right': 'sample_id_previously',
        'cluster_parent_right': 'cluster_parent_previously',
        'cluster_children_right': 'cluster_children_previously'
    })
    # It's okay if cluster_parent is null, but the way we detect cluster children changes might freak out with nulls
    for column in ["cluster_children", "cluster_children_previously"]:
        all_cluster_information = all_cluster_information.with_columns(
            pl.col(column).fill_null([])
        )
    all_cluster_information = all_cluster_information.drop("cluster_brand_new_right")
    # Older cluster JSONs don't have a "decimated" column, so we're not gonna rely on it at all
    debug_logging_handler_txt("Declaring clusters decimated, or not...", "7_join", 20)
    if "decimated" in all_cluster_information.columns:
        all_cluster_information = all_cluster_information.drop("decimated")
    all_cluster_information = all_cluster_information.with_columns([
        # TODO: Consider adding a check for an empty list instead of just null -- seems to work though?
        # decimated: samples are gone on one side, and the cluster isn't brand new
        pl.when(
            (
                (pl.col('sample_id').is_null())
                .or_(pl.col("sample_id_previously").is_null())
            )
            .and_(pl.col("cluster_brand_new") == pl.lit(False))  # We already changed these from null to False
        )
        .then(pl.lit(True))
        .otherwise(pl.lit(False))
        .alias("decimated"),
        # We distinguish between old and newly decimated clusters since that affects if they need updating,
        # and I don't really trust polars to compare empty/null lists properly
        pl.when((pl.col('sample_id').is_null()).and_(pl.col("sample_id_previously").is_not_null()))
        .then(pl.lit(True))
        .otherwise(pl.lit(False))
        .alias("newly_decimated"),
        # This means a decimated cluster's persistent ID got reused, which should never happen
        pl.when(
            (pl.col('sample_id').is_not_null())
            .and_(pl.col("sample_id_previously").is_null())
            .and_(pl.col("cluster_brand_new") == pl.lit(False))  # We already changed these from null to False
        )
        .then(pl.lit(True))
        .otherwise(pl.lit(False))
        .alias("reused_decimated_persistent_id")
    ])
    reused_decimated = all_cluster_information.filter(pl.col("reused_decimated_persistent_id"))
    if len(reused_decimated) > 0:
        # NOTE(review): logging.basicConfig() is a no-op once handlers are configured unless force=True --
        # confirm this actually raises the log level here (same pattern appears elsewhere in this file)
        logging.basicConfig(level=logging.DEBUG) # effectively overrides global verbose in order to force dumping df to stdout
        debug_logging_handler_txt(f"We appear to have reused {reused_decimated.shape[0]} decimated cluster IDs", "7_join", 40)
        debug_logging_handler_df("reused decimated persistent IDs", reused_decimated, "7_join")
        # NOTE(review): a message argument (e.g. ValueError("reused decimated persistent cluster IDs")) would
        # make the traceback self-explanatory
        raise ValueError
    all_cluster_information = all_cluster_information.drop("reused_decimated_persistent_id")
debug_logging_handler_df("all_cluster_information at end", all_cluster_information, "7_join")
# Now we finally have all the information we need to declare which clusters have ACTUALLY changed or not
print("################# (8) RECOGNIZE (have I seen you before?) #################")
# Previously we tried to get clever and rely on samples_previously_in_cluster via:
# [True, False]/[False, True] --> some samples were in cluster previously --> old cluster, needs updating
# [False] --> no samples were in this cluster previously --> new cluster, needs updating
# [True] --> all samples were in this cluster previously --> old cluster, unchanged
#
# However, this doesn't work if an existing cluster splits into a new cluster where the new cluster only has old samples,
# and likely missed other edge cases too. Still, we can detect many (most?) changes without comparing lists of samples directly.
#
# Unfortunately, to be on the safe side, I still think it's worth comparing lists. I don't think there's ever been
# a situation where the only way to have caught it is with a list compare, but it is hypothetically possible.
# The clearest way to do this in polars is to iterate the dataframe rowwise, extract the two lists as sets, and
# then do a set comparison. There might be a more efficient way to do this, but let's keep it simple for now.
debug_logging_handler_txt("Determining which clusters are brand new and/or need updating...", "8_recognize", 20)
# Make sure a "changes" column exists before the change flags below start accumulating
all_cluster_information = add_col_if_not_there(all_cluster_information, "changes")
# KEEP IN MIND:
# * These cases aren't mutually exclusive
# * The easiest way to check if a polars list col is [False, True] or [True, False] is by checking its length and praying there's no nulls
# Existing parent gains or loses children (actual child IDs checked, not just number)
# Compare as order-insensitive, deduplicated lists so child ordering differences don't count as a change
all_cluster_information = all_cluster_information.with_columns(
    pl.when(
        (pl.col("cluster_children").list.unique().list.sort() != pl.col("cluster_children_previously").list.unique().list.sort())
        .and_(pl.col('cluster_brand_new') == pl.lit(False))
    )
    .then(True)
    .otherwise(False)
    .alias("different_children")
)
# Only materialize the debug view if we'd actually log it
if logging.root.level in (logging.INFO, logging.DEBUG):
    different_children = all_cluster_information.filter(pl.col('different_children')).select(
        ['cluster_id', 'cluster_distance', 'sample_brand_new',
         'cluster_children', 'cluster_children_previously',
         'sample_id', 'sample_id_previously']
    )
    debug_logging_handler_txt(f"Found {different_children.shape[0]} clusters with different children", "8_recognize", 20)
    debug_logging_handler_df("different_children", different_children, "8_recognize")
# Let's be a little more specific...
# We can use this without .fill_null([]) because cluster_children, unlike cluster_parent, is an empty list instead of null when empty
all_cluster_information = all_cluster_information.with_columns(
    pl.col("cluster_children").list.set_difference(pl.col("cluster_children_previously")).alias('new_child_clusters'),
    pl.col("cluster_children_previously").list.set_difference(pl.col("cluster_children")).alias('missing_child_clusters')
)
# Only build and dump these debug views if we'd actually log them
if logging.root.level in (logging.INFO, logging.DEBUG):
    new_child_clusters = all_cluster_information.filter(pl.col('new_child_clusters').list.len().gt(pl.lit(0))).select(
        ['cluster_id', 'cluster_distance', 'sample_brand_new',
         'different_children', 'new_child_clusters', 'missing_child_clusters',
         'cluster_children', 'cluster_children_previously',
         'sample_id', 'sample_id_previously']
    )
    debug_logging_handler_txt(f"Found {new_child_clusters.shape[0]} clusters with new children", "8_recognize", 20)
    debug_logging_handler_df("new_child_clusters", new_child_clusters, "8_recognize")
    missing_child_clusters = all_cluster_information.filter(pl.col('missing_child_clusters').list.len().gt(pl.lit(0))).select(
        ['cluster_id', 'cluster_distance', 'sample_brand_new',
         'different_children', 'new_child_clusters', 'missing_child_clusters',
         'cluster_children', 'cluster_children_previously',
         'sample_id', 'sample_id_previously']
    )
    debug_logging_handler_txt(f"Found {missing_child_clusters.shape[0]} clusters missing a child cluster", "8_recognize", 20)
    # Fixed copy-paste bug: this dataframe was previously logged under the misleading label "new_child_clusters"
    debug_logging_handler_df("missing_child_clusters", missing_child_clusters, "8_recognize")
# Child has new parent (hypothetically possible if a 20/10 cluster splits weirdly enough)
# NOTE(review): if either parent column is null, the != comparison yields null rather than True --
# confirm a null-vs-non-null parent change is meant to fall through to False here
all_cluster_information = all_cluster_information.with_columns(
    pl.when(
        (pl.col('cluster_parent') != pl.col('cluster_parent_previously'))
        .and_(pl.col('cluster_brand_new') == pl.lit(False))
    )
    .then(True)
    .otherwise(False)
    .alias("new_parent")
)
new_parent = all_cluster_information.filter(pl.col('new_parent'))
debug_logging_handler_txt(f"Found {new_parent.shape[0]} clusters with a new parent cluster", "8_recognize", 20)
debug_logging_handler_df("new_parent", new_parent, "8_recognize")
# The cluster is newly decimated (we print all decimated clusters but will only flag newly decimated as needs_updating)
decimated = all_cluster_information.filter(pl.col("decimated"))
new_decimated = all_cluster_information.filter(pl.col("newly_decimated"))
debug_logging_handler_txt(f"Found {decimated.shape[0]} decimated clusters of which {new_decimated.shape[0]} are newly decimated", "8_recognize", 30)
debug_logging_handler_df("decimated clusters", decimated, "8_recognize")
# Existing cluster has brand new samples
all_cluster_information = all_cluster_information.with_columns(
    pl.when(
        # A bool list of length 2 is taken to mean it contains both True and False (presumably these are
        # unique'd aggregations -- see the "KEEP IN MIND" caveat earlier in this section)
        (pl.col('samples_previously_in_cluster').list.len() == pl.lit(2))
        .and_(pl.col('sample_brand_new').list.len() == pl.lit(2))
        .and_(
            (pl.col('special_handling') == pl.lit("none"))
            .or_(pl.col('special_handling') == pl.lit("renamed"))
        )
        .and_(pl.col('cluster_brand_new') == pl.lit(False))
    )
    .then(True)
    .otherwise(False)
    .alias("existing_new_samps")
)
existing_new_samps = all_cluster_information.filter(pl.col('existing_new_samps'))
debug_logging_handler_txt(f"Found {existing_new_samps.shape[0]} existing clusters that got new samples", "8_recognize", 20)
debug_logging_handler_df("existing_new_samps", existing_new_samps, "8_recognize")
# Cluster is brand new (which may have only new samples, or old-and-new samples, or perhaps even only old samples)
cluster_brand_new = all_cluster_information.filter(pl.col('cluster_brand_new'))
debug_logging_handler_txt(f"Found {cluster_brand_new.shape[0]} brand new clusters", "8_recognize", 20)
debug_logging_handler_df("cluster_brand_new", cluster_brand_new, "8_recognize")
# Cluster's sample contents changed
# Same order-insensitive, deduplicated list comparison as used for different_children above
all_cluster_information = all_cluster_information.with_columns(
    pl.when(
        (pl.col("sample_id").list.unique().list.sort() != pl.col("sample_id_previously").list.unique().list.sort())
        .and_(pl.col('cluster_brand_new') == pl.lit(False))
    )
    .then(True)
    .otherwise(False)
    .alias("different_samples")
)
different_samples = all_cluster_information.filter(pl.col('different_samples'))
debug_logging_handler_txt(f"Found {different_samples.shape[0]} clusters whose sample contents changed in some way", "8_recognize", 20)
debug_logging_handler_df("different_samples", different_samples, "8_recognize")
# These situations should all be covered by the above, but might be worth pulling out on their own eventually:
# * The cluster itself is brand new, made up of only brand new samples (sample_brand_new = [true], special_handling = "brand new", samples_previously_in_cluster = [false])