diff --git a/.env b/.env index 74786fe4..88ac0bb3 100644 --- a/.env +++ b/.env @@ -10,7 +10,7 @@ # expected that you use the same saving conventions as the RCSB PDB, which means: # `1a2b` --> /path/to/pdb_mirror/a2/1a2b.cif.gz # To set up a mirror, you can use tha atomworks commandline: `atomworks pdb sync /path/to/mirror` -PDB_MIRROR_PATH= +PDB_MIRROR_PATH=/projects/ml/frozen_pdb_copies/2024_12_01_pdb # The `CCD_MIRROR_PATH` is a path to a local mirror of the CCD database. # It's expected that you use the same saving conventions as the RCSB CCD, which means: @@ -19,7 +19,7 @@ PDB_MIRROR_PATH= # If no mirror is provided, the internal biotite CCD will be used as a fallback. To provide a # custom CCD for a ligand, you can place it in the in the CCD mirror path following the CCDs pattern. # Example: /path/to/ccd_mirror/M/MYLIGAND1/MYLIGAND1.cif -CCD_MIRROR_PATH= +CCD_MIRROR_PATH=/projects/ml/frozen_pdb_copies/2024_12_11_ccd # --- Local MSA directories --- LOCAL_MSA_DIRS= @@ -29,14 +29,14 @@ LOCAL_MSA_DIRS= # The HBPLUS_PATH is a path to the hbplus tool, which is used for hydrogen bond calculation # during training and during metrics computation. # Example: /path/to/hbplus -HBPLUS_PATH= +HBPLUS_PATH=/projects/ml/hbplus # The `X3DNA_PATH` is a path to the x3dna tool, which is used for DNA structure analysis. # Example: /path/to/x3dna-v2.4 -X3DNA_PATH= +X3DNA_PATH=/projects/ml/prot_dna/x3dna-v2.4 # For secondary structure prediction (not currently used) -DSSP_PATH= +DSSP_PATH=/projects/ml/dssp/install/bin/mkdssp # The `HHFILTER_PATH` is a path to the hhfilter tool from the HH-suite, which is used for # filtering MSAs to reduce redundancy. 
diff --git a/models/rf3/configs/model/components/rf3_net_with_confidence_head.yaml b/models/rf3/configs/model/components/rf3_net_with_confidence_head.yaml index 063e6f5e..4714de10 100644 --- a/models/rf3/configs/model/components/rf3_net_with_confidence_head.yaml +++ b/models/rf3/configs/model/components/rf3_net_with_confidence_head.yaml @@ -42,4 +42,4 @@ confidence_head: n_bins_exp_resolved: 2 use_Cb_distances: False use_af3_style_binning_and_final_layer_norms: True - symmetrize_Cb_logits: True \ No newline at end of file + symmetrize_Cb_logits: True diff --git a/models/rfd3/README.md b/models/rfd3/README.md index ecb4f34c..bdbf0dd0 100644 --- a/models/rfd3/README.md +++ b/models/rfd3/README.md @@ -1,6 +1,6 @@ # De novo Design of Biomolecular Interactions with RFdiffusion3 -RFdiffusion3 (RFD3) is a diffusion method that can design protein structures +RFdiffusion3 (RFD3) is a diffusion method that can design biopolymer structures under complex constraints. This repository contains both the training and inference code, and @@ -62,6 +62,8 @@ For example, you can fix sequence and not structure (prediction-type task), fix For full details on how to specify inputs, see the [input specification documentation](./docs/input.md). You can also see `foundry/models/rfd3/configs/inference_engine/rfdiffusion3.yaml` for even more options. +Nucleic acid design, along with proteins, is also possible using RFD3 using the atom23 checkpoints. For full details see the [atom23 design documentation](./docs/examples/atom23_design.md) + ## Further example JSONs for different applications Additional examples are broken up by use case. If you have cloned the repository, matching `.json` files are in `foundry/models/rfd3/docs/examples` @@ -75,27 +77,33 @@ you will need to change the path in the `.json` file(s) before running. - + + + +
-

Nucleic acid binder design

- -
-

Small molecule binder design

+

Small molecule binder design

-

Protein binder design

+

Protein binder design

+

Nucleic acid binder design

+ +
-

Enzyme design

+

Enzyme design

-

Symmetric design

+

Symmetric design

+

Multipolymer design

+ +
diff --git a/models/rfd3/configs/callbacks/design_callbacks.yaml b/models/rfd3/configs/callbacks/design_callbacks.yaml index 309b492a..68c9a4ce 100644 --- a/models/rfd3/configs/callbacks/design_callbacks.yaml +++ b/models/rfd3/configs/callbacks/design_callbacks.yaml @@ -1,5 +1,6 @@ defaults: - train_logging + - metrics_logging - _self_ log_learning_rate_callback: diff --git a/models/rfd3/configs/datasets/design_base_rfd3na.yaml b/models/rfd3/configs/datasets/design_base_rfd3na.yaml new file mode 100644 index 00000000..661b2fb0 --- /dev/null +++ b/models/rfd3/configs/datasets/design_base_rfd3na.yaml @@ -0,0 +1,105 @@ +# base training dataset for training AF3 design models (atom14 variants): +# protein subsampling only. + +defaults: + # Grab datasets + - train/pdb/rfd3_train_interface@train.pdb.sub_datasets.interface + - train/pdb/rfd3_train_pn_unit@train.pdb.sub_datasets.pn_unit + - train/rfd3_monomer_distillation@train + - train/rna_monomer_distillation@train + + # Customized validation datasets + #- val/unconditional@val.unconditional + #- val/unconditional_deep@val.unconditional_deep + #- val/indexed@val.indexed + - val/pseudoknot@val.pseudoknot + + # Customized train masks + - conditions/unconditional@global_transform_args.train_conditions.unconditional + - conditions/island@global_transform_args.train_conditions.island + - conditions/tipatom@global_transform_args.train_conditions.tipatom + - conditions/sequence_design@global_transform_args.train_conditions.sequence_design + - conditions/ppi@global_transform_args.train_conditions.ppi + + - _self_ + +# Create a dictionary used for transform arguments +pipeline_target: rfd3.transforms.pipelines.build_atom14_base_pipeline + +# Base config overrides: +diffusion_batch_size_train: 32 +diffusion_batch_size_inference: 8 +crop_size: 384 +n_recycles_train: 2 +n_recycles_validation: 1 +max_atoms_in_crop: 3840 # ~10x crop size. 
+ +# Global transform arguments are necessary for arguments shared between training and inference +global_transform_args: + n_atoms_per_token: 14 + central_atom: CB + sigma_perturb: 2.0 + sigma_perturb_com: 1.0 + association_scheme: dense + center_option: diffuse # options are ["all", "motif", "diffuse"] + + # Reference conformer policy + generate_conformers: True + generate_conformers_for_non_protein_only: True + provide_reference_conformer_when_unmasked: True + ground_truth_conformer_policy: IGNORE # Other options: REPLACE, ADD, FALLBACK. See atomworks.enums for details + provide_elements_for_unindexed_components: True + use_element_for_atom_names_of_atomized_tokens: True # TODO: correct name, implies unindexed do too + + # PPI Cropping + keep_full_binder_in_spatial_crop: False + max_binder_length: 170 + + # PPI Hotspots + max_ppi_hotspots_frac_to_provide: 0.2 + ppi_hotspot_max_distance: 4.5 + + # Secondary structure features + max_ss_frac_to_provide: 0.4 + min_ss_island_len: 1 + max_ss_island_len: 10 + + # Nucleic acid features + add_na_pair_features: false + + train_conditions: + unconditional: + frequency: 5.0 + sequence_design: + frequency: 2.0 + island: + frequency: 1.0 + tipatom: + frequency: 0.0 + ppi: + frequency: 0.0 + + # Used to create simple boolean flags for downstream conditioning + meta_conditioning_probabilities: + p_is_nucleic_ss_example: 0.1 + p_nucleic_ss_show_partial_feats: 0.7 + calculate_NA_SS: 0.5 + calculate_hbonds: 0.2 + calculate_rasa: 0.6 + + keep_protein_motif_rasa: 0.1 # Small to prevent noisy input to model + hbond_subsample: 0.5 + + # fully indexed training + unindex_leak_global_index: 0.10 + unindex_insert_random_break: 0.10 + unindex_remove_random_break: 0.10 + + # Probability of adding 1d secondary structure conditioning + add_1d_ss_features: 0.1 + featurize_plddt: 0.9 # Applied for monomer distillation only + add_global_is_non_loopy_feature: 0.99 + + # PPI + add_ppi_hotspots: 0.75 + full_binder_crop: 0.75 diff --git 
a/models/rfd3/configs/datasets/train/pdb/base_transform_args.yaml b/models/rfd3/configs/datasets/train/pdb/base_transform_args.yaml index 08c735ca..28e2384d 100644 --- a/models/rfd3/configs/datasets/train/pdb/base_transform_args.yaml +++ b/models/rfd3/configs/datasets/train/pdb/base_transform_args.yaml @@ -43,6 +43,9 @@ dataset: min_ss_island_len: ${datasets.global_transform_args.min_ss_island_len} max_ss_island_len: ${datasets.global_transform_args.max_ss_island_len} + # Nucleic acid features + add_na_pair_features: ${datasets.global_transform_args.add_na_pair_features} + # Cropping crop_size: ${datasets.crop_size} max_atoms_in_crop: ${datasets.max_atoms_in_crop} @@ -56,4 +59,5 @@ dataset: # Other dataset-specific parameters atom_1d_features: ${model.net.token_initializer.atom_1d_features} - token_1d_features: ${model.net.token_initializer.token_1d_features} \ No newline at end of file + token_1d_features: ${model.net.token_initializer.token_1d_features} + token_2d_features: ${model.net.token_initializer.token_2d_features} diff --git a/models/rfd3/configs/datasets/train/pdb/rfd3_train_interface.yaml b/models/rfd3/configs/datasets/train/pdb/rfd3_train_interface.yaml index d5df571b..9ff698a8 100644 --- a/models/rfd3/configs/datasets/train/pdb/rfd3_train_interface.yaml +++ b/models/rfd3/configs/datasets/train/pdb/rfd3_train_interface.yaml @@ -12,6 +12,7 @@ dataset: # filters common across all PDB datasets - 'pdb_id not in ["7rte", "7m5w", "7n5u"]' - 'pdb_id not in ["3di3", "5o45", "1z92", "2gy5", "4zxb"]' + - 'pdb_id not in ["1drz", "2m8k", "2miy", "3q3z", "4oqu", "4plx", "4znp", "7kd1", "7kga", "7qr4"]' - "deposition_date < '2024-12-16'" - "resolution < 9.0" - "num_polymer_pn_units <= 300" @@ -19,4 +20,4 @@ dataset: # interface specific filters - "~(pn_unit_1_non_polymer_res_names.notnull() and pn_unit_1_non_polymer_res_names.str.contains('${resolve_import:atomworks.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))" - "~(pn_unit_2_non_polymer_res_names.notnull() 
and pn_unit_2_non_polymer_res_names.str.contains('${resolve_import:atomworks.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))" - - "is_inter_molecule" \ No newline at end of file + - "is_inter_molecule" diff --git a/models/rfd3/configs/datasets/train/pdb/rfd3_train_pn_unit.yaml b/models/rfd3/configs/datasets/train/pdb/rfd3_train_pn_unit.yaml index dfd548fc..19d31e8b 100644 --- a/models/rfd3/configs/datasets/train/pdb/rfd3_train_pn_unit.yaml +++ b/models/rfd3/configs/datasets/train/pdb/rfd3_train_pn_unit.yaml @@ -15,6 +15,7 @@ dataset: # filters common across all PDB datasets - 'pdb_id not in ["7rte", "7m5w", "7n5u"]' - 'pdb_id not in ["3di3", "5o45", "1z92", "2gy5", "4zxb"]' + - 'pdb_id not in ["1drz", "2m8k", "2miy", "3q3z", "4oqu", "4plx", "4znp", "7kd1", "7kga", "7qr4"]' - "deposition_date < '2024-12-16'" - "resolution < 9.0" - "num_polymer_pn_units <= 300" diff --git a/models/rfd3/configs/datasets/train/rna_monomer_distillation.yaml b/models/rfd3/configs/datasets/train/rna_monomer_distillation.yaml new file mode 100644 index 00000000..05ac8472 --- /dev/null +++ b/models/rfd3/configs/datasets/train/rna_monomer_distillation.yaml @@ -0,0 +1,39 @@ +defaults: + - pdb/base_transform_args@rna_monomer_distillation + - _self_ + +rna_monomer_distillation: + dataset: + _target_: atomworks.ml.datasets.StructuralDatasetWrapper + save_failed_examples_to_dir: ${paths.data.failed_examples_dir} + + # cif parser arguments + cif_parser_args: + cache_dir: null + load_from_cache: False + save_to_cache: False + + # metadata parser + dataset_parser: + _target_: atomworks.ml.datasets.parsers.GenericDFParser + pn_unit_iid_colnames: null + + # metadata dataset + dataset: + _target_: atomworks.ml.datasets.PandasDataset + name: rna_monomer_distillation + id_column: example_id + data: /projects/ml/afavor/rna_distillation/rna_distillation_filtered_df.parquet + columns_to_load: + - example_id + - path + - cluster_id + - seq_hash + - overall_plddt + - overall_pde + - overall_pae + + 
transform: + crop_contiguous_probability: 0.67 + crop_spatial_probability: 0.33 + diff --git a/models/rfd3/configs/datasets/val/design_validation_base.yaml b/models/rfd3/configs/datasets/val/design_validation_base.yaml index 5aabcc07..90931e09 100644 --- a/models/rfd3/configs/datasets/val/design_validation_base.yaml +++ b/models/rfd3/configs/datasets/val/design_validation_base.yaml @@ -37,4 +37,5 @@ dataset: # Other dataset-specific parameters atom_1d_features: ${model.net.token_initializer.atom_1d_features} - token_1d_features: ${model.net.token_initializer.token_1d_features} \ No newline at end of file + token_1d_features: ${model.net.token_initializer.token_1d_features} + token_2d_features: ${model.net.token_initializer.token_2d_features} \ No newline at end of file diff --git a/models/rfd3/configs/datasets/val/pseudoknot.yaml b/models/rfd3/configs/datasets/val/pseudoknot.yaml new file mode 100644 index 00000000..7cf5ce14 --- /dev/null +++ b/models/rfd3/configs/datasets/val/pseudoknot.yaml @@ -0,0 +1,9 @@ + +defaults: + - design_validation_base + - _self_ + +dataset: + name: pseudoknot + eval_every_n: 1 + data: ${paths.data.design_benchmark_data_dir}/pseudoknot.json diff --git a/models/rfd3/configs/experiment/rfd3na_fine_tune.yaml b/models/rfd3/configs/experiment/rfd3na_fine_tune.yaml new file mode 100644 index 00000000..3f2dfbe1 --- /dev/null +++ b/models/rfd3/configs/experiment/rfd3na_fine_tune.yaml @@ -0,0 +1,105 @@ +# @package _global_ +# Training configuration for RFD3 + +defaults: + #- /debug/default + - override /model: rfd3_base + - override /logger: wandb + - override /datasets: design_base_rfd3na + - _self_ + +name: rfd3na-fine-tune-ss0.5 +tags: [print-model] +ckpt_path: null + +model: + net: + token_initializer: + token_1d_features: + ref_motif_token_type: 3 + restype: 32 + is_dna_token: 1 + is_rna_token: 1 + is_protein_token: 1 + token_2d_features: + bp_partners: 3 # Unspecified, pair, loop + atom_1d_features: + ref_atom_name_chars: 256 + 
ref_element: 128 + ref_charge: 1 + ref_mask: 1 + ref_is_motif_atom_with_fixed_coord: 1 + ref_is_motif_atom_unindexed: 1 + has_zero_occupancy: 1 + ref_pos: 3 + + # Guided features + ref_atomwise_rasa: 3 + active_donor: 1 + active_acceptor: 1 + is_atom_level_hotspot: 1 + diffusion_module: + n_recycle: 2 + use_local_token_attention: True + diffusion_transformer: + n_local_tokens: 32 + n_keys: 128 + + inference_sampler: + num_timesteps: 100 + + +datasets: + diffusion_batch_size_train: 16 + crop_size: 256 + max_atoms_in_crop: 2560 # ~10x crop size. + global_transform_args: + meta_conditioning_probabilities: + p_is_nucleic_ss_example: 0.5 + p_nucleic_ss_show_partial_feats: 0.7 + p_canonical_bp_filter: 0.2 + #calculate_NA_SS: 0.3 + + association_scheme: atom23 + #add_na_pair_features: true + train_conditions: + unconditional: + frequency: 2.0 + island: + frequency: 2.0 + sequence_design: + frequency: 0.5 + tipatom: + frequency: 5.0 + ppi: + frequency: 0.0 + train: + # These are the ratios used in the preprint but we set all pdb sampling by default since not everyone might download the distillation data. 
+ pdb: + probability: 0.6 + sub_datasets: + pn_unit: + weights: + alphas: + a_nuc: 3.0 + interface: + weights: + alphas: + a_nuc: 3.0 + rna_monomer_distillation: + probability: 0.3 + monomer_distillation: + probability: 0.1 + + val: + pseudoknot: + dataset: + # eval_every_n: 10 + eval_every_n: 5 + +trainer: + #devices_per_node: 1 + #limit_train_batches: 10 + #limit_val_batches: 1 + validate_every_n_epochs: 5 + prevalidate: true diff --git a/models/rfd3/configs/inference_engine/rfdiffusion3.yaml b/models/rfd3/configs/inference_engine/rfdiffusion3.yaml index c5a6a72b..fe4a00fb 100644 --- a/models/rfd3/configs/inference_engine/rfdiffusion3.yaml +++ b/models/rfd3/configs/inference_engine/rfdiffusion3.yaml @@ -28,6 +28,7 @@ inference_sampler: - active_donor - active_acceptor - ref_atomwise_rasa + - bp_partners use_classifier_free_guidance: False cfg_t_max: null # max t to apply cfg guidance diff --git a/models/rfd3/configs/model/components/rfd3_net.yaml b/models/rfd3/configs/model/components/rfd3_net.yaml index 40334833..83cddc4b 100644 --- a/models/rfd3/configs/model/components/rfd3_net.yaml +++ b/models/rfd3/configs/model/components/rfd3_net.yaml @@ -25,6 +25,9 @@ token_initializer: # formerly known as the trunk ref_plddt: 1 is_non_loopy: 1 + # Optional 2D token feature definitions (empty by default) + token_2d_features: {} + downcast: ${model.net.diffusion_module.downcast} atom_1d_features: ref_atom_name_chars: 256 diff --git a/models/rfd3/configs/trainer/metrics/design_metrics.yaml b/models/rfd3/configs/trainer/metrics/design_metrics.yaml index 2a456051..c0bde00d 100644 --- a/models/rfd3/configs/trainer/metrics/design_metrics.yaml +++ b/models/rfd3/configs/trainer/metrics/design_metrics.yaml @@ -20,3 +20,16 @@ hbond_metrics: _target_: rfd3.metrics.hbonds_hbplus_metrics.HbondMetrics cutoff_HA_dist: 3 cutoff_DA_distance: 3.5 + +nucleic_ss_similarity: + _target_: rfd3.metrics.nucleic_ss_metrics.NucleicSSSimilarityMetrics + restrict_to_nucleic: True + 
compute_for_diffused_region_only: False + annotate_predicted_fresh: True + annotation_NA_only: False + annotation_planar_only: True + +rna_aptamer_contacts: + _target_: rfd3.metrics.rna_aptamer_metrics.LigandContactMetrics + restrict_to_nucleic: True + diff --git a/models/rfd3/docs/.assets/multipolymer.png b/models/rfd3/docs/.assets/multipolymer.png new file mode 100644 index 00000000..f081043a Binary files /dev/null and b/models/rfd3/docs/.assets/multipolymer.png differ diff --git a/models/rfd3/docs/.assets/overview.png b/models/rfd3/docs/.assets/overview.png index ee53c36a..6495cf69 100644 Binary files a/models/rfd3/docs/.assets/overview.png and b/models/rfd3/docs/.assets/overview.png differ diff --git a/models/rfd3/docs/examples/atom23_design.json b/models/rfd3/docs/examples/atom23_design.json new file mode 100644 index 00000000..4632fef8 --- /dev/null +++ b/models/rfd3/docs/examples/atom23_design.json @@ -0,0 +1,98 @@ +{ + "multipolymer": { + "contig": "40-50R,/0,10-20D,/0,80-110", + "length": "130-180", + "input": "../input_pdbs/AMP.pdb" + }, + "W05": { + "ss_dbn": ".(((((((((((((((((((..[[[[[[.)))))(((....)))(((....)))))))))))))))))((((((..]]]]]].)))))).", + "select_fixed_atoms": false, + "contig": "90-90R", + "length": "90-90", + "input": "../input_pdbs/AMP.pdb" + }, + "AMP_aptamer": { + "input": "../input_pdbs/AMP.pdb", + "ligand": "AMP", + "contig": "40-50R", + "length": "40-50", + "ori_jitter": 1, + "select_buried": {"AMP": "ALL"}, + "select_hbond_acceptor": { + "AMP": "N7,O4',O1P,O2P,O3P,N3,N1" + }, + "select_hbond_donor": { + "AMP": "N6,O3',O2'" + } + }, + "FMN_aptamer": { + "input": "../input_pdbs/FMN_3x21.pdb", + "ligand": "FMN", + "contig": "40-50R", + "length": "40-50", + "ori_jitter": 1, + "select_buried": {"FMN": "ALL"}, + "select_hbond_acceptor": { + "FMN": "O2,O4,N1,N5,O5',O2P,O3P" + }, + "select_hbond_donor": { + "FMN": "N3,O2',O3',O4'" + } + }, + "AMP_aptamer_noh": { + "input": "../input_pdbs/AMP.pdb", + "ligand": "AMP", + "contig": "40-50R", 
+ "length": "40-50", + "ori_jitter": 1, + "select_buried": {"AMP": "ALL"} + }, + "FMN_aptamer_noh": { + "input": "../input_pdbs/FMN_3x21.pdb", + "ligand": "FMN", + "contig": "40-50R", + "length": "40-50", + "ori_jitter": 1, + "select_buried": {"FMN": "ALL"} + }, + "unindexed_rnasep": { + "input": "../input_pdbs/rnase_p_3q1q_active_site_small.pdb", + "contig": "50-80R,/0,100-120,/0,C1-4,C79-86", + "length": "162-212", + "ligand": "MG,PO4", + "unindex": "B49,B50,B51,B52,B321,/0,A56-58,/0", + "select_fixed_atoms": { + "B49": "ALL", + "B50": "ALL", + "B51": "ALL", + "B52": "ALL", + "B321": "ALL", + "A56-58": "ALL", + "C1-4": "ALL", + "C79-86": "ALL" + } + }, + "dict_input_ss": { + "ss_dbn_dict": { + "A6-25":"(((..)))....(((..)))", + "B1-20":"((((..))))...((...))" + }, + "contig":"30-30R,/0,30-30R", + "length":"60-60", + "input":"../input_pdbs/AMP.pdb" + }, + "paired_region_input_ss": { + "paired_region_list": ["A20-25,B10-15"], + "loop_region_list":["A10-19","B20-30"], + "contig":"50-50R,/0,50-50R", + "length":"100-100", + "input":"../input_pdbs/AMP.pdb" + }, + "paired_position_input_ss": { + "paired_position_list": ["A3,B3","A5,B5","A7,B7","A9,B9","A11,B11","A13,B13","A15,B15","A17,B17","A19,B19"], + "contig":"20-20R,/0,20-20R", + "length":"40-40", + "input":"../input_pdbs/AMP.pdb" + + } +} diff --git a/models/rfd3/docs/examples/atom23_design.md b/models/rfd3/docs/examples/atom23_design.md new file mode 100644 index 00000000..91f54ebc --- /dev/null +++ b/models/rfd3/docs/examples/atom23_design.md @@ -0,0 +1,228 @@ +# RNA / DNA Design in RFdiffusion3 + +This guide describes extensions to RFdiffusion3 for nucleic acid and hybrid RNA–protein design, including: + +- RNA/DNA-aware contigs (`R` / `D` suffix) +- Ligand-conditioned aptamer design +- Secondary structure (SS) conditioning +- Base-pair constraints (region- and position-level) +- Partial structure fixing and unindexing + +--- + +## 1. 
Contig Syntax for RNA/DNA + +Contigs now support nucleic acid specification: + +- `R` → RNA segment +- `D` → DNA segment +- No suffix → protein (default) + +### Example + +```json +{ + "contig": "40-50R,/0,10-20D,/0,80-110" +} +``` +This corresponds to: 40–50 nt RNA, chain break, 10–20 nt DNA, chain break, 80–110 aa protein + +Multipolymer Design + +```json + +{ + "multipolymer": { + "contig": "40-50R,/0,10-20D,/0,80-110", + "length": "130-180", + "input": "../input_pdbs/AMP.pdb" + } +} +``` + +## 2. Secondary Structure Conditioning +### 2.1 Dot-Bracket Notation (Global) +```json +{ + "W05": { + "ss_dbn": ".(((((((((((((((((((..[[[[[[.)))))(((....)))(((....)))))))))))))))))((((((..]]]]]].)))))).", + "select_fixed_atoms": false, + "contig": "90-90R", + "length": "90-90", + "input": "../input_pdbs/AMP.pdb" + } +} +``` +`ss_dbn` specifies full RNA secondary structure + +Will be applied to the first L tokens, where L is the length of `ss_dbn`. + +### 2.2 Dictionary-Based SS Input + +Specify secondary structure for subsections: +``` json +{ + "ss_dbn_dict": { + "A6-25": "(((..)))....(((..)))", + "B1-20": "((((..))))...((...))" + } +} +``` +Used in: +``` json +{ + "dict_input_ss": { + "ss_dbn_dict": { + "A6-25": "(((..)))....(((..)))", + "B1-20": "((((..))))...((...))" + }, + "contig": "30-30R,/0,30-30R", + "length": "60-60", + "input": "../input_pdbs/AMP.pdb" + } +} +``` +## 3. 
Base Pair region Conditioning +### 3.1 Paired Regions + +Define paired and loop regions: +```json +{ + "paired_region_list": ["A20-25,B10-15"], + "loop_region_list": ["A10-19","B20-30"] +} +``` +Enforces pairing and loop propensity between residue ranges during sampling + +Used in: +```json +{ + "paired_region_input_ss": { + "paired_region_list": ["A20-25,B10-15"], + "loop_region_list": ["A10-19","B20-30"], + "contig": "50-50R,/0,50-50R", + "length": "100-100", + "input": "../input_pdbs/AMP.pdb" + } +} +``` + +### 3.2 Explicit Base Pair Positions + +Fine-grained base pairing control: + +```json +{ + "paired_position_list": [ + "A3,B3","A5,B5","A7,B7","A9,B9","A11,B11", + "A13,B13","A15,B15","A17,B17","A19,B19" + ] +} +``` +Used in: +```json +{ + "paired_position_input_ss": { + "paired_position_list": [ + "A3,B3","A5,B5","A7,B7","A9,B9","A11,B11", + "A13,B13","A15,B15","A17,B17","A19,B19" + ], + "contig": "20-20R,/0,20-20R", + "length": "40-40", + "input": "../input_pdbs/AMP.pdb" + } +} +``` +### Note: Most of the above jsons is not actually reading the `input` field. Kept as a dummy for the `inference_engine`. + +## 4. Ligand-Conditioned Aptamer Design + +Supports small molecule binding RNA design. + +AMP Aptamer Example +```json +{ + "AMP_aptamer": { + "input": "../input_pdbs/AMP.pdb", + "ligand": "AMP", + "contig": "40-50R", + "length": "40-50", + "ori_jitter": 1, + "select_buried": {"AMP": "ALL"}, + "select_hbond_acceptor": { + "AMP": "N7,O4',O1P,O2P,O3P,N3,N1" + }, + "select_hbond_donor": { + "AMP": "N6,O3',O2'" + } + } +} +``` +Key Options + +`ligand`: ligand name in the input PDB + +`select_buried`: enforce burial of ligand atoms + +`select_hbond_acceptor` / `select_hbond_donor`: suggest Hbond interaction atoms + +`ori_jitter`: small random perturbation of ori token (from ligand COM) + + +## 5. 
Hybrid RNA–Protein Design with Constraints +### RNase P Active Site Example + +```json +{ + "unindexed_rnasep": { + "input": "../input_pdbs/rnase_p_3q1q_active_site_small.pdb", + "contig": "50-80R,/0,100-120,/0,C1-4,C79-86", + "length": "162-212", + "ligand": "MG,PO4", + "unindex": "B49,B50,B51,B52,B321,/0,A56-58,/0", + "select_fixed_atoms": { + "B49": "ALL", + "B50": "ALL", + "B51": "ALL", + "B52": "ALL", + "B321": "ALL", + "A56-58": "ALL", + "C1-4": "ALL", + "C79-86": "ALL" + } + } +} +``` +Key Features + +Mixed RNA + protein + fixed fragments + +`unindex`: removes residues from positional indexing + +`select_fixed_atoms`: freezes specified atoms + +Ligands (MG, PO4) included in design context + +Useful for catalytic residues or structural motifs + +## 6. Summary of Features Used + +R / D suffix → RNA / DNA specification in contigs + +`ss_dbn` → global secondary structure constraint + +`ss_dbn_dict` → local secondary structure constraints + +`paired_region_list` → helix-level pairing constraints + +`paired_position_list` → base-level pairing constraints + +ligand + selection options → aptamer design + +`unindex` → remove residues from indexing + +`select_fixed_atoms` → freeze structural elements + + +--- + diff --git a/models/rfd3/docs/input.md b/models/rfd3/docs/input.md index 9253abc5..384101c6 100644 --- a/models/rfd3/docs/input.md +++ b/models/rfd3/docs/input.md @@ -120,7 +120,7 @@ Below is a table of all of the inputs that the `InputSpecification` accepts. Use | -------------------------------------------------------------- | ----------------- | --------------------------------------------------------------------- | | `input` | `str` | Path to and file name of **PDB/CIF**. Required if you provide contig+length. | | `atom_array_input` | internal | Pre-loaded [`AtomArray`](https://www.biotite-python.org/latest/apidoc/biotite.structure.AtomArray.html) (not recommended). | -| `contig` | `InputSelection` | (Can only pass a contig string.) 
Indexed motif specification, e.g., `"A1-80,10,/0,B5-12"`. | +| `contig` | `InputSelection` | (Can only pass a contig string.) Indexed motif specification, e.g., `"A1-80,10,/0,B5-12"`. When running inference with an atom23 checkpoint, contigs can include `R` or `D` suffixes to denote that the designed polymer type is RNA or DNA, e.g. "A1-80,10-10R,20-20D,30-30,/0,B5-12" (designs a 10 length RNA chain, 20 length DNA chain and a 30 length protein chain). | | `unindex` | `InputSelection` | (Can only pass a contig string or dictionary.) Unindexed motif components, the specified residues can be anywhere in the final sequence. See [Unindexing Specifics](#unindexing-specifics) for more information. | | `length` | `str` | Total design length constraint; `"min-max"` or int for specified length. | | `ligand` | `str` | Ligand(s) by chemical component name (from [RSCB PDB](https://www.rcsb.org/)) or index. | @@ -136,6 +136,7 @@ Below is a table of all of the inputs that the `InputSpecification` accepts. Use | `symmetry` | `SymmetryConfig` | See {doc}`examples/symmetry`. | | `ori_token` | `list[float]` | `[x,y,z]` origin override to control COM (center of mass) placement of designed structure. | | `infer_ori_strategy` | `str` | `"com"` or `"hotspots"`. The center of mass of the diffused region will typically be within 5Å of the ORI token. Using `hotspots` will place the ORI token 10Å outward from the center of mass of the specified hotspots. Using `com` will place the token at the center of mass of the input structure.| +| `ori_jitter` | `float` | default `None`. Per batch, move the ori token in a random direction by a distance sampled from a geometric distribution with the mean as the specified float value in Angstrom.| | `plddt_enhanced` | `bool` | Default `True`. Enables pLDDT (predicted Local Distance Difference Test) enhancement. | | `is_non_loopy` | `bool \| None` | Default `None`. 
If `True`/`False`, produces output structures with fewer/more loops.| | `partial_t` | `float` | Noise (Å) for partial diffusion, enables partial diffusion (sets the noise level.) Recommended values are 5.0-15.0 Å. See [Partial Diffusion](#partial-diffusion) for more information. | diff --git a/models/rfd3/src/rfd3/constants.py b/models/rfd3/src/rfd3/constants.py index 59c387c3..5407e681 100644 --- a/models/rfd3/src/rfd3/constants.py +++ b/models/rfd3/src/rfd3/constants.py @@ -72,6 +72,37 @@ 'GLY': (" N "," CA "," C "," O ", None, None, None, None, None, None, None, None, None, None), # gly 'UNK': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # unk 'MSK': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # mask + 'DA': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'", + ' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' N6 ', ' N1 ', ' C2 ', ' N3 ', ' C4 ', + None), + + 'DC': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'", + ' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' N4 ', ' C5 ', ' C6 ', + None, None, None), + + 'DG': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'", + ' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' O6 ', ' N1 ', ' C2 ', ' N2 ', ' N3 ', ' C4 '), + + 'DT': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'", + ' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' O4 ', ' C5 ', ' C7 ', ' C6 ', + None, None), + + 'A' : (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'", + ' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' N6 ', ' N1 ', ' C2 ', ' N3 ', ' C4 ', + None), + + 'C' : (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'", + ' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' N4 ', ' C5 ', ' C6 ', + None, None, None), + + 'G' : (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " 
O4'", " C3'", " O3'", " C2'", " O2'", " C1'", + ' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' O6 ', ' N1 ', ' C2 ', ' N2 ', ' N3 ', ' C4 '), + + 'U' : (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'", + ' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' O4 ', ' C5 ', ' C6 ', + None, None, None), + 'DX': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'", None, None, None, None, None, None, None, None, None, None, None), #dna_mask + 'X': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'", None, None, None, None, None, None, None, None, None, None, None), #rna mask } """Canonical ordering of amino acid atom names in the CCD.""" @@ -195,10 +226,7 @@ strip_list = lambda x: [(x.strip() if x is not None else None) for x in x] # noqa -association_schemes_stripped = { - name: {k: strip_list(v) for k, v in scheme.items()} - for name, scheme in association_schemes.items() -} + SELECTION_PROTEIN = ["POLYPEPTIDE(D)", "POLYPEPTIDE(L)"] SELECTION_NONPROTEIN = [ "POLYDEOXYRIBONUCLEOTIDE", @@ -210,3 +238,461 @@ "MACROLIDE", "POLYDEOXYRIBONUCLEOTIDE/POLYRIBONUCLEOTIDE HYBRID", ] + +backbone_atomscheme_DNA = [ + " P ", + " OP1", + " OP2", + " O5'", + " C5'", + " C4'", + " O4'", + " C3'", + " O3'", + " C2'", + " C1'", +] # , None] + +backbone_atomscheme_RNA = [ + " P ", + " OP1", + " OP2", + " O5'", + " C5'", + " C4'", + " O4'", + " C3'", + " O3'", + " C2'", + " O2'", + " C1'", +] + +DNA_atoms = { + "DA": [ + " N9 ", + " C8 ", + " N7 ", + " C5 ", + " C6 ", + " N6 ", + " N1 ", + " C2 ", + " N3 ", + " C4 ", + ], + "DC": [" N1 ", " C2 ", " O2 ", " N3 ", " C4 ", " N4 ", " C5 ", " C6 "], + "DG": [ + " N9 ", + " C8 ", + " N7 ", + " C5 ", + " C6 ", + " O6 ", + " N1 ", + " C2 ", + " N2 ", + " N3 ", + " C4 ", + ], + "DT": [" N1 ", " C2 ", " O2 ", " N3 ", " C4 ", " O4 ", " C5 ", " C7 ", " C6 "], +} + +RNA_atoms = { + "A": [ + " N9 ", + " C8 ", + " N7 ", + " C5 ", + " C6 ", + " N6 ", 
+ " N1 ", + " C2 ", + " N3 ", + " C4 ", + ], + "C": [" N1 ", " C2 ", " O2 ", " N3 ", " C4 ", " N4 ", " C5 ", " C6 "], + "G": [ + " N9 ", + " C8 ", + " N7 ", + " C5 ", + " C6 ", + " O6 ", + " N1 ", + " C2 ", + " N2 ", + " N3 ", + " C4 ", + ], + "U": [" N1 ", " C2 ", " O2 ", " N3 ", " C4 ", " O4 ", " C5 ", " C6 "], +} + +association_schemes["atom23"] = {} +for item in DNA_atoms: + association_schemes["atom23"][item] = tuple( + backbone_atomscheme_DNA + + DNA_atoms[item] + + [None] * (22 - len(DNA_atoms[item] + backbone_atomscheme_DNA)) + ) +for item in RNA_atoms: + association_schemes["atom23"][item] = tuple( + backbone_atomscheme_RNA + + RNA_atoms[item] + + [None] * (23 - len(RNA_atoms[item] + backbone_atomscheme_RNA)) + ) + +for item in association_schemes["dense"]: + association_schemes["atom23"][item] = association_schemes["dense"][item] + +association_schemes["atom23"]["DX"] = ( + " P ", + " OP1", + " OP2", + " O5'", + " C5'", + " C4'", + " O4'", + " C3'", + " O3'", + " C2'", + " C1'", + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, +) # rna_mask +association_schemes["atom23"]["X"] = ( + " P ", + " OP1", + " OP2", + " O5'", + " C5'", + " C4'", + " O4'", + " C3'", + " O3'", + " C2'", + " O2'", + " C1'", + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, +) # rna mask + +ATOM23_ATOM_NAMES_RNA = np.array( + [item.strip() for item in backbone_atomscheme_RNA] + + [f"V{i}" for i in range(23 - len(backbone_atomscheme_RNA))] +) +"""Atom23 atom names (e.g. CA, V1)""" + +ATOM23_ATOM_ELEMENTS_RNA = np.array( + ["P", "O", "O", "O", "C", "C", "O", "C", "O", "C", "O", "C"] + + [VIRTUAL_ATOM_ELEMENT_NAME for i in range(23 - len(backbone_atomscheme_RNA))] +) +"""Atom23 element names (e.g. 
C, VX)""" + +ATOM23_ATOM_NAME_TO_ELEMENT = { + name: elem for name, elem in zip(ATOM23_ATOM_NAMES_RNA, ATOM23_ATOM_ELEMENTS_RNA) +} +ATOM23_ATOM_NAMES_DNA = np.array( + [item.strip() for item in backbone_atomscheme_DNA] + + [f"V{i}" for i in range(22 - len(backbone_atomscheme_DNA))] +) +"""Atom23 atom names (e.g. CA, V1)""" + +ATOM23_ATOM_ELEMENTS_DNA = np.array( + ["P", "O", "O", "O", "C", "C", "O", "C", "O", "C", "C"] + + [VIRTUAL_ATOM_ELEMENT_NAME for i in range(22 - len(backbone_atomscheme_DNA))] +) +"""Atom23 element names (e.g. C, VX)""" + + +"""Mapping from atom14 atom names (e.g. CA, V1) to their corresponding element names (e.g. C, VX)""" +## combining name to element mapping, should be fine +for item in ATOM14_ATOM_NAME_TO_ELEMENT: + ATOM23_ATOM_NAME_TO_ELEMENT[item] = ATOM14_ATOM_NAME_TO_ELEMENT[item] + +association_schemes_stripped = { + name: {k: strip_list(v) for k, v in scheme.items()} + for name, scheme in association_schemes.items() +} + +backbone_atoms_RNA = strip_list(backbone_atomscheme_RNA) +backbone_atoms_DNA = strip_list(backbone_atomscheme_DNA) + +# Mapping from residue type to its backbone and sidechain atoms (for convenience) +ATOM_REGION_BY_RESI = { + "ALA": {"bb": ("N", "CA", "C", "O"), "sc": ("CB")}, + "ARG": { + "bb": ("N", "CA", "C", "O"), + "sc": ("CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"), + }, + "ASN": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "CG", "OD1", "ND2")}, + "ASP": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "CG", "OD1", "OD2")}, + "CYS": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "SG")}, + "GLN": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "CG", "CD", "OE1", "NE2")}, + "GLU": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "CG", "CD", "OE1", "OE2")}, + "GLY": {"bb": ("N", "CA", "C", "O"), "sc": ()}, + "HIS": { + "bb": ("N", "CA", "C", "O"), + "sc": ("CB", "CG", "ND1", "CD2", "CE1", "NE2"), + }, + "ILE": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "CG1", "CG2", "CD1")}, + "LEU": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "CG", 
"CD1", "CD2")}, + "LYS": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "CG", "CD", "CE", "NZ")}, + "MET": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "CG", "SD", "CE")}, + "PHE": { + "bb": ("N", "CA", "C", "O"), + "sc": ("CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"), + }, + "PRO": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "CG", "CD")}, + "SER": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "OG")}, + "THR": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "OG1", "CG2")}, + "TRP": { + "bb": ("N", "CA", "C", "O"), + "sc": ("CB", "CG", "CD1", "CD2", "CE2", "CE3", "NE1", "CZ2", "CZ3", "CH2"), + }, + "TYR": { + "bb": ("N", "CA", "C", "O"), + "sc": ("CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"), + }, + "VAL": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "CG1", "CG2")}, + "UNK": {"bb": ("N", "CA", "C", "O"), "sc": ("CB")}, + "MAS": {"bb": ("N", "CA", "C", "O"), "sc": ("CB")}, + "DA": { + "bb": ( + "O4'", + "C1'", + "C2'", + "OP1", + "P", + "OP2", + "O5'", + "C5'", + "C4'", + "C3'", + "O3'", + ), + "sc": ("N9", "C4", "N3", "C2", "N1", "C6", "C5", "N7", "C8", "N6"), + }, + "DC": { + "bb": ( + "O4'", + "C1'", + "C2'", + "OP1", + "P", + "OP2", + "O5'", + "C5'", + "C4'", + "C3'", + "O3'", + ), + "sc": ("N1", "C2", "O2", "N3", "C4", "N4", "C5", "C6"), + }, + "DG": { + "bb": ( + "O4'", + "C1'", + "C2'", + "OP1", + "P", + "OP2", + "O5'", + "C5'", + "C4'", + "C3'", + "O3'", + ), + "sc": ("N9", "C4", "N3", "C2", "N1", "C6", "C5", "N7", "C8", "N2", "O6"), + }, + "DT": { + "bb": ( + "O4'", + "C1'", + "C2'", + "OP1", + "P", + "OP2", + "O5'", + "C5'", + "C4'", + "C3'", + "O3'", + ), + "sc": ("N1", "C2", "O2", "N3", "C4", "O4", "C5", "C7", "C6"), + }, + "DX": { + "bb": ( + "O4'", + "C1'", + "C2'", + "OP1", + "P", + "OP2", + "O5'", + "C5'", + "C4'", + "C3'", + "O3'", + ), + "sc": (), + }, + "A": { + "bb": ( + "O4'", + "C1'", + "C2'", + "OP1", + "P", + "OP2", + "O5'", + "C5'", + "C4'", + "C3'", + "O3'", + "O2'", + ), + "sc": ("N1", "C2", "N3", "C4", "C5", "C6", "N6", "N7", "C8", "N9"), + }, + "C": 
{ + "bb": ( + "O4'", + "C1'", + "C2'", + "OP1", + "P", + "OP2", + "O5'", + "C5'", + "C4'", + "C3'", + "O3'", + "O2'", + ), + "sc": ("N1", "C2", "O2", "N3", "C4", "N4", "C5", "C6"), + }, + "G": { + "bb": ( + "O4'", + "C1'", + "C2'", + "OP1", + "P", + "OP2", + "O5'", + "C5'", + "C4'", + "C3'", + "O3'", + "O2'", + ), + "sc": ("N1", "C2", "N2", "N3", "C4", "C5", "C6", "O6", "N7", "C8", "N9"), + }, + "U": { + "bb": ( + "O4'", + "C1'", + "C2'", + "OP1", + "P", + "OP2", + "O5'", + "C5'", + "C4'", + "C3'", + "O3'", + "O2'", + ), + "sc": ("N1", "C2", "O2", "N3", "C4", "O4", "C5", "C6"), + }, + "X": { + "bb": ( + "O4'", + "C1'", + "C2'", + "OP1", + "P", + "OP2", + "O5'", + "C5'", + "C4'", + "C3'", + "O3'", + "O2'", + ), + "sc": (), + }, + "HIS_D": { + "bb": ("N", "CA", "C", "O"), + "sc": ("CB", "CG", "NE2", "CD2", "CE1", "ND1"), + }, +} +# Known planar sidechain atoms for each canonical residue type: +PLANAR_ATOMS_BY_RESI = { + "ALA": [], + "ARG": ["NH1", "NH2", "CZ", "NE", "CD"], + "ASN": ["OD1", "ND2", "CG", "CB"], + "ASP": ["OD1", "OD2", "CG", "CB"], + "CYS": [], + "GLN": ["OE1", "NE2", "CD", "CG"], + "GLU": ["OE1", "OE2", "CD", "CG"], + "GLY": [], + "HIS": ["ND1", "CE1", "NE2", "CD2", "CG", "CB"], + "ILE": [], + "LEU": [], + "LYS": [], + "MET": [], + "PHE": ["CZ", "CE1", "CE2", "CD1", "CD2", "CG", "CB"], + "PRO": [], + "SER": [], + "THR": [], + "TRP": ["CH2", "CZ3", "CZ2", "CE3", "CE2", "CD2", "NE1", "CD1", "CG", "CB"], + "TYR": ["OH", "CZ", "CE1", "CE2", "CD1", "CD2", "CG", "CB"], + "VAL": [], + "UNK": [], + "MAS": [], + "DA": ["N6", "C6", "N1", "C2", "N3", "C4", "C5", "N7", "C8", "N9"], + "DC": ["N4", "C4", "N3", "O2", "C2", "C5", "C6", "N1"], + "DG": ["O6", "C6", "N1", "N2", "C2", "N3", "C4", "C5", "N7", "C8", "N9"], + "DT": ["O4", "O2", "N3", "C4", "C2", "C5", "C6", "N1", "C7"], + "DX": [], + "A": ["N6", "C6", "N1", "C2", "N3", "C4", "C5", "N7", "C8", "N9"], + "C": ["N4", "C4", "N3", "O2", "C2", "C5", "C6", "N1"], + "G": ["O6", "C6", "N1", "N2", "C2", "N3", "C4", 
"C5", "N7", "C8", "N9"], + "U": ["O4", "O2", "N3", "C4", "C2", "C5", "C6", "N1"], + "X": [], + "HIS_D": ["ND1", "CD2", "CE1", "NE2", "CG", "CB"], +} + +# fix C/U symmetry +temp = list(association_schemes["atom23"]["U"]) +temp[19], temp[20] = temp[20], temp[19] +association_schemes["atom23"]["U"] = tuple(temp) + +association_schemes_stripped = { + name: {k: strip_list(v) for k, v in scheme.items()} + for name, scheme in association_schemes.items() +} + +if __name__ == "__main__": + import pdb + + pdb.set_trace() diff --git a/models/rfd3/src/rfd3/inference/input_parsing.py b/models/rfd3/src/rfd3/inference/input_parsing.py index d97b3be3..8293ca74 100644 --- a/models/rfd3/src/rfd3/inference/input_parsing.py +++ b/models/rfd3/src/rfd3/inference/input_parsing.py @@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Union import numpy as np -from atomworks.constants import STANDARD_AA +from atomworks.constants import STANDARD_AA, STANDARD_DNA, STANDARD_RNA from atomworks.io.parser import parse_atom_array # from atomworks.ml.datasets.datasets import BaseDataset @@ -30,9 +30,12 @@ OPTIONAL_CONDITIONING_VALUES, REQUIRED_CONDITIONING_ANNOTATION_VALUES, REQUIRED_INFERENCE_ANNOTATIONS, + backbone_atoms_DNA, + backbone_atoms_RNA, ) from rfd3.inference.legacy_input_parsing import ( create_atom_array_from_design_specification_legacy, + reorder_atoms_per_residue, ) from rfd3.inference.parsing import InputSelection from rfd3.inference.symmetry.symmetry_utils import ( @@ -67,7 +70,6 @@ logger = RankedLogger(__name__, rank_zero_only=True) - ################################################################################# # Custom infer_ori functions ################################################################################# @@ -117,7 +119,9 @@ class DesignInputSpecification(BaseModel): validate_assignment=False, str_strip_whitespace=True, str_min_length=1, - extra="forbid", + # extra="forbid", #################################################### + extra="allow", + ## 
for now allowing extra for rfd3na-ss purposes, can decide later ## ) # fmt: off # ======================================================================== @@ -180,6 +184,7 @@ class DesignInputSpecification(BaseModel): symmetry: Optional[SymmetryConfig] = Field(None, description="Symmetry specification, see docs/symmetry.md") # Centering & COM guidance ori_token: Optional[list[float]] = Field(None, description="Origin coordinates") + ori_jitter: Optional[float] = Field(None, description="Jitter ori in a random direction and use ori_jitter to sample distance via exponential distribution") infer_ori_strategy: Optional[str] = Field(None, description="Strategy for inferring origin; `com` or `hotspots`") # Additional global conditioning plddt_enhanced: Optional[bool] = Field(True, description="Enable pLDDT enhancement") @@ -492,6 +497,15 @@ def apply_selections(start, end): aa.is_motif_atom_with_fixed_seq[start:end] = np.full_like( is_bkbn, False, dtype=int ) + elif ( + aa.res_name[start] in (STANDARD_DNA + STANDARD_RNA) + and self.redesign_motif_sidechains + ): + is_bkbn = np.isin(aa.atom_name[start:end], backbone_atoms_RNA) + aa.is_motif_atom_with_fixed_coord[start:end] = is_bkbn.astype(int) + aa.is_motif_atom_with_fixed_seq[start:end] = np.full_like( + is_bkbn, False, dtype=int + ) # ... 
Apply selections on top apply_selections(start, end) @@ -505,6 +519,24 @@ def apply_selections(start, end): def build(self, return_metadata=False): """Main build pipeline.""" atom_array_input_annotated = copy.deepcopy(self.atom_array_input) + + ########## reorder NA atoms ########### + if exists(atom_array_input_annotated): + is_dna = np.isin( + atom_array_input_annotated.res_name, ["DA", "DC", "DG", "DT"] + ) + is_rna = np.isin(atom_array_input_annotated.res_name, ["A", "C", "G", "U"]) + dna_array = atom_array_input_annotated[is_dna] + rna_array = atom_array_input_annotated[is_rna] + + atom_array_input_annotated[is_dna] = reorder_atoms_per_residue( + dna_array, backbone_atoms_DNA + ) + atom_array_input_annotated[is_rna] = reorder_atoms_per_residue( + rna_array, backbone_atoms_RNA + ) + ####################################### + atom_array = self._build_init(atom_array_input_annotated) # Apply post-processing @@ -699,6 +731,13 @@ def _append_ligand(self, atom_array, atom_array_input_annotated): ligand_array.set_annotation( annot, np.full(ligand_array.array_length(), default) ) + + chain_cand = "X" + while chain_cand in atom_array.chain_id.tolist(): + chain_cand = chain_cand + chain_cand + ligand_chain = np.array([chain_cand] * len(ligand_array)) + ligand_array.chain_id = ligand_chain + atom_array = atom_array + ligand_array return atom_array @@ -723,8 +762,13 @@ def _set_origin(self, atom_array): "Partial diffusion with symmetry: skipping COM centering to preserve chain spacing" ) else: + if not exists(self.ori_jitter): + self.ori_jitter = None atom_array = set_com( - atom_array, ori_token=None, infer_ori_strategy="com" + atom_array, + ori_token=None, + infer_ori_strategy="com", + ori_jitter=self.ori_jitter, ) else: # Standard: set ori token, zero out diffused atoms @@ -894,35 +938,57 @@ def validator_context(validator_name: str, data: dict = None): raise e -def create_diffused_residues(n, additional_annotations=None): +def create_diffused_residues(n, 
additional_annotations=None, polymer_type="P"): + from rfd3.constants import ( + ATOM23_ATOM_NAME_TO_ELEMENT, + backbone_atoms_DNA, + backbone_atoms_RNA, + ) + if n <= 0: raise ValueError(f"Negative/null residue count ({n}) not allowed.") + if polymer_type == "P": + res_name = "ALA" + bb_len = 5 + bb_atom_names = ["N", "CA", "C", "O", "CB"] + elif polymer_type == "R": + res_name = "A" + bb_len = len(backbone_atoms_RNA) + bb_atom_names = backbone_atoms_RNA + elif polymer_type == "D": + res_name = "DA" + bb_len = len(backbone_atoms_DNA) + bb_atom_names = backbone_atoms_DNA + else: + raise ValueError( + f"invalid polymer type detected: {polymer_type}, check contig!" + ) + + bb_elements = [ATOM23_ATOM_NAME_TO_ELEMENT[item] for item in bb_atom_names] + atoms = [] [ atoms.extend( [ struc.Atom( np.array([0.0, 0.0, 0.0], dtype=np.float32), - res_name="ALA", + res_name=res_name, res_id=idx, ) - for _ in range(5) + for _ in range(bb_len) ] ) for idx in range(1, n + 1) ] array = struc.array(atoms) - array.set_annotation( - "element", np.array(["N", "C", "C", "O", "C"] * n, dtype=" AtomArray: + """ + Reorder atoms within each residue of an AtomArray. + Atoms in `desired_order` appear first (in that order), followed by all others + in original order. Faster version using get_residue_starts(). + + Parameters: + - atom_array: AtomArray to reorder. + - desired_order: List of atom names in the desired per-residue order. + + Returns: + - AtomArray with reordered atoms per residue. 
+ """ + if len(atom_array) == 0: + return atom_array + res_starts = get_residue_starts(atom_array) + res_starts = np.append(res_starts, len(atom_array)) # add end index for slicing + reordered_chunks = [] + order_dict = {name: i for i, name in enumerate(desired_order)} + + for i in range(len(res_starts) - 1): + start, end = res_starts[i], res_starts[i + 1] + residue = atom_array[start:end] + + # Boolean masks for matching and non-matching atom names + in_order_mask = np.isin(residue.atom_name, desired_order) + not_in_order_mask = ~in_order_mask + + # Sort matching atoms by desired order + atoms_in_order = residue[in_order_mask] + sort_idx = np.argsort([order_dict[name] for name in atoms_in_order.atom_name]) + ordered_atoms = atoms_in_order[sort_idx] + + # Remaining atoms as-is + remaining_atoms = residue[not_in_order_mask] + + # Concatenate reordered residue + reordered_chunks.append(concatenate([ordered_atoms, remaining_atoms])) + return concatenate(reordered_chunks) diff --git a/models/rfd3/src/rfd3/metrics/design_metrics.py b/models/rfd3/src/rfd3/metrics/design_metrics.py index 5ac24fbc..1ac0ad51 100644 --- a/models/rfd3/src/rfd3/metrics/design_metrics.py +++ b/models/rfd3/src/rfd3/metrics/design_metrics.py @@ -4,6 +4,7 @@ get_token_starts, ) from beartype.typing import Any +from rfd3.constants import backbone_atoms_RNA from rfd3.metrics.metrics_utils import ( _flatten_dict, get_hotspot_contacts, @@ -14,6 +15,7 @@ from foundry.metrics.metric import Metric STANDARD_CACA_DIST = 3.8 +STANDARD_P_P_DISTANCE = 6.4 ## average of B and A form 7 and 5.9 def get_clash_metrics( @@ -28,11 +30,17 @@ def get_clash_metrics( ) def get_chainbreaks(): - ca_atoms = atom_array[atom_array.atom_name == "CA"] + if "CA" in atom_array.atom_name: + ca_atoms = atom_array[atom_array.atom_name == "CA"] + cut_off = STANDARD_CACA_DIST + elif "P" in atom_array.atom_name: + ca_atoms = atom_array[atom_array.atom_name == "P"] + cut_off = STANDARD_P_P_DISTANCE + xyz = ca_atoms.coord xyz = 
torch.from_numpy(xyz) ca_dists = torch.norm(xyz[1:] - xyz[:-1], dim=-1) - deviation = torch.abs(ca_dists - STANDARD_CACA_DIST) + deviation = torch.abs(ca_dists - cut_off) # Allow leniency for expected chain breaks (e.g. PPI) chain_breaks = ca_atoms.chain_iid[1:] != ca_atoms.chain_iid[:-1] @@ -45,7 +53,9 @@ def get_chainbreaks(): } def get_interresidue_clashes(backbone_only=False): - protein_array = atom_array[atom_array.is_protein] + protein_array = atom_array[ + atom_array.is_protein | atom_array.is_dna | atom_array.is_rna + ] resid = protein_array.res_id - protein_array.res_id.min() xyz = protein_array.coord dists = np.linalg.norm(xyz[:, None] - xyz[None], axis=-1) # N_atoms x N_atoms @@ -58,7 +68,9 @@ def get_interresidue_clashes(backbone_only=False): if backbone_only: # Block out non-backbone atoms - backbone_mask = np.isin(protein_array.atom_name, ["N", "CA", "C"]) + backbone_mask = np.isin( + protein_array.atom_name, ["N", "CA", "C"] + backbone_atoms_RNA + ) mask = backbone_mask[:, None] & backbone_mask[None, :] dists[~mask] = 999 @@ -291,6 +303,7 @@ def __init__(self, compute_for_diffused_region_only: bool = False): 3.0 # maximum closest-neighbour distance before considered a floating atom ) self.standard_ca_dist = 3.8 + self.standard_PP_dist = 6.4 self.compute_for_diffused_region_only = compute_for_diffused_region_only @property @@ -310,6 +323,8 @@ def compute(self, X_L, tok_idx, f): ) # N_atoms x N_atoms is_protein = f["is_protein"][tok_idx].cpu().numpy() # n_atoms + is_rna = f["is_rna"][tok_idx].cpu().numpy() + is_dna = f["is_dna"][tok_idx].cpu().numpy() mask = np.zeros_like(dists, dtype=bool) mask = mask | (np.eye(dists.shape[-1], dtype=bool))[None] @@ -362,19 +377,51 @@ def compute(self, X_L, tok_idx, f): if self.compute_for_diffused_region_only: is_ca = is_ca[diffused_region] is_protein = is_protein[diffused_region] - idx_mask = is_ca & is_protein + is_dna = is_dna[diffused_region] + is_rna = is_rna[diffused_region] + protein_idx_mask = is_ca & 
(is_protein) + na_idx_mask = is_ca & (is_rna | is_dna) + if self.compute_for_diffused_region_only: - xyz = X_L.cpu()[:, diffused_region][:, idx_mask] + xyz_protein = X_L.cpu()[:, diffused_region][:, protein_idx_mask] + xyz_na = X_L.cpu()[:, diffused_region][:, na_idx_mask] else: - xyz = X_L.cpu()[:, idx_mask] + xyz_protein = X_L.cpu()[:, protein_idx_mask] + xyz_na = X_L.cpu()[:, na_idx_mask] - ca_dists = torch.norm(xyz[:, 1:] - xyz[:, :-1], dim=-1) - deviation = torch.abs(ca_dists - self.standard_ca_dist) # B, (I-1) - is_chainbreak = deviation > 0.75 + ca_dists_protein = torch.norm( + xyz_protein[:, 1:] - xyz_protein[:, :-1], dim=-1 + ) + ca_dists_na = torch.norm(xyz_na[:, 1:] - xyz_na[:, :-1], dim=-1) + + deviation_protein = torch.abs( + ca_dists_protein - self.standard_ca_dist + ) # B, (I-1) + deviation_na = torch.abs(ca_dists_na - self.standard_PP_dist) # B, (I-1) + is_chainbreak_protein = deviation_protein > 0.75 + is_chainbreak_na = deviation_na > 1 + + try: + o["max_ca_deviation_protein"] = float( + deviation_protein.max(-1).values.mean() + ) + o["fraction_chainbreaks_protein"] = float( + is_chainbreak_protein.float().mean(-1).mean() + ) + o["n_chainbreaks_protein"] = float( + is_chainbreak_protein.float().sum(-1).mean() + ) + except Exception: + print("No protein in this example, skipping protein chainbreak metrics") - o["max_ca_deviation"] = float(deviation.max(-1).values.mean()) - o["fraction_chainbreaks"] = float(is_chainbreak.float().mean(-1).mean()) - o["n_chainbreaks"] = float(is_chainbreak.float().sum(-1).mean()) + try: + o["max_ca_deviation_na"] = float(deviation_na.max(-1).values.mean()) + o["fraction_chainbreaks_na"] = float( + is_chainbreak_na.float().mean(-1).mean() + ) + o["n_chainbreaks_na"] = float(is_chainbreak_na.float().sum(-1).mean()) + except Exception: + print("No NA in this example, skipping NA chainbreak metrics") return o diff --git a/models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py 
b/models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py new file mode 100644 index 00000000..53c21194 --- /dev/null +++ b/models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py @@ -0,0 +1,380 @@ +import logging + +import numpy as np +from atomworks.ml.utils.token import ( + get_token_starts, +) +from biotite.structure import AtomArray +from rfd3.trainer.trainer_utils import ( + _cleanup_virtual_atoms_and_assign_atom_name_elements, + _readout_seq_from_struc, +) +from rfd3.transforms.na_geom_utils import annotate_na_ss + +from foundry.metrics.metric import Metric +from foundry.utils.ddp import RankedLogger + +logging.basicConfig(level=logging.INFO) +global_logger = RankedLogger(__name__, rank_zero_only=False) + + +def _get_bp_partners_annotation(atom_array: AtomArray): + """Return bp-partners annotation.""" + categories = atom_array.get_annotation_categories() + if "bp_partners" in categories: + return atom_array.bp_partners + raise ValueError("atom_array missing bp_partners annotation") + + +def _safe_f1_from_sizes(intersection_n: int, pred_n: int, gt_n: int) -> float: + """Return F1 with sensible empty-set handling.""" + if pred_n == 0 and gt_n == 0: + return 1.0 + + precision = float(intersection_n / pred_n) if pred_n > 0 else 0.0 + recall = float(intersection_n / gt_n) if gt_n > 0 else 0.0 + + if precision + recall == 0.0: + return 0.0 + + return float(2.0 * precision * recall / (precision + recall)) + + +def _get_token_ids(atom_array: AtomArray) -> np.ndarray: + token_starts = get_token_starts(atom_array) + token_level_array = atom_array[token_starts] + return np.asarray(token_level_array.token_id, dtype=int) + + +def _get_candidate_token_ids( + atom_array: AtomArray, + *, + restrict_to_nucleic: bool, + compute_for_diffused_region_only: bool, +) -> set[int]: + """Return a set of token_ids to include for scoring.""" + token_starts = get_token_starts(atom_array) + token_level_array = atom_array[token_starts] + token_ids = np.asarray(token_level_array.token_id, dtype=int) 
+ + token_mask = np.ones(len(token_ids), dtype=bool) + + if restrict_to_nucleic: + is_rna = ( + np.asarray(getattr(token_level_array, "is_rna"), dtype=bool) + if hasattr(token_level_array, "is_rna") + else np.zeros(len(token_ids), dtype=bool) + ) + is_dna = ( + np.asarray(getattr(token_level_array, "is_dna"), dtype=bool) + if hasattr(token_level_array, "is_dna") + else np.zeros(len(token_ids), dtype=bool) + ) + token_mask &= ( + (is_rna | is_dna) if (is_rna.any() or is_dna.any()) else token_mask + ) + + if compute_for_diffused_region_only: + if hasattr(token_level_array, "is_motif_atom"): + token_mask &= ~np.asarray(token_level_array.is_motif_atom, dtype=bool) + elif hasattr(token_level_array, "is_motif_token"): + token_mask &= ~np.asarray(token_level_array.is_motif_token, dtype=bool) + + return set(int(t) for t in token_ids[token_mask].tolist()) + + +def _extract_bp_pairs( + atom_array: AtomArray, + *, + allowed_token_ids: set[int], +) -> set[tuple[int, int]]: + """Extract unordered base-pair edges from bp-partner annotations. + + Pairs are represented as (min_token_id, max_token_id). 
+ """ + token_starts = get_token_starts(atom_array) + token_level_array = atom_array[token_starts] + token_ids = np.asarray(token_level_array.token_id, dtype=int) + token_id_to_pos = {int(tid): i for i, tid in enumerate(token_ids.tolist())} + + bp_partner_ann = _get_bp_partners_annotation(atom_array) + pairs: set[tuple[int, int]] = set() + + for pos, start_idx in enumerate(token_starts.tolist()): + i_tid = int(token_ids[pos]) + if i_tid not in allowed_token_ids: + continue + + partners = bp_partner_ann[int(start_idx)] + if partners is None: + continue + if not isinstance(partners, (list, tuple, np.ndarray)): + continue + + for partner_token_id in partners: + try: + j_tid = int(partner_token_id) + except Exception: + continue + + if j_tid == i_tid or j_tid not in allowed_token_ids: + continue + + if j_tid not in token_id_to_pos: + continue + + a, b = (i_tid, j_tid) if i_tid < j_tid else (j_tid, i_tid) + pairs.add((a, b)) + + return pairs + + +def _extract_loop_and_paired_token_ids( + atom_array: AtomArray, + *, + allowed_token_ids: set[int], +) -> tuple[set[int], set[int]]: + """Return (loop_token_ids, paired_token_ids) within the allowed token set.""" + token_starts = get_token_starts(atom_array) + token_level_array = atom_array[token_starts] + token_ids = np.asarray(token_level_array.token_id, dtype=int) + token_id_to_pos = {int(tid): i for i, tid in enumerate(token_ids.tolist())} + + bp_partner_ann = _get_bp_partners_annotation(atom_array) + + loop_token_ids: set[int] = set() + paired_token_ids: set[int] = set() + + for pos, start_idx in enumerate(token_starts.tolist()): + i_tid = int(token_ids[pos]) + if i_tid not in allowed_token_ids: + continue + + partners = bp_partner_ann[int(start_idx)] + # New semantics: + # - None => unannotated/masked (NOT a loop) + # - [] => explicitly unpaired loop + if partners is None: + continue + if not isinstance(partners, (list, tuple, np.ndarray)): + continue + if len(partners) == 0: + loop_token_ids.add(i_tid) + continue + + 
for partner_token_id in partners: + try: + j_tid = int(partner_token_id) + except Exception: + continue + + if j_tid == i_tid or j_tid not in allowed_token_ids: + continue + if j_tid not in token_id_to_pos: + continue + paired_token_ids.add(i_tid) + paired_token_ids.add(j_tid) + + return loop_token_ids, paired_token_ids + + +def compute_from_two_arr( + gt_arr, pred_arr, restrict_to_nucleic=True, compute_for_diffused_region_only=False +): + gt_token_ids = _get_token_ids(gt_arr) + pred_token_ids = _get_token_ids(pred_arr) + if len(gt_token_ids) != len(pred_token_ids): + None + + # Restrict to token_ids that are valid in both arrays. + gt_allowed = _get_candidate_token_ids( + gt_arr, + restrict_to_nucleic=restrict_to_nucleic, + compute_for_diffused_region_only=compute_for_diffused_region_only, + ) + pred_allowed = _get_candidate_token_ids( + pred_arr, + restrict_to_nucleic=restrict_to_nucleic, + compute_for_diffused_region_only=compute_for_diffused_region_only, + ) + allowed = gt_allowed & pred_allowed + + if len(allowed) == 0: + return None + + gt_pairs = _extract_bp_pairs(gt_arr, allowed_token_ids=allowed) + pred_pairs = _extract_bp_pairs(pred_arr, allowed_token_ids=allowed) + + gt_loop, gt_paired_tokens = _extract_loop_and_paired_token_ids( + gt_arr, allowed_token_ids=allowed + ) + pred_loop, _pred_paired_tokens = _extract_loop_and_paired_token_ids( + pred_arr, allowed_token_ids=allowed + ) + + pair_tp = len(gt_pairs & pred_pairs) + pair_pred_n = len(pred_pairs) + pair_gt_n = len(gt_pairs) + + loop_tp = len(gt_loop & pred_loop) + loop_pred_n = len(pred_loop) + loop_gt_n = len(gt_loop) + + pair_f1 = _safe_f1_from_sizes(pair_tp, pair_pred_n, pair_gt_n) + loop_f1 = _safe_f1_from_sizes(loop_tp, loop_pred_n, loop_gt_n) + + pair_weight = len(gt_paired_tokens) + loop_weight = len(gt_loop) + total_weight = pair_weight + loop_weight + if total_weight == 0: + weighted_f1 = 1.0 + else: + weighted_f1 = float( + (pair_weight * pair_f1 + loop_weight * loop_f1) / total_weight + ) 
+ + return pair_f1, loop_f1, weighted_f1 + + +def get_NA_SS_F1(pred_array): + ## save the original bop_partner annotation + gt_array = pred_array.copy() + + ## replace by annotating again + pred_array = annotate_na_ss( + pred_array, + NA_only=True, + planar_only=True, + overwrite=True, + p_canonical_bp_filter=0.0, + ) + + try: + pair_f1, loop_f1, weighted_f1 = compute_from_two_arr(gt_array, pred_array) + except Exception: + # fails when returns None because expects three returns + return {} + + return { + "pair_f1": pair_f1, + "loop_f1": loop_f1, + "weighted_f1": weighted_f1, + } + + +class NucleicSSSimilarityMetrics(Metric): + """Secondary-structure similarity for nucleic acids. + + Reports: + - `pair_f1`: F1 over basepair edges from token-level bp-partner annotation. + - `loop_f1`: F1 over explicitly-unpaired loop tokens (`bp_partners == []`). + Unannotated tokens (`bp_partners is None`) are masked. + - `weighted_f1`: GT-weighted average of `pair_f1` and `loop_f1`, weighted by + the prevalence of paired vs loop tokens in the GT. 
+ """ + + def __init__( + self, + *, + restrict_to_nucleic: bool = True, + compute_for_diffused_region_only: bool = False, + annotate_predicted_fresh: bool = False, + annotation_NA_only: bool = False, + annotation_planar_only: bool = True, + ): + super().__init__() + self.restrict_to_nucleic = restrict_to_nucleic + self.compute_for_diffused_region_only = compute_for_diffused_region_only + self.annotate_predicted_fresh = annotate_predicted_fresh + self.annotation_NA_only = annotation_NA_only + self.annotation_planar_only = annotation_planar_only + + @property + def kwargs_to_compute_args(self): + return { + "ground_truth_atom_array_stack": ("ground_truth_atom_array_stack",), + "predicted_atom_array_stack": ("predicted_atom_array_stack",), + } + + def compute(self, *, ground_truth_atom_array_stack, predicted_atom_array_stack): + if ground_truth_atom_array_stack is None or predicted_atom_array_stack is None: + return {} + + pair_f1_list: list[float] = [] + loop_f1_list: list[float] = [] + weighted_f1_list: list[float] = [] + + n_valid = 0 + + for gt_arr, pred_arr in zip( + ground_truth_atom_array_stack, predicted_atom_array_stack + ): + gt_categories = gt_arr.get_annotation_categories() + if "bp_partners" not in gt_categories: + continue + + # Important: predicted AtomArrays are built from a template AtomArray. + # If that template already carries bp_partners (often GT-derived), the + # prediction can inherit it, yielding artificially perfect scores. + # Optionally recompute bp_partners from the *predicted coordinates*. 
+ if self.annotate_predicted_fresh: + try: + # Infer res name from geometry first + pred_arr = _readout_seq_from_struc( + pred_arr, + central_atom="C1'", + threshold=0.5, + association_scheme="atom23", + ) + + # strip virtuals and set final atom names/elements + pred_arr = _cleanup_virtual_atoms_and_assign_atom_name_elements( + pred_arr, + association_scheme="atom23", + ) + except Exception: + # this can fail early in training + print("could not cleanup virtuals for nucleic ss metric compute") + pass + # clear annotation to avoid potential info leak + if "bp_partners" in pred_arr.get_annotation_categories(): + pred_arr.del_annotation("bp_partners") + + # add nucleic-ss annotations + annotate_na_ss( + pred_arr, + NA_only=self.annotation_NA_only, + planar_only=self.annotation_planar_only, + overwrite=True, + p_canonical_bp_filter=0.0, + ) + pred_categories = pred_arr.get_annotation_categories() + if "bp_partners" not in pred_categories: + continue + + # Basic sanity check: token counts should match for aligned comparisons + try: + pair_f1, loop_f1, weighted_f1 = compute_from_two_arr( + gt_arr, + pred_arr, + restrict_to_nucleic=self.restrict_to_nucleic, + compute_for_diffused_region_only=self.compute_for_diffused_region_only, + ) + except Exception: + # fails when returns None because expects three returns + continue + + pair_f1_list.append(pair_f1) + loop_f1_list.append(loop_f1) + weighted_f1_list.append(weighted_f1) + n_valid += 1 + + if n_valid == 0: + return {} + + return { + "pair_f1": float(np.mean(pair_f1_list)), + "loop_f1": float(np.mean(loop_f1_list)), + "weighted_f1": float(np.mean(weighted_f1_list)), + "n_valid_samples": int(n_valid), + } diff --git a/models/rfd3/src/rfd3/metrics/rna_aptamer_metrics.py b/models/rfd3/src/rfd3/metrics/rna_aptamer_metrics.py new file mode 100644 index 00000000..bf1c7a36 --- /dev/null +++ b/models/rfd3/src/rfd3/metrics/rna_aptamer_metrics.py @@ -0,0 +1,115 @@ +import logging + +import numpy as np + +from foundry.metrics.metric 
import Metric +from foundry.utils.ddp import RankedLogger + +logging.basicConfig(level=logging.INFO) +global_logger = RankedLogger(__name__, rank_zero_only=False) + + +def calculate_ligand_contacts( + atom_array_stack, + cutoff_distance=4.0, +): + """ + Count number of atom contacts within cutoff of any ligand atom. + + Parameters + ---------- + atom_array_stack : AtomArrayStack + Shape: (n_models, n_atoms) + cutoff_distance : float + Distance cutoff in Å + + Returns + ------- + total_contacts : int + mean_contacts_per_model : float + """ + + cutoff_sq = cutoff_distance**2 + contacts_per_model = [] + + n_models = len(atom_array_stack) + + for i in range(n_models): + atoms = atom_array_stack[i] + + coords = atoms.coord + hetero_mask = atoms.hetero.astype(bool) + + # Skip if no ligand + if not np.any(hetero_mask): + contacts_per_model.append(0) + continue + + ligand_coords = coords[hetero_mask] + non_ligand_coords = coords[~hetero_mask] + + if len(non_ligand_coords) == 0: + contacts_per_model.append(0) + continue + + # Pairwise squared distances + diff = non_ligand_coords[:, None, :] - ligand_coords[None, :, :] + dist_sq = np.sum(diff**2, axis=-1) + + # Any ligand within cutoff + contact_mask = np.any(dist_sq < cutoff_sq, axis=1) + + n_contacts = np.sum(contact_mask) + contacts_per_model.append(n_contacts) + + contacts_per_model = np.array(contacts_per_model) + + return ( + int(np.sum(contacts_per_model)), + float(np.mean(contacts_per_model)), + float(np.mean(contacts_per_model)) / hetero_mask.sum(), + ) + + +class LigandContactMetrics(Metric): + def __init__( + self, + *, + cutoff_distance: float = 4.0, + restrict_to_nucleic: bool = True, + ): + super().__init__() + self.cutoff_distance = cutoff_distance + self.restrict_to_nucleic = restrict_to_nucleic + + @property + def kwargs_to_compute_args(self): + return { + "predicted_atom_array_stack": ("predicted_atom_array_stack",), + } + + def compute(self, *, predicted_atom_array_stack): + if self.restrict_to_nucleic: + 
if ( + predicted_atom_array_stack[0].is_rna.sum() + + predicted_atom_array_stack[0].is_dna.sum() + == 0 + ): + return {} + try: + total_contacts, mean_contacts, mean_contacts_per_atom = ( + calculate_ligand_contacts( + atom_array_stack=predicted_atom_array_stack, + cutoff_distance=self.cutoff_distance, + ) + ) + except Exception as e: + global_logger.error( + f"Error calculating ligand contact metrics: {e} | Skipping" + ) + return {} + + return { + "mean_ligand_contacts_per_model": float(mean_contacts), + "mean_ligand_contacts_per_atom": float(mean_contacts_per_atom), + } diff --git a/models/rfd3/src/rfd3/model/cfg_utils.py b/models/rfd3/src/rfd3/model/cfg_utils.py index d99c2fa0..2b9f860d 100644 --- a/models/rfd3/src/rfd3/model/cfg_utils.py +++ b/models/rfd3/src/rfd3/model/cfg_utils.py @@ -57,9 +57,14 @@ def strip_f( # set the feature to default value if it is in the cfg_features if k in cfg_features: - v_cropped = torch.zeros_like(v_cropped).to( - v_cropped.device, dtype=v_cropped.dtype - ) + if k not in ["bp_partners"]: + v_cropped = torch.zeros_like(v_cropped).to( + v_cropped.device, dtype=v_cropped.dtype + ) + else: + ## for bp_partners default is a mask feature + v_cropped[:, :, 0] = 1 + v_cropped[:, :, 1:] = 0 # update the feature in the dictionary f_stripped[k] = v_cropped diff --git a/models/rfd3/src/rfd3/model/layers/block_utils.py b/models/rfd3/src/rfd3/model/layers/block_utils.py index aeac08c8..3dfa0d56 100644 --- a/models/rfd3/src/rfd3/model/layers/block_utils.py +++ b/models/rfd3/src/rfd3/model/layers/block_utils.py @@ -210,7 +210,9 @@ def create_attention_indices( chain_ids is not None and len(torch.unique(chain_ids)) > 3 ): # Multi-chain structure # Reserve 25% of attention keys for inter-chain interactions - k_inter_chain = max(32, k_actual // 4) # At least 32 inter-chain keys + k_inter_chain = min( + max(32, k_actual // 4), k_actual + ) # At least 32 inter-chain keys k_intra_chain = k_actual - k_inter_chain attn_indices = 
get_sparse_attention_indices_with_inter_chain( diff --git a/models/rfd3/src/rfd3/model/layers/blocks.py b/models/rfd3/src/rfd3/model/layers/blocks.py index eaf08093..3290b9e0 100644 --- a/models/rfd3/src/rfd3/model/layers/blocks.py +++ b/models/rfd3/src/rfd3/model/layers/blocks.py @@ -144,6 +144,43 @@ def forward(self, f, collapse_length): ) +class TwoDFeatureEmbedder(nn.Module): + """ + Embeds 2D features into a single vector. + + Args: + features (dict): Dictionary of feature names and their number of channels. + output_channels (int): Output dimension of the projected embedding. + """ + + def __init__(self, features, output_channels): + super().__init__() + self.features = {k: v for k, v in features.items() if exists(v)} + total_embedding_input_features = sum(self.features.values()) + self.embedders = nn.ModuleDict( + { + feature: EmbeddingLayer( + n_channels, total_embedding_input_features, output_channels + ) + for feature, n_channels in self.features.items() + } + ) + + def collapse2D(self, x, L): + return x.reshape((L, L, x.numel() // (L * L))) + + def forward(self, f, collapse_length): + return sum( + tuple( + self.embedders[feature]( + self.collapse2D(f[feature].float(), collapse_length) + ) + for feature, n_channels in self.features.items() + if exists(n_channels) + ) + ) + + class SinusoidalDistEmbed(nn.Module): """ Applies sinusoidal embedding to pairwise distances and projects to c_atompair. 
diff --git a/models/rfd3/src/rfd3/model/layers/encoders.py b/models/rfd3/src/rfd3/model/layers/encoders.py index b0ed86fa..fc19a95c 100644 --- a/models/rfd3/src/rfd3/model/layers/encoders.py +++ b/models/rfd3/src/rfd3/model/layers/encoders.py @@ -14,6 +14,7 @@ PositionPairDistEmbedder, RelativePositionEncodingWithIndexRemoval, SinusoidalDistEmbed, + TwoDFeatureEmbedder, ) from rfd3.model.layers.chunked_pairwise import ( ChunkedPairwiseEmbedder, @@ -51,6 +52,7 @@ def __init__( token_1d_features, atom_1d_features, atom_transformer, + token_2d_features=None, use_chunked_pll=False, # New parameter for memory optimization ): super().__init__() @@ -62,6 +64,10 @@ def __init__( self.atom_1d_embedder_1 = OneDFeatureEmbedder(atom_1d_features, c_s) self.atom_1d_embedder_2 = OneDFeatureEmbedder(atom_1d_features, c_atom) self.token_1d_embedder = OneDFeatureEmbedder(token_1d_features, c_s) + if token_2d_features is not None: + self.token_2d_embedder = TwoDFeatureEmbedder(token_2d_features, c_z) + else: + self.token_2d_embedder = None self.downcast_atom = Downcast(c_atom=c_s, c_token=c_s, c_s=None, **downcast) self.transition_post_token = Transition(c=c_s, n=2) @@ -202,6 +208,9 @@ def init_tokens(): Z_init_II = Z_init_II + self.ref_pos_embedder_tok( f["ref_pos"][f["is_ca"]], valid_mask ) + # Add extra token pair features + if self.token_2d_embedder is not None: + Z_init_II = Z_init_II + self.token_2d_embedder(f, I) # Run a small transformer to provide position encodings to single. 
for block in self.transformer_stack: diff --git a/models/rfd3/src/rfd3/trainer/rfd3.py b/models/rfd3/src/rfd3/trainer/rfd3.py index a7f72e6f..93a8f18c 100644 --- a/models/rfd3/src/rfd3/trainer/rfd3.py +++ b/models/rfd3/src/rfd3/trainer/rfd3.py @@ -8,6 +8,7 @@ from omegaconf import DictConfig from rfd3.metrics.design_metrics import get_all_backbone_metrics from rfd3.metrics.hbonds_hbplus_metrics import get_hbond_metrics +from rfd3.metrics.nucleic_ss_metrics import get_NA_SS_F1 from rfd3.trainer.recycling import get_recycle_schedule from rfd3.trainer.trainer_utils import ( _build_atom_array_stack, @@ -428,9 +429,14 @@ def _build_predicted_atom_array_stack( # ... Delete virtual atoms and assign atom names and elements if self.cleanup_virtual_atoms: - atom_array = _cleanup_virtual_atoms_and_assign_atom_name_elements( - atom_array, association_scheme=self.association_scheme - ) + try: + atom_array = _cleanup_virtual_atoms_and_assign_atom_name_elements( + atom_array, association_scheme=self.association_scheme + ) + except Exception as e: + global_logger.warning( + f"Failed to cleanup virtual atoms from diffusion output: {e}" + ) # ... 
When cleaning up virtual atoms, we can also calculate native_array_metrics
CA_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CA"] CB_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CB"] - if np.linalg.norm(CA_coord - CB_coord) < threshold: + + if cur_res_atom_array.is_dna[0] or cur_res_atom_array.is_rna[0]: + cur_central_atom = "C1'" + elif np.linalg.norm(CA_coord - CB_coord) < threshold: cur_central_atom = "CA" else: cur_central_atom = central_atom @@ -252,15 +258,28 @@ def _readout_seq_from_struc( continue # ... Find the index of virtual atom names in the standard atom14 names + ATOM_NAMES = ATOM14_ATOM_NAMES + if restype in STANDARD_DNA: + ATOM_NAMES = ATOM23_ATOM_NAMES_DNA + if not cur_res_atom_array.is_dna[0]: + continue + elif restype in STANDARD_RNA: + ATOM_NAMES = ATOM23_ATOM_NAMES_RNA + if not cur_res_atom_array.is_rna[0]: + continue + else: + # ATOM_NAMES = ATOM23_ATOM_NAMES_RNA + if not cur_res_atom_array.is_protein[0]: + continue + atom_name_idx_in_atom14_scheme = np.array( [ - np.where(ATOM14_ATOM_NAMES == atom_name)[0][0] + np.where(ATOM_NAMES == atom_name)[0][0] for atom_name in cur_pred_res_atom_names ] ) # five backbone atoms + some virtual atoms, returning e.g. [0, 1, 2, 3, 4, 11, 7] - atom14_scheme_mask = np.zeros_like(ATOM14_ATOM_NAMES, dtype=bool) + atom14_scheme_mask = np.zeros_like(ATOM_NAMES, dtype=bool) atom14_scheme_mask[atom_name_idx_in_atom14_scheme] = True - # ... Find the matched restype by checking if all the non-None posititons and None positions match # This is designed to keep virtual atoms and doesn't assign the atom names for now, which will be handled later. 
if all(x is not None for x in atom_names[atom14_scheme_mask]) and all( @@ -403,7 +422,13 @@ def process_unindexed_outputs( row_ind, col_ind = linear_sum_assignment(dists) res_id_, chain_id_, _ = indices_to_components_(atom_array_diffused, col_ind) - assert (res_id_ == res_id) & (chain_id_ == chain_id) + try: + assert (res_id_ == res_id) & (chain_id_ == chain_id) + except Exception: + global_logger.warning( + "Unindexed mapping did not work properly, res_id, chain_id" + ) + inserted_mask = np.logical_or(inserted_mask, token_match) # ... Compute metrics based on the new distances @@ -427,13 +452,30 @@ def process_unindexed_outputs( else: join_atom = None + if join_atom is None: + pass + else: + dist = float(dists[row_ind[join_atom], col_ind[join_atom]]) + + elif not np.any(np.isin(token.atom_name, backbone_atoms_RNA)): + if np.sum(token.atomize) == 1: + join_atom = np.where(token.atomize)[0][0] + elif "C1'" in token.atom_name: + join_atom = np.where(token.atom_name == "C1'")[0][0] + else: + join_atom = None + if join_atom is None: global_logger.warning( - f"Token {token_pdb_id} does not contain backbone atoms or CB, skipping join point distance calculation {token}." 
+ "Skipping joint point rmsd, neither protein or NA backbone" ) else: dist = float(dists[row_ind[join_atom], col_ind[join_atom]]) + + try: metadata["join_point_rmsd_by_token"][token_pdb_id] = dist + except Exception: + pass metadata["diffused_index_map"][token_pdb_id] = f"{chain_id}{res_id}" diff --git a/models/rfd3/src/rfd3/transforms/conditioning_base.py b/models/rfd3/src/rfd3/transforms/conditioning_base.py index 01a303c4..682005bb 100644 --- a/models/rfd3/src/rfd3/transforms/conditioning_base.py +++ b/models/rfd3/src/rfd3/transforms/conditioning_base.py @@ -235,6 +235,7 @@ def __init__( train_conditions: dict, meta_conditioning_probabilities: dict, sequence_encoding, + association_scheme, ): if exists(train_conditions): train_conditions = hydra.utils.instantiate( @@ -242,7 +243,12 @@ def __init__( ) self.meta_conditioning_probabilities = meta_conditioning_probabilities self.train_conditions = train_conditions + + for item in self.train_conditions: + self.train_conditions[item].association_scheme = association_scheme + self.sequence_encoding = sequence_encoding + self.association_scheme = association_scheme def check_input(self, data: dict): assert not data["is_inference"], "This transform is only used during training!" 
@@ -259,10 +265,13 @@ def check_input(self, data: dict): assert "conditions" in data, "Conditioning dict not initialized" def forward(self, data): + # for item in self.train_conditions: + # print(self.train_conditions[item].is_valid_for_example(data)) + valid_conditions = [ cond for cond in self.train_conditions.values() - if cond.frequency > 0 and cond.is_valid_for_example(data) + if cond.is_valid_for_example(data) and cond.frequency > 0 ] if len(valid_conditions) == 0: @@ -278,6 +287,8 @@ def forward(self, data): i_cond = np.random.choice(np.arange(len(p_cond)), p=p_cond) cond = valid_conditions[i_cond] + cond.association_scheme = self.association_scheme + data["sampled_condition"] = cond data["sampled_condition_name"] = cond.name data["sampled_condition_cls"] = cond.__class__ @@ -296,6 +307,9 @@ class SampleConditioningFlags(Transform): "SampleConditioningType", ] # We use is_protein in the PPI training condition + def __init__(self, association_scheme): + self.association_scheme = association_scheme + def check_input(self, data): assert not data[ "is_inference" @@ -317,13 +331,14 @@ class UnindexFlaggedTokens(Transform): Serves as the merge point between training / infernece conditioning pipelines """ - def __init__(self, central_atom): + def __init__(self, central_atom, association_scheme): """ Args: central_atom: The atom to use as the central atom for unindexed motifs. 
""" super().__init__() self.central_atom = central_atom + self.association_scheme = association_scheme def check_input(self, data: dict): check_contains_keys(data, ["atom_array"]) @@ -368,8 +383,18 @@ def expand_unindexed_motifs( token.res_id = token.res_id + max_resid token.is_C_terminus[:] = False token.is_N_terminus[:] = False - assert token.is_protein.all(), f"Cannot unindex non-protein token: {token}" - token = add_representative_atom(token, central_atom=self.central_atom) + + if not self.association_scheme == "atom23": + assert token.is_protein.all(), f"Cannot unindex non-protein token: {token} unless using atom23 association scheme" + token = add_representative_atom(token, central_atom=self.central_atom) + else: + if token.is_protein.all(): + token = add_representative_atom( + token, central_atom=self.central_atom + ) + else: + token = add_representative_atom(token, central_atom="C1'") + unindexed_tokens.append(token) # ... Remove original tokens e.g. during inference @@ -404,7 +429,6 @@ def expand_unindexed_motifs( f"Failed to create uniquely recognised tokens after concatenation.\n" f"Concatenated tokens: {get_token_count(atom_array_full)}, unindexed: {n_unindexed_tokens}" ) - return atom_array_full def create_unindexed_masks( diff --git a/models/rfd3/src/rfd3/transforms/design_transforms.py b/models/rfd3/src/rfd3/transforms/design_transforms.py index 38c79979..e7d2d4cb 100644 --- a/models/rfd3/src/rfd3/transforms/design_transforms.py +++ b/models/rfd3/src/rfd3/transforms/design_transforms.py @@ -38,6 +38,7 @@ UnindexFlaggedTokens, get_motif_features, ) +from rfd3.transforms.na_geom import na_ss_feats_from_annotation from rfd3.transforms.rasa import discretize_rasa from rfd3.transforms.util_transforms import ( AssignTypes, @@ -70,8 +71,10 @@ class SubsampleToTypes(Transform): def __init__( self, allowed_types: list | str = ["is_protein"], + association_scheme: str = "atom14", ): self.allowed_types = allowed_types + self.association_scheme = 
association_scheme if not self.allowed_types == "ALL": for k in allowed_types: if not k.startswith("is_"): @@ -103,7 +106,7 @@ def forward(self, data): ) ) - if atom_array.is_protein.sum() == 0: + if self.association_scheme != "atom23" and atom_array.is_protein.sum() == 0: raise ValueError( "No protein atoms found in the atom array. Example ID: {}".format( data.get("example_id", "unknown") @@ -697,10 +700,12 @@ def __init__( token_1d_features, atom_1d_features, autofill_zeros_if_not_present_in_atomarray=False, + association_scheme="atom14", ): self.autofill = autofill_zeros_if_not_present_in_atomarray self.token_1d_features = token_1d_features self.atom_1d_features = atom_1d_features + self.association_scheme = association_scheme def check_input(self, data) -> None: check_contains_keys(data, ["atom_array"]) @@ -753,6 +758,13 @@ def forward(self, data: Dict[str, Any]) -> Dict[str, Any]: if "feats" not in data.keys(): data["feats"] = {} + if self.association_scheme == "atom23": + data["atom_array"].set_annotation( + "is_protein_token", data["atom_array"].is_protein + ) + data["atom_array"].set_annotation("is_dna_token", data["atom_array"].is_dna) + data["atom_array"].set_annotation("is_rna_token", data["atom_array"].is_rna) + for feature_name, n_dims in self.token_1d_features.items(): data = self.generate_feature(feature_name, n_dims, data, "token") @@ -762,6 +774,91 @@ def forward(self, data: Dict[str, Any]) -> Dict[str, Any]: return data +class AddAdditional2dFeaturesToFeats(Transform): + """ + Adds any net.token_initializer.token_2d_features and net.diffusion_module.diffusion_atom_encoder.atom_2d_features present in the atomarray but not in data['feats'] to data['feats'] + Args: + - autofill_zeros_if_not_present_in_atomarray: self explanatory + - token_2d_features: List of single-item dictionaries, corresponding to feature_name: n_feature_dims. 
Should be hydra interpolated from + net.token_initializer.token_2d_features + """ + + incompatible_previous_transforms = ["AddAdditional2dFeaturesToFeats"] + + def __init__( + self, + token_2d_features, + autofill_zeros_if_not_present_in_atomarray=False, + association_scheme="atom14", + ): + self.autofill = autofill_zeros_if_not_present_in_atomarray + self.token_2d_features = token_2d_features + self.association_scheme = association_scheme + + # Need to pre-define custom constructor functions + # to map from atomarray annotations to tensors. + self.constructor_functions = { + "bp_partners": na_ss_feats_from_annotation, + } + + def check_input(self, data) -> None: + check_contains_keys(data, ["atom_array"]) + check_is_instance(data, "atom_array", AtomArray) + + def generate_token_feature(self, feature_name, n_dims, data): + # Don't do this if we already have the feature + if feature_name in data["feats"].keys(): + return data + + # For these, we need to use a constructor function mapping, + # since pair features may require custom logic/conventions. + + ## for old ckpt handling ## + if feature_name in self.constructor_functions.keys(): + feature_array = self.constructor_functions[feature_name](data["atom_array"]) + else: + raise ValueError( + f"No constructor function found for 2d feature `{feature_name}`" + ) + + # We can fix shape issues here: + if len(feature_array.shape) == 2 and n_dims == 1: + feature_array = feature_array.unsqueeze(1) + + # ensure that feature_array is a 3d array with third dim == n_dims: + if len(feature_array.shape) != 3: + raise ValueError( + f"token 2d_feature `{feature_name}` must be a 3d array, got {len(feature_array.shape)}d." + ) + if feature_array.shape[2] != n_dims: + raise ValueError( + f"token 2d_feature `{feature_name}` dimensions in atomarray ({feature_array.shape[-1]}) does not match dimension declared in config, ({n_dims})" + ) + # Ensure correct shape in first two dims (I,I,...) 
+ if feature_array.shape[0] != feature_array.shape[1]: + raise ValueError( + f"token 2d_feature `{feature_name}` first two dimensions must be equal (square matrix), got {feature_array.shape[0]} and {feature_array.shape[1]}" + ) + + data["feats"][feature_name] = feature_array + + return data + + def forward(self, data: Dict[str, Any]) -> Dict[str, Any]: + """ + Checks if the 2d_features are present in data['feats']. If not present, adds them from the atomarray. + If annotation is not present in atomarray, either autofills the feature with 0s or throws an error + """ + if "feats" not in data.keys(): + data["feats"] = {} + # Only apply for features that the model is expecting: + if self.token_2d_features is None: + return data + for feature_name, n_dims in self.token_2d_features.items(): + data = self.generate_token_feature(feature_name, n_dims, data) + return data + + class FeaturizepLDDT(Transform): """ Provides: diff --git a/models/rfd3/src/rfd3/transforms/na_geom.py b/models/rfd3/src/rfd3/transforms/na_geom.py new file mode 100644 index 00000000..ee264f2b --- /dev/null +++ b/models/rfd3/src/rfd3/transforms/na_geom.py @@ -0,0 +1,323 @@ +from typing import Any + +import numpy as np +from atomworks.ml.transforms._checks import ( + check_atom_array_annotation, + check_contains_keys, + check_is_instance, +) +from atomworks.ml.transforms.base import Transform +from atomworks.ml.utils.token import get_token_starts, spread_token_wise +from biotite.structure import AtomArray +from rfd3.transforms.conditioning_utils import sample_island_tokens +from rfd3.transforms.na_geom_utils import ( + DEFAULT_NA_SS_FEATURE_INFO, + annotate_na_ss, + annotate_na_ss_from_data_specification, +) + + +def na_ss_feats_from_annotation( + atom_array: AtomArray, + token_starts=None, + n_tokens=None, + return_as_onehot=True, +) -> np.ndarray: + """ + Takes in atom array and constucts a base pair feature matrix from annotations, + according to to custom feature constuction + masking system. 
+ This featurization utilizes info from BasePairEnum to assign int values + to paired, unpaired, and masked positions in the matrix. + + Args: + * atom_array: AtomArray with bp_partners annotation at atom level + * token_starts (optional): indices of token starts in the atom array + * n_tokens (optional): number of tokens (length of token_starts) + * return_as_onehot (optional): if False, return integer-encoded + matrix instead of one-hot encoded matrix + + returns: + * na_ss_matrix: + If ``return_as_onehot`` is True (default): + np.ndarray of shape (n_tokens, n_tokens, n_classes) + with one-hot encoded values according to BasePairEnum + + If ``return_as_onehot`` is False : + np.ndarray of shape (n_tokens, n_tokens) + with int values according to BasePairEnum + + + """ + # Get this info from atom_array, or avoid if given + if (token_starts is None) or (n_tokens is None): + token_starts = get_token_starts(atom_array) + n_tokens = len(token_starts) + + # Collect token inds for paired or loop positions: + pair_inds = [] + loop_inds = [] + token_bp_partners = atom_array.get_annotation("bp_partners")[ + token_starts + ] # get bp_partners at token level + assert ( + len(token_bp_partners) == n_tokens + ), "Length of token_bp_partners should match n_tokens" + for i, j_list in enumerate(token_bp_partners): + if j_list is not None: + if len(j_list) > 0: + for j in j_list: + pair_inds.append((i, j)) + else: + loop_inds.append(i) + + # The standard system for constructing meaningful base pair features: + # 0). Initialize with values of UNSPECIFIED (0): int matrix of shape (n_tokens, n_tokens) + na_ss_matrix = np.full( + (n_tokens, n_tokens), DEFAULT_NA_SS_FEATURE_INFO["NA_SS_MASK"], dtype=np.int64 + ) + + # 1). 
Fill in with values of PAIR (1) at positions that have bp_partners annotated as a non-empty list + for pair_i, pair_j in pair_inds: + na_ss_matrix[pair_i, pair_j] = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_PAIR"] + na_ss_matrix[pair_j, pair_i] = DEFAULT_NA_SS_FEATURE_INFO[ + "NA_SS_PAIR" + ] # ensure symmetry + + # 2). Fill in with values of LOOP (2) at positions that have bp_partners annotated as an empty list (explicitly unpaired) + # (we make full stripes across that position's row/col to indicate that NONE of those other positions are paired ) + for loop_i in loop_inds: + na_ss_matrix[loop_i, :] = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_LOOP"] + na_ss_matrix[:, loop_i] = DEFAULT_NA_SS_FEATURE_INFO[ + "NA_SS_LOOP" + ] # ensure symmetry + + # Optional: convert NA-SS matrix to one-hot encoding according for model input: + if return_as_onehot: + na_ss_matrix = np.eye(len(DEFAULT_NA_SS_FEATURE_INFO), dtype=np.int64)[ + na_ss_matrix + ] + + return na_ss_matrix + + +class CalculateNucleicAcidGeomFeats(Transform): + """ + Transform for constructing nucleic-acid conditioning features. + + This transform currently produces only nucleic-acid secondary-structure (NA-SS) + features as a 2D token-token matrix with 3 bins: + * 0: mask / unspecified + * 1: paired + * 2: loop / explicitly unpaired + + Training: + - Computes geometry/H-bond-based base pairs and writes them onto the AtomArray + via the ``bp_partners`` annotation (annotation-first), then reconstructs the + matrix (and optionally masks parts of it) before one-hot encoding. + + Inference: + - Interprets user-provided secondary-structure specifications, writes the same + ``bp_partners`` annotation, then follows the same matrix + one-hot path. + + Note: helical-parameter features are not implemented/used in this refactored path. 
+ """ + + def __init__( + self, + is_inference, + # Conditional sampling parameters all stored in this dict: + meta_conditioning_probabilities: dict[str, float] = None, + # Mask control paramerers: + nucleic_ss_min_shown: float = 0.2, + nucleic_ss_max_shown: float = 1.0, + n_islands_min: int = 1, + n_islands_max: int = 6, + # USE_RF2AA_NAMES: bool = False, + NA_only: bool = False, + planar_only: bool = True, + ): + # Critical, must always have to know how to handle + self.is_inference = is_inference + + self.meta_conditioning_probabilities = meta_conditioning_probabilities or {} + + # Control whether we show some nucleic SS or default to full 2D mask + self.p_is_nucleic_ss_example = self.meta_conditioning_probabilities.get( + "p_is_nucleic_ss_example", 0.0 + ) + + # Control whether we define full SS or just part of it (only applies if is NA SS example) + self.p_show_partial_feats = self.meta_conditioning_probabilities.get( + "p_nucleic_ss_show_partial_feats", 0.0 + ) + + # Some frac of time default to only showing canonical base pairs + self.p_canonical_bp_filter = self.meta_conditioning_probabilities.get( + "p_canonical_bp_filter", 0.5 + ) + + # mask patterning control to make things resemble design scenarios + self.nucleic_ss_min_shown = nucleic_ss_min_shown + self.nucleic_ss_max_shown = nucleic_ss_max_shown + self.n_islands_min = n_islands_min + self.n_islands_max = n_islands_max + + # Filters for what can be considered a planar contact interaction + self.NA_only = ( + NA_only # only annotate base-like interactions for nucleic acid residues + ) + self.planar_only = planar_only # only consider planar atoms in sidechains for geometry calculations, + + def check_input(self, data: dict[str, Any]) -> None: + check_contains_keys(data, ["atom_array"]) + check_is_instance(data, "atom_array", AtomArray) + check_atom_array_annotation(data, ["res_name"]) + # maybe do later: check_atom_array_has_hydrogen(data) + + def _sample_training_flags(self) -> tuple[bool, bool]: + 
"""Sample booleans controlling whether/how features are shown in training.""" + is_nucleic_ss_example = bool(np.random.rand() < self.p_is_nucleic_ss_example) + give_partial_feats = bool(np.random.rand() < self.p_show_partial_feats) + return is_nucleic_ss_example, give_partial_feats + + def forward(self, data: dict) -> dict: + atom_array = data["atom_array"] + + # Calculate n_tokens (assuming one token per residue for simplicity) + token_starts = get_token_starts(atom_array) + n_tokens = len(token_starts) + # token_level_array = atom_array[token_starts] + + # Handle the training case with ground truth and masking + if not self.is_inference: + # First, annotate as usual + is_nucleic_ss_example, give_partial_feats = self._sample_training_flags() + + if is_nucleic_ss_example: + atom_array = annotate_na_ss( + atom_array, + NA_only=self.NA_only, + planar_only=self.planar_only, + p_canonical_bp_filter=self.p_canonical_bp_filter, + ) + + # Generate symmetric partner annotations at the token level for masking purposes. + # choice for object-consistency: if already masked/undefined: be a list mapping to self-index. 
+ partner_sym_map = { + i: atom_array.bp_partners[ts_i] + if atom_array.bp_partners[ts_i] is not None + else [i] + for i, ts_i in enumerate(token_starts) + } + + # # Sample mask on token level: + token_mask_to_show = self._sample_where_to_show_ss( + n_tokens, + is_nucleic_ss_example=is_nucleic_ss_example, + give_partial_feats=give_partial_feats, + partner_sym_map=partner_sym_map, + ) # Mask vec for tokens where ss shown + + # Spread mask to atom level + is_ss_shown = spread_token_wise(atom_array, token_mask_to_show) + + # Extract the base pair annotations + bp_partners_atom = atom_array.get_annotation("bp_partners") + + # Remove unshown positions from bp_partners annotation + bp_partners_atom[~is_ss_shown] = None + + # Reset the annotation with newly hidden positions + atom_array.set_annotation("bp_partners", bp_partners_atom) + else: + atom_array.set_annotation( + "bp_partners", np.array([None] * len(atom_array)) + ) + + # Inference case: create from commandline args + else: + """ + Different cases handled: + - 1). Single dot-bracket string + - 2). multiple dot bracket strings with chain/ind ranges specified + - 3). Lists of paired indices + """ + atom_array = annotate_na_ss_from_data_specification( + data, + overwrite=True, + ) + + # Check feats existence and update: + if "feats" not in data: + data["feats"] = {} + + data.setdefault("log_dict", {}) + log_dict = data["log_dict"] + data["log_dict"] = log_dict + data["atom_array"] = atom_array + + return data + + def _sample_where_to_show_ss( + self, + n_tokens: int, + is_nucleic_ss_example: bool = True, + give_partial_feats: bool = True, + partner_sym_map: dict[int, list[int]] = None, + ) -> np.ndarray: + """Sample token-level islands indicating which SS rows/cols to reveal. + This custom function allows for enforcing symmetry in the shown features according + to the partner_sym_map, which encodes which tokens are partners in the SS + matrix and thus should be masked/unmasked together to maintain consistency. 
+ + """ + # If NOT is_nucleic_ss_example, set is_shown to all False + if not is_nucleic_ss_example: + token_mask_to_show = np.zeros((n_tokens,), dtype=bool) + + # If NOT give_partial_feats, set is_shown to all True + if not give_partial_feats: + token_mask_to_show = np.ones((n_tokens,), dtype=bool) + else: + # Get numerical parameters for that govern the mask pattern + frac_shown = ( + self.nucleic_ss_min_shown + + (self.nucleic_ss_max_shown - self.nucleic_ss_min_shown) + * np.random.rand() + ) + frac_shown = float(np.clip(frac_shown, 0.0, 1.0)) + max_length = int(np.ceil(frac_shown * n_tokens)) + if max_length <= 0: + token_mask_to_show = np.zeros((n_tokens,), dtype=bool) + island_len_min = max( + 1, int(frac_shown * n_tokens // max(int(self.n_islands_max), 1)) + ) + island_len_max = max( + 1, int(frac_shown * n_tokens // max(int(self.n_islands_min), 1)) + ) + island_len_min = min(island_len_min, n_tokens) + island_len_max = min(island_len_max, n_tokens) + island_len_max = max(island_len_max, island_len_min) + + # Sample the actual mask using the utility function: + token_mask_to_show = sample_island_tokens( + n_tokens, + island_len_min=island_len_min, + island_len_max=island_len_max, + n_islands_min=self.n_islands_min, + n_islands_max=self.n_islands_max, + max_length=max_length, + ) + + # Handle symmetry by iterating through the partner_sym_map items and setting + # `partner_mask_to_show` at partner positions to match `token_mask_to_show` + # initialize as all shown so effect comes from hiding + logical AND condition + partner_mask_to_show = np.ones_like(token_mask_to_show) + for token_i, partner_ind_list in partner_sym_map.items(): + for partner_ind in partner_ind_list: + partner_mask_to_show[partner_ind] = token_mask_to_show[token_i] + + # Combine the original mask with the partner mask to ensure symmetry + token_mask_to_show = token_mask_to_show & partner_mask_to_show + + return token_mask_to_show diff --git a/models/rfd3/src/rfd3/transforms/na_geom_utils.py 
class NucMolInfo:
    """Constants and parameters for nucleic-acid geometry and interaction scoring.

    All values are empirically validated defaults; the constructor accepts no
    arguments.
    """

    def __init__(self) -> None:
        # Indices into the last axis of the (L, L, 3) `hbond_count` array,
        # encoding the interaction class between tokens i and j.
        self.BB_BB = 0  # backbone-backbone H-bond interactions
        self.BB_SC = 1  # backbone-sidechain H-bond interactions
        self.SC_SC = 2  # sidechain-sidechain H-bond interactions

        # Per-class weights used when summing `hbond_count` over its last
        # axis to obtain the interaction score.
        self.bp_weight_BB_BB = 0.0
        self.bp_weight_BB_SC = 0.5
        self.bp_weight_SC_SC = 1.0
        self.bp_summation_weights = [
            self.bp_weight_BB_BB,
            self.bp_weight_BB_SC,
            self.bp_weight_SC_SC,
        ]

        # Sigmoid parameters giving a continuous step function for meeting
        # base-pair interaction criteria from H-bond counts alone
        # (first filter). Calibrated such that:
        #   >= 2 base-base H-bonds                  -> ~1.0
        #   1 base-base + 1 base-backbone H-bond    -> ~0.5
        self.min_hbonds_for_bp = 2.0
        self.bp_hbond_coeff = 9.8  # determined heuristically
        # Minimum base-pairing score when binarizing base pairs.
        self.bp_val_cutoff = 0.5

        # Hard geometric limits for accepting a base pair.
        self.base_geometry_limits = {
            "D_ij": 16.0,  # max centroid separation
            "H_ij": 1.5,  # max |rise|
            "P_ij": math.pi / 5,  # max propeller deviation
            "B_ij": math.pi / 5,  # max buckle deviation
        }

        self.rep_atom_dict = {"protein": "CA", "rna": "C1'", "dna": "C1'"}

        # Edge-vector atom names per base. Purines share one sugar/base
        # scaffold and pyrimidines another; each DNA base uses exactly the
        # same atoms as its RNA counterpart (DT matches U), so the table is
        # assembled from four templates instead of eight literal dicts.
        purine_edges = {
            "S_start": "C1'",
            "S_stop": "N3",
            "B_start": "C1'",
            "B_stop": "N9",
        }
        pyrimidine_edges = {
            "S_start": "C1'",
            "S_stop": "O2",
            "B_start": "C1'",
            "B_stop": "N1",
        }
        base_templates = {
            "A": {
                "W_start": "N1",
                "W_stop": "N6",
                "H_start": "N7",
                "H_stop": "N6",
                **purine_edges,
            },
            "G": {
                "W_start": "N1",
                "W_stop": "O6",
                "H_start": "N7",
                "H_stop": "O6",
                **purine_edges,
            },
            "C": {
                "W_start": "N3",
                "W_stop": "N4",
                "H_start": "C5",
                "H_stop": "N4",
                **pyrimidine_edges,
            },
            "T": {
                "W_start": "N3",
                "W_stop": "O4",
                "H_start": "C5",
                "H_stop": "O4",
                **pyrimidine_edges,
            },
        }
        self.vec_atom_dict = {
            "DA": dict(base_templates["A"]),
            "DG": dict(base_templates["G"]),
            "DC": dict(base_templates["C"]),
            "DT": dict(base_templates["T"]),
            "A": dict(base_templates["A"]),
            "G": dict(base_templates["G"]),
            "C": dict(base_templates["C"]),
            "U": dict(base_templates["T"]),
        }
"DC": { + "W_start": "N3", + "W_stop": "N4", + "H_start": "C5", + "H_stop": "N4", + "S_start": "C1'", + "S_stop": "O2", + "B_start": "C1'", + "B_stop": "N1", + }, + "DT": { + "W_start": "N3", + "W_stop": "O4", + "H_start": "C5", + "H_stop": "O4", + "S_start": "C1'", + "S_stop": "O2", + "B_start": "C1'", + "B_stop": "N1", + }, + "A": { + "W_start": "N1", + "W_stop": "N6", + "H_start": "N7", + "H_stop": "N6", + "S_start": "C1'", + "S_stop": "N3", + "B_start": "C1'", + "B_stop": "N9", + }, + "G": { + "W_start": "N1", + "W_stop": "O6", + "H_start": "N7", + "H_stop": "O6", + "S_start": "C1'", + "S_stop": "N3", + "B_start": "C1'", + "B_stop": "N9", + }, + "C": { + "W_start": "N3", + "W_stop": "N4", + "H_start": "C5", + "H_stop": "N4", + "S_start": "C1'", + "S_stop": "O2", + "B_start": "C1'", + "B_stop": "N1", + }, + "U": { + "W_start": "N3", + "W_stop": "O4", + "H_start": "C5", + "H_stop": "O4", + "S_start": "C1'", + "S_stop": "O2", + "B_start": "C1'", + "B_stop": "N1", + }, + } + + +def calculate_hb_counts( + atom_array: AtomArray, + token_level_data: dict, + mol_info: NucMolInfo, + cutoff_HA_dist: float = 2.5, + cutoff_DA_dist: float = 3.9, +): + """Count hydrogen bonds between residue pairs using HBPLUS. + + Args: + atom_array: Structure to analyse. + token_level_data: Token-level metadata dict (must contain + ``token_id_list`` and ``resi2index``). + mol_info: Molecular-info object for backbone/sidechain atom lookup. + cutoff_HA_dist: H–A distance cutoff (Å) passed to HBPLUS. + cutoff_DA_dist: D–A distance cutoff (Å) passed to HBPLUS. + + Returns: + np.ndarray of shape ``(I, I, 3)`` (int32) where the last axis + encodes: 0 = BB–BB, 1 = BB–SC, 2 = SC–SC H-bond counts. + """ + hbplus_exe = os.environ.get("HBPLUS_PATH") + + if hbplus_exe is None or hbplus_exe == "": + raise ValueError( + "HBPLUS_PATH environment variable not set. " + "Please set it to the path of the hbplus executable in order to calculate hydrogen bonds." 
+ ) + with tempfile.TemporaryDirectory() as tmpdir: + dtstr = datetime.now().strftime("%Y%m%d%H%M%S") + pdb_filename = f"{dtstr}_{np.random.randint(10000)}.pdb" + pdb_path = os.path.join(tmpdir, pdb_filename) + atom_array, nan_mask, chain_map = save_atomarray_to_pdb(atom_array, pdb_path) + + subprocess.call( + [ + hbplus_exe, + "-h", + str(cutoff_HA_dist), + "-d", + str(cutoff_DA_dist), + pdb_path, + pdb_path, + ], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + cwd=tmpdir, + ) + + num_resis_total = len(token_level_data["token_id_list"]) + + hbond_count = np.zeros((num_resis_total, num_resis_total, 3), dtype=np.int32) + + hb2_path = pdb_path.replace("pdb", "hb2") + if not os.path.exists(hb2_path): + print("WARNING: HB2 file could not be found; skipping NA SS metric") + return hbond_count + with open(hb2_path, "r") as hb2_f: + for i, line in enumerate(hb2_f): + if i < 8: + continue + if len(line) < 28: + continue + + d_chain_iid = chain_map[line[0]] + d_resi = int(line[1:5].strip()) + d_resn = line[6:9].strip() + d_atom_name = line[9:13].strip() + + # Initialize donor/acceptor sidechain/backbone flags: + # then replace with True if valid for summation + d_is_sc = False + d_is_bb = False + a_is_sc = False + a_is_bb = False + + d_mask = ( + (atom_array.atom_name == d_atom_name) + & (atom_array.res_name == d_resn) + & (atom_array.res_id == d_resi) + & (atom_array.chain_iid == d_chain_iid) + ) + # d_atm = atom_array[d_mask] + # d_idx = d_atm.token_id + d_idx = token_level_data["resi2index"].get( + f"{d_chain_iid}__{d_resi}", None + ) + if d_idx is None: + continue + + # Handle standard polymer residues for donor atom: + if d_resn in ATOM_REGION_BY_RESI.keys(): + d_is_sc = d_atom_name in ATOM_REGION_BY_RESI[d_resn]["sc"] + d_is_bb = d_atom_name in ATOM_REGION_BY_RESI[d_resn]["bb"] + else: + # If non-polymer, define any ligand HBonding atom as backbone: + if d_mask.sum() > 0: + d_is_bb = atom_array[d_mask][0].is_ligand + + a_chain_iid = chain_map[line[14]] + 
a_resi = int(line[15:19].strip()) + a_resn = line[20:23].strip() + a_atom_name = line[23:27].strip() + + a_mask = ( + (atom_array.atom_name == a_atom_name) + & (atom_array.res_name == a_resn) + & (atom_array.res_id == a_resi) + & (atom_array.chain_iid == a_chain_iid) + ) + a_idx = token_level_data["resi2index"].get( + f"{a_chain_iid}__{a_resi}", None + ) + if a_idx is None: + continue + + # Handle standard polymer residues for acceptor atom: + if a_resn in ATOM_REGION_BY_RESI.keys(): + a_is_sc = a_atom_name in ATOM_REGION_BY_RESI[a_resn]["sc"] + a_is_bb = a_atom_name in ATOM_REGION_BY_RESI[a_resn]["bb"] + else: + # If non-polymer, define any ligand HBonding atom as backbone: + if a_mask.sum() > 0: + a_is_bb = atom_array[a_mask][0].is_ligand + + # 0 -> both backbone (BB-BB) + hbond_count[a_idx, d_idx, 0] += a_is_bb * d_is_bb + hbond_count[d_idx, a_idx, 0] += d_is_bb * a_is_bb + + # 1 -> one backbone, one sidechain (BB-SC) + hbond_count[a_idx, d_idx, 1] += (a_is_bb * d_is_sc) | ( + a_is_sc * d_is_bb + ) + hbond_count[d_idx, a_idx, 1] += (d_is_bb * a_is_sc) | ( + d_is_sc * a_is_bb + ) + + # 2 -> both sidechain (SC-SC) + hbond_count[a_idx, d_idx, 2] += a_is_sc * d_is_sc + hbond_count[d_idx, a_idx, 2] += d_is_sc * a_is_sc + """ + try: + os.remove(pdb_path) + os.remove(hb2_path) + except: + print("temp pdb/hb already removed or not created to begin with") + """ + return hbond_count + + +def find_planar_positions( + atom_array: AtomArray, + mol_info: NucMolInfo, + tol: float = 1e-2, +) -> Dict: + """Identify residues with planar sidechains via known atom lists or PCA plane-fitting. + + For canonical residues the planar atoms are looked up from ``mol_info``; + for non-canonical residues a plane is fitted to the four tip-most sidechain + atoms, and all atoms within *tol* of that plane are returned. + + Args: + atom_array: Structure to analyse. + mol_info: Molecular-info object supplying per-residue planar atom lists. 
+ tol: Distance tolerance (Å) from the fitted plane for an atom to be + considered planar. + + Returns: + Dictionary ``{(chain_iid, res_id): [atom_name, ...]}`` mapping each + unique residue position to its list of planar sidechain atom names. + """ + unique_positions_list = [] + for atm in atom_array: + pos_id = (atm.chain_iid, atm.res_id, atm.res_name) + if pos_id not in unique_positions_list: + unique_positions_list.append(pos_id) + + # Get candidate planar atoms: + planar_atom_list_dict = {} + + # for chain_iid, res_id in unique_positions_list: + for chain_iid, res_id, res_name in unique_positions_list: + mask = ( + (atom_array.chain_iid == chain_iid) + & (atom_array.res_id == res_id) + & (atom_array.res_name == res_name) + ) + res_atoms = atom_array[mask] + + # If possible, speed up by using known planar atoms for this residue type: + if res_name in PLANAR_ATOMS_BY_RESI.keys(): + # Shared atoms between residue and known planar atoms for that residue type: + planar_atom_list = list( + set([atm.atom_name for atm in res_atoms]) + & set(PLANAR_ATOMS_BY_RESI[res_name]) + ) + planar_atom_list_dict[(chain_iid, res_id)] = planar_atom_list + + # If unknown or noncanonical residue, compute planar atoms geometrically: + else: + candidate_planar_atm_names = [] + candidate_planar_atm_coords = [] + + for atm in res_atoms: + # Can pre-filter protein planar atoms: + if atm.is_protein and (atm.atom_name in AA_PLANAR_ATOMS): + candidate_planar_atm_names.append(atm.atom_name) + candidate_planar_atm_coords.append(atm.coord) + # Can pre-filter nucleic acid planar atoms: + elif (atm.is_rna or atm.is_dna) and (atm.atom_name in NA_PLANAR_ATOMS): + candidate_planar_atm_names.append(atm.atom_name) + candidate_planar_atm_coords.append(atm.coord) + # Otherwise, consider all atoms for plane fitting: + else: + candidate_planar_atm_names.append(atm.atom_name) + candidate_planar_atm_coords.append(atm.coord) + + # reverse order to prioritize atoms further away from bb: + 
candidate_planar_atm_names = list(reversed(candidate_planar_atm_names)) + candidate_planar_atm_coords = list(reversed(candidate_planar_atm_coords)) + + # Use first four candidate atoms only to define the plane: + if len(candidate_planar_atm_coords) >= 4: + coords = np.asarray(candidate_planar_atm_coords, dtype=np.float32) + + # compute 4-atom based plane: + quad_coords = coords[:4, :] + + # fit plane via PCA (use smallest‑variance eigenvector as normal) + quad_center = quad_coords.mean(axis=0, keepdims=True) + all_quad_centered = coords - quad_center + quad_centered = quad_coords - quad_center + # covariance matrix + quad_cov = (quad_centered.T @ quad_centered) / max( + quad_coords.shape[0] - 1, 1 + ) + # eigen decomposition + _, quad_eigvecs = np.linalg.eigh(quad_cov) + quad_normal = quad_eigvecs[:, 0] # eigenvector with smallest eigenvalue + quad_normal = quad_normal / (np.linalg.norm(quad_normal) + 1e-8) + # compute distances from plane for all candidate atoms + quad_dists = np.abs(all_quad_centered @ quad_normal) + # keep only atoms within tolerance + quad_valid_mask = quad_dists <= tol + + # Filter for if we have a valid plane in the first place: + valid_plane_filter = np.nanmax(quad_dists[:4]) < tol + # Filter for if we have enough atoms in the plane: + plane_atom_filter = int(np.sum(quad_valid_mask)) >= 4 + if valid_plane_filter and plane_atom_filter: + # Set the planar atom list for this position to those that are within tol of the plane: + # using quad_valid_mask and candidate_planar_atm_names: + planar_atom_list = [ + n + for n, keep in zip( + candidate_planar_atm_names, quad_valid_mask.tolist() + ) + if keep + ] + + # not enough atoms close to a common plane + else: + planar_atom_list = [] + + else: + # need at least 4 atoms to define a robust plane + planar_atom_list = [] + + planar_atom_list_dict[(chain_iid, res_id)] = planar_atom_list + + return planar_atom_list_dict + + +def make_coord_list( + atom_array: AtomArray, + residue_list: list[str], + 
chain_list: list[str], + atom_list: list[str], +) -> list[list[str]]: + """Extract per-residue representative coordinates from an AtomArray. + + All three input lists must have the same length. Missing atoms are + filled with ``[NaN, NaN, NaN]``. + + Args: + atom_array: Biotite AtomArray to query. + residue_list: Residue IDs (one per token). + chain_list: Chain identifiers (one per token). + atom_list: Atom names to extract (use ``"atomized"`` to take the + first atom of the residue). + + Returns: + List of ``[x, y, z]`` coordinate lists, same length as input. + """ + coord_list = [] + for res_id, chain_id, atom_name in zip(residue_list, chain_list, atom_list): + # Check if the residue exists in the atom array + if atom_name == "atomized": + # Check for atomized residue, in which case we take the first atom of the residue + # full mask should be length-1 if atomized + mask = (atom_array.chain_id == chain_id) & (atom_array.res_id == res_id) + else: + # General case for non-atomized residues + # should have a unique solution, but we take the first entry either way. + mask = ( + (atom_array.chain_id == chain_id) + & (atom_array.res_id == res_id) + & (atom_array.atom_name == atom_name) + ) + + # Get the coordinates for the masked atoms + coords = atom_array.coord[mask][0:1] + + if len(coords) < 1: + coord_list.append([float("nan"), float("nan"), float("nan")]) + else: + coord_list.append(coords[0].tolist()) + + return coord_list + + +def get_token_level_metadata( + atom_array: AtomArray, + mol_info: "NucMolInfo", + *, + NA_only: bool = False, + planar_only: bool = True, + seq_cutoff=2, + gap_length=200, +) -> dict: + """Build lightweight token-level metadata (no coordinate geometry). + + Sufficient for SS reconstruction, loop labeling from ``bp_partners``, + and inference-time SS specification parsing. For geometry keys + (``xyz_planar``, ``frame_xyz``, ``M_i``), follow up with + :func:`add_token_level_geometry_data`. + + Args: + atom_array: Structure to analyse. 
def get_token_level_metadata(
    atom_array: AtomArray,
    mol_info: "NucMolInfo",
    *,
    NA_only: bool = False,
    planar_only: bool = True,
    seq_cutoff=2,
    gap_length=200,
) -> dict:
    """Assemble lightweight token-level metadata without coordinate geometry.

    The result is sufficient for SS reconstruction, loop labeling from
    ``bp_partners``, and inference-time SS specification parsing. Call
    :func:`add_token_level_geometry_data` afterwards when the geometry keys
    (``xyz_planar``, ``frame_xyz``, ``M_i``) are needed.

    Args:
        atom_array: Structure to analyse.
        mol_info: Molecular-info constants.
        NA_only: If True, restrict ``filter_mask`` to nucleic-acid tokens.
        planar_only: If True, restrict ``filter_mask`` to tokens with planar
            sidechains.
        seq_cutoff: Sequence-distance threshold for the ``seq_neighbors``
            boolean mask.
        gap_length: Artificial gap inserted between chains when computing
            relative sequence positions.

    Returns:
        Dict with keys: ``token_starts``, ``token_index``, ``is_na``,
        ``is_planar``, ``chain_list``, ``chain_iid_list``, ``resi_list``,
        ``resn_list``, ``token_id_list``, ``resi2index``, ``len_s``,
        ``seq_neighbors``, ``na_inds``, ``na_tensor_inds``, ``filter_mask``,
        ``rep_atom_list``, ``S_start_atom_list``, ``S_stop_atom_list``,
        ``include_geometry`` (always False).
    """
    # Residue starts (not token starts), so atomized atoms within a single
    # residue collapse onto one NA-SS position.
    token_starts = struc.get_residue_starts(atom_array)
    token_atoms = atom_array[token_starts]
    token_index = np.arange(len(token_starts))

    # Instantiated locally so its large arrays are not retained at module scope.
    encoding = AF3SequenceEncoding()
    rna_like = np.isin(
        token_atoms.res_name, encoding.all_res_names[encoding.is_rna_like]
    )
    dna_like = np.isin(
        token_atoms.res_name, encoding.all_res_names[encoding.is_dna_like]
    )
    is_na_arr = (rna_like | dna_like).astype(bool)

    def _rep_and_sugar_atoms(atm):
        """Return (representative atom, sugar-edge start, sugar-edge stop)."""
        name = atm.res_name
        if is_glycine(name) | is_protein_unknown(name):
            return "CA", None, None
        if is_standard_aa_not_glycine(name):
            return "CA", "CA", "CB"
        if is_pyrimidine(name):
            return "C1'", "C1'", "O2"
        if is_purine(name):
            return "C1'", "C1'", "N3"
        if is_unknown_nucleotide(name):
            return "C1'", None, None
        if getattr(atm, "atomize", False):
            return atm.atom_name, None, None
        return None, None, None

    chain_list: list[str] = []
    chain_iid_list: list[str] = []
    resi_list: list[int] = []
    ind_list: list[int] = []
    res_name_list: list[str] = []
    token_id_list: list[str] = []
    rep_atom_list: list[str | None] = []
    S_start_atom_list: list[str | None] = []
    S_stop_atom_list: list[str | None] = []
    sc_planarity_list: list[bool] = []

    for pos, atm in enumerate(token_atoms):
        chain_list.append(atm.chain_id)
        chain_iid_list.append(atm.chain_iid)
        resi_list.append(int(atm.res_id))
        ind_list.append(int(pos))
        res_name_list.append(atm.res_name)
        token_id_list.append(str(atm.token_id))

        # Cheap planarity heuristic from residue-name lookup (polymers only).
        if atm.is_polymer and (atm.res_name in HAS_PLANAR_SC):
            sc_planarity_list.append(bool(HAS_PLANAR_SC[atm.res_name]))
        else:
            sc_planarity_list.append(False)

        rep_atm, s_start_atm, s_stop_atm = _rep_and_sugar_atoms(atm)
        rep_atom_list.append(rep_atm)
        S_start_atom_list.append(s_start_atm)
        S_stop_atom_list.append(s_stop_atm)

    # residue position <-> token index map
    resi2index = {
        f"{c}__{r}": i for c, r, i in zip(chain_iid_list, resi_list, ind_list)
    }

    # Relative sequence positions with an artificial gap between chains.
    rel_pos_list: list[int] = []
    previous_chain = ""
    bias = -gap_length
    for r, c in zip(resi_list, chain_iid_list):
        if c != previous_chain:
            bias += gap_length
            previous_chain = c
        rel_pos_list.append(int(r + bias))

    rel_pos = np.asarray(rel_pos_list, dtype=np.int64)
    seq_neighbors = np.abs(rel_pos[:, None] - rel_pos[None, :]) <= int(seq_cutoff)

    na_inds = np.nonzero(is_na_arr)[0].tolist()
    na_tensor_inds = {na_i: i for i, na_i in enumerate(na_inds)}

    is_planar_arr = np.asarray(sc_planarity_list, dtype=bool)

    # filter_mask is the AND of the requested restrictions (all-True baseline),
    # equivalent to the previous four-way branch.
    filter_mask = np.ones_like(is_na_arr, dtype=bool)
    if NA_only:
        filter_mask = filter_mask & is_na_arr
    if planar_only:
        filter_mask = filter_mask & is_planar_arr

    return {
        "token_starts": token_starts,
        "token_index": token_index,
        "is_na": is_na_arr,
        "is_planar": is_planar_arr,
        "chain_list": chain_list,
        "chain_iid_list": chain_iid_list,
        "resi_list": resi_list,
        "resn_list": res_name_list,
        "token_id_list": token_id_list,
        "resi2index": resi2index,
        "len_s": int(len(token_atoms)),
        "seq_neighbors": seq_neighbors,
        "na_inds": na_inds,
        "na_tensor_inds": na_tensor_inds,
        "filter_mask": filter_mask,
        "rep_atom_list": rep_atom_list,
        "S_start_atom_list": S_start_atom_list,
        "S_stop_atom_list": S_stop_atom_list,
        "include_geometry": False,
    }
def add_token_level_geometry_data(
    atom_array: AtomArray,
    mol_info: "NucMolInfo",
    token_level_data: dict,
    *,
    NA_only: bool = False,
    planar_only: bool = True,
) -> dict:
    """Augment token-level metadata with coordinate-derived geometry fields.

    Populates ``xyz_planar``, ``xyz_S_start``, ``xyz_S_stop``,
    ``frame_xyz``, ``M_i`` and updates ``is_planar`` / ``filter_mask``
    using coordinate-derived planarity. Sets ``include_geometry=True``.

    No-ops if geometry was already computed.

    Args:
        atom_array: Structure to extract coordinates from.
        mol_info: Molecular-info constants.
        token_level_data: Dict produced by :func:`get_token_level_metadata`
            (modified in-place and returned).
        NA_only: Restrict filter_mask to nucleic-acid tokens.
        planar_only: Restrict filter_mask to tokens with planar sidechains.

    Returns:
        The same ``token_level_data`` dict, augmented with geometry keys.
    """

    # Idempotency guard: geometry already attached.
    if bool(token_level_data.get("include_geometry", False)):
        return token_level_data

    # Backward-compatibility: older token_level_data dicts (or user-provided ones)
    # may not contain the metadata keys this function needs.
    required_keys = (
        "chain_iid_list",
        "chain_list",
        "resi_list",
        "rep_atom_list",
        "S_start_atom_list",
        "S_stop_atom_list",
        "is_na",
    )
    if any(k not in token_level_data for k in required_keys):
        # Rebuild metadata from scratch when any required key is missing.
        token_level_data = get_token_level_metadata(
            atom_array,
            mol_info,
            NA_only=NA_only,
            planar_only=planar_only,
        )

    chain_iid_list: list[str] = token_level_data["chain_iid_list"]
    chain_list: list[str] = token_level_data["chain_list"]
    resi_list: list[int] = token_level_data["resi_list"]
    rep_atom_list: list[str | None] = token_level_data["rep_atom_list"]
    S_start_atom_list: list[str | None] = token_level_data["S_start_atom_list"]
    S_stop_atom_list: list[str | None] = token_level_data["S_stop_atom_list"]

    planar_atom_list_dict = find_planar_positions(
        atom_array, mol_info
    )  # {(chain_iid, res_id): [atom_name, ...]}
    has_planar_sc: list[bool] = []

    xyz_planar: list[
        list[list[float]]
    ] = []  # list[I] of [K_i, 3] (K_i varies per residue)
    xyz_S_start: list[list[float]] = []  # list[I] of [3]
    xyz_S_stop: list[list[float]] = []  # list[I] of [3]

    for c, r, S_start_atm, S_stop_atm in zip(
        chain_iid_list,
        resi_list,
        S_start_atom_list,
        S_stop_atom_list,
    ):
        planar_atoms_i = planar_atom_list_dict[(c, r)]
        # A residue counts as planar only when >= 4 atoms define its plane.
        has_planar_sc.append(bool(len(planar_atoms_i) >= 4))

        atom_array_i = atom_array[
            (atom_array.chain_iid == c) & (atom_array.res_id == r)
        ]

        # Gather planar-atom coordinates; missing atoms become NaN rows so
        # downstream consumers can mask them out.
        planar_coords_i: list[list[float]] = []
        for pl_atm_name_j in planar_atoms_i:
            pl_atom_array_ij = atom_array_i[atom_array_i.atom_name == pl_atm_name_j]
            if len(pl_atom_array_ij) == 0:
                planar_coords_i.append([float("nan"), float("nan"), float("nan")])
            else:
                # NOTE(review): this appends an ndarray while the missing case
                # appends a list — np.asarray downstream homogenizes it, but
                # confirm no consumer relies on list-ness.
                planar_coords_i.append(pl_atom_array_ij[0].coord)

        # Fewer than 4 planar atoms -> single NaN placeholder row.
        xyz_planar.append(
            planar_coords_i if len(planar_coords_i) > 3 else [[float("nan")] * 3]
        )

        if S_start_atm is None:
            xyz_S_start.append([float("nan"), float("nan"), float("nan")])
        else:
            S_start_atom_array_i = atom_array_i[atom_array_i.atom_name == S_start_atm]
            xyz_S_start.append(
                [float("nan"), float("nan"), float("nan")]
                if len(S_start_atom_array_i) == 0
                else S_start_atom_array_i[0].coord
            )

        if S_stop_atm is None:
            xyz_S_stop.append([float("nan"), float("nan"), float("nan")])
        else:
            S_stop_atom_array_i = atom_array_i[atom_array_i.atom_name == S_stop_atm]
            xyz_S_stop.append(
                [float("nan"), float("nan"), float("nan")]
                if len(S_stop_atom_array_i) == 0
                else S_stop_atom_array_i[0].coord
            )

        # Free the per-residue slice before the next iteration's allocation.
        del atom_array_i

    # frame coordinates and backbone direction
    # NOTE(review): the frame lookup matches on chain_id (chain_list) while
    # the planar dict above is keyed by chain_iid — confirm these agree for
    # multi-copy assemblies.
    frame_xyz = np.asarray(  # [I, 3] representative-atom coordinates
        make_coord_list(atom_array, resi_list, chain_list, rep_atom_list),
        dtype=np.float32,
    )

    # Edge-replicate padding so the finite-difference below is defined at
    # both chain termini.
    padded_centers = np.concatenate(
        [frame_xyz[:1], frame_xyz, frame_xyz[-1:]], axis=0
    )  # [I+2, 3]
    M_i = (
        (  # [I, 3] smoothed backbone-direction vectors (central difference)
            (padded_centers[1:-1] - padded_centers[:-2])
            + (padded_centers[2:] - padded_centers[1:-1])
        )
        / 2.0
    )

    # Replace the residue-name planarity heuristic with the
    # coordinate-derived result.
    is_planar_arr = np.asarray(has_planar_sc, dtype=bool)  # [I]
    token_level_data["is_planar"] = is_planar_arr

    is_na_arr = np.asarray(token_level_data["is_na"], dtype=bool)  # [I]
    if NA_only and planar_only:
        filter_mask = is_na_arr & is_planar_arr
    elif NA_only and (not planar_only):
        filter_mask = is_na_arr.copy()
    elif (not NA_only) and planar_only:
        filter_mask = is_planar_arr.copy()
    else:
        filter_mask = np.ones_like(is_na_arr, dtype=bool)
    token_level_data["filter_mask"] = filter_mask  # [I] bool

    token_level_data.update(
        {
            "xyz_planar": xyz_planar,
            "xyz_S_start": xyz_S_start,
            "xyz_S_stop": xyz_S_stop,
            "frame_xyz": frame_xyz,
            "M_i": M_i,
            "include_geometry": True,
        }
    )

    # Drop large temporaries promptly to reduce peak memory.
    del planar_atom_list_dict, padded_centers
    return token_level_data
_compute_local_frames( + xyz_planar: list[np.ndarray], + planar_centers: np.ndarray, + M_i: np.ndarray, + *, + xyz_S_start: list | None = None, + xyz_S_stop: list | None = None, + compute_full_frame: bool = False, + eps: float = 1e-8, +) -> dict[str, np.ndarray]: + """Build per-residue local coordinate frames from planar sidechain atoms. + + The base-normal direction Z_i is always computed via PCA on the planar + atom cloud, corrected for backbone direction. When *compute_full_frame* + is True the sugar-edge vector is used to derive X_i and Y_i as well. + + Args: + xyz_planar: Per-residue planar-atom coordinates, list[I] of [K_i, 3]. + planar_centers: Sidechain planar-atom centroids, [I, 3]. + M_i: Backbone-direction vectors, [I, 3]. + xyz_S_start: Sugar-edge start coordinates, list[I] of [3]. + Required when *compute_full_frame* is True. + xyz_S_stop: Sugar-edge stop coordinates, list[I] of [3]. + Required when *compute_full_frame* is True. + compute_full_frame: If True, also compute X_i and Y_i. + eps: Small constant for numerical stability. + + Returns: + Dict with ``"Z_i"`` (always), and ``"X_i"``, ``"Y_i"`` when + *compute_full_frame* is True. Each array has shape ``[I, 3]``. 
+ """ + n_tokens = len(xyz_planar) + + # Mean-centre the planar atoms per residue + centered_points = [ # list[I] of [K_i, 3] + np.asarray(xyz_i, dtype=np.float32) - cen_i + for xyz_i, cen_i in zip(xyz_planar, planar_centers) + ] + + # PCA → eigenvectors per residue + eigenvectors = np.full((n_tokens, 3, 3), np.nan, dtype=np.float32) # [I, 3, 3] + + for i, xyz_i in enumerate(centered_points): + xyz_i = xyz_i[~np.isnan(xyz_i).any(axis=1)] + if xyz_i.shape[0] >= 3: + cov_matrix = np.einsum("ij,ik->jk", xyz_i, xyz_i) / max( # [3, 3] + xyz_i.shape[0] - 1, 1 + ) + _, eigvecs = np.linalg.eigh(cov_matrix) # [3, 3] + eigenvectors[i] = eigvecs + + # Base-normal: smallest-eigenvalue direction, corrected for backbone dir + N_i = eigenvectors[:, :, 0] # [I, 3] + N_i = N_i / (np.linalg.norm(N_i, axis=1, keepdims=True) + eps) + + Z_i = N_i * np.sum(M_i * N_i, axis=-1, keepdims=True) # [I, 3] + Z_i = Z_i / (np.linalg.norm(Z_i, axis=-1, keepdims=True) + eps) + + result: dict[str, np.ndarray] = {"Z_i": Z_i} + + if compute_full_frame: + if xyz_S_start is None or xyz_S_stop is None: + raise ValueError("xyz_S_start and xyz_S_stop are required for full frame") + + X_s_i = ( # [I, 3] sugar-edge direction + np.asarray(xyz_S_stop, dtype=np.float32) + - np.asarray(xyz_S_start, dtype=np.float32) + ) + X_s_i = X_s_i / (np.linalg.norm(X_s_i, axis=-1, keepdims=True) + eps) + + X_i = np.cross(Z_i, X_s_i) # [I, 3] + X_i = X_i / (np.linalg.norm(X_i, axis=-1, keepdims=True) + eps) + result["X_i"] = X_i + + Y_i = np.cross(X_i, Z_i) # [I, 3] + Y_i = Y_i / (np.linalg.norm(Y_i, axis=-1, keepdims=True) + eps) + result["Y_i"] = Y_i + + return result + + +def _compute_pairwise_geometry( + Z_i: np.ndarray, + frame_D_ij_vec: np.ndarray, + sc_D_ij_vec: np.ndarray, + *, + X_i: np.ndarray | None = None, + clamp: bool = True, + compute_opening: bool = False, + eps: float = 1e-8, +) -> dict[str, np.ndarray]: + """Compute pairwise base-step geometry between all residue pairs. 
def _compute_pairwise_geometry(
    Z_i: np.ndarray,
    frame_D_ij_vec: np.ndarray,
    sc_D_ij_vec: np.ndarray,
    *,
    X_i: np.ndarray | None = None,
    clamp: bool = True,
    compute_opening: bool = False,
    eps: float = 1e-8,
) -> dict[str, np.ndarray]:
    """Compute pairwise base-step geometry between all residue pairs.

    Derives the pairwise coordinate frame (X_ij, Y_ij, Z_ij) and the
    base-pair geometry parameters: rise (H_ij), buckle (B_ij), propeller
    (P_ij), centroid distance (D_ij), and optionally opening angle (O_ij).

    Args:
        Z_i: Per-residue base-normal vectors, [I, 3].
        frame_D_ij_vec: Pairwise backbone displacement vectors, [I, I, 3].
        sc_D_ij_vec: Pairwise sidechain-centroid displacement vectors, [I, I, 3].
        X_i: Per-residue local X-axis, [I, 3]. Required when
            *compute_opening* is True.
        clamp: Clamp cosines to [-1, 1] before ``arccos``.
        compute_opening: If True, compute opening angle O_ij.
        eps: Small constant for numerical stability.

    Returns:
        Dict with keys ``"H_ij"`` [I, I], ``"B_ij"`` [I, I],
        ``"P_ij"`` [I, I], ``"D_ij"`` [I, I], ``"base_ori_ij"`` [I, I],
        ``"X_ij"`` [I, I, 3], ``"Y_ij"`` [I, I, 3],
        ``"Z_ij"`` [I, I, 3], and optionally ``"O_ij"`` [I, I].

    Raises:
        ValueError: If *compute_opening* is True but *X_i* is None.
    """
    # Orientation-selected pairwise Z-axis: the half-sum (parallel normals)
    # and half-difference (antiparallel normals) candidates are compared and
    # the longer one wins per pair.
    Z_sum = Z_i[:, None, :] + Z_i[None, :, :]  # [I, I, 3]
    Z_diff = Z_i[:, None, :] - Z_i[None, :, :]  # [I, I, 3]
    Z_ij_oris = 0.5 * np.stack((Z_sum, Z_diff), axis=0)  # [2, I, I, 3]

    base_ori_ij = (  # [I, I] 0=parallel, 1=antiparallel
        np.linalg.norm(Z_ij_oris[1], axis=-1) > np.linalg.norm(Z_ij_oris[0], axis=-1)
    ).astype(np.int64)

    Z_ij = np.where(
        base_ori_ij[..., None] == 0, Z_ij_oris[0], Z_ij_oris[1]
    )  # [I, I, 3]
    Z_ij = Z_ij / (np.linalg.norm(Z_ij, axis=-1, keepdims=True) + eps)

    # Pairwise Y (inter-residue direction) and X axes
    Y_ij = frame_D_ij_vec / (
        np.linalg.norm(frame_D_ij_vec, axis=-1, keepdims=True) + eps
    )  # [I, I, 3]
    X_ij = np.cross(Z_ij, Y_ij)  # [I, I, 3]
    X_ij = X_ij / (np.linalg.norm(X_ij, axis=-1, keepdims=True) + eps)

    # Rise (H_ij): sidechain displacement projected onto the pair Z-axis.
    H_ij = np.sum(sc_D_ij_vec * Z_ij, axis=-1)  # [I, I]
    D_ij = np.linalg.norm(sc_D_ij_vec, axis=-1)  # [I, I]

    # Buckle (B_ij): angle between the residues' normals after projecting
    # each into the pair's YZ-plane (the transpose gives residue j's view).
    proj_Z_i_YZ = (  # [I, I, 3]
        np.sum(Z_i[:, None, :] * Y_ij, axis=-1, keepdims=True) * Y_ij
        + np.sum(Z_i[:, None, :] * Z_ij, axis=-1, keepdims=True) * Z_ij
    )
    proj_Z_i_YZ_norm = proj_Z_i_YZ / (
        np.linalg.norm(proj_Z_i_YZ, axis=-1, keepdims=True) + eps
    )
    cos_buckle = np.sum(
        proj_Z_i_YZ_norm * (-proj_Z_i_YZ_norm.swapaxes(0, 1)), axis=-1
    )  # [I, I]

    # Propeller (P_ij): same construction in the pair's ZX-plane.
    proj_Z_i_ZX = (  # [I, I, 3]
        np.sum(Z_i[:, None, :] * Z_ij, axis=-1, keepdims=True) * Z_ij
        + np.sum(Z_i[:, None, :] * X_ij, axis=-1, keepdims=True) * X_ij
    )
    proj_Z_i_ZX_norm = proj_Z_i_ZX / (
        np.linalg.norm(proj_Z_i_ZX, axis=-1, keepdims=True) + eps
    )
    cos_propeller = np.sum(
        proj_Z_i_ZX_norm * (-proj_Z_i_ZX_norm.swapaxes(0, 1)), axis=-1
    )  # [I, I]

    if clamp:
        # Guard arccos against cosines drifting just outside [-1, 1].
        cos_buckle = np.clip(cos_buckle, -1.0, 1.0)
        cos_propeller = np.clip(cos_propeller, -1.0, 1.0)

    B_ij = np.arccos(cos_buckle)  # [I, I]
    P_ij = np.arccos(cos_propeller)  # [I, I]

    result: dict[str, np.ndarray] = {
        "H_ij": H_ij,
        "B_ij": B_ij,
        "P_ij": P_ij,
        "D_ij": D_ij,
        "base_ori_ij": base_ori_ij,
        "X_ij": X_ij,
        "Y_ij": Y_ij,
        "Z_ij": Z_ij,
    }

    # Opening angle (O_ij) — purely diagnostic
    if compute_opening:
        if X_i is None:
            raise ValueError("X_i is required to compute opening angle")

        proj_X_i_XY = (  # [I, I, 3]
            np.sum(X_i[:, None, :] * X_ij, axis=-1, keepdims=True) * X_ij
            + np.sum(X_i[:, None, :] * Y_ij, axis=-1, keepdims=True) * Y_ij
        )
        proj_X_i_XY_norm = proj_X_i_XY / (
            np.linalg.norm(proj_X_i_XY, axis=-1, keepdims=True) + eps
        )
        # NOTE(review): unlike buckle/propeller, the transposed projection is
        # NOT negated here — confirm this sign convention is intentional for
        # the opening angle.
        cos_opening = np.sum(
            proj_X_i_XY_norm * proj_X_i_XY_norm.swapaxes(0, 1), axis=-1
        )  # [I, I]
        if clamp:
            cos_opening = np.clip(cos_opening, -1.0, 1.0)
        result["O_ij"] = np.arccos(cos_opening)  # [I, I]

    return result
+ + Computes a sigmoid-based base-pair probability from weighted H-bond + counts and gates it with rise / buckle / propeller geometry limits. + + Args: + hbond_count: H-bond counts, [I, I, 3] (BB-BB / BB-SC / SC-SC). + seq_neighbors: Sequence-neighbor boolean mask, [I, I]. + H_ij: Rise displacement, [I, I]. + B_ij: Buckle angle (radians), [I, I]. + P_ij: Propeller angle (radians), [I, I]. + mol_info: Molecular-info object with ``bp_summation_weights``, + ``bp_hbond_coeff``, ``min_hbonds_for_bp``, ``bp_val_cutoff``, + and ``base_geometry_limits``. + bool_only: If True, return only the boolean mask array. + eps: Small constant for numerical stability. + + Returns: + If *bool_only*: ``np.ndarray`` of shape ``(I, I)`` (bool). + Otherwise: dict with ``"basepairs_bool_ij"`` [I, I] (bool), + ``"basepairs_ij"`` [I, I] (float), and + ``"hbond_summation"`` [I, I] (float). + """ + hbond_summation = np.tensordot( # [I, I] + hbond_count.astype(np.float32), + np.asarray(mol_info.bp_summation_weights, dtype=np.float32), + axes=([2], [0]), + ) + + logits = mol_info.bp_hbond_coeff * ( # [I, I] + hbond_summation - (mol_info.min_hbonds_for_bp - 1) + ) + bp_preds = (1.0 / (1.0 + np.exp(-logits))) + eps # [I, I] + + # Geometry filters + H_ij_filter = ( # [I, I] + (H_ij >= -mol_info.base_geometry_limits["H_ij"]) + & (H_ij <= mol_info.base_geometry_limits["H_ij"]) + ) + B_ij_filter = ( # [I, I] + (B_ij <= mol_info.base_geometry_limits["B_ij"]) + | (B_ij >= math.pi - mol_info.base_geometry_limits["B_ij"]) + ) + P_ij_filter = ( # [I, I] + (P_ij <= mol_info.base_geometry_limits["P_ij"]) + | (P_ij >= math.pi - mol_info.base_geometry_limits["P_ij"]) + ) + + D_ij_filter = D_ij <= mol_info.base_geometry_limits["D_ij"] + + bp_geom_filter = H_ij_filter & B_ij_filter & P_ij_filter & D_ij_filter # [I, I] + + if bool_only: + basepairs_bool_ij = ( # [I, I] + (~seq_neighbors) + & bp_geom_filter + & (bp_preds >= float(mol_info.bp_val_cutoff)) + ) + return basepairs_bool_ij + + basepairs_ij = ( # [I, I] 
def compute_nucleic_ss(
    mol_info,
    token_level_data,
    hbond_count,
    clamp_pairwise_params=True,
    eps=1e-8,
    *,
    return_local_params: bool = False,
    return_pairwise_geometry: bool = False,
    return_opening_angle: bool = False,
    return_basepairs_only: bool = False,
):
    """Compute nucleic-acid pairwise base-pair geometry and filters.

    Operates in two modes:

    * **Fast annotation** (default / ``return_basepairs_only=True``): returns
      only ``basepairs_bool_ij`` and frees intermediate arrays.
    * **Diagnostic**: additionally returns local/pairwise geometry when
      ``return_pairwise_geometry``, ``return_local_params``, or
      ``return_opening_angle`` are set.

    Args:
        mol_info: Molecular-info constants (geometry limits, H-bond weights).
        token_level_data: Token-level dict with geometry (from
            :func:`add_token_level_geometry_data`).
        hbond_count: H-bond count array, shape ``(I_full, I_full, 3)``.
        clamp_pairwise_params: Clamp cosines to [-1, 1] before ``arccos``.
        eps: Small constant for numerical stability.
        return_local_params: Return per-residue X/Y/Z local frames.
        return_pairwise_geometry: Return pairwise X_ij/Y_ij/Z_ij arrays.
        return_opening_angle: Return pairwise opening angle O_ij.
        return_basepairs_only: Return only the boolean base-pair mask
            (fastest path).

    Returns:
        If ``return_basepairs_only``: ``np.ndarray`` of shape ``(I, I)``
        (bool) — the base-pair boolean mask.

        Otherwise: dict ``{"pair_params": {...}, "local_params": {...}}``
        containing the requested geometry arrays (all shape ``(I, I)`` or
        ``(I, 3)``).
    """

    mask_1d = np.asarray(token_level_data["filter_mask"], dtype=bool)  # [I_full]

    # --- Unpack and filter token-level data ----------------------
    M_i = np.asarray(token_level_data["M_i"], dtype=np.float32)[mask_1d]  # [I, 3]
    frame_xyz = np.asarray(token_level_data["frame_xyz"], dtype=np.float32)[
        mask_1d
    ]  # [I, 3]
    xyz_S_start = [
        v for v, k in zip(token_level_data["xyz_S_start"], mask_1d) if k
    ]  # list[I] of [3]
    xyz_S_stop = [
        v for v, k in zip(token_level_data["xyz_S_stop"], mask_1d) if k
    ]  # list[I] of [3]
    xyz_planar = [
        v for v, k in zip(token_level_data["xyz_planar"], mask_1d) if k
    ]  # list[I] of [K_i, 3]

    hbond_count = np.asarray(hbond_count)[mask_1d, :][:, mask_1d]  # [I, I, 3]
    seq_neighbors = np.asarray(token_level_data["seq_neighbors"], dtype=bool)[
        mask_1d, :
    ][:, mask_1d]  # [I, I]

    # Nothing passed NA/planar filtering for this structure.
    # Return empty outputs instead of failing downstream on np.stack([]).
    if len(xyz_planar) == 0:
        if return_basepairs_only:
            return np.zeros((0, 0), dtype=bool)

        pair_params: dict[str, np.ndarray] = {
            "H_ij": np.zeros((0, 0), dtype=np.float32),
            "B_ij": np.zeros((0, 0), dtype=np.float32),
            "P_ij": np.zeros((0, 0), dtype=np.float32),
            "D_ij": np.zeros((0, 0), dtype=np.float32),
            "base_ori_ij": np.zeros((0, 0), dtype=np.float32),
            "basepairs_bool_ij": np.zeros((0, 0), dtype=bool),
            "basepairs_ij": np.zeros((0, 0), dtype=np.float32),
            "hbond_summation": np.zeros((0, 0), dtype=np.float32),
        }

        if return_opening_angle:
            pair_params["O_ij"] = np.zeros((0, 0), dtype=np.float32)

        if return_pairwise_geometry:
            pair_params["X_ij"] = np.zeros((0, 0), dtype=np.float32)
            pair_params["Y_ij"] = np.zeros((0, 0), dtype=np.float32)
            pair_params["Z_ij"] = np.zeros((0, 0), dtype=np.float32)

        nucleic_ss_data: dict = {"pair_params": pair_params}
        if return_local_params:
            nucleic_ss_data["local_params"] = {
                "X_i": np.zeros((0, 3), dtype=np.float32),
                "Y_i": np.zeros((0, 3), dtype=np.float32),
                "Z_i": np.zeros((0, 3), dtype=np.float32),
            }

        return nucleic_ss_data

    # --- Precompute centroids and displacement vectors -----------
    planar_centers = np.stack(  # [I, 3]
        [
            np.nanmean(np.asarray(xyz_i, dtype=np.float32), axis=0)
            for xyz_i in xyz_planar
        ],
        axis=0,
    ).astype(np.float32)

    frame_D_ij_vec = frame_xyz[None, :, :] - frame_xyz[:, None, :]  # [I, I, 3]
    sc_D_ij_vec = planar_centers[None, :, :] - planar_centers[:, None, :]  # [I, I, 3]

    # --- CALC I: per-residue local coordinate frames -------------
    need_full_frame = return_local_params or return_opening_angle
    local_frames = _compute_local_frames(
        xyz_planar,
        planar_centers,
        M_i,
        xyz_S_start=xyz_S_start if need_full_frame else None,
        xyz_S_stop=xyz_S_stop if need_full_frame else None,
        compute_full_frame=need_full_frame,
        eps=eps,
    )
    Z_i = local_frames["Z_i"]  # [I, 3]
    X_i = local_frames.get("X_i")  # [I, 3] or None

    # --- CALC II: pairwise base-step geometry --------------------
    pw_geom = _compute_pairwise_geometry(
        Z_i,
        frame_D_ij_vec,
        sc_D_ij_vec,
        X_i=X_i,
        clamp=clamp_pairwise_params,
        compute_opening=return_opening_angle,
        eps=eps,
    )

    # --- CALC III: base-pair identification ----------------------
    bp_result = _compute_basepair_mask(
        hbond_count,
        seq_neighbors,
        pw_geom["H_ij"],
        pw_geom["B_ij"],
        pw_geom["P_ij"],
        pw_geom["D_ij"],
        mol_info,
        bool_only=return_basepairs_only,
        eps=eps,
    )

    if return_basepairs_only:
        return bp_result  # np.ndarray [I, I] bool

    # --- Assemble output dict ------------------------------------
    assert isinstance(bp_result, dict)

    pair_params: dict[str, np.ndarray] = {
        "H_ij": pw_geom["H_ij"],
        "B_ij": pw_geom["B_ij"],
        "P_ij": pw_geom["P_ij"],
        # FIX: include D_ij here for parity with the empty-structure branch
        # above, which already returns a "D_ij" key.
        "D_ij": pw_geom["D_ij"],
        "base_ori_ij": pw_geom["base_ori_ij"],
        "basepairs_bool_ij": bp_result["basepairs_bool_ij"],
        "basepairs_ij": bp_result["basepairs_ij"],
        "hbond_summation": bp_result["hbond_summation"],
    }

    if return_opening_angle and "O_ij" in pw_geom:
        pair_params["O_ij"] = pw_geom["O_ij"]

    if return_pairwise_geometry:
        pair_params["X_ij"] = pw_geom["X_ij"]
        pair_params["Y_ij"] = pw_geom["Y_ij"]
        pair_params["Z_ij"] = pw_geom["Z_ij"]

    nucleic_ss_data: dict = {"pair_params": pair_params}
    if return_local_params and "Y_i" in local_frames:
        nucleic_ss_data["local_params"] = {
            "X_i": local_frames["X_i"],
            "Y_i": local_frames["Y_i"],
            "Z_i": local_frames["Z_i"],
        }

    return nucleic_ss_data
def annotate_na_ss(
    atom_array: AtomArray,
    *,
    NA_only: bool = False,
    planar_only: bool = True,
    p_canonical_bp_filter: float = 0.0,
    mol_info: Optional[NucMolInfo] = None,
    overwrite: bool = True,
    token_level_data: Optional[dict] = None,
    cutoff_HA_dist: float = 3.5,
    cutoff_DA_dist: float = 3.5,
) -> AtomArray:
    """Compute base pairs and write a ``bp_partners`` annotation onto *atom_array*.

    Uses H-bond counts and pairwise geometry filters to identify base pairs,
    then stores the result as a per-atom annotation with the following
    semantics:

    * ``[]`` — explicitly unpaired (loop)
    * ``[token_id, ...]`` — paired partner token IDs
    * ``None`` — unannotated / masked (non-NA or filtered-out tokens)

    Args:
        atom_array: Structure to annotate (modified in-place).
        NA_only: Restrict geometry filter to nucleic-acid tokens.
        planar_only: Restrict geometry filter to tokens with planar
            sidechains.
        p_canonical_bp_filter: Probability of discarding non-canonical
            base pairs (keeps only A–U, A–T, G–C).
        mol_info: Molecular-info constants; created if ``None``.
        overwrite: If False, merge with existing ``bp_partners``.
        token_level_data: Pre-computed metadata dict; augmented with
            geometry as needed.
        cutoff_HA_dist: H–A distance cutoff (Å) for HBPLUS.
        cutoff_DA_dist: D–A distance cutoff (Å) for HBPLUS.

    Returns:
        The same *atom_array* with the ``bp_partners`` annotation set.
    """

    if mol_info is None:
        mol_info = NucMolInfo()

    # Residue representatives (0..L-1) and their corresponding atom indices.
    # Keep this aligned with get_token_level_metadata(), which uses residue starts.
    if token_level_data is not None and "token_starts" in token_level_data:
        token_starts = np.asarray(token_level_data["token_starts"], dtype=int)
    else:
        token_starts = struc.get_residue_starts(atom_array)
    residue_start_end = np.concatenate([token_starts, [atom_array.array_length()]])
    token_level_array = atom_array[token_starts]
    # token_id is assigned token-wise and matches get_token_starts() segmentation.
    token_ids: list[int] = [int(t) for t in list(token_level_array.token_id)]
    token_res_names: list[str] = [str(rn) for rn in list(token_level_array.res_name)]

    # Compute basepairs on the token graph (respecting NA_only/planar_only filtering)
    if token_level_data is None:
        token_level_data = get_token_level_metadata(
            atom_array,
            mol_info,
            NA_only=NA_only,
            planar_only=planar_only,
        )
    token_level_data = add_token_level_geometry_data(
        atom_array,
        mol_info,
        token_level_data,
        NA_only=NA_only,
        planar_only=planar_only,
    )
    # Note: this mask gives positions that are *chemically valid* for forming
    # base pairs, which is different from custom mask-generation for features
    mask_1d = np.asarray(token_level_data["filter_mask"], dtype=bool)

    # Maps subset (filtered) token indices back to full token indices.
    subset_idxs = np.nonzero(mask_1d)[0]

    is_na_full = np.asarray(token_level_data["is_na"], dtype=bool)

    hbond_count = calculate_hb_counts(
        atom_array,
        token_level_data,
        mol_info,
        cutoff_HA_dist=cutoff_HA_dist,
        cutoff_DA_dist=cutoff_DA_dist,
    )
    # bp_bool is sized over the *filtered* token subset (see compute_nucleic_ss,
    # which applies filter_mask before returning).
    bp_bool = np.asarray(
        compute_nucleic_ss(
            mol_info,
            token_level_data,
            hbond_count,
            clamp_pairwise_params=True,
            eps=1e-8,
            return_local_params=False,
            return_pairwise_geometry=False,
            return_opening_angle=False,
            return_basepairs_only=True,
        ),
        dtype=bool,
    )

    # Apply optional filters
    # NOTE(review): bp_bool is subset-sized [I, I] while is_na_full and
    # has_planar_sc are full-token-length; these broadcasts are only shape-safe
    # when filter_mask keeps every token (or the masks should be indexed with
    # subset_idxs first) — confirm against add_token_level_geometry_data.
    if NA_only:
        bp_bool &= is_na_full[:, None]
        bp_bool &= is_na_full[None, :]
    if planar_only:
        n_tokens = bp_bool.shape[0]
        has_planar_sc = np.asarray(
            token_level_data.get("has_planar_sc", np.ones(n_tokens, dtype=bool)),
            dtype=bool,
        )
        bp_bool &= has_planar_sc[:, None]
        bp_bool &= has_planar_sc[None, :]

    # Optional: filter to canonical Watson-Crick basepairs only.
    # Sampled probabilistically to allow mixed supervision during training.
    do_canonical_filter = bool(
        p_canonical_bp_filter and (np.random.rand() < float(p_canonical_bp_filter))
    )
    if do_canonical_filter:

        def _base_letter(res_name: str) -> str | None:
            # Map residue names to a single base letter; None for non-standard.
            rn = str(res_name).strip().upper()
            if rn in STANDARD_RNA:
                return rn
            if rn in STANDARD_DNA:
                return rn[1]  # DA/DC/DG/DT -> A/C/G/T
            return None

        allowed_pairs = {
            ("A", "U"),
            ("U", "A"),
            ("A", "T"),
            ("T", "A"),
            ("G", "C"),
            ("C", "G"),
        }
        base_letters_full: list[str | None] = [
            _base_letter(rn) for rn in token_res_names
        ]

        bp_bool = np.asarray(bp_bool, dtype=bool)
        bp_rows_tmp, bp_cols_tmp = np.nonzero(bp_bool)
        for r, c in zip(bp_rows_tmp.tolist(), bp_cols_tmp.tolist()):
            full_i = int(subset_idxs[int(r)])
            full_j = int(subset_idxs[int(c)])
            bi = base_letters_full[full_i]
            bj = base_letters_full[full_j]
            if bi is None or bj is None or (bi, bj) not in allowed_pairs:
                # Clear both orientations to keep the matrix symmetric.
                bp_bool[int(r), int(c)] = False
                bp_bool[int(c), int(r)] = False

    bp_bool = np.asarray(bp_bool, dtype=bool)
    bp_rows, bp_cols = np.nonzero(bp_bool)

    # Build residue-level annotation first, then spread to all atoms in each residue.
    if (not overwrite) and ("bp_partners" in atom_array.get_annotation_categories()):
        existing_ann = atom_array.bp_partners
        if len(existing_ann) != len(atom_array):
            raise ValueError("Existing bp_partners annotation has wrong length")
        residue_bp_partners = np.empty(len(token_starts), dtype=object)
        residue_bp_partners[:] = None
        # Seed from the token-start atom of each residue (representative atom).
        for i, start in enumerate(token_starts.tolist()):
            residue_bp_partners[i] = existing_ann[int(start)]
    else:
        residue_bp_partners = np.empty(len(token_starts), dtype=object)
        residue_bp_partners[:] = None

    # Explicit-loop semantics:
    # - Only nucleic-acid token-start atoms *within subset_idxs* get a list container.
    # - [] means explicitly unpaired loop.
    # - None means unannotated/masked.
    for full_i in subset_idxs.tolist():
        if not bool(is_na_full[int(full_i)]):
            continue
        if residue_bp_partners[int(full_i)] is None:
            residue_bp_partners[int(full_i)] = []

    # Populate partners using token_id ints
    # We only process each unordered pair once to avoid duplicates.
    for r, c in zip(bp_rows.tolist(), bp_cols.tolist()):
        if r == c:
            continue

        full_i = int(subset_idxs[int(r)])
        full_j = int(subset_idxs[int(c)])
        if full_i == full_j:
            continue

        # Only annotate NA-NA basepairs as nucleic secondary structure.
        if (not bool(is_na_full[int(full_i)])) or (not bool(is_na_full[int(full_j)])):
            continue

        # Enforce uniqueness: only handle (i,j) where i < j
        if full_j < full_i:
            continue

        partner_i = int(token_ids[full_j])
        partner_j = int(token_ids[full_i])

        if residue_bp_partners[full_i] is None:
            residue_bp_partners[full_i] = []
        if residue_bp_partners[full_j] is None:
            residue_bp_partners[full_j] = []

        # Add if not present
        if partner_i not in residue_bp_partners[full_i]:
            residue_bp_partners[full_i].append(partner_i)
        if partner_j not in residue_bp_partners[full_j]:
            residue_bp_partners[full_j].append(partner_j)

    # Project residue-level annotations back to atom-level storage:
    # - atomized residues: spread to all atoms in that residue
    # - non-atomized residues: keep only on token-start representative atom
    bp_partners_ann = np.empty(len(atom_array), dtype=object)
    bp_partners_ann[:] = None
    for i, start in enumerate(token_starts.tolist()):
        stop = int(residue_start_end[i + 1])
        value = residue_bp_partners[i]
        if value is None:
            continue
        # A residue is treated as atomized if any atom in the residue carries atomize=True.
        if "atomize" in atom_array.get_annotation_categories():
            residue_is_atomized = bool(
                np.any(np.asarray(atom_array.atomize[int(start) : stop], dtype=bool))
            )
        else:
            residue_is_atomized = False
        if residue_is_atomized:
            for atom_idx in range(int(start), stop):
                # list(value) copies so residues don't share one mutable list.
                bp_partners_ann[atom_idx] = list(value)
        else:
            bp_partners_ann[int(start)] = list(value)

    atom_array.set_annotation("bp_partners", bp_partners_ann)
    return atom_array
def parse_dot_bracket(dot_bracket: str) -> tuple[list[tuple[int, int]], list[int]]:
    """Parse a dot-bracket string into base pairs and unpaired positions.

    Supports standard ``()``, ``[]``, ``{}``, ``<>`` and pseudoknot
    brackets ``A``–``E`` / ``a``–``e``.  Unrecognized characters and
    unmatched closing brackets are silently ignored.

    Args:
        dot_bracket: Dot-bracket notation string.

    Returns:
        Tuple of ``(pairs, unpaired)`` where *pairs* is a list of 0-based
        ``(i, j)`` index tuples with ``i < j`` and *unpaired* is a list of
        0-based indices corresponding to ``.`` characters.
    """
    # Map each closing character to its opening counterpart.
    opener_for = {
        ")": "(",
        "]": "[",
        "}": "{",
        ">": "<",
        "a": "A",
        "b": "B",
        "c": "C",
        "d": "D",
        "e": "E",
    }
    openers = set(opener_for.values())

    # One stack per bracket flavor so pseudoknot levels don't interfere.
    stack: dict[str, list[int]] = {}
    pairs: list[tuple[int, int]] = []
    unpaired: list[int] = []

    for i, ch in enumerate(str(dot_bracket)):
        if ch == ".":
            unpaired.append(i)
        elif ch in openers:
            # BUGFIX: opening brackets were previously never pushed (the old
            # branch only consulted opener_for, which maps *closers*), so the
            # function always returned an empty pair list.
            stack.setdefault(ch, []).append(i)
        elif ch in opener_for:
            o = opener_for[ch]
            if stack.get(o):
                j = stack[o].pop()
                pairs.append((j, i))
            # Unmatched closer: ignored.
        # Any other character: ignored.

    return pairs, unpaired
def annotate_na_ss_from_specification(
    atom_array: AtomArray,
    specification: dict,
    *,
    overwrite: bool = True,
) -> AtomArray:
    """Write ``bp_partners`` annotation from an inference-time specification.

    Inference analogue of :func:`annotate_na_ss`: interprets user-provided
    dot-bracket strings and/or residue ranges rather than computing base
    pairs from geometry.

    Supported *specification* keys (all optional):

    * ``ss_dbn``: global dot-bracket string (applied to the first *L* tokens).
    * ``ss_dbn_dict``: ``{"-": dbn_str, ...}``.
    * ``paired_region_list``: ``["A5-15,B1-11", ...]``.
    * ``paired_position_list``: ``["A19,A61,A20", ...]``.
    * ``loop_region_list``: ``["A5-10", ...]`` (forced unpaired).

    Args:
        atom_array: Structure to annotate (modified in-place).
        specification: Specification dict as described above.
        overwrite: If False, merge with existing ``bp_partners``.

    Returns:
        The same *atom_array* with the ``bp_partners`` annotation set.
    """

    spec = specification or {}
    token_starts = get_token_starts(atom_array)
    token_level_array = atom_array[token_starts]
    token_ids: list[int] = [int(t) for t in list(token_level_array.token_id)]
    n_tokens = len(token_starts)

    # Prepare/overwrite annotation array
    if (not overwrite) and ("bp_partners" in atom_array.get_annotation_categories()):
        bp_partners_ann = atom_array.bp_partners
        if len(bp_partners_ann) != len(atom_array):
            raise ValueError("Existing bp_partners annotation has wrong length")
    else:
        bp_partners_ann = np.empty(len(atom_array), dtype=object)
        bp_partners_ann[:] = None

    # Build chain/res -> token index map for region/position specs.
    # Accept both chain_iid-like keys (e.g. "A_1") and plain chain IDs (e.g. "A")
    # so CLI/json specs like "A1,B3" work reliably in inference.
    chain_iid_list: list[str] = [str(atm.chain_iid) for atm in token_level_array]
    chain_id_list: list[str] = [str(atm.chain_id) for atm in token_level_array]
    resi_list: list[int] = [int(atm.res_id) for atm in token_level_array]
    chain_res_to_tok: dict[tuple[str, int], int] = {}
    for i, (chain_iid, chain_id, res_id) in enumerate(
        zip(chain_iid_list, chain_id_list, resi_list)
    ):
        key_iid = (chain_iid, int(res_id))
        key_chain = (chain_id, int(res_id))
        # setdefault keeps the first token seen for a duplicated key.
        chain_res_to_tok.setdefault(key_iid, int(i))
        chain_res_to_tok.setdefault(key_chain, int(i))
        # Also support the short alias from chain_iid (e.g. "A_1" -> "A")
        short_chain = chain_iid.split("_", 1)[0]
        chain_res_to_tok.setdefault((short_chain, int(res_id)), int(i))

    def _parse_region(region_str: str) -> tuple[str, int, int] | None:
        # Parse "A5-15" into ("A", 5, 15); None on malformed input.
        # NOTE(review): assumes a single-character chain ID (region_str[0]);
        # multi-letter chains like "AB" would mis-parse — confirm inputs.
        region_str = str(region_str).strip()
        if not region_str:
            return None
        chain_id = region_str[0]
        rest = region_str[1:]
        if "-" not in rest:
            return None
        start_s, end_s = rest.split("-", 1)
        try:
            start_res = int(start_s)
            end_res = int(end_s)
        except ValueError:
            return None
        if start_res > end_res:
            start_res, end_res = end_res, start_res
        return chain_id, start_res, end_res

    def _parse_single_pos(pos_str: str) -> tuple[str, int] | None:
        # Parse "A19" into ("A", 19); None on malformed input.
        # NOTE(review): same single-character chain-ID assumption as above.
        pos_str = str(pos_str).strip()
        if not pos_str:
            return None
        chain_id = pos_str[0]
        rest = pos_str[1:]
        try:
            res_id = int(rest)
        except ValueError:
            return None
        return chain_id, res_id

    def _region_to_token_indices(region_str: str) -> list[int]:
        # Expand a region spec to the token indices actually present.
        parsed = _parse_region(region_str)
        if parsed is None:
            return []
        chain_id, start_res, end_res = parsed
        token_indices: list[int] = []
        for res_id in range(start_res, end_res + 1):
            idx = chain_res_to_tok.get((chain_id, int(res_id)))
            if idx is not None:
                token_indices.append(int(idx))
        return token_indices

    def _pos_to_token_index(pos_str: str) -> int | None:
        parsed = _parse_single_pos(pos_str)
        if parsed is None:
            return None
        chain_id, res_id = parsed
        return chain_res_to_tok.get((chain_id, int(res_id)))

    # Accumulate partners as token-index sets
    partners: list[set[int]] = [set() for _ in range(n_tokens)]
    loop_token_idxs: set[int] = set()

    def _add_pair(i: int, j: int) -> None:
        # Symmetric insert with bounds / self-pair / loop-token guards.
        if not (0 <= i < n_tokens and 0 <= j < n_tokens):
            return
        if i == j:
            return
        if i in loop_token_idxs or j in loop_token_idxs:
            return
        partners[i].add(j)
        partners[j].add(i)

    # Case 1: global ss_dbn
    ss_dbn = spec.get("ss_dbn")
    if isinstance(ss_dbn, str) and ss_dbn.strip():
        pairs, unpaired = parse_dot_bracket(ss_dbn.strip())
        L = min(len(ss_dbn), n_tokens)
        for i_local, j_local in pairs:
            if 0 <= i_local < L and 0 <= j_local < L:
                _add_pair(int(i_local), int(j_local))
        for i_local in unpaired:
            if 0 <= int(i_local) < L:
                loop_token_idxs.add(int(i_local))

    # Case 1b: ss_dbn_dict
    ss_dbn_dict = spec.get("ss_dbn_dict", {}) or {}
    if isinstance(ss_dbn_dict, dict):
        for region_str, dbn_str in ss_dbn_dict.items():
            if not isinstance(region_str, str) or not isinstance(dbn_str, str):
                continue
            dbn_str = dbn_str.strip()
            if not dbn_str:
                continue
            toks = _region_to_token_indices(region_str)
            # Region must resolve to exactly one token per dbn character.
            if not toks or len(toks) != len(dbn_str):
                continue
            pairs, unpaired = parse_dot_bracket(dbn_str)
            for i_local, j_local in pairs:
                if 0 <= i_local < len(toks) and 0 <= j_local < len(toks):
                    _add_pair(int(toks[int(i_local)]), int(toks[int(j_local)]))
            for i_local in unpaired:
                if 0 <= i_local < len(toks):
                    loop_token_idxs.add(int(toks[int(i_local)]))

    # Case 2: paired_region_list
    paired_region_list = spec.get("paired_region_list", [])
    if isinstance(paired_region_list, str):
        paired_region_list = [paired_region_list]
    if isinstance(paired_region_list, list):
        for region_entry in paired_region_list:
            if not isinstance(region_entry, str) or not region_entry.strip():
                continue
            region_parts = [p.strip() for p in region_entry.split(",") if p.strip()]
            if len(region_parts) != 2:
                continue
            toks1 = _region_to_token_indices(region_parts[0])
            toks2 = _region_to_token_indices(region_parts[1])
            if not toks1 or not toks2:
                continue
            # All-vs-all pairing between the two regions.
            for ti in toks1:
                for tj in toks2:
                    _add_pair(int(ti), int(tj))

    # Case 3: paired_position_list
    paired_position_list = spec.get("paired_position_list", [])
    if isinstance(paired_position_list, str):
        paired_position_list = [paired_position_list]
    if isinstance(paired_position_list, list):
        for group_str in paired_position_list:
            if not isinstance(group_str, str) or not group_str.strip():
                continue
            pos_parts = [p.strip() for p in group_str.split(",") if p.strip()]
            tok_indices: list[int] = []
            for pos_str in pos_parts:
                tok = _pos_to_token_index(pos_str)
                if tok is not None:
                    tok_indices.append(int(tok))
            # All positions in a group are mutually paired.
            for i in range(len(tok_indices)):
                for j in range(i + 1, len(tok_indices)):
                    _add_pair(tok_indices[i], tok_indices[j])

    # Case 4: loop_region_list
    loop_region_list = spec.get("loop_region_list", [])
    if isinstance(loop_region_list, str):
        loop_region_list = [loop_region_list]
    if isinstance(loop_region_list, list):
        for region_str in loop_region_list:
            if not isinstance(region_str, str) or not region_str.strip():
                continue
            for tok in _region_to_token_indices(region_str):
                loop_token_idxs.add(int(tok))

    # Enforce loop tokens as unpaired: remove any pairs involving them
    for i in list(loop_token_idxs):
        if not (0 <= i < n_tokens):
            continue
        for j in list(partners[i]):
            partners[j].discard(i)
        partners[i].clear()

    # Write lists of partner token_ids onto token-start atoms.
    # Unspecified tokens remain unannotated (None) -> NA_SS_MASK.
    for i in range(n_tokens):
        atom_i = int(token_starts[i])
        if len(partners[i]) > 0:
            bp_partners_ann[atom_i] = []
            for j in sorted(partners[i]):
                partner_token_id = int(token_ids[int(j)])
                bp_partners_ann[atom_i].append(partner_token_id)
        elif int(i) in loop_token_idxs:
            # Explicitly unpaired loop token -> empty list (not None).
            bp_partners_ann[atom_i] = []

    atom_array.set_annotation("bp_partners", bp_partners_ann)
    return atom_array
def annotate_na_ss_from_data_specification(
    data: dict,
    *,
    overwrite: bool = True,
) -> AtomArray:
    """Annotate ``bp_partners`` from ``data["specification"]``.

    Convenience wrapper around :func:`annotate_na_ss_from_specification`.

    Args:
        data: Pipeline data dict containing ``atom_array`` and optionally
            ``specification``.
        overwrite: If False, merge with existing ``bp_partners``.

    Returns:
        The annotated AtomArray (also stored back in *data*).
    """
    # Missing or falsy specification collapses to an empty dict.
    spec = data.get("specification") or {}
    return annotate_na_ss_from_specification(
        data["atom_array"], spec, overwrite=overwrite
    )
transform itself ##### **_, # dump additional kwargs (e.g. msa stuff) ): """ @@ -383,6 +392,7 @@ def build_atom14_base_pipeline_( train_conditions=train_conditions, meta_conditioning_probabilities=meta_conditioning_probabilities, sequence_encoding=af3_sequence_encoding, + association_scheme=association_scheme, ), ), ] @@ -405,6 +415,7 @@ def build_atom14_base_pipeline_( max_binder_length=max_binder_length, max_atoms_in_crop=max_atoms_in_crop, allowed_types=allowed_types, + association_scheme=association_scheme, ) if zero_occ_on_exposure_after_cropping: @@ -422,7 +433,9 @@ def build_atom14_base_pipeline_( # ... Add global token features (since number of tokens is fixed after cropping) transforms.append(AddGlobalTokenIdAnnotation()) # ... Create masks (NOTE: Modulates token count, and resets global token id if necessary) - transforms.append(TrainingRoute(SampleConditioningFlags())) + transforms.append( + TrainingRoute(SampleConditioningFlags(association_scheme=association_scheme)) + ) # Post-crop transforms transforms.append( @@ -434,6 +447,16 @@ def build_atom14_base_pipeline_( ), ) ) + # Add nucleic acid geometry features + # if add_na_pair_features: + transforms.append( + CalculateNucleicAcidGeomFeats( + is_inference, + meta_conditioning_probabilities, + NA_only=False, + planar_only=True, + ) + ) # Design Transforms transforms += [ @@ -442,7 +465,9 @@ def build_atom14_base_pipeline_( sharding_depth=1, ), # ... Fuse inference and training conditioning assignments - UnindexFlaggedTokens(central_atom=central_atom), + UnindexFlaggedTokens( + central_atom=central_atom, association_scheme=association_scheme + ), # ... 
Virtual atom padding (NOTE: Last transform which modulates atom count) PadTokensWithVirtualAtoms( n_atoms_per_token=n_atoms_per_token, @@ -518,6 +543,12 @@ def build_atom14_base_pipeline_( autofill_zeros_if_not_present_in_atomarray=True, token_1d_features=token_1d_features, atom_1d_features=atom_1d_features, + association_scheme=association_scheme, + ), + AddAdditional2dFeaturesToFeats( + autofill_zeros_if_not_present_in_atomarray=True, + token_2d_features=token_2d_features, + association_scheme=association_scheme, ), AddAF3TokenBondFeatures(), AddGroundTruthSequence(sequence_encoding=af3_sequence_encoding), @@ -593,7 +624,6 @@ def build_atom14_base_pipeline( Wrapper around pipeline construction to handle empty training args Sets default behaviour for inference to keep backward compatibility """ - if is_inference: # Provide explicit defaults for training-only args kwargs.setdefault("crop_size", 512) @@ -609,6 +639,8 @@ def build_atom14_base_pipeline( kwargs.setdefault("min_ss_island_len", 0) kwargs.setdefault("max_ss_island_len", 999) kwargs.setdefault("max_binder_length", 999) + # This should not be necessary. 
+ # kwargs.setdefault("add_na_pair_features", False) kwargs.setdefault("b_factor_min", None) kwargs.setdefault("zero_occ_on_exposure_after_cropping", False) @@ -621,7 +653,7 @@ def build_atom14_base_pipeline( kwargs.setdefault("residue_cache_dir", None) # TODO: Delete these once all checkpoints are updated with the latest defaults - kwargs.setdefault("generate_conformers_for_non_protein_only", True) + kwargs.setdefault("generate_conformers_for_non_protein_only", False) kwargs.setdefault("return_atom_array", True) kwargs.setdefault("provide_elements_for_unindexed_components", False) kwargs.setdefault("center_option", "all") diff --git a/models/rfd3/src/rfd3/transforms/training_conditions.py b/models/rfd3/src/rfd3/transforms/training_conditions.py index f93e346f..67ced33a 100644 --- a/models/rfd3/src/rfd3/transforms/training_conditions.py +++ b/models/rfd3/src/rfd3/transforms/training_conditions.py @@ -13,6 +13,7 @@ spread_token_wise, ) from biotite.structure import AtomArray, get_residue_starts +from rfd3.constants import backbone_atoms_RNA from rfd3.transforms.conditioning_utils import ( random_condition, sample_island_tokens, @@ -58,6 +59,8 @@ class IslandCondition(TrainingCondition): Select islands as motif and assign conditioning strategies. 
""" + association_scheme = "atom14" + def __init__( self, *, @@ -70,9 +73,11 @@ def __init__( p_fix_motif_coordinates, p_fix_motif_sequence, p_unindex_motif_tokens, + association_scheme="atom14", ): self.name = name self.frequency = frequency + self.association_scheme = association_scheme # Token selection self.island_sampling_kwargs = island_sampling_kwargs @@ -88,11 +93,21 @@ def __init__( self.p_fix_motif_sequence = p_fix_motif_sequence self.p_unindex_motif_tokens = p_unindex_motif_tokens + self.association_scheme = association_scheme + def is_valid_for_example(self, data) -> bool: is_protein = data["atom_array"].is_protein - if not np.any(is_protein): - return False - return True + is_dna = data["atom_array"].is_dna + is_rna = data["atom_array"].is_rna + ### updating this to allow other polymers + if self.association_scheme == "atom23": + if np.any(is_protein | is_dna | is_rna): + return True + else: + if np.any(is_protein): + return True + + return False def sample_motif_tokens(self, atom_array): """ @@ -101,13 +116,30 @@ def sample_motif_tokens(self, atom_array): token_level_array = atom_array[get_token_starts(atom_array)] # initialize motif tokens as all non-protein tokens - is_motif_token = np.asarray(~token_level_array.is_protein, dtype=bool).copy() - n_protein_tokens = np.sum(token_level_array.is_protein) - islands_mask = sample_island_tokens( - n_protein_tokens, - **self.island_sampling_kwargs, - ) - is_motif_token[token_level_array.is_protein] = islands_mask + if self.association_scheme == "atom23": + polymer_mask = ( + token_level_array.is_protein + | token_level_array.is_dna + | token_level_array.is_rna + ) + is_motif_token = np.asarray(~polymer_mask, dtype=bool).copy() + n_polymer_tokens = np.sum(polymer_mask) + islands_mask = sample_island_tokens( + n_polymer_tokens, + **self.island_sampling_kwargs, + ) + is_motif_token[polymer_mask] = islands_mask + else: + is_motif_token = np.asarray( + ~token_level_array.is_protein, dtype=bool + ).copy() + 
n_protein_tokens = np.sum(token_level_array.is_protein) + + islands_mask = sample_island_tokens( + n_protein_tokens, + **self.island_sampling_kwargs, + ) + is_motif_token[token_level_array.is_protein] = islands_mask # TODO: Atoms with covalent bonds should be motif, needs FlagAndReassignCovalentModifications transform prior to this # atom_with_coval_bond = token_level_array.covale # (n_atoms, ) @@ -127,7 +159,9 @@ def sample_motif_atoms(self, atom_array): is_motif_atom = np.asarray(atom_array.is_motif_token, dtype=bool).copy() if random_condition(self.p_diffuse_motif_sidechains): - backbone_atoms = ["N", "C", "CA"] + backbone_atoms = backbone_atoms_RNA.copy() + backbone_atoms.remove("C1'") + backbone_atoms = ["N", "C", "CA"] + backbone_atoms # covers DNA also if random_condition(self.p_include_oxygen_in_backbone_mask): backbone_atoms.append("O") is_motif_atom = is_motif_atom & np.isin( @@ -137,11 +171,11 @@ def sample_motif_atoms(self, atom_array): is_motif_atom = sample_motif_subgraphs( atom_array=atom_array, **self.subgraph_sampling_kwargs, + association_scheme=self.association_scheme, ) # We also only want resolved atoms to be motif is_motif_atom = (is_motif_atom) & (atom_array.occupancy > 0.0) - return is_motif_atom def sample(self, data): @@ -158,6 +192,7 @@ def sample(self, data): p_fix_motif_sequence=self.p_fix_motif_sequence, p_fix_motif_coordinates=self.p_fix_motif_coordinates, p_unindex_motif_tokens=self.p_unindex_motif_tokens, + association_scheme=self.association_scheme, ) atom_array.set_annotation( @@ -169,7 +204,6 @@ def sample(self, data): leak_global_index=data["conditions"]["unindex_leak_global_index"], ), ) - return atom_array @@ -177,6 +211,7 @@ class PPICondition(TrainingCondition): """Get condition indicating what is motif and what is to be diffused for protein-protein interaction training.""" name = "ppi" + association_scheme = "atom14" def is_valid_for_example(self, data): # Extract relevant data @@ -275,6 +310,7 @@ class 
SubtypeCondition(TrainingCondition): """ name = "subtype" + association_scheme = "atom14" def __init__(self, frequency: float, subtype: list[str], fix_pos: bool = False): self.frequency = frequency @@ -370,6 +406,7 @@ def sample_motif_subgraphs( hetatom_n_bond_expectation, residue_p_fix_all, hetatom_p_fix_all, + association_scheme="atom14", ): """ Returns a boolean mask over atoms, indicating which atoms are part of the sampled motif. @@ -402,7 +439,17 @@ def sample_motif_subgraphs( "n_bond_expectation": residue_n_bond_expectation, "p_fix_all": residue_p_fix_all, } - if not atom_array_subset.is_protein.all(): + + if association_scheme == "atom23": + clause = ( + atom_array_subset.is_protein.all() + | atom_array_subset.is_dna.all() + | atom_array_subset.is_rna.all() + ) + else: + clause = atom_array_subset.is_protein.all() + + if not clause: args.update( { "p_seed_furthest_from_o": 0.0, @@ -431,11 +478,14 @@ def sample_conditioning_strategy( p_fix_motif_sequence, p_fix_motif_coordinates, p_unindex_motif_tokens, + association_scheme, ): atom_array.set_annotation( "is_motif_atom_with_fixed_seq", sample_is_motif_atom_with_fixed_seq( - atom_array, p_fix_motif_sequence=p_fix_motif_sequence + atom_array, + p_fix_motif_sequence=p_fix_motif_sequence, + association_scheme=association_scheme, ), ) @@ -449,14 +499,17 @@ def sample_conditioning_strategy( atom_array.set_annotation( "is_motif_atom_unindexed", sample_unindexed_atoms( - atom_array, p_unindex_motif_tokens=p_unindex_motif_tokens + atom_array, + p_unindex_motif_tokens=p_unindex_motif_tokens, + association_scheme=association_scheme, ), ) - return atom_array -def sample_is_motif_atom_with_fixed_seq(atom_array, p_fix_motif_sequence): +def sample_is_motif_atom_with_fixed_seq( + atom_array, p_fix_motif_sequence, association_scheme +): """ Samples what kind of conditioning to apply to motif tokens. 
@@ -469,7 +522,12 @@ def sample_is_motif_atom_with_fixed_seq(atom_array, p_fix_motif_sequence): is_motif_atom_with_fixed_seq = np.zeros(atom_array.array_length(), dtype=bool) # By default reveal sequence for non-protein - is_motif_atom_with_fixed_seq = is_motif_atom_with_fixed_seq | ~atom_array.is_protein + + if not association_scheme == "atom23": + is_motif_atom_with_fixed_seq = ( + is_motif_atom_with_fixed_seq | ~atom_array.is_protein + ) + return is_motif_atom_with_fixed_seq @@ -487,7 +545,9 @@ def sample_fix_motif_coordinates(atom_array, p_fix_motif_coordinates): return is_motif_atom_with_fixed_coord -def sample_unindexed_atoms(atom_array, p_unindex_motif_tokens): +def sample_unindexed_atoms( + atom_array, p_unindex_motif_tokens, association_scheme="atom14" +): """ Samples which atoms in motif tokens should be flagged for unindexing. @@ -498,11 +558,16 @@ def sample_unindexed_atoms(atom_array, p_unindex_motif_tokens): is_motif_atom_unindexed = atom_array.is_motif_atom.copy() else: is_motif_atom_unindexed = np.zeros(atom_array.array_length(), dtype=bool) - # ensure non-residue atoms are not already flagged - is_motif_atom_unindexed = np.logical_and( - is_motif_atom_unindexed, atom_array.is_residue - ) + if association_scheme == "atom23": + is_motif_atom_unindexed = np.logical_and( + is_motif_atom_unindexed, + (atom_array.is_residue | atom_array.is_dna | atom_array.is_rna), + ) # is_residue refers to is_protein here + else: + is_motif_atom_unindexed = np.logical_and( + is_motif_atom_unindexed, atom_array.is_residue + ) return is_motif_atom_unindexed diff --git a/models/rfd3/src/rfd3/transforms/util_transforms.py b/models/rfd3/src/rfd3/transforms/util_transforms.py index 024b1bf1..37f08f19 100644 --- a/models/rfd3/src/rfd3/transforms/util_transforms.py +++ b/models/rfd3/src/rfd3/transforms/util_transforms.py @@ -37,7 +37,7 @@ def assert_single_representative(token, central_atom="CB"): mask = get_af3_token_representative_masks(token, central_atom=central_atom) 
assert ( np.sum(mask) == 1 - ), f"No representative atom (CB) found. mask: {mask}\nToken: {token}" + ), f"No representative atom ({central_atom}) found. mask: {mask}\nToken: {token}" def assert_single_token(token): @@ -252,13 +252,13 @@ def get_af3_token_representative_masks( atom_array: AtomArray, central_atom: str = "CA" ) -> np.ndarray: pyrimidine_representative_atom = is_pyrimidine(atom_array.res_name) & ( - atom_array.atom_name == "C2" + atom_array.atom_name == "C1'" ) purine_representative_atom = is_purine(atom_array.res_name) & ( - atom_array.atom_name == "C4" + atom_array.atom_name == "C1'" ) unknown_na_representative_atom = is_unknown_nucleotide(atom_array.res_name) & ( - atom_array.atom_name == "C4" + atom_array.atom_name == "C1'" ) glycine_representative_atom = is_glycine(atom_array.res_name) & ( diff --git a/models/rfd3/src/rfd3/transforms/virtual_atoms.py b/models/rfd3/src/rfd3/transforms/virtual_atoms.py index ae40b3ce..d9097fea 100644 --- a/models/rfd3/src/rfd3/transforms/virtual_atoms.py +++ b/models/rfd3/src/rfd3/transforms/virtual_atoms.py @@ -10,8 +10,10 @@ ) from atomworks.ml.utils.token import get_token_starts from rfd3.constants import ( - ATOM14_ATOM_NAME_TO_ELEMENT, ATOM14_ATOM_NAMES, + ATOM23_ATOM_NAME_TO_ELEMENT, + ATOM23_ATOM_NAMES_DNA, + ATOM23_ATOM_NAMES_RNA, VIRTUAL_ATOM_ELEMENT_NAME, association_schemes, association_schemes_stripped, @@ -28,7 +30,9 @@ from foundry.common import exists -def map_to_association_scheme(atom_names: list | str, res_name: str, scheme="atom14"): +def map_to_association_scheme( + atom_names: list | str, res_name: str, scheme="atom14", ATOM_NAMES=None +): """ Maps a list of names to the atom14 naming scheme for that particular name (within a specific residue) NB this function is a bit more general since it is used to handle tipatoms too. 
@@ -37,16 +41,17 @@ def map_to_association_scheme(atom_names: list | str, res_name: str, scheme="ato raise ValueError( f"Scheme {scheme} not found in association_schemes_stripped. Available schemes: {list(association_schemes_stripped.keys())}" ) - atom_names = ( - [str(atom_names)] if isinstance(atom_names, (str, np.str_)) else atom_names - ) + atom_names = [atom_names] if isinstance(atom_names, str) else atom_names idxs = np.array( [ association_schemes_stripped[scheme][res_name].index(name) for name in atom_names ] ) - return ATOM14_ATOM_NAMES[idxs] + if ATOM_NAMES is None: + return ATOM14_ATOM_NAMES[idxs] + else: + return ATOM_NAMES[idxs] def map_names_to_elements( @@ -58,7 +63,7 @@ def map_names_to_elements( then it returns the default value """ atom_names = [atom_names] if isinstance(atom_names, str) else atom_names - elements = [ATOM14_ATOM_NAME_TO_ELEMENT.get(name, default) for name in atom_names] + elements = [ATOM23_ATOM_NAME_TO_ELEMENT.get(name, default) for name in atom_names] return np.array(elements) @@ -68,17 +73,20 @@ def generate_atom_mappings_(scheme="atom14"): atom_mapping = {} symmetry_mapping = {} - for aaa, atom14_names in ccd_ordering_atomchar.items(): - mapping = list(range(14)) + for aaa, atom_names in ccd_ordering_atomchar.items(): + if aaa not in scheme: + continue + + mapping = list(range(len(atom_names))) scheme_names = scheme[aaa] - for ccd_index in range(len(atom14_names)): - atom14_name = atom14_names[ccd_index] - if atom14_name is not None: + for ccd_index in range(len(atom_names)): + atom_name = atom_names[ccd_index] + if atom_name is not None: assert ( - atom14_name in scheme_names - ), f"{atom14_name} not in CCD ordering for {aaa}" - scheme_index = scheme_names.index(atom14_name) + atom_name in scheme_names + ), f"{atom_name} not in CCD ordering for {aaa}" + scheme_index = scheme_names.index(atom_name) scheme_index_in_cur_mapping = mapping.index(scheme_index) mapping[ccd_index], mapping[scheme_index_in_cur_mapping] = ( 
mapping[scheme_index_in_cur_mapping], @@ -121,6 +129,7 @@ def permute_symmetric_atom_names_( ) -> list: # NB: Can leak GT sequence if the model receives the canconical ordering of atoms as input # With the structure-local atom attention it will not unless N_keys(n_attn_seq_neighbours) > n_atom_attn_queries. + if res_name in association_map: idx_to_swap = association_map[res_name] atom_names = atom_names[idx_to_swap] @@ -174,18 +183,35 @@ def forward(self, data: dict) -> dict: ), "Token ids and token level array have different lengths!" # Unindexed tokens are never fully atomized, but may be assigned as atomized to have repr atoms: - is_residue = ( - token_level_array.is_protein & ~token_level_array.atomize - ) | is_motif_token_unindexed - - # Unindexed tokens are never padded, and so are treated as residues with fixed sequence. - is_paddable = is_residue & ~( - is_motif_atom_with_fixed_seq | is_motif_token_unindexed - ) - is_non_paddable_residue = is_residue & ( - is_motif_atom_with_fixed_seq | is_motif_token_unindexed - ) + if self.association_scheme == "atom23": + is_residue = ( + token_level_array.is_protein & ~token_level_array.atomize + ) | is_motif_token_unindexed + + is_residue_NA = ( + (token_level_array.is_dna | token_level_array.is_rna) + & ~token_level_array.atomize + ) | is_motif_token_unindexed + + # Unindexed tokens are never padded, and so are treated as residues with fixed sequence. + is_paddable = (is_residue_NA | is_residue) & ~( + is_motif_atom_with_fixed_seq | is_motif_token_unindexed + ) + is_non_paddable_residue = (is_residue_NA | is_residue) & ( + is_motif_atom_with_fixed_seq | is_motif_token_unindexed + ) + else: + is_residue = ( + token_level_array.is_protein & ~token_level_array.atomize + ) | is_motif_token_unindexed + # Unindexed tokens are never padded, and so are treated as residues with fixed sequence. 
+ is_paddable = is_residue & ~( + is_motif_atom_with_fixed_seq | is_motif_token_unindexed + ) + is_non_paddable_residue = is_residue & ( + is_motif_atom_with_fixed_seq | is_motif_token_unindexed + ) # Collect virtual atoms to insert (we will insert them all at once) virtual_atoms_to_insert = [] insert_positions = [] @@ -194,13 +220,26 @@ def forward(self, data: dict) -> dict: for token_id, (start, end) in enumerate(zip(starts[:-1], starts[1:])): if is_paddable[token_id]: token = atom_array[start:end] + # First, pad with virtual atoms if needed - n_pad = self.n_atoms_per_token - len(token) + if self.association_scheme == "atom23" and atom_array[start].is_dna: + n_atoms_per_token = 22 + central_atom = "C1'" + elif self.association_scheme == "atom23" and atom_array[start].is_rna: + n_atoms_per_token = 23 + central_atom = "C1'" + else: + n_atoms_per_token = self.n_atoms_per_token + central_atom = self.atom_to_pad_from + + n_pad = n_atoms_per_token - len(token) + if n_pad > 0: mask = get_af3_token_representative_masks( - token, central_atom=self.atom_to_pad_from + token, central_atom=central_atom ) - assert_single_representative(token) + + assert_single_representative(token, central_atom=central_atom) # ... Create virtual atoms pad_atoms = token[mask].copy() @@ -263,17 +302,25 @@ def _fix_multidimensional_annotations_in_pad_array( for token_id, (start, end) in enumerate( zip(starts_padded[:-1], starts_padded[1:]) ): + if atom_array_padded[start].is_dna: + ATOM_NAMES = ATOM23_ATOM_NAMES_DNA + elif atom_array_padded[start].is_rna: + ATOM_NAMES = ATOM23_ATOM_NAMES_RNA + else: + ATOM_NAMES = ATOM14_ATOM_NAMES + if is_paddable[token_id]: # ... 
Permutation of atom names during training if not data["is_inference"] and exists(self.association_scheme): atom_names = permute_symmetric_atom_names_( - ATOM14_ATOM_NAMES, + ATOM_NAMES, atom_array_padded.res_name[start], association_map=self.association_map_, symmetry_map=self.symmetry_map_, ) else: - atom_names = ATOM14_ATOM_NAMES + atom_names = ATOM_NAMES + atom_array_padded.atom_name[start:end] = atom_names atom_array_padded.get_annotation("gt_atom_name")[start:end] = atom_names @@ -285,7 +332,10 @@ def _fix_multidimensional_annotations_in_pad_array( ) atom_array_padded.get_annotation("gt_atom_name")[start:end] = atom_names atom_names = map_to_association_scheme( - atom_names, res_name, scheme=self.association_scheme + atom_names, + res_name, + scheme=self.association_scheme, + ATOM_NAMES=ATOM_NAMES, ) atom_array_padded.atom_name[start:end] = atom_names else: diff --git a/models/rfd3/src/rfd3/utils/inference.py b/models/rfd3/src/rfd3/utils/inference.py index 2d04e149..1e00f566 100644 --- a/models/rfd3/src/rfd3/utils/inference.py +++ b/models/rfd3/src/rfd3/utils/inference.py @@ -452,7 +452,10 @@ def infer_ori_from_com(atom_array): def set_com( - atom_array, ori_token: list | None = None, infer_ori_strategy: str | None = None + atom_array, + ori_token: list | None = None, + infer_ori_strategy: str | None = None, + ori_jitter: float | None = None, ): if exists(ori_token): center = np.array([float(x) for x in ori_token], dtype=atom_array.coord.dtype) @@ -505,6 +508,17 @@ def set_com( atom_array.coord = np.zeros_like( atom_array.coord, dtype=atom_array.coord.dtype ) + if ori_jitter is not None: + # randomly jitter ori with given scale + direction = np.random.normal(size=3) + direction /= np.linalg.norm(direction) + + # Random length (mean ~ ori_jitter) + length = np.random.exponential(scale=ori_jitter) + jittered_offset = direction * length + + atom_array.coord -= jittered_offset + return atom_array diff --git a/src/foundry/utils/components.py b/src/foundry/utils/components.py 
index 75bc87f3..0af9c7ea 100644 --- a/src/foundry/utils/components.py +++ b/src/foundry/utils/components.py @@ -96,8 +96,21 @@ def get_design_pattern_with_constraints(contig, length=None): fixed_parts = [] pos_to_put_motif = [] + suff = [] # suffixes for diffused regions P(optional),R,D + for part in contig_parts: - if any(c.isalpha() for c in part): # Detect parts containing letters as fixed + ## updating to include DNA and RNA generation + if part[-1] in ["R", "D"]: ##Detect non-fixed RNA and DNA contig part + suff.append(part[-1]) + part = part[:-1] + if "-" in part: + start, end = map(int, part.split("-")) + else: + start = end = int(part) + variable_ranges.append([start, end]) + pos_to_put_motif.append(0) + + elif any(c.isalpha() for c in part): # Detect parts containing letters as fixed pn_unit_id, pn_unit_start, pn_unit_end = extract_pn_unit_info(part) fixed_parts.append([pn_unit_id, pn_unit_start, pn_unit_end]) pos_to_put_motif.append(1) @@ -110,6 +123,7 @@ def get_design_pattern_with_constraints(contig, length=None): start = end = int(part) variable_ranges.append([start, end]) pos_to_put_motif.append(0) + suff.append("P") # adjust the total length to solely for free residues num_motif_residues = sum([i[2] - i[1] + 1 for i in fixed_parts]) @@ -167,7 +181,7 @@ def get_design_pattern_with_constraints(contig, length=None): atoms_with_motif.append(f"{pn_unit_id}{index}") elif pos_to_put_motif[idx] == 0: free_atom = num_free_atoms.pop(0) - atoms_with_motif.append(free_atom) + atoms_with_motif.append(str(free_atom) + suff.pop(0)) elif pos_to_put_motif[idx] == 2: atoms_with_motif.append("/0")