diff --git a/.env b/.env
index 74786fe4..88ac0bb3 100644
--- a/.env
+++ b/.env
@@ -10,7 +10,7 @@
# expected that you use the same saving conventions as the RCSB PDB, which means:
# `1a2b` --> /path/to/pdb_mirror/a2/1a2b.cif.gz
# To set up a mirror, you can use tha atomworks commandline: `atomworks pdb sync /path/to/mirror`
-PDB_MIRROR_PATH=
+PDB_MIRROR_PATH=/projects/ml/frozen_pdb_copies/2024_12_01_pdb
# The `CCD_MIRROR_PATH` is a path to a local mirror of the CCD database.
# It's expected that you use the same saving conventions as the RCSB CCD, which means:
@@ -19,7 +19,7 @@ PDB_MIRROR_PATH=
# If no mirror is provided, the internal biotite CCD will be used as a fallback. To provide a
# custom CCD for a ligand, you can place it in the in the CCD mirror path following the CCDs pattern.
# Example: /path/to/ccd_mirror/M/MYLIGAND1/MYLIGAND1.cif
-CCD_MIRROR_PATH=
+CCD_MIRROR_PATH=/projects/ml/frozen_pdb_copies/2024_12_11_ccd
# --- Local MSA directories ---
LOCAL_MSA_DIRS=
@@ -29,14 +29,14 @@ LOCAL_MSA_DIRS=
# The HBPLUS_PATH is a path to the hbplus tool, which is used for hydrogen bond calculation
# during training and during metrics computation.
# Example: /path/to/hbplus
-HBPLUS_PATH=
+HBPLUS_PATH=/projects/ml/hbplus
# The `X3DNA_PATH` is a path to the x3dna tool, which is used for DNA structure analysis.
# Example: /path/to/x3dna-v2.4
-X3DNA_PATH=
+X3DNA_PATH=/projects/ml/prot_dna/x3dna-v2.4
# For secondary structure prediction (not currently used)
-DSSP_PATH=
+DSSP_PATH=/projects/ml/dssp/install/bin/mkdssp
# The `HHFILTER_PATH` is a path to the hhfilter tool from the HH-suite, which is used for
# filtering MSAs to reduce redundancy.
diff --git a/models/rf3/configs/model/components/rf3_net_with_confidence_head.yaml b/models/rf3/configs/model/components/rf3_net_with_confidence_head.yaml
index 063e6f5e..4714de10 100644
--- a/models/rf3/configs/model/components/rf3_net_with_confidence_head.yaml
+++ b/models/rf3/configs/model/components/rf3_net_with_confidence_head.yaml
@@ -42,4 +42,4 @@ confidence_head:
n_bins_exp_resolved: 2
use_Cb_distances: False
use_af3_style_binning_and_final_layer_norms: True
- symmetrize_Cb_logits: True
\ No newline at end of file
+ symmetrize_Cb_logits: True
diff --git a/models/rfd3/README.md b/models/rfd3/README.md
index ecb4f34c..bdbf0dd0 100644
--- a/models/rfd3/README.md
+++ b/models/rfd3/README.md
@@ -1,6 +1,6 @@
# De novo Design of Biomolecular Interactions with RFdiffusion3
-RFdiffusion3 (RFD3) is a diffusion method that can design protein structures
+RFdiffusion3 (RFD3) is a diffusion method that can design biopolymer structures
under complex constraints.
This repository contains both the training and inference code, and
@@ -62,6 +62,8 @@ For example, you can fix sequence and not structure (prediction-type task), fix
For full details on how to specify inputs, see the [input specification documentation](./docs/input.md). You can also see `foundry/models/rfd3/configs/inference_engine/rfdiffusion3.yaml` for even more options.
+Nucleic acid design, alongside proteins, is also possible with RFD3 using the atom23 checkpoints. For full details, see the [atom23 design documentation](./docs/examples/atom23_design.md).
+
## Further example JSONs for different applications
Additional examples are broken up by use case. If you have cloned the
repository, matching `.json` files are in `foundry/models/rfd3/docs/examples`
@@ -75,27 +77,33 @@ you will need to change the path in the `.json` file(s) before running.
-
-
- |
-
-
+
|
-
+
|
+
+
+
+ |
+
-
+
|
-
+
|
+
+
+
+ |
+
diff --git a/models/rfd3/configs/callbacks/design_callbacks.yaml b/models/rfd3/configs/callbacks/design_callbacks.yaml
index 309b492a..68c9a4ce 100644
--- a/models/rfd3/configs/callbacks/design_callbacks.yaml
+++ b/models/rfd3/configs/callbacks/design_callbacks.yaml
@@ -1,5 +1,6 @@
defaults:
- train_logging
+ - metrics_logging
- _self_
log_learning_rate_callback:
diff --git a/models/rfd3/configs/datasets/design_base_rfd3na.yaml b/models/rfd3/configs/datasets/design_base_rfd3na.yaml
new file mode 100644
index 00000000..661b2fb0
--- /dev/null
+++ b/models/rfd3/configs/datasets/design_base_rfd3na.yaml
@@ -0,0 +1,105 @@
+# base training dataset for training AF3 design models (atom14 variants):
+# protein subsampling only.
+
+defaults:
+ # Grab datasets
+ - train/pdb/rfd3_train_interface@train.pdb.sub_datasets.interface
+ - train/pdb/rfd3_train_pn_unit@train.pdb.sub_datasets.pn_unit
+ - train/rfd3_monomer_distillation@train
+ - train/rna_monomer_distillation@train
+
+ # Customized validation datasets
+ #- val/unconditional@val.unconditional
+ #- val/unconditional_deep@val.unconditional_deep
+ #- val/indexed@val.indexed
+ - val/pseudoknot@val.pseudoknot
+
+ # Customized train masks
+ - conditions/unconditional@global_transform_args.train_conditions.unconditional
+ - conditions/island@global_transform_args.train_conditions.island
+ - conditions/tipatom@global_transform_args.train_conditions.tipatom
+ - conditions/sequence_design@global_transform_args.train_conditions.sequence_design
+ - conditions/ppi@global_transform_args.train_conditions.ppi
+
+ - _self_
+
+# Create a dictionary used for transform arguments
+pipeline_target: rfd3.transforms.pipelines.build_atom14_base_pipeline
+
+# Base config overrides:
+diffusion_batch_size_train: 32
+diffusion_batch_size_inference: 8
+crop_size: 384
+n_recycles_train: 2
+n_recycles_validation: 1
+max_atoms_in_crop: 3840 # ~10x crop size.
+
+# Global transform arguments are necessary for arguments shared between training and inference
+global_transform_args:
+ n_atoms_per_token: 14
+ central_atom: CB
+ sigma_perturb: 2.0
+ sigma_perturb_com: 1.0
+ association_scheme: dense
+ center_option: diffuse # options are ["all", "motif", "diffuse"]
+
+ # Reference conformer policy
+ generate_conformers: True
+ generate_conformers_for_non_protein_only: True
+ provide_reference_conformer_when_unmasked: True
+ ground_truth_conformer_policy: IGNORE # Other options: REPLACE, ADD, FALLBACK. See atomworks.enums for details
+ provide_elements_for_unindexed_components: True
+ use_element_for_atom_names_of_atomized_tokens: True # TODO: correct name, implies unindexed do too
+
+ # PPI Cropping
+ keep_full_binder_in_spatial_crop: False
+ max_binder_length: 170
+
+ # PPI Hotspots
+ max_ppi_hotspots_frac_to_provide: 0.2
+ ppi_hotspot_max_distance: 4.5
+
+ # Secondary structure features
+ max_ss_frac_to_provide: 0.4
+ min_ss_island_len: 1
+ max_ss_island_len: 10
+
+ # Nucleic acid features
+ add_na_pair_features: false
+
+ train_conditions:
+ unconditional:
+ frequency: 5.0
+ sequence_design:
+ frequency: 2.0
+ island:
+ frequency: 1.0
+ tipatom:
+ frequency: 0.0
+ ppi:
+ frequency: 0.0
+
+ # Used to create simple boolean flags for downstream conditioning
+ meta_conditioning_probabilities:
+ p_is_nucleic_ss_example: 0.1
+ p_nucleic_ss_show_partial_feats: 0.7
+ calculate_NA_SS: 0.5
+ calculate_hbonds: 0.2
+ calculate_rasa: 0.6
+
+ keep_protein_motif_rasa: 0.1 # Small to prevent noisy input to model
+ hbond_subsample: 0.5
+
+ # fully indexed training
+ unindex_leak_global_index: 0.10
+ unindex_insert_random_break: 0.10
+ unindex_remove_random_break: 0.10
+
+ # Probability of adding 1d secondary structure conditioning
+ add_1d_ss_features: 0.1
+ featurize_plddt: 0.9 # Applied for monomer distillation only
+ add_global_is_non_loopy_feature: 0.99
+
+ # PPI
+ add_ppi_hotspots: 0.75
+ full_binder_crop: 0.75
diff --git a/models/rfd3/configs/datasets/train/pdb/base_transform_args.yaml b/models/rfd3/configs/datasets/train/pdb/base_transform_args.yaml
index 08c735ca..28e2384d 100644
--- a/models/rfd3/configs/datasets/train/pdb/base_transform_args.yaml
+++ b/models/rfd3/configs/datasets/train/pdb/base_transform_args.yaml
@@ -43,6 +43,9 @@ dataset:
min_ss_island_len: ${datasets.global_transform_args.min_ss_island_len}
max_ss_island_len: ${datasets.global_transform_args.max_ss_island_len}
+ # Nucleic acid features
+ add_na_pair_features: ${datasets.global_transform_args.add_na_pair_features}
+
# Cropping
crop_size: ${datasets.crop_size}
max_atoms_in_crop: ${datasets.max_atoms_in_crop}
@@ -56,4 +59,5 @@ dataset:
# Other dataset-specific parameters
atom_1d_features: ${model.net.token_initializer.atom_1d_features}
- token_1d_features: ${model.net.token_initializer.token_1d_features}
\ No newline at end of file
+ token_1d_features: ${model.net.token_initializer.token_1d_features}
+ token_2d_features: ${model.net.token_initializer.token_2d_features}
diff --git a/models/rfd3/configs/datasets/train/pdb/rfd3_train_interface.yaml b/models/rfd3/configs/datasets/train/pdb/rfd3_train_interface.yaml
index d5df571b..9ff698a8 100644
--- a/models/rfd3/configs/datasets/train/pdb/rfd3_train_interface.yaml
+++ b/models/rfd3/configs/datasets/train/pdb/rfd3_train_interface.yaml
@@ -12,6 +12,7 @@ dataset:
# filters common across all PDB datasets
- 'pdb_id not in ["7rte", "7m5w", "7n5u"]'
- 'pdb_id not in ["3di3", "5o45", "1z92", "2gy5", "4zxb"]'
+ - 'pdb_id not in ["1drz", "2m8k", "2miy", "3q3z", "4oqu", "4plx", "4znp", "7kd1", "7kga", "7qr4"]'
- "deposition_date < '2024-12-16'"
- "resolution < 9.0"
- "num_polymer_pn_units <= 300"
@@ -19,4 +20,4 @@ dataset:
# interface specific filters
- "~(pn_unit_1_non_polymer_res_names.notnull() and pn_unit_1_non_polymer_res_names.str.contains('${resolve_import:atomworks.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
- "~(pn_unit_2_non_polymer_res_names.notnull() and pn_unit_2_non_polymer_res_names.str.contains('${resolve_import:atomworks.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
- - "is_inter_molecule"
\ No newline at end of file
+ - "is_inter_molecule"
diff --git a/models/rfd3/configs/datasets/train/pdb/rfd3_train_pn_unit.yaml b/models/rfd3/configs/datasets/train/pdb/rfd3_train_pn_unit.yaml
index dfd548fc..19d31e8b 100644
--- a/models/rfd3/configs/datasets/train/pdb/rfd3_train_pn_unit.yaml
+++ b/models/rfd3/configs/datasets/train/pdb/rfd3_train_pn_unit.yaml
@@ -15,6 +15,7 @@ dataset:
# filters common across all PDB datasets
- 'pdb_id not in ["7rte", "7m5w", "7n5u"]'
- 'pdb_id not in ["3di3", "5o45", "1z92", "2gy5", "4zxb"]'
+ - 'pdb_id not in ["1drz", "2m8k", "2miy", "3q3z", "4oqu", "4plx", "4znp", "7kd1", "7kga", "7qr4"]'
- "deposition_date < '2024-12-16'"
- "resolution < 9.0"
- "num_polymer_pn_units <= 300"
diff --git a/models/rfd3/configs/datasets/train/rna_monomer_distillation.yaml b/models/rfd3/configs/datasets/train/rna_monomer_distillation.yaml
new file mode 100644
index 00000000..05ac8472
--- /dev/null
+++ b/models/rfd3/configs/datasets/train/rna_monomer_distillation.yaml
@@ -0,0 +1,39 @@
+defaults:
+ - pdb/base_transform_args@rna_monomer_distillation
+ - _self_
+
+rna_monomer_distillation:
+ dataset:
+ _target_: atomworks.ml.datasets.StructuralDatasetWrapper
+ save_failed_examples_to_dir: ${paths.data.failed_examples_dir}
+
+ # cif parser arguments
+ cif_parser_args:
+ cache_dir: null
+ load_from_cache: False
+ save_to_cache: False
+
+ # metadata parser
+ dataset_parser:
+ _target_: atomworks.ml.datasets.parsers.GenericDFParser
+ pn_unit_iid_colnames: null
+
+ # metadata dataset
+ dataset:
+ _target_: atomworks.ml.datasets.PandasDataset
+ name: rna_monomer_distillation
+ id_column: example_id
+ data: /projects/ml/afavor/rna_distillation/rna_distillation_filtered_df.parquet
+ columns_to_load:
+ - example_id
+ - path
+ - cluster_id
+ - seq_hash
+ - overall_plddt
+ - overall_pde
+ - overall_pae
+
+ transform:
+ crop_contiguous_probability: 0.67
+ crop_spatial_probability: 0.33
+
diff --git a/models/rfd3/configs/datasets/val/design_validation_base.yaml b/models/rfd3/configs/datasets/val/design_validation_base.yaml
index 5aabcc07..90931e09 100644
--- a/models/rfd3/configs/datasets/val/design_validation_base.yaml
+++ b/models/rfd3/configs/datasets/val/design_validation_base.yaml
@@ -37,4 +37,5 @@ dataset:
# Other dataset-specific parameters
atom_1d_features: ${model.net.token_initializer.atom_1d_features}
- token_1d_features: ${model.net.token_initializer.token_1d_features}
\ No newline at end of file
+ token_1d_features: ${model.net.token_initializer.token_1d_features}
+ token_2d_features: ${model.net.token_initializer.token_2d_features}
\ No newline at end of file
diff --git a/models/rfd3/configs/datasets/val/pseudoknot.yaml b/models/rfd3/configs/datasets/val/pseudoknot.yaml
new file mode 100644
index 00000000..7cf5ce14
--- /dev/null
+++ b/models/rfd3/configs/datasets/val/pseudoknot.yaml
@@ -0,0 +1,9 @@
+
+defaults:
+ - design_validation_base
+ - _self_
+
+dataset:
+ name: pseudoknot
+ eval_every_n: 1
+ data: ${paths.data.design_benchmark_data_dir}/pseudoknot.json
diff --git a/models/rfd3/configs/experiment/rfd3na_fine_tune.yaml b/models/rfd3/configs/experiment/rfd3na_fine_tune.yaml
new file mode 100644
index 00000000..3f2dfbe1
--- /dev/null
+++ b/models/rfd3/configs/experiment/rfd3na_fine_tune.yaml
@@ -0,0 +1,105 @@
+# @package _global_
+# Training configuration for RFD3
+
+defaults:
+ #- /debug/default
+ - override /model: rfd3_base
+ - override /logger: wandb
+ - override /datasets: design_base_rfd3na
+ - _self_
+
+name: rfd3na-fine-tune-ss0.5
+tags: [print-model]
+ckpt_path: null
+
+model:
+ net:
+ token_initializer:
+ token_1d_features:
+ ref_motif_token_type: 3
+ restype: 32
+ is_dna_token: 1
+ is_rna_token: 1
+ is_protein_token: 1
+ token_2d_features:
+ bp_partners: 3 # Unspecified, pair, loop
+ atom_1d_features:
+ ref_atom_name_chars: 256
+ ref_element: 128
+ ref_charge: 1
+ ref_mask: 1
+ ref_is_motif_atom_with_fixed_coord: 1
+ ref_is_motif_atom_unindexed: 1
+ has_zero_occupancy: 1
+ ref_pos: 3
+
+ # Guided features
+ ref_atomwise_rasa: 3
+ active_donor: 1
+ active_acceptor: 1
+ is_atom_level_hotspot: 1
+ diffusion_module:
+ n_recycle: 2
+ use_local_token_attention: True
+ diffusion_transformer:
+ n_local_tokens: 32
+ n_keys: 128
+
+ inference_sampler:
+ num_timesteps: 100
+
+
+datasets:
+ diffusion_batch_size_train: 16
+ crop_size: 256
+ max_atoms_in_crop: 2560 # ~10x crop size.
+ global_transform_args:
+ meta_conditioning_probabilities:
+ p_is_nucleic_ss_example: 0.5
+ p_nucleic_ss_show_partial_feats: 0.7
+ p_canonical_bp_filter: 0.2
+ #calculate_NA_SS: 0.3
+
+ association_scheme: atom23
+ #add_na_pair_features: true
+ train_conditions:
+ unconditional:
+ frequency: 2.0
+ island:
+ frequency: 2.0
+ sequence_design:
+ frequency: 0.5
+ tipatom:
+ frequency: 5.0
+ ppi:
+ frequency: 0.0
+ train:
+ # These are the ratios used in the preprint but we set all pdb sampling by default since not everyone might download the distillation data.
+ pdb:
+ probability: 0.6
+ sub_datasets:
+ pn_unit:
+ weights:
+ alphas:
+ a_nuc: 3.0
+ interface:
+ weights:
+ alphas:
+ a_nuc: 3.0
+ rna_monomer_distillation:
+ probability: 0.3
+ monomer_distillation:
+ probability: 0.1
+
+ val:
+ pseudoknot:
+ dataset:
+ # eval_every_n: 10
+ eval_every_n: 5
+
+trainer:
+ #devices_per_node: 1
+ #limit_train_batches: 10
+ #limit_val_batches: 1
+ validate_every_n_epochs: 5
+ prevalidate: true
diff --git a/models/rfd3/configs/inference_engine/rfdiffusion3.yaml b/models/rfd3/configs/inference_engine/rfdiffusion3.yaml
index c5a6a72b..fe4a00fb 100644
--- a/models/rfd3/configs/inference_engine/rfdiffusion3.yaml
+++ b/models/rfd3/configs/inference_engine/rfdiffusion3.yaml
@@ -28,6 +28,7 @@ inference_sampler:
- active_donor
- active_acceptor
- ref_atomwise_rasa
+ - bp_partners
use_classifier_free_guidance: False
cfg_t_max: null # max t to apply cfg guidance
diff --git a/models/rfd3/configs/model/components/rfd3_net.yaml b/models/rfd3/configs/model/components/rfd3_net.yaml
index 40334833..83cddc4b 100644
--- a/models/rfd3/configs/model/components/rfd3_net.yaml
+++ b/models/rfd3/configs/model/components/rfd3_net.yaml
@@ -25,6 +25,9 @@ token_initializer: # formerly known as the trunk
ref_plddt: 1
is_non_loopy: 1
+ # Optional 2D token feature definitions (empty by default)
+ token_2d_features: {}
+
downcast: ${model.net.diffusion_module.downcast}
atom_1d_features:
ref_atom_name_chars: 256
diff --git a/models/rfd3/configs/trainer/metrics/design_metrics.yaml b/models/rfd3/configs/trainer/metrics/design_metrics.yaml
index 2a456051..c0bde00d 100644
--- a/models/rfd3/configs/trainer/metrics/design_metrics.yaml
+++ b/models/rfd3/configs/trainer/metrics/design_metrics.yaml
@@ -20,3 +20,16 @@ hbond_metrics:
_target_: rfd3.metrics.hbonds_hbplus_metrics.HbondMetrics
cutoff_HA_dist: 3
cutoff_DA_distance: 3.5
+
+nucleic_ss_similarity:
+ _target_: rfd3.metrics.nucleic_ss_metrics.NucleicSSSimilarityMetrics
+ restrict_to_nucleic: True
+ compute_for_diffused_region_only: False
+ annotate_predicted_fresh: True
+ annotation_NA_only: False
+ annotation_planar_only: True
+
+rna_aptamer_contacts:
+ _target_: rfd3.metrics.rna_aptamer_metrics.LigandContactMetrics
+ restrict_to_nucleic: True
+
diff --git a/models/rfd3/docs/.assets/multipolymer.png b/models/rfd3/docs/.assets/multipolymer.png
new file mode 100644
index 00000000..f081043a
Binary files /dev/null and b/models/rfd3/docs/.assets/multipolymer.png differ
diff --git a/models/rfd3/docs/.assets/overview.png b/models/rfd3/docs/.assets/overview.png
index ee53c36a..6495cf69 100644
Binary files a/models/rfd3/docs/.assets/overview.png and b/models/rfd3/docs/.assets/overview.png differ
diff --git a/models/rfd3/docs/examples/atom23_design.json b/models/rfd3/docs/examples/atom23_design.json
new file mode 100644
index 00000000..4632fef8
--- /dev/null
+++ b/models/rfd3/docs/examples/atom23_design.json
@@ -0,0 +1,98 @@
+{
+ "multipolymer": {
+ "contig": "40-50R,/0,10-20D,/0,80-110",
+ "length": "130-180",
+ "input": "../input_pdbs/AMP.pdb"
+ },
+ "W05": {
+ "ss_dbn": ".(((((((((((((((((((..[[[[[[.)))))(((....)))(((....)))))))))))))))))((((((..]]]]]].)))))).",
+ "select_fixed_atoms": false,
+ "contig": "90-90R",
+ "length": "90-90",
+ "input": "../input_pdbs/AMP.pdb"
+ },
+ "AMP_aptamer": {
+ "input": "../input_pdbs/AMP.pdb",
+ "ligand": "AMP",
+ "contig": "40-50R",
+ "length": "40-50",
+ "ori_jitter": 1,
+ "select_buried": {"AMP": "ALL"},
+ "select_hbond_acceptor": {
+ "AMP": "N7,O4',O1P,O2P,O3P,N3,N1"
+ },
+ "select_hbond_donor": {
+ "AMP": "N6,O3',O2'"
+ }
+ },
+ "FMN_aptamer": {
+ "input": "../input_pdbs/FMN_3x21.pdb",
+ "ligand": "FMN",
+ "contig": "40-50R",
+ "length": "40-50",
+ "ori_jitter": 1,
+ "select_buried": {"FMN": "ALL"},
+ "select_hbond_acceptor": {
+ "FMN": "O2,O4,N1,N5,O5',O2P,O3P"
+ },
+ "select_hbond_donor": {
+ "FMN": "N3,O2',O3',O4'"
+ }
+ },
+ "AMP_aptamer_noh": {
+ "input": "../input_pdbs/AMP.pdb",
+ "ligand": "AMP",
+ "contig": "40-50R",
+ "length": "40-50",
+ "ori_jitter": 1,
+ "select_buried": {"AMP": "ALL"}
+ },
+ "FMN_aptamer_noh": {
+ "input": "../input_pdbs/FMN_3x21.pdb",
+ "ligand": "FMN",
+ "contig": "40-50R",
+ "length": "40-50",
+ "ori_jitter": 1,
+ "select_buried": {"FMN": "ALL"}
+ },
+ "unindexed_rnasep": {
+ "input": "../input_pdbs/rnase_p_3q1q_active_site_small.pdb",
+ "contig": "50-80R,/0,100-120,/0,C1-4,C79-86",
+ "length": "162-212",
+ "ligand": "MG,PO4",
+ "unindex": "B49,B50,B51,B52,B321,/0,A56-58,/0",
+ "select_fixed_atoms": {
+ "B49": "ALL",
+ "B50": "ALL",
+ "B51": "ALL",
+ "B52": "ALL",
+ "B321": "ALL",
+ "A56-58": "ALL",
+ "C1-4": "ALL",
+ "C79-86": "ALL"
+ }
+ },
+ "dict_input_ss": {
+ "ss_dbn_dict": {
+ "A6-25":"(((..)))....(((..)))",
+ "B1-20":"((((..))))...((...))"
+ },
+ "contig":"30-30R,/0,30-30R",
+ "length":"60-60",
+ "input":"../input_pdbs/AMP.pdb"
+ },
+ "paired_region_input_ss": {
+ "paired_region_list": ["A20-25,B10-15"],
+ "loop_region_list":["A10-19","B20-30"],
+ "contig":"50-50R,/0,50-50R",
+ "length":"100-100",
+ "input":"../input_pdbs/AMP.pdb"
+ },
+ "paired_position_input_ss": {
+ "paired_position_list": ["A3,B3","A5,B5","A7,B7","A9,B9","A11,B11","A13,B13","A15,B15","A17,B17","A19,B19"],
+ "contig":"20-20R,/0,20-20R",
+ "length":"40-40",
+ "input":"../input_pdbs/AMP.pdb"
+
+ }
+}
diff --git a/models/rfd3/docs/examples/atom23_design.md b/models/rfd3/docs/examples/atom23_design.md
new file mode 100644
index 00000000..91f54ebc
--- /dev/null
+++ b/models/rfd3/docs/examples/atom23_design.md
@@ -0,0 +1,228 @@
+# RNA / DNA Design in RFdiffusion3
+
+This guide describes extensions to RFdiffusion3 for nucleic acid and hybrid RNA–protein design, including:
+
+- RNA/DNA-aware contigs (`R` / `D` suffix)
+- Ligand-conditioned aptamer design
+- Secondary structure (SS) conditioning
+- Base-pair constraints (region- and position-level)
+- Partial structure fixing and unindexing
+
+---
+
+## 1. Contig Syntax for RNA/DNA
+
+Contigs now support nucleic acid specification:
+
+- `R` → RNA segment
+- `D` → DNA segment
+- No suffix → protein (default)
+
+### Example
+
+```json
+{
+ "contig": "40-50R,/0,10-20D,/0,80-110"
+}
+```
+This corresponds to: 40–50 nt RNA, chain break, 10–20 nt DNA, chain break, 80–110 aa protein
+
+Multipolymer Design
+
+```json
+
+{
+ "multipolymer": {
+ "contig": "40-50R,/0,10-20D,/0,80-110",
+ "length": "130-180",
+ "input": "../input_pdbs/AMP.pdb"
+ }
+}
+```
+
+## 2. Secondary Structure Conditioning
+### 2.1 Dot-Bracket Notation (Global)
+```json
+{
+ "W05": {
+ "ss_dbn": ".(((((((((((((((((((..[[[[[[.)))))(((....)))(((....)))))))))))))))))((((((..]]]]]].)))))).",
+ "select_fixed_atoms": false,
+ "contig": "90-90R",
+ "length": "90-90",
+ "input": "../input_pdbs/AMP.pdb"
+ }
+}
+```
+`ss_dbn` specifies full RNA secondary structure
+
+Will be applied to the first L tokens, where L is the length of `ss_dbn`.
+
+### 2.2 Dictionary-Based SS Input
+
+Specify secondary structure for subsections:
+``` json
+{
+ "ss_dbn_dict": {
+ "A6-25": "(((..)))....(((..)))",
+ "B1-20": "((((..))))...((...))"
+ }
+}
+```
+Used in:
+``` json
+{
+ "dict_input_ss": {
+ "ss_dbn_dict": {
+ "A6-25": "(((..)))....(((..)))",
+ "B1-20": "((((..))))...((...))"
+ },
+ "contig": "30-30R,/0,30-30R",
+ "length": "60-60",
+ "input": "../input_pdbs/AMP.pdb"
+ }
+}
+```
+## 3. Base Pair region Conditioning
+### 3.1 Paired Regions
+
+Define paired and loop regions:
+```json
+{
+ "paired_region_list": ["A20-25,B10-15"],
+ "loop_region_list": ["A10-19","B20-30"]
+}
+```
+Enforces pairing and loop propensity between residue ranges during sampling
+
+Used in:
+```json
+{
+ "paired_region_input_ss": {
+ "paired_region_list": ["A20-25,B10-15"],
+ "loop_region_list": ["A10-19","B20-30"],
+ "contig": "50-50R,/0,50-50R",
+ "length": "100-100",
+ "input": "../input_pdbs/AMP.pdb"
+ }
+}
+```
+
+### 3.2 Explicit Base Pair Positions
+
+Fine-grained base pairing control:
+
+```json
+{
+ "paired_position_list": [
+ "A3,B3","A5,B5","A7,B7","A9,B9","A11,B11",
+ "A13,B13","A15,B15","A17,B17","A19,B19"
+ ]
+}
+```
+Used in:
+```json
+{
+ "paired_position_input_ss": {
+ "paired_position_list": [
+ "A3,B3","A5,B5","A7,B7","A9,B9","A11,B11",
+ "A13,B13","A15,B15","A17,B17","A19,B19"
+ ],
+ "contig": "20-20R,/0,20-20R",
+ "length": "40-40",
+ "input": "../input_pdbs/AMP.pdb"
+ }
+}
+```
+### Note: Most of the above JSONs do not actually read the `input` field. It is kept as a dummy for the `inference_engine`.
+
+## 4. Ligand-Conditioned Aptamer Design
+
+Supports small molecule binding RNA design.
+
+AMP Aptamer Example
+```json
+{
+ "AMP_aptamer": {
+ "input": "../input_pdbs/AMP.pdb",
+ "ligand": "AMP",
+ "contig": "40-50R",
+ "length": "40-50",
+ "ori_jitter": 1,
+ "select_buried": {"AMP": "ALL"},
+ "select_hbond_acceptor": {
+ "AMP": "N7,O4',O1P,O2P,O3P,N3,N1"
+ },
+ "select_hbond_donor": {
+ "AMP": "N6,O3',O2'"
+ }
+ }
+}
+```
+Key Options
+
+`ligand`: ligand name in the input PDB
+
+`select_buried`: enforce burial of ligand atoms
+
+`select_hbond_acceptor` / `select_hbond_donor`: suggest Hbond interaction atoms
+
+`ori_jitter`: small random perturbation of ori token (from ligand COM)
+
+
+## 5. Hybrid RNA–Protein Design with Constraints
+### RNase P Active Site Example
+
+```json
+{
+ "unindexed_rnasep": {
+ "input": "../input_pdbs/rnase_p_3q1q_active_site_small.pdb",
+ "contig": "50-80R,/0,100-120,/0,C1-4,C79-86",
+ "length": "162-212",
+ "ligand": "MG,PO4",
+ "unindex": "B49,B50,B51,B52,B321,/0,A56-58,/0",
+ "select_fixed_atoms": {
+ "B49": "ALL",
+ "B50": "ALL",
+ "B51": "ALL",
+ "B52": "ALL",
+ "B321": "ALL",
+ "A56-58": "ALL",
+ "C1-4": "ALL",
+ "C79-86": "ALL"
+ }
+ }
+}
+```
+Key Features
+
+Mixed RNA + protein + fixed fragments
+
+`unindex`: removes residues from positional indexing
+
+`select_fixed_atoms`: freezes specified atoms
+
+Ligands (MG, PO4) included in design context
+
+Useful for catalytic residues or structural motifs
+
+## 6. Summary of Features Used
+
+R / D suffix → RNA / DNA specification in contigs
+
+`ss_dbn` → global secondary structure constraint
+
+`ss_dbn_dict` → local secondary structure constraints
+
+`paired_region_list` → helix-level pairing constraints
+
+`paired_position_list` → base-level pairing constraints
+
+ligand + selection options → aptamer design
+
+`unindex` → remove residues from indexing
+
+`select_fixed_atoms` → freeze structural elements
+
+
+---
+
diff --git a/models/rfd3/docs/input.md b/models/rfd3/docs/input.md
index 9253abc5..384101c6 100644
--- a/models/rfd3/docs/input.md
+++ b/models/rfd3/docs/input.md
@@ -120,7 +120,7 @@ Below is a table of all of the inputs that the `InputSpecification` accepts. Use
| -------------------------------------------------------------- | ----------------- | --------------------------------------------------------------------- |
| `input` | `str` | Path to and file name of **PDB/CIF**. Required if you provide contig+length. |
| `atom_array_input` | internal | Pre-loaded [`AtomArray`](https://www.biotite-python.org/latest/apidoc/biotite.structure.AtomArray.html) (not recommended). |
-| `contig` | `InputSelection` | (Can only pass a contig string.) Indexed motif specification, e.g., `"A1-80,10,/0,B5-12"`. |
+| `contig` | `InputSelection` | (Can only pass a contig string.) Indexed motif specification, e.g., `"A1-80,10,/0,B5-12"`. When running inference with an atom23 checkpoint, contigs can include `R` or `D` suffixes to denote that the designed polymer type is RNA or DNA, e.g. "A1-80,10-10R,20-20D,30-30,/0,B5-12" (designs a 10 length RNA chain, 20 length DNA chain and a 30 length protein chain). |
| `unindex` | `InputSelection` | (Can only pass a contig string or dictionary.) Unindexed motif components, the specified residues can be anywhere in the final sequence. See [Unindexing Specifics](#unindexing-specifics) for more information. |
| `length` | `str` | Total design length constraint; `"min-max"` or int for specified length. |
| `ligand` | `str` | Ligand(s) by chemical component name (from [RSCB PDB](https://www.rcsb.org/)) or index. |
@@ -136,6 +136,7 @@ Below is a table of all of the inputs that the `InputSpecification` accepts. Use
| `symmetry` | `SymmetryConfig` | See {doc}`examples/symmetry`. |
| `ori_token` | `list[float]` | `[x,y,z]` origin override to control COM (center of mass) placement of designed structure. |
| `infer_ori_strategy` | `str` | `"com"` or `"hotspots"`. The center of mass of the diffused region will typically be within 5Å of the ORI token. Using `hotspots` will place the ORI token 10Å outward from the center of mass of the specified hotspots. Using `com` will place the token at the center of mass of the input structure.|
+| `ori_jitter` | `float` | default `None`. Per batch, move the ori token in a random direction by a distance sampled from a geometric distribution with the mean as the specified float value in Angstrom.|
| `plddt_enhanced` | `bool` | Default `True`. Enables pLDDT (predicted Local Distance Difference Test) enhancement. |
| `is_non_loopy` | `bool \| None` | Default `None`. If `True`/`False`, produces output structures with fewer/more loops.|
| `partial_t` | `float` | Noise (Å) for partial diffusion, enables partial diffusion (sets the noise level.) Recommended values are 5.0-15.0 Å. See [Partial Diffusion](#partial-diffusion) for more information. |
diff --git a/models/rfd3/src/rfd3/constants.py b/models/rfd3/src/rfd3/constants.py
index 59c387c3..5407e681 100644
--- a/models/rfd3/src/rfd3/constants.py
+++ b/models/rfd3/src/rfd3/constants.py
@@ -72,6 +72,37 @@
'GLY': (" N "," CA "," C "," O ", None, None, None, None, None, None, None, None, None, None), # gly
'UNK': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # unk
'MSK': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # mask
+ 'DA': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'",
+ ' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' N6 ', ' N1 ', ' C2 ', ' N3 ', ' C4 ',
+ None),
+
+ 'DC': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'",
+ ' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' N4 ', ' C5 ', ' C6 ',
+ None, None, None),
+
+ 'DG': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'",
+ ' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' O6 ', ' N1 ', ' C2 ', ' N2 ', ' N3 ', ' C4 '),
+
+ 'DT': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'",
+ ' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' O4 ', ' C5 ', ' C7 ', ' C6 ',
+ None, None),
+
+ 'A' : (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'",
+ ' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' N6 ', ' N1 ', ' C2 ', ' N3 ', ' C4 ',
+ None),
+
+ 'C' : (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'",
+ ' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' N4 ', ' C5 ', ' C6 ',
+ None, None, None),
+
+ 'G' : (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'",
+ ' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' O6 ', ' N1 ', ' C2 ', ' N2 ', ' N3 ', ' C4 '),
+
+ 'U' : (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'",
+ ' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' O4 ', ' C5 ', ' C6 ',
+ None, None, None),
+ 'DX': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'", None, None, None, None, None, None, None, None, None, None, None), #dna_mask
+ 'X': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'", None, None, None, None, None, None, None, None, None, None, None), #rna mask
}
"""Canonical ordering of amino acid atom names in the CCD."""
@@ -195,10 +226,7 @@
strip_list = lambda x: [(x.strip() if x is not None else None) for x in x] # noqa
-association_schemes_stripped = {
- name: {k: strip_list(v) for k, v in scheme.items()}
- for name, scheme in association_schemes.items()
-}
+
SELECTION_PROTEIN = ["POLYPEPTIDE(D)", "POLYPEPTIDE(L)"]
SELECTION_NONPROTEIN = [
"POLYDEOXYRIBONUCLEOTIDE",
@@ -210,3 +238,461 @@
"MACROLIDE",
"POLYDEOXYRIBONUCLEOTIDE/POLYRIBONUCLEOTIDE HYBRID",
]
+
# Fixed-width (CCD-style) atom names of the DNA sugar-phosphate backbone,
# ordered 5'->3': phosphate group first, then the deoxyribose ring atoms.
backbone_atomscheme_DNA = [
    " P ",
    " OP1",
    " OP2",
    " O5'",
    " C5'",
    " C4'",
    " O4'",
    " C3'",
    " O3'",
    " C2'",
    " C1'",
]

# RNA backbone: same as DNA plus the 2'-hydroxyl oxygen (" O2'") before C1'.
backbone_atomscheme_RNA = [
    " P ",
    " OP1",
    " OP2",
    " O5'",
    " C5'",
    " C4'",
    " O4'",
    " C3'",
    " O3'",
    " C2'",
    " O2'",
    " C1'",
]
+
# Nucleobase atom names per standard DNA residue (fixed-width CCD style).
# Purines (DA/DG) are listed starting from the glycosidic N9; pyrimidines
# (DC/DT) from N1.
DNA_atoms = {
    "DA": [
        " N9 ",
        " C8 ",
        " N7 ",
        " C5 ",
        " C6 ",
        " N6 ",
        " N1 ",
        " C2 ",
        " N3 ",
        " C4 ",
    ],
    "DC": [" N1 ", " C2 ", " O2 ", " N3 ", " C4 ", " N4 ", " C5 ", " C6 "],
    "DG": [
        " N9 ",
        " C8 ",
        " N7 ",
        " C5 ",
        " C6 ",
        " O6 ",
        " N1 ",
        " C2 ",
        " N2 ",
        " N3 ",
        " C4 ",
    ],
    "DT": [" N1 ", " C2 ", " O2 ", " N3 ", " C4 ", " O4 ", " C5 ", " C7 ", " C6 "],
}

# Nucleobase atom names per standard RNA residue. Same ring orderings as the
# DNA table; note U lacks the thymine C7 methyl present in DT.
RNA_atoms = {
    "A": [
        " N9 ",
        " C8 ",
        " N7 ",
        " C5 ",
        " C6 ",
        " N6 ",
        " N1 ",
        " C2 ",
        " N3 ",
        " C4 ",
    ],
    "C": [" N1 ", " C2 ", " O2 ", " N3 ", " C4 ", " N4 ", " C5 ", " C6 "],
    "G": [
        " N9 ",
        " C8 ",
        " N7 ",
        " C5 ",
        " C6 ",
        " O6 ",
        " N1 ",
        " C2 ",
        " N2 ",
        " N3 ",
        " C4 ",
    ],
    "U": [" N1 ", " C2 ", " O2 ", " N3 ", " C4 ", " O4 ", " C5 ", " C6 "],
}
+
+association_schemes["atom23"] = {}
+for item in DNA_atoms:
+ association_schemes["atom23"][item] = tuple(
+ backbone_atomscheme_DNA
+ + DNA_atoms[item]
+ + [None] * (22 - len(DNA_atoms[item] + backbone_atomscheme_DNA))
+ )
+for item in RNA_atoms:
+ association_schemes["atom23"][item] = tuple(
+ backbone_atomscheme_RNA
+ + RNA_atoms[item]
+ + [None] * (23 - len(RNA_atoms[item] + backbone_atomscheme_RNA))
+ )
+
+for item in association_schemes["dense"]:
+ association_schemes["atom23"][item] = association_schemes["dense"][item]
+
+association_schemes["atom23"]["DX"] = (
+ " P ",
+ " OP1",
+ " OP2",
+ " O5'",
+ " C5'",
+ " C4'",
+ " O4'",
+ " C3'",
+ " O3'",
+ " C2'",
+ " C1'",
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+) # rna_mask
+association_schemes["atom23"]["X"] = (
+ " P ",
+ " OP1",
+ " OP2",
+ " O5'",
+ " C5'",
+ " C4'",
+ " O4'",
+ " C3'",
+ " O3'",
+ " C2'",
+ " O2'",
+ " C1'",
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+) # rna mask
+
ATOM23_ATOM_NAMES_RNA = np.array(
    [item.strip() for item in backbone_atomscheme_RNA]
    + [f"V{i}" for i in range(23 - len(backbone_atomscheme_RNA))]
)
"""Atom23 RNA atom names: stripped backbone names plus virtual slots (e.g. P, V1)."""

ATOM23_ATOM_ELEMENTS_RNA = np.array(
    ["P", "O", "O", "O", "C", "C", "O", "C", "O", "C", "O", "C"]
    + [VIRTUAL_ATOM_ELEMENT_NAME for i in range(23 - len(backbone_atomscheme_RNA))]
)
"""Atom23 RNA element names, index-aligned with ATOM23_ATOM_NAMES_RNA (e.g. P, VX)."""

ATOM23_ATOM_NAME_TO_ELEMENT = {
    name: elem for name, elem in zip(ATOM23_ATOM_NAMES_RNA, ATOM23_ATOM_ELEMENTS_RNA)
}
"""Mapping from atom23 atom names (e.g. P, V1) to element names (e.g. P, VX)."""

ATOM23_ATOM_NAMES_DNA = np.array(
    [item.strip() for item in backbone_atomscheme_DNA]
    + [f"V{i}" for i in range(22 - len(backbone_atomscheme_DNA))]
)
"""Atom23 DNA atom names: stripped backbone names plus virtual slots (e.g. P, V1)."""

ATOM23_ATOM_ELEMENTS_DNA = np.array(
    ["P", "O", "O", "O", "C", "C", "O", "C", "O", "C", "C"]
    + [VIRTUAL_ATOM_ELEMENT_NAME for i in range(22 - len(backbone_atomscheme_DNA))]
)
"""Atom23 DNA element names, index-aligned with ATOM23_ATOM_NAMES_DNA (e.g. C, VX)."""

# Fold the protein atom14 name->element entries into the atom23 mapping so a
# single lookup table covers protein, DNA and RNA atom names.
ATOM23_ATOM_NAME_TO_ELEMENT.update(ATOM14_ATOM_NAME_TO_ELEMENT)
+
# Stripped-name variants of every association scheme (fixed-width CCD names ->
# bare names, None padding preserved).
# NOTE(review): this dict is recomputed further below, after the atom23 "U"
# entry is modified, so this first computation is superseded there.
association_schemes_stripped = {
    name: {k: strip_list(v) for k, v in scheme.items()}
    for name, scheme in association_schemes.items()
}

# Backbone atom names without fixed-width padding (e.g. "P", "OP1").
backbone_atoms_RNA = strip_list(backbone_atomscheme_RNA)
backbone_atoms_DNA = strip_list(backbone_atomscheme_DNA)
+
# Mapping from residue type to its backbone ("bb") and sidechain ("sc") atom
# names (stripped, no fixed-width padding), for convenience lookups.
# Shared backbone tuples:
_BB_PROTEIN = ("N", "CA", "C", "O")
_BB_DNA = ("O4'", "C1'", "C2'", "OP1", "P", "OP2", "O5'", "C5'", "C4'", "C3'", "O3'")
_BB_RNA = _BB_DNA + ("O2'",)

ATOM_REGION_BY_RESI = {
    # --- Amino acids ---
    # BUG FIX (review): single-atom sidechains were written as ("CB"), which
    # is a plain string rather than a 1-tuple (so e.g. `"B" in sc` was True);
    # fixed to ("CB",) for type consistency with every other entry.
    "ALA": {"bb": _BB_PROTEIN, "sc": ("CB",)},
    "ARG": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "CD", "NE", "CZ", "NH1", "NH2")},
    "ASN": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "OD1", "ND2")},
    "ASP": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "OD1", "OD2")},
    "CYS": {"bb": _BB_PROTEIN, "sc": ("CB", "SG")},
    "GLN": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "CD", "OE1", "NE2")},
    "GLU": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "CD", "OE1", "OE2")},
    "GLY": {"bb": _BB_PROTEIN, "sc": ()},
    "HIS": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "ND1", "CD2", "CE1", "NE2")},
    "ILE": {"bb": _BB_PROTEIN, "sc": ("CB", "CG1", "CG2", "CD1")},
    "LEU": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "CD1", "CD2")},
    "LYS": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "CD", "CE", "NZ")},
    "MET": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "SD", "CE")},
    "PHE": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ")},
    "PRO": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "CD")},
    "SER": {"bb": _BB_PROTEIN, "sc": ("CB", "OG")},
    "THR": {"bb": _BB_PROTEIN, "sc": ("CB", "OG1", "CG2")},
    "TRP": {
        "bb": _BB_PROTEIN,
        "sc": ("CB", "CG", "CD1", "CD2", "CE2", "CE3", "NE1", "CZ2", "CZ3", "CH2"),
    },
    "TYR": {
        "bb": _BB_PROTEIN,
        "sc": ("CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"),
    },
    "VAL": {"bb": _BB_PROTEIN, "sc": ("CB", "CG1", "CG2")},
    "UNK": {"bb": _BB_PROTEIN, "sc": ("CB",)},
    "MAS": {"bb": _BB_PROTEIN, "sc": ("CB",)},
    # --- DNA (base atoms listed as "sc") ---
    "DA": {"bb": _BB_DNA, "sc": ("N9", "C4", "N3", "C2", "N1", "C6", "C5", "N7", "C8", "N6")},
    "DC": {"bb": _BB_DNA, "sc": ("N1", "C2", "O2", "N3", "C4", "N4", "C5", "C6")},
    "DG": {"bb": _BB_DNA, "sc": ("N9", "C4", "N3", "C2", "N1", "C6", "C5", "N7", "C8", "N2", "O6")},
    "DT": {"bb": _BB_DNA, "sc": ("N1", "C2", "O2", "N3", "C4", "O4", "C5", "C7", "C6")},
    "DX": {"bb": _BB_DNA, "sc": ()},  # unknown/masked DNA base: backbone only
    # --- RNA ---
    "A": {"bb": _BB_RNA, "sc": ("N1", "C2", "N3", "C4", "C5", "C6", "N6", "N7", "C8", "N9")},
    "C": {"bb": _BB_RNA, "sc": ("N1", "C2", "O2", "N3", "C4", "N4", "C5", "C6")},
    "G": {"bb": _BB_RNA, "sc": ("N1", "C2", "N2", "N3", "C4", "C5", "C6", "O6", "N7", "C8", "N9")},
    "U": {"bb": _BB_RNA, "sc": ("N1", "C2", "O2", "N3", "C4", "O4", "C5", "C6")},
    "X": {"bb": _BB_RNA, "sc": ()},  # unknown/masked RNA base: backbone only
    # --- Alternate protonation state ---
    "HIS_D": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "NE2", "CD2", "CE1", "ND1")},
}
# Known planar sidechain atoms for each canonical residue type.
# For nucleotides this is the (planar) base ring system; an empty list means
# the residue has no planar sidechain group.
PLANAR_ATOMS_BY_RESI = {
    "ALA": [],
    "ARG": ["NH1", "NH2", "CZ", "NE", "CD"],
    "ASN": ["OD1", "ND2", "CG", "CB"],
    "ASP": ["OD1", "OD2", "CG", "CB"],
    "CYS": [],
    "GLN": ["OE1", "NE2", "CD", "CG"],
    "GLU": ["OE1", "OE2", "CD", "CG"],
    "GLY": [],
    "HIS": ["ND1", "CE1", "NE2", "CD2", "CG", "CB"],
    "ILE": [],
    "LEU": [],
    "LYS": [],
    "MET": [],
    "PHE": ["CZ", "CE1", "CE2", "CD1", "CD2", "CG", "CB"],
    "PRO": [],
    "SER": [],
    "THR": [],
    "TRP": ["CH2", "CZ3", "CZ2", "CE3", "CE2", "CD2", "NE1", "CD1", "CG", "CB"],
    "TYR": ["OH", "CZ", "CE1", "CE2", "CD1", "CD2", "CG", "CB"],
    "VAL": [],
    "UNK": [],
    "MAS": [],
    "DA": ["N6", "C6", "N1", "C2", "N3", "C4", "C5", "N7", "C8", "N9"],
    "DC": ["N4", "C4", "N3", "O2", "C2", "C5", "C6", "N1"],
    "DG": ["O6", "C6", "N1", "N2", "C2", "N3", "C4", "C5", "N7", "C8", "N9"],
    "DT": ["O4", "O2", "N3", "C4", "C2", "C5", "C6", "N1", "C7"],
    "DX": [],
    "A": ["N6", "C6", "N1", "C2", "N3", "C4", "C5", "N7", "C8", "N9"],
    "C": ["N4", "C4", "N3", "O2", "C2", "C5", "C6", "N1"],
    "G": ["O6", "C6", "N1", "N2", "C2", "N3", "C4", "C5", "N7", "C8", "N9"],
    "U": ["O4", "O2", "N3", "C4", "C2", "C5", "C6", "N1"],
    "X": [],
    "HIS_D": ["ND1", "CD2", "CE1", "NE2", "CG", "CB"],
}
+
# Fix C/U symmetry: swap slots 19 and 20 of the atom23 "U" entry. This moves
# " C6 " from slot 19 to slot 20 and leaves a None placeholder at slot 19.
# NOTE(review): confirm the intended alignment against the "C" entry — the
# apparent goal is consistent slot assignment of shared pyrimidine atoms.
temp = list(association_schemes["atom23"]["U"])
temp[19], temp[20] = temp[20], temp[19]
association_schemes["atom23"]["U"] = tuple(temp)

# Recompute the stripped schemes now that the "U" entry has been modified
# (this supersedes the earlier computation above).
association_schemes_stripped = {
    name: {k: strip_list(v) for k, v in scheme.items()}
    for name, scheme in association_schemes.items()
}

# NOTE(review): removed a leftover `if __name__ == "__main__": pdb.set_trace()`
# debugging hook — importing or running this constants module must have no
# interactive side effects.
diff --git a/models/rfd3/src/rfd3/inference/input_parsing.py b/models/rfd3/src/rfd3/inference/input_parsing.py
index d97b3be3..8293ca74 100644
--- a/models/rfd3/src/rfd3/inference/input_parsing.py
+++ b/models/rfd3/src/rfd3/inference/input_parsing.py
@@ -9,7 +9,7 @@
from typing import Any, Dict, List, Optional, Union
import numpy as np
-from atomworks.constants import STANDARD_AA
+from atomworks.constants import STANDARD_AA, STANDARD_DNA, STANDARD_RNA
from atomworks.io.parser import parse_atom_array
# from atomworks.ml.datasets.datasets import BaseDataset
@@ -30,9 +30,12 @@
OPTIONAL_CONDITIONING_VALUES,
REQUIRED_CONDITIONING_ANNOTATION_VALUES,
REQUIRED_INFERENCE_ANNOTATIONS,
+ backbone_atoms_DNA,
+ backbone_atoms_RNA,
)
from rfd3.inference.legacy_input_parsing import (
create_atom_array_from_design_specification_legacy,
+ reorder_atoms_per_residue,
)
from rfd3.inference.parsing import InputSelection
from rfd3.inference.symmetry.symmetry_utils import (
@@ -67,7 +70,6 @@
logger = RankedLogger(__name__, rank_zero_only=True)
-
#################################################################################
# Custom infer_ori functions
#################################################################################
@@ -117,7 +119,9 @@ class DesignInputSpecification(BaseModel):
validate_assignment=False,
str_strip_whitespace=True,
str_min_length=1,
- extra="forbid",
+ # extra="forbid", ####################################################
+ extra="allow",
+ ## for now allowing extra for rfd3na-ss purposes, can decide later ##
)
# fmt: off
# ========================================================================
@@ -180,6 +184,7 @@ class DesignInputSpecification(BaseModel):
symmetry: Optional[SymmetryConfig] = Field(None, description="Symmetry specification, see docs/symmetry.md")
# Centering & COM guidance
ori_token: Optional[list[float]] = Field(None, description="Origin coordinates")
+ ori_jitter: Optional[float] = Field(None, description="Jitter ori in a random direction and use ori_jitter to sample distance via exponential distribution")
infer_ori_strategy: Optional[str] = Field(None, description="Strategy for inferring origin; `com` or `hotspots`")
# Additional global conditioning
plddt_enhanced: Optional[bool] = Field(True, description="Enable pLDDT enhancement")
@@ -492,6 +497,15 @@ def apply_selections(start, end):
aa.is_motif_atom_with_fixed_seq[start:end] = np.full_like(
is_bkbn, False, dtype=int
)
+ elif (
+ aa.res_name[start] in (STANDARD_DNA + STANDARD_RNA)
+ and self.redesign_motif_sidechains
+ ):
+ is_bkbn = np.isin(aa.atom_name[start:end], backbone_atoms_RNA)
+ aa.is_motif_atom_with_fixed_coord[start:end] = is_bkbn.astype(int)
+ aa.is_motif_atom_with_fixed_seq[start:end] = np.full_like(
+ is_bkbn, False, dtype=int
+ )
# ... Apply selections on top
apply_selections(start, end)
@@ -505,6 +519,24 @@ def apply_selections(start, end):
def build(self, return_metadata=False):
"""Main build pipeline."""
atom_array_input_annotated = copy.deepcopy(self.atom_array_input)
+
+ ########## reorder NA atoms ###########
+ if exists(atom_array_input_annotated):
+ is_dna = np.isin(
+ atom_array_input_annotated.res_name, ["DA", "DC", "DG", "DT"]
+ )
+ is_rna = np.isin(atom_array_input_annotated.res_name, ["A", "C", "G", "U"])
+ dna_array = atom_array_input_annotated[is_dna]
+ rna_array = atom_array_input_annotated[is_rna]
+
+ atom_array_input_annotated[is_dna] = reorder_atoms_per_residue(
+ dna_array, backbone_atoms_DNA
+ )
+ atom_array_input_annotated[is_rna] = reorder_atoms_per_residue(
+ rna_array, backbone_atoms_RNA
+ )
+ #######################################
+
atom_array = self._build_init(atom_array_input_annotated)
# Apply post-processing
@@ -699,6 +731,13 @@ def _append_ligand(self, atom_array, atom_array_input_annotated):
ligand_array.set_annotation(
annot, np.full(ligand_array.array_length(), default)
)
+
+ chain_cand = "X"
+ while chain_cand in atom_array.chain_id.tolist():
+ chain_cand = chain_cand + chain_cand
+ ligand_chain = np.array([chain_cand] * len(ligand_array))
+ ligand_array.chain_id = ligand_chain
+
atom_array = atom_array + ligand_array
return atom_array
@@ -723,8 +762,13 @@ def _set_origin(self, atom_array):
"Partial diffusion with symmetry: skipping COM centering to preserve chain spacing"
)
else:
+ if not exists(self.ori_jitter):
+ self.ori_jitter = None
atom_array = set_com(
- atom_array, ori_token=None, infer_ori_strategy="com"
+ atom_array,
+ ori_token=None,
+ infer_ori_strategy="com",
+ ori_jitter=self.ori_jitter,
)
else:
# Standard: set ori token, zero out diffused atoms
@@ -894,35 +938,57 @@ def validator_context(validator_name: str, data: dict = None):
raise e
-def create_diffused_residues(n, additional_annotations=None):
+def create_diffused_residues(n, additional_annotations=None, polymer_type="P"):
+ from rfd3.constants import (
+ ATOM23_ATOM_NAME_TO_ELEMENT,
+ backbone_atoms_DNA,
+ backbone_atoms_RNA,
+ )
+
if n <= 0:
raise ValueError(f"Negative/null residue count ({n}) not allowed.")
+ if polymer_type == "P":
+ res_name = "ALA"
+ bb_len = 5
+ bb_atom_names = ["N", "CA", "C", "O", "CB"]
+ elif polymer_type == "R":
+ res_name = "A"
+ bb_len = len(backbone_atoms_RNA)
+ bb_atom_names = backbone_atoms_RNA
+ elif polymer_type == "D":
+ res_name = "DA"
+ bb_len = len(backbone_atoms_DNA)
+ bb_atom_names = backbone_atoms_DNA
+ else:
+ raise ValueError(
+ f"invalid polymer type detected: {polymer_type}, check contig!"
+ )
+
+ bb_elements = [ATOM23_ATOM_NAME_TO_ELEMENT[item] for item in bb_atom_names]
+
atoms = []
[
atoms.extend(
[
struc.Atom(
np.array([0.0, 0.0, 0.0], dtype=np.float32),
- res_name="ALA",
+ res_name=res_name,
res_id=idx,
)
- for _ in range(5)
+ for _ in range(bb_len)
]
)
for idx in range(1, n + 1)
]
array = struc.array(atoms)
- array.set_annotation(
- "element", np.array(["N", "C", "C", "O", "C"] * n, dtype=" AtomArray:
+ """
+ Reorder atoms within each residue of an AtomArray.
+ Atoms in `desired_order` appear first (in that order), followed by all others
+ in original order. Faster version using get_residue_starts().
+
+ Parameters:
+ - atom_array: AtomArray to reorder.
+ - desired_order: List of atom names in the desired per-residue order.
+
+ Returns:
+ - AtomArray with reordered atoms per residue.
+ """
+ if len(atom_array) == 0:
+ return atom_array
+ res_starts = get_residue_starts(atom_array)
+ res_starts = np.append(res_starts, len(atom_array)) # add end index for slicing
+ reordered_chunks = []
+ order_dict = {name: i for i, name in enumerate(desired_order)}
+
+ for i in range(len(res_starts) - 1):
+ start, end = res_starts[i], res_starts[i + 1]
+ residue = atom_array[start:end]
+
+ # Boolean masks for matching and non-matching atom names
+ in_order_mask = np.isin(residue.atom_name, desired_order)
+ not_in_order_mask = ~in_order_mask
+
+ # Sort matching atoms by desired order
+ atoms_in_order = residue[in_order_mask]
+ sort_idx = np.argsort([order_dict[name] for name in atoms_in_order.atom_name])
+ ordered_atoms = atoms_in_order[sort_idx]
+
+ # Remaining atoms as-is
+ remaining_atoms = residue[not_in_order_mask]
+
+ # Concatenate reordered residue
+ reordered_chunks.append(concatenate([ordered_atoms, remaining_atoms]))
+ return concatenate(reordered_chunks)
diff --git a/models/rfd3/src/rfd3/metrics/design_metrics.py b/models/rfd3/src/rfd3/metrics/design_metrics.py
index 5ac24fbc..1ac0ad51 100644
--- a/models/rfd3/src/rfd3/metrics/design_metrics.py
+++ b/models/rfd3/src/rfd3/metrics/design_metrics.py
@@ -4,6 +4,7 @@
get_token_starts,
)
from beartype.typing import Any
+from rfd3.constants import backbone_atoms_RNA
from rfd3.metrics.metrics_utils import (
_flatten_dict,
get_hotspot_contacts,
@@ -14,6 +15,7 @@
from foundry.metrics.metric import Metric
STANDARD_CACA_DIST = 3.8
+STANDARD_P_P_DISTANCE = 6.4 ## average of B and A form 7 and 5.9
def get_clash_metrics(
@@ -28,11 +30,17 @@ def get_clash_metrics(
)
def get_chainbreaks():
- ca_atoms = atom_array[atom_array.atom_name == "CA"]
+ if "CA" in atom_array.atom_name:
+ ca_atoms = atom_array[atom_array.atom_name == "CA"]
+ cut_off = STANDARD_CACA_DIST
+ elif "P" in atom_array.atom_name:
+ ca_atoms = atom_array[atom_array.atom_name == "P"]
+ cut_off = STANDARD_P_P_DISTANCE
+
xyz = ca_atoms.coord
xyz = torch.from_numpy(xyz)
ca_dists = torch.norm(xyz[1:] - xyz[:-1], dim=-1)
- deviation = torch.abs(ca_dists - STANDARD_CACA_DIST)
+ deviation = torch.abs(ca_dists - cut_off)
# Allow leniency for expected chain breaks (e.g. PPI)
chain_breaks = ca_atoms.chain_iid[1:] != ca_atoms.chain_iid[:-1]
@@ -45,7 +53,9 @@ def get_chainbreaks():
}
def get_interresidue_clashes(backbone_only=False):
- protein_array = atom_array[atom_array.is_protein]
+ protein_array = atom_array[
+ atom_array.is_protein | atom_array.is_dna | atom_array.is_rna
+ ]
resid = protein_array.res_id - protein_array.res_id.min()
xyz = protein_array.coord
dists = np.linalg.norm(xyz[:, None] - xyz[None], axis=-1) # N_atoms x N_atoms
@@ -58,7 +68,9 @@ def get_interresidue_clashes(backbone_only=False):
if backbone_only:
# Block out non-backbone atoms
- backbone_mask = np.isin(protein_array.atom_name, ["N", "CA", "C"])
+ backbone_mask = np.isin(
+ protein_array.atom_name, ["N", "CA", "C"] + backbone_atoms_RNA
+ )
mask = backbone_mask[:, None] & backbone_mask[None, :]
dists[~mask] = 999
@@ -291,6 +303,7 @@ def __init__(self, compute_for_diffused_region_only: bool = False):
3.0 # maximum closest-neighbour distance before considered a floating atom
)
self.standard_ca_dist = 3.8
+ self.standard_PP_dist = 6.4
self.compute_for_diffused_region_only = compute_for_diffused_region_only
@property
@@ -310,6 +323,8 @@ def compute(self, X_L, tok_idx, f):
) # N_atoms x N_atoms
is_protein = f["is_protein"][tok_idx].cpu().numpy() # n_atoms
+ is_rna = f["is_rna"][tok_idx].cpu().numpy()
+ is_dna = f["is_dna"][tok_idx].cpu().numpy()
mask = np.zeros_like(dists, dtype=bool)
mask = mask | (np.eye(dists.shape[-1], dtype=bool))[None]
@@ -362,19 +377,51 @@ def compute(self, X_L, tok_idx, f):
if self.compute_for_diffused_region_only:
is_ca = is_ca[diffused_region]
is_protein = is_protein[diffused_region]
- idx_mask = is_ca & is_protein
+ is_dna = is_dna[diffused_region]
+ is_rna = is_rna[diffused_region]
+ protein_idx_mask = is_ca & (is_protein)
+ na_idx_mask = is_ca & (is_rna | is_dna)
+
if self.compute_for_diffused_region_only:
- xyz = X_L.cpu()[:, diffused_region][:, idx_mask]
+ xyz_protein = X_L.cpu()[:, diffused_region][:, protein_idx_mask]
+ xyz_na = X_L.cpu()[:, diffused_region][:, na_idx_mask]
else:
- xyz = X_L.cpu()[:, idx_mask]
+ xyz_protein = X_L.cpu()[:, protein_idx_mask]
+ xyz_na = X_L.cpu()[:, na_idx_mask]
- ca_dists = torch.norm(xyz[:, 1:] - xyz[:, :-1], dim=-1)
- deviation = torch.abs(ca_dists - self.standard_ca_dist) # B, (I-1)
- is_chainbreak = deviation > 0.75
+ ca_dists_protein = torch.norm(
+ xyz_protein[:, 1:] - xyz_protein[:, :-1], dim=-1
+ )
+ ca_dists_na = torch.norm(xyz_na[:, 1:] - xyz_na[:, :-1], dim=-1)
+
+ deviation_protein = torch.abs(
+ ca_dists_protein - self.standard_ca_dist
+ ) # B, (I-1)
+ deviation_na = torch.abs(ca_dists_na - self.standard_PP_dist) # B, (I-1)
+ is_chainbreak_protein = deviation_protein > 0.75
+ is_chainbreak_na = deviation_na > 1
+
+ try:
+ o["max_ca_deviation_protein"] = float(
+ deviation_protein.max(-1).values.mean()
+ )
+ o["fraction_chainbreaks_protein"] = float(
+ is_chainbreak_protein.float().mean(-1).mean()
+ )
+ o["n_chainbreaks_protein"] = float(
+ is_chainbreak_protein.float().sum(-1).mean()
+ )
+ except Exception:
+ print("No protein in this example, skipping protein chainbreak metrics")
- o["max_ca_deviation"] = float(deviation.max(-1).values.mean())
- o["fraction_chainbreaks"] = float(is_chainbreak.float().mean(-1).mean())
- o["n_chainbreaks"] = float(is_chainbreak.float().sum(-1).mean())
+ try:
+ o["max_ca_deviation_na"] = float(deviation_na.max(-1).values.mean())
+ o["fraction_chainbreaks_na"] = float(
+ is_chainbreak_na.float().mean(-1).mean()
+ )
+ o["n_chainbreaks_na"] = float(is_chainbreak_na.float().sum(-1).mean())
+ except Exception:
+ print("No NA in this example, skipping NA chainbreak metrics")
return o
diff --git a/models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py b/models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py
new file mode 100644
index 00000000..53c21194
--- /dev/null
+++ b/models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py
@@ -0,0 +1,380 @@
+import logging
+
+import numpy as np
+from atomworks.ml.utils.token import (
+ get_token_starts,
+)
+from biotite.structure import AtomArray
+from rfd3.trainer.trainer_utils import (
+ _cleanup_virtual_atoms_and_assign_atom_name_elements,
+ _readout_seq_from_struc,
+)
+from rfd3.transforms.na_geom_utils import annotate_na_ss
+
+from foundry.metrics.metric import Metric
+from foundry.utils.ddp import RankedLogger
+
+logging.basicConfig(level=logging.INFO)
+global_logger = RankedLogger(__name__, rank_zero_only=False)
+
+
def _get_bp_partners_annotation(atom_array: AtomArray):
    """Return the per-atom ``bp_partners`` annotation, raising if it is absent."""
    if "bp_partners" not in atom_array.get_annotation_categories():
        raise ValueError("atom_array missing bp_partners annotation")
    return atom_array.bp_partners
+
+
+def _safe_f1_from_sizes(intersection_n: int, pred_n: int, gt_n: int) -> float:
+ """Return F1 with sensible empty-set handling."""
+ if pred_n == 0 and gt_n == 0:
+ return 1.0
+
+ precision = float(intersection_n / pred_n) if pred_n > 0 else 0.0
+ recall = float(intersection_n / gt_n) if gt_n > 0 else 0.0
+
+ if precision + recall == 0.0:
+ return 0.0
+
+ return float(2.0 * precision * recall / (precision + recall))
+
+
def _get_token_ids(atom_array: AtomArray) -> np.ndarray:
    """Return one integer token_id per token (read at each token's first atom)."""
    starts = get_token_starts(atom_array)
    return np.asarray(atom_array[starts].token_id, dtype=int)
+
+
def _get_candidate_token_ids(
    atom_array: AtomArray,
    *,
    restrict_to_nucleic: bool,
    compute_for_diffused_region_only: bool,
) -> set[int]:
    """Return a set of token_ids to include for scoring.

    restrict_to_nucleic: keep only tokens flagged `is_rna` / `is_dna` — but
        only when at least one token carries either flag (see NOTE below).
    compute_for_diffused_region_only: drop motif tokens, preferring the
        `is_motif_atom` annotation and falling back to `is_motif_token`.
    """
    token_starts = get_token_starts(atom_array)
    token_level_array = atom_array[token_starts]
    token_ids = np.asarray(token_level_array.token_id, dtype=int)

    # Start from "all tokens allowed" and intersect restrictions below.
    token_mask = np.ones(len(token_ids), dtype=bool)

    if restrict_to_nucleic:
        # Missing annotations are treated as all-False.
        is_rna = (
            np.asarray(getattr(token_level_array, "is_rna"), dtype=bool)
            if hasattr(token_level_array, "is_rna")
            else np.zeros(len(token_ids), dtype=bool)
        )
        is_dna = (
            np.asarray(getattr(token_level_array, "is_dna"), dtype=bool)
            if hasattr(token_level_array, "is_dna")
            else np.zeros(len(token_ids), dtype=bool)
        )
        # NOTE(review): when no token is flagged RNA/DNA this line is a no-op
        # (everything stays allowed) — presumably to avoid returning an empty
        # set for unannotated arrays; confirm this fallback is intended.
        token_mask &= (
            (is_rna | is_dna) if (is_rna.any() or is_dna.any()) else token_mask
        )

    if compute_for_diffused_region_only:
        # Exclude motif (fixed) tokens; atom-level flag takes precedence.
        if hasattr(token_level_array, "is_motif_atom"):
            token_mask &= ~np.asarray(token_level_array.is_motif_atom, dtype=bool)
        elif hasattr(token_level_array, "is_motif_token"):
            token_mask &= ~np.asarray(token_level_array.is_motif_token, dtype=bool)

    return set(int(t) for t in token_ids[token_mask].tolist())
+
+
def _extract_bp_pairs(
    atom_array: AtomArray,
    *,
    allowed_token_ids: set[int],
) -> set[tuple[int, int]]:
    """Extract unordered base-pair edges from bp-partner annotations.

    Pairs are represented as (min_token_id, max_token_id) so each edge is
    counted exactly once regardless of direction. Only edges whose BOTH
    endpoints are in `allowed_token_ids` (and present in this array) are kept.
    """
    token_starts = get_token_starts(atom_array)
    token_level_array = atom_array[token_starts]
    token_ids = np.asarray(token_level_array.token_id, dtype=int)
    # Map token_id -> token position, used to drop partner references to
    # tokens that are absent from this array.
    token_id_to_pos = {int(tid): i for i, tid in enumerate(token_ids.tolist())}

    bp_partner_ann = _get_bp_partners_annotation(atom_array)
    pairs: set[tuple[int, int]] = set()

    for pos, start_idx in enumerate(token_starts.tolist()):
        i_tid = int(token_ids[pos])
        if i_tid not in allowed_token_ids:
            continue

        # The annotation is read at the token's first atom.
        partners = bp_partner_ann[int(start_idx)]
        if partners is None:
            continue  # unannotated/masked token
        if not isinstance(partners, (list, tuple, np.ndarray)):
            continue  # malformed entry — skip defensively

        for partner_token_id in partners:
            try:
                j_tid = int(partner_token_id)
            except Exception:
                continue  # non-numeric partner entry — skip defensively

            if j_tid == i_tid or j_tid not in allowed_token_ids:
                continue

            if j_tid not in token_id_to_pos:
                continue

            # Canonicalize the edge as (min, max).
            a, b = (i_tid, j_tid) if i_tid < j_tid else (j_tid, i_tid)
            pairs.add((a, b))

    return pairs
+
+
def _extract_loop_and_paired_token_ids(
    atom_array: AtomArray,
    *,
    allowed_token_ids: set[int],
) -> tuple[set[int], set[int]]:
    """Return (loop_token_ids, paired_token_ids) within the allowed token set.

    A token with `bp_partners == []` is an explicitly-unpaired loop token; a
    token with at least one valid partner contributes both endpoints to the
    paired set; `bp_partners is None` means unannotated/masked and is ignored.
    """
    token_starts = get_token_starts(atom_array)
    token_level_array = atom_array[token_starts]
    token_ids = np.asarray(token_level_array.token_id, dtype=int)
    token_id_to_pos = {int(tid): i for i, tid in enumerate(token_ids.tolist())}

    bp_partner_ann = _get_bp_partners_annotation(atom_array)

    loop_token_ids: set[int] = set()
    paired_token_ids: set[int] = set()

    for pos, start_idx in enumerate(token_starts.tolist()):
        i_tid = int(token_ids[pos])
        if i_tid not in allowed_token_ids:
            continue

        partners = bp_partner_ann[int(start_idx)]
        # New semantics:
        # - None => unannotated/masked (NOT a loop)
        # - [] => explicitly unpaired loop
        if partners is None:
            continue
        if not isinstance(partners, (list, tuple, np.ndarray)):
            continue
        if len(partners) == 0:
            loop_token_ids.add(i_tid)
            continue

        # Token has at least one candidate partner: validate each and record
        # both endpoints of every valid pairing.
        for partner_token_id in partners:
            try:
                j_tid = int(partner_token_id)
            except Exception:
                continue

            if j_tid == i_tid or j_tid not in allowed_token_ids:
                continue
            if j_tid not in token_id_to_pos:
                continue
            paired_token_ids.add(i_tid)
            paired_token_ids.add(j_tid)

    return loop_token_ids, paired_token_ids
+
+
def compute_from_two_arr(
    gt_arr, pred_arr, restrict_to_nucleic=True, compute_for_diffused_region_only=False
):
    """Score secondary-structure agreement between two annotated AtomArrays.

    Returns (pair_f1, loop_f1, weighted_f1), or None when no token_ids are
    eligible in both arrays.
    """
    # NOTE(review): the original computed both arrays' token ids and then had
    # a dead `if len(gt_ids) != len(pred_ids): None` statement (a bare-`None`
    # no-op). Length mismatches are tolerated anyway, because scoring is
    # restricted to the intersection of allowed token_ids below — so the dead
    # check (and the now-unused token-id computation) was removed.

    # Restrict to token_ids that are valid in both arrays.
    gt_allowed = _get_candidate_token_ids(
        gt_arr,
        restrict_to_nucleic=restrict_to_nucleic,
        compute_for_diffused_region_only=compute_for_diffused_region_only,
    )
    pred_allowed = _get_candidate_token_ids(
        pred_arr,
        restrict_to_nucleic=restrict_to_nucleic,
        compute_for_diffused_region_only=compute_for_diffused_region_only,
    )
    allowed = gt_allowed & pred_allowed

    if len(allowed) == 0:
        return None

    gt_pairs = _extract_bp_pairs(gt_arr, allowed_token_ids=allowed)
    pred_pairs = _extract_bp_pairs(pred_arr, allowed_token_ids=allowed)

    gt_loop, gt_paired_tokens = _extract_loop_and_paired_token_ids(
        gt_arr, allowed_token_ids=allowed
    )
    pred_loop, _pred_paired_tokens = _extract_loop_and_paired_token_ids(
        pred_arr, allowed_token_ids=allowed
    )

    # F1 over basepair edges.
    pair_f1 = _safe_f1_from_sizes(len(gt_pairs & pred_pairs), len(pred_pairs), len(gt_pairs))
    # F1 over explicitly-unpaired loop tokens.
    loop_f1 = _safe_f1_from_sizes(len(gt_loop & pred_loop), len(pred_loop), len(gt_loop))

    # Weight the two F1s by the GT prevalence of paired vs loop tokens.
    pair_weight = len(gt_paired_tokens)
    loop_weight = len(gt_loop)
    total_weight = pair_weight + loop_weight
    if total_weight == 0:
        weighted_f1 = 1.0
    else:
        weighted_f1 = float(
            (pair_weight * pair_f1 + loop_weight * loop_f1) / total_weight
        )

    return pair_f1, loop_f1, weighted_f1
+
+
def get_NA_SS_F1(pred_array):
    """Self-consistency SS score for a single array.

    The array's existing `bp_partners` annotation is treated as ground truth;
    base-pairing is then re-derived from the coordinates and compared against
    it. Returns {} when the comparison is not possible.
    """
    # Save a copy carrying the original bp_partners annotation as "GT".
    gt_array = pred_array.copy()

    # Replace the annotation by re-annotating from geometry.
    pred_array = annotate_na_ss(
        pred_array,
        NA_only=True,
        planar_only=True,
        overwrite=True,
        p_canonical_bp_filter=0.0,
    )

    try:
        pair_f1, loop_f1, weighted_f1 = compute_from_two_arr(gt_array, pred_array)
    except Exception:
        # compute_from_two_arr returns None when no tokens are comparable,
        # which makes the 3-way unpack above raise — treat as "no metrics".
        return {}

    return {
        "pair_f1": pair_f1,
        "loop_f1": loop_f1,
        "weighted_f1": weighted_f1,
    }
+
+
class NucleicSSSimilarityMetrics(Metric):
    """Secondary-structure similarity for nucleic acids.

    Reports:
    - `pair_f1`: F1 over basepair edges from token-level bp-partner annotation.
    - `loop_f1`: F1 over explicitly-unpaired loop tokens (`bp_partners == []`).
      Unannotated tokens (`bp_partners is None`) are masked.
    - `weighted_f1`: GT-weighted average of `pair_f1` and `loop_f1`, weighted by
      the prevalence of paired vs loop tokens in the GT.
    """

    def __init__(
        self,
        *,
        restrict_to_nucleic: bool = True,
        compute_for_diffused_region_only: bool = False,
        annotate_predicted_fresh: bool = False,
        annotation_NA_only: bool = False,
        annotation_planar_only: bool = True,
    ):
        """Configure the metric.

        restrict_to_nucleic: score RNA/DNA tokens only.
        compute_for_diffused_region_only: exclude motif tokens from scoring.
        annotate_predicted_fresh: re-infer residue identities and clean up
            virtual atoms from the predicted coordinates before re-annotating.
        annotation_NA_only / annotation_planar_only: forwarded to
            `annotate_na_ss` when re-annotating the prediction.
        """
        super().__init__()
        self.restrict_to_nucleic = restrict_to_nucleic
        self.compute_for_diffused_region_only = compute_for_diffused_region_only
        self.annotate_predicted_fresh = annotate_predicted_fresh
        self.annotation_NA_only = annotation_NA_only
        self.annotation_planar_only = annotation_planar_only

    @property
    def kwargs_to_compute_args(self):
        # Maps framework-provided kwargs to the arguments of `compute`.
        return {
            "ground_truth_atom_array_stack": ("ground_truth_atom_array_stack",),
            "predicted_atom_array_stack": ("predicted_atom_array_stack",),
        }

    def compute(self, *, ground_truth_atom_array_stack, predicted_atom_array_stack):
        """Average pair/loop/weighted F1 over all valid (GT, prediction) pairs."""
        if ground_truth_atom_array_stack is None or predicted_atom_array_stack is None:
            return {}

        pair_f1_list: list[float] = []
        loop_f1_list: list[float] = []
        weighted_f1_list: list[float] = []

        n_valid = 0

        for gt_arr, pred_arr in zip(
            ground_truth_atom_array_stack, predicted_atom_array_stack
        ):
            # GT must carry bp_partners to serve as a reference.
            gt_categories = gt_arr.get_annotation_categories()
            if "bp_partners" not in gt_categories:
                continue

            # Important: predicted AtomArrays are built from a template AtomArray.
            # If that template already carries bp_partners (often GT-derived), the
            # prediction can inherit it, yielding artificially perfect scores.
            # Optionally recompute bp_partners from the *predicted coordinates*.
            if self.annotate_predicted_fresh:
                try:
                    # Infer res name from geometry first
                    pred_arr = _readout_seq_from_struc(
                        pred_arr,
                        central_atom="C1'",
                        threshold=0.5,
                        association_scheme="atom23",
                    )

                    # strip virtuals and set final atom names/elements
                    pred_arr = _cleanup_virtual_atoms_and_assign_atom_name_elements(
                        pred_arr,
                        association_scheme="atom23",
                    )
                except Exception:
                    # this can fail early in training
                    print("could not cleanup virtuals for nucleic ss metric compute")
                    pass
            # clear annotation to avoid potential info leak
            if "bp_partners" in pred_arr.get_annotation_categories():
                pred_arr.del_annotation("bp_partners")

            # add nucleic-ss annotations
            # NOTE(review): return value intentionally unused here — presumably
            # annotate_na_ss mutates `pred_arr` in place; confirm (contrast
            # with the assignment in get_NA_SS_F1).
            annotate_na_ss(
                pred_arr,
                NA_only=self.annotation_NA_only,
                planar_only=self.annotation_planar_only,
                overwrite=True,
                p_canonical_bp_filter=0.0,
            )
            pred_categories = pred_arr.get_annotation_categories()
            if "bp_partners" not in pred_categories:
                continue

            # Basic sanity check: token counts should match for aligned comparisons
            try:
                pair_f1, loop_f1, weighted_f1 = compute_from_two_arr(
                    gt_arr,
                    pred_arr,
                    restrict_to_nucleic=self.restrict_to_nucleic,
                    compute_for_diffused_region_only=self.compute_for_diffused_region_only,
                )
            except Exception:
                # fails when compute_from_two_arr returns None (no eligible
                # tokens) because the 3-way unpack raises
                continue

            pair_f1_list.append(pair_f1)
            loop_f1_list.append(loop_f1)
            weighted_f1_list.append(weighted_f1)
            n_valid += 1

        if n_valid == 0:
            return {}

        return {
            "pair_f1": float(np.mean(pair_f1_list)),
            "loop_f1": float(np.mean(loop_f1_list)),
            "weighted_f1": float(np.mean(weighted_f1_list)),
            "n_valid_samples": int(n_valid),
        }
diff --git a/models/rfd3/src/rfd3/metrics/rna_aptamer_metrics.py b/models/rfd3/src/rfd3/metrics/rna_aptamer_metrics.py
new file mode 100644
index 00000000..bf1c7a36
--- /dev/null
+++ b/models/rfd3/src/rfd3/metrics/rna_aptamer_metrics.py
@@ -0,0 +1,115 @@
+import logging
+
+import numpy as np
+
+from foundry.metrics.metric import Metric
+from foundry.utils.ddp import RankedLogger
+
# NOTE(review): calling basicConfig at import time mutates process-wide logging
# state; consider moving this to the application entry point.
logging.basicConfig(level=logging.INFO)
# Rank-aware logger; rank_zero_only=False means every DDP rank emits messages.
global_logger = RankedLogger(__name__, rank_zero_only=False)
+
+
def calculate_ligand_contacts(
    atom_array_stack,
    cutoff_distance=4.0,
):
    """
    Count non-ligand atoms within `cutoff_distance` of any ligand (hetero) atom.

    An atom counts as "in contact" when it lies within the cutoff of at least
    one ligand atom of the same model; each atom is counted at most once per
    model.

    Parameters
    ----------
    atom_array_stack : AtomArrayStack
        Shape: (n_models, n_atoms); each model must expose `.coord` and `.hetero`.
    cutoff_distance : float
        Distance cutoff in Å

    Returns
    -------
    total_contacts : int
        Sum of contacting atoms over all models.
    mean_contacts_per_model : float
        Mean number of contacting atoms per model (0.0 for an empty stack).
    mean_contacts_per_ligand_atom : float
        Mean over models of (contacts / number of ligand atoms in that model);
        models without ligand atoms contribute 0.
    """
    cutoff_sq = cutoff_distance**2
    contacts_per_model = []
    n_ligand_atoms_per_model = []

    for atoms in atom_array_stack:
        coords = atoms.coord
        hetero_mask = np.asarray(atoms.hetero, dtype=bool)
        n_ligand_atoms = int(hetero_mask.sum())
        n_ligand_atoms_per_model.append(n_ligand_atoms)

        # Skip models with no ligand atoms, or with nothing but ligand atoms.
        if n_ligand_atoms == 0 or n_ligand_atoms == len(hetero_mask):
            contacts_per_model.append(0)
            continue

        ligand_coords = coords[hetero_mask]
        non_ligand_coords = coords[~hetero_mask]

        # Pairwise squared distances, shape (n_non_ligand, n_ligand).
        diff = non_ligand_coords[:, None, :] - ligand_coords[None, :, :]
        dist_sq = np.sum(diff**2, axis=-1)

        # An atom is a contact if ANY ligand atom lies within the cutoff.
        contact_mask = np.any(dist_sq < cutoff_sq, axis=1)
        contacts_per_model.append(int(np.sum(contact_mask)))

    # BUGFIX: the previous version normalized by `hetero_mask.sum()` of
    # whichever model happened to be iterated last (NameError on an empty
    # stack, ZeroDivisionError if that model had no ligand, and a wrong
    # normalization whenever ligand counts vary between models); it also
    # returned NaN means for empty stacks.
    if not contacts_per_model:
        return 0, 0.0, 0.0

    contacts = np.asarray(contacts_per_model, dtype=float)
    n_ligand = np.asarray(n_ligand_atoms_per_model, dtype=float)
    per_ligand_atom = np.where(n_ligand > 0, contacts / np.maximum(n_ligand, 1.0), 0.0)

    return (
        int(contacts.sum()),
        float(contacts.mean()),
        float(per_ligand_atom.mean()),
    )
+
+
class LigandContactMetrics(Metric):
    """
    Reports ligand-contact statistics for predicted structures.

    Wraps ``calculate_ligand_contacts`` and exposes its per-model mean and
    per-ligand-atom mean as metric entries. When ``restrict_to_nucleic`` is
    set, the metric is only computed for examples whose first model contains
    at least one RNA or DNA atom; otherwise an empty dict is returned.
    """

    def __init__(
        self,
        *,
        cutoff_distance: float = 4.0,
        restrict_to_nucleic: bool = True,
    ):
        """
        Args:
            cutoff_distance: contact distance cutoff in Å.
            restrict_to_nucleic: skip examples without any nucleic-acid atoms.
        """
        super().__init__()
        self.cutoff_distance = cutoff_distance
        self.restrict_to_nucleic = restrict_to_nucleic

    @property
    def kwargs_to_compute_args(self):
        # Maps the compute() keyword to the pipeline key it is pulled from.
        return {"predicted_atom_array_stack": ("predicted_atom_array_stack",)}

    def compute(self, *, predicted_atom_array_stack):
        """Return the contact metrics, or {} when skipped or on failure."""
        if self.restrict_to_nucleic:
            first_model = predicted_atom_array_stack[0]
            n_nucleic = first_model.is_rna.sum() + first_model.is_dna.sum()
            if n_nucleic == 0:
                return {}

        try:
            _total, mean_contacts, mean_contacts_per_atom = (
                calculate_ligand_contacts(
                    atom_array_stack=predicted_atom_array_stack,
                    cutoff_distance=self.cutoff_distance,
                )
            )
        except Exception as e:
            # Best-effort metric: log and skip rather than break the eval loop.
            global_logger.error(
                f"Error calculating ligand contact metrics: {e} | Skipping"
            )
            return {}

        return {
            "mean_ligand_contacts_per_model": float(mean_contacts),
            "mean_ligand_contacts_per_atom": float(mean_contacts_per_atom),
        }
diff --git a/models/rfd3/src/rfd3/model/cfg_utils.py b/models/rfd3/src/rfd3/model/cfg_utils.py
index d99c2fa0..2b9f860d 100644
--- a/models/rfd3/src/rfd3/model/cfg_utils.py
+++ b/models/rfd3/src/rfd3/model/cfg_utils.py
@@ -57,9 +57,14 @@ def strip_f(
# set the feature to default value if it is in the cfg_features
if k in cfg_features:
- v_cropped = torch.zeros_like(v_cropped).to(
- v_cropped.device, dtype=v_cropped.dtype
- )
+ if k not in ["bp_partners"]:
+ v_cropped = torch.zeros_like(v_cropped).to(
+ v_cropped.device, dtype=v_cropped.dtype
+ )
+ else:
+ ## for bp_partners default is a mask feature
+ v_cropped[:, :, 0] = 1
+ v_cropped[:, :, 1:] = 0
# update the feature in the dictionary
f_stripped[k] = v_cropped
diff --git a/models/rfd3/src/rfd3/model/layers/block_utils.py b/models/rfd3/src/rfd3/model/layers/block_utils.py
index aeac08c8..3dfa0d56 100644
--- a/models/rfd3/src/rfd3/model/layers/block_utils.py
+++ b/models/rfd3/src/rfd3/model/layers/block_utils.py
@@ -210,7 +210,9 @@ def create_attention_indices(
chain_ids is not None and len(torch.unique(chain_ids)) > 3
): # Multi-chain structure
# Reserve 25% of attention keys for inter-chain interactions
- k_inter_chain = max(32, k_actual // 4) # At least 32 inter-chain keys
+ k_inter_chain = min(
+ max(32, k_actual // 4), k_actual
+ ) # At least 32 inter-chain keys
k_intra_chain = k_actual - k_inter_chain
attn_indices = get_sparse_attention_indices_with_inter_chain(
diff --git a/models/rfd3/src/rfd3/model/layers/blocks.py b/models/rfd3/src/rfd3/model/layers/blocks.py
index eaf08093..3290b9e0 100644
--- a/models/rfd3/src/rfd3/model/layers/blocks.py
+++ b/models/rfd3/src/rfd3/model/layers/blocks.py
@@ -144,6 +144,43 @@ def forward(self, f, collapse_length):
)
class TwoDFeatureEmbedder(nn.Module):
    """
    Embeds token-pair (2D) features and sums them into a single pair embedding.

    Args:
        features (dict): Mapping of feature name -> number of input channels;
            entries are dropped when `exists(channels)` is falsy.
        output_channels (int): Output dimension of the projected embedding.
    """

    def __init__(self, features, output_channels):
        super().__init__()
        # Keep only features with a declared channel count.
        self.features = {name: dim for name, dim in features.items() if exists(dim)}
        total_input_channels = sum(self.features.values())
        self.embedders = nn.ModuleDict(
            {
                name: EmbeddingLayer(dim, total_input_channels, output_channels)
                for name, dim in self.features.items()
            }
        )

    def collapse2D(self, x, L):
        # Flatten any leading structure into an (L, L, C) layout.
        return x.reshape((L, L, x.numel() // (L * L)))

    def forward(self, f, collapse_length):
        # Collapse and embed each registered feature, then sum the embeddings.
        # (self.features was already filtered in __init__, so no re-check here.)
        return sum(
            self.embedders[name](self.collapse2D(f[name].float(), collapse_length))
            for name in self.features
        )
+
+
class SinusoidalDistEmbed(nn.Module):
"""
Applies sinusoidal embedding to pairwise distances and projects to c_atompair.
diff --git a/models/rfd3/src/rfd3/model/layers/encoders.py b/models/rfd3/src/rfd3/model/layers/encoders.py
index b0ed86fa..fc19a95c 100644
--- a/models/rfd3/src/rfd3/model/layers/encoders.py
+++ b/models/rfd3/src/rfd3/model/layers/encoders.py
@@ -14,6 +14,7 @@
PositionPairDistEmbedder,
RelativePositionEncodingWithIndexRemoval,
SinusoidalDistEmbed,
+ TwoDFeatureEmbedder,
)
from rfd3.model.layers.chunked_pairwise import (
ChunkedPairwiseEmbedder,
@@ -51,6 +52,7 @@ def __init__(
token_1d_features,
atom_1d_features,
atom_transformer,
+ token_2d_features=None,
use_chunked_pll=False, # New parameter for memory optimization
):
super().__init__()
@@ -62,6 +64,10 @@ def __init__(
self.atom_1d_embedder_1 = OneDFeatureEmbedder(atom_1d_features, c_s)
self.atom_1d_embedder_2 = OneDFeatureEmbedder(atom_1d_features, c_atom)
self.token_1d_embedder = OneDFeatureEmbedder(token_1d_features, c_s)
+ if token_2d_features is not None:
+ self.token_2d_embedder = TwoDFeatureEmbedder(token_2d_features, c_z)
+ else:
+ self.token_2d_embedder = None
self.downcast_atom = Downcast(c_atom=c_s, c_token=c_s, c_s=None, **downcast)
self.transition_post_token = Transition(c=c_s, n=2)
@@ -202,6 +208,9 @@ def init_tokens():
Z_init_II = Z_init_II + self.ref_pos_embedder_tok(
f["ref_pos"][f["is_ca"]], valid_mask
)
+ # Add extra token pair features
+ if self.token_2d_embedder is not None:
+ Z_init_II = Z_init_II + self.token_2d_embedder(f, I)
# Run a small transformer to provide position encodings to single.
for block in self.transformer_stack:
diff --git a/models/rfd3/src/rfd3/trainer/rfd3.py b/models/rfd3/src/rfd3/trainer/rfd3.py
index a7f72e6f..93a8f18c 100644
--- a/models/rfd3/src/rfd3/trainer/rfd3.py
+++ b/models/rfd3/src/rfd3/trainer/rfd3.py
@@ -8,6 +8,7 @@
from omegaconf import DictConfig
from rfd3.metrics.design_metrics import get_all_backbone_metrics
from rfd3.metrics.hbonds_hbplus_metrics import get_hbond_metrics
+from rfd3.metrics.nucleic_ss_metrics import get_NA_SS_F1
from rfd3.trainer.recycling import get_recycle_schedule
from rfd3.trainer.trainer_utils import (
_build_atom_array_stack,
@@ -428,9 +429,14 @@ def _build_predicted_atom_array_stack(
# ... Delete virtual atoms and assign atom names and elements
if self.cleanup_virtual_atoms:
- atom_array = _cleanup_virtual_atoms_and_assign_atom_name_elements(
- atom_array, association_scheme=self.association_scheme
- )
+ try:
+ atom_array = _cleanup_virtual_atoms_and_assign_atom_name_elements(
+ atom_array, association_scheme=self.association_scheme
+ )
+ except Exception as e:
+ global_logger.warning(
+ f"Failed to cleanup virtual atoms from diffusion output: {e}"
+ )
# ... When cleaning up virtual atoms, we can also calculate native_array_metricsl
metadata_dict[i]["metrics"] |= get_all_backbone_metrics(
@@ -444,6 +450,12 @@ def _build_predicted_atom_array_stack(
):
metadata_dict[i]["metrics"] |= get_hbond_metrics(atom_array)
+ if "bp_partners" in atom_array.get_annotation_categories():
+ if not np.all(atom_array.bp_partners == None): # noqa: E711
+ try:
+ metadata_dict[i]["metrics"] |= get_NA_SS_F1(atom_array)
+ except Exception:
+ pass
if "partial_t" in f:
# Try calcualte a CA RMSD to input:
aa_in = example["atom_array"]
diff --git a/models/rfd3/src/rfd3/trainer/trainer_utils.py b/models/rfd3/src/rfd3/trainer/trainer_utils.py
index 59f43a1c..a2cd07c9 100644
--- a/models/rfd3/src/rfd3/trainer/trainer_utils.py
+++ b/models/rfd3/src/rfd3/trainer/trainer_utils.py
@@ -2,6 +2,7 @@
import numpy as np
import torch
+from atomworks.constants import STANDARD_DNA, STANDARD_RNA
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
from atomworks.ml.utils.token import (
get_token_starts,
@@ -11,9 +12,12 @@
from jaxtyping import Float, Int
from rfd3.constants import (
ATOM14_ATOM_NAMES,
+ ATOM23_ATOM_NAMES_DNA,
+ ATOM23_ATOM_NAMES_RNA,
VIRTUAL_ATOM_ELEMENT_NAME,
association_schemes,
association_schemes_stripped,
+ backbone_atoms_RNA,
)
from rfd3.utils.io import (
build_stack_from_atom_array_and_batched_coords,
@@ -148,7 +152,6 @@ def _cleanup_virtual_atoms_and_assign_atom_name_elements(
is_seq_known = all(
np.array(res_array.is_motif_atom_with_fixed_seq, dtype=bool)
) or all(np.array(res_array.is_motif_atom_unindexed, dtype=bool))
-
# ... If sequence is known for the original atom array, just skip
if is_seq_known:
ret_mask += [True] * len(res_array)
@@ -218,7 +221,10 @@ def _readout_seq_from_struc(
# There might be a better way to do this.
CA_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CA"]
CB_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CB"]
- if np.linalg.norm(CA_coord - CB_coord) < threshold:
+
+ if cur_res_atom_array.is_dna[0] or cur_res_atom_array.is_rna[0]:
+ cur_central_atom = "C1'"
+ elif np.linalg.norm(CA_coord - CB_coord) < threshold:
cur_central_atom = "CA"
else:
cur_central_atom = central_atom
@@ -252,15 +258,28 @@ def _readout_seq_from_struc(
continue
# ... Find the index of virtual atom names in the standard atom14 names
+ ATOM_NAMES = ATOM14_ATOM_NAMES
+ if restype in STANDARD_DNA:
+ ATOM_NAMES = ATOM23_ATOM_NAMES_DNA
+ if not cur_res_atom_array.is_dna[0]:
+ continue
+ elif restype in STANDARD_RNA:
+ ATOM_NAMES = ATOM23_ATOM_NAMES_RNA
+ if not cur_res_atom_array.is_rna[0]:
+ continue
+ else:
+ # ATOM_NAMES = ATOM23_ATOM_NAMES_RNA
+ if not cur_res_atom_array.is_protein[0]:
+ continue
+
atom_name_idx_in_atom14_scheme = np.array(
[
- np.where(ATOM14_ATOM_NAMES == atom_name)[0][0]
+ np.where(ATOM_NAMES == atom_name)[0][0]
for atom_name in cur_pred_res_atom_names
]
) # five backbone atoms + some virtual atoms, returning e.g. [0, 1, 2, 3, 4, 11, 7]
- atom14_scheme_mask = np.zeros_like(ATOM14_ATOM_NAMES, dtype=bool)
+ atom14_scheme_mask = np.zeros_like(ATOM_NAMES, dtype=bool)
atom14_scheme_mask[atom_name_idx_in_atom14_scheme] = True
-
# ... Find the matched restype by checking if all the non-None posititons and None positions match
# This is designed to keep virtual atoms and doesn't assign the atom names for now, which will be handled later.
if all(x is not None for x in atom_names[atom14_scheme_mask]) and all(
@@ -403,7 +422,13 @@ def process_unindexed_outputs(
row_ind, col_ind = linear_sum_assignment(dists)
res_id_, chain_id_, _ = indices_to_components_(atom_array_diffused, col_ind)
- assert (res_id_ == res_id) & (chain_id_ == chain_id)
+ try:
+ assert (res_id_ == res_id) & (chain_id_ == chain_id)
+ except Exception:
+ global_logger.warning(
+ "Unindexed mapping did not work properly, res_id, chain_id"
+ )
+
inserted_mask = np.logical_or(inserted_mask, token_match)
# ... Compute metrics based on the new distances
@@ -427,13 +452,30 @@ def process_unindexed_outputs(
else:
join_atom = None
+ if join_atom is None:
+ pass
+ else:
+ dist = float(dists[row_ind[join_atom], col_ind[join_atom]])
+
+ elif not np.any(np.isin(token.atom_name, backbone_atoms_RNA)):
+ if np.sum(token.atomize) == 1:
+ join_atom = np.where(token.atomize)[0][0]
+ elif "C1'" in token.atom_name:
+ join_atom = np.where(token.atom_name == "C1'")[0][0]
+ else:
+ join_atom = None
+
if join_atom is None:
global_logger.warning(
- f"Token {token_pdb_id} does not contain backbone atoms or CB, skipping join point distance calculation {token}."
+ "Skipping joint point rmsd, neither protein or NA backbone"
)
else:
dist = float(dists[row_ind[join_atom], col_ind[join_atom]])
+
+ try:
metadata["join_point_rmsd_by_token"][token_pdb_id] = dist
+ except Exception:
+ pass
metadata["diffused_index_map"][token_pdb_id] = f"{chain_id}{res_id}"
diff --git a/models/rfd3/src/rfd3/transforms/conditioning_base.py b/models/rfd3/src/rfd3/transforms/conditioning_base.py
index 01a303c4..682005bb 100644
--- a/models/rfd3/src/rfd3/transforms/conditioning_base.py
+++ b/models/rfd3/src/rfd3/transforms/conditioning_base.py
@@ -235,6 +235,7 @@ def __init__(
train_conditions: dict,
meta_conditioning_probabilities: dict,
sequence_encoding,
+ association_scheme,
):
if exists(train_conditions):
train_conditions = hydra.utils.instantiate(
@@ -242,7 +243,12 @@ def __init__(
)
self.meta_conditioning_probabilities = meta_conditioning_probabilities
self.train_conditions = train_conditions
+
+ for item in self.train_conditions:
+ self.train_conditions[item].association_scheme = association_scheme
+
self.sequence_encoding = sequence_encoding
+ self.association_scheme = association_scheme
def check_input(self, data: dict):
assert not data["is_inference"], "This transform is only used during training!"
@@ -259,10 +265,13 @@ def check_input(self, data: dict):
assert "conditions" in data, "Conditioning dict not initialized"
def forward(self, data):
+ # for item in self.train_conditions:
+ # print(self.train_conditions[item].is_valid_for_example(data))
+
valid_conditions = [
cond
for cond in self.train_conditions.values()
- if cond.frequency > 0 and cond.is_valid_for_example(data)
+ if cond.is_valid_for_example(data) and cond.frequency > 0
]
if len(valid_conditions) == 0:
@@ -278,6 +287,8 @@ def forward(self, data):
i_cond = np.random.choice(np.arange(len(p_cond)), p=p_cond)
cond = valid_conditions[i_cond]
+ cond.association_scheme = self.association_scheme
+
data["sampled_condition"] = cond
data["sampled_condition_name"] = cond.name
data["sampled_condition_cls"] = cond.__class__
@@ -296,6 +307,9 @@ class SampleConditioningFlags(Transform):
"SampleConditioningType",
] # We use is_protein in the PPI training condition
+ def __init__(self, association_scheme):
+ self.association_scheme = association_scheme
+
def check_input(self, data):
assert not data[
"is_inference"
@@ -317,13 +331,14 @@ class UnindexFlaggedTokens(Transform):
Serves as the merge point between training / infernece conditioning pipelines
"""
- def __init__(self, central_atom):
+ def __init__(self, central_atom, association_scheme):
"""
Args:
central_atom: The atom to use as the central atom for unindexed motifs.
"""
super().__init__()
self.central_atom = central_atom
+ self.association_scheme = association_scheme
def check_input(self, data: dict):
check_contains_keys(data, ["atom_array"])
@@ -368,8 +383,18 @@ def expand_unindexed_motifs(
token.res_id = token.res_id + max_resid
token.is_C_terminus[:] = False
token.is_N_terminus[:] = False
- assert token.is_protein.all(), f"Cannot unindex non-protein token: {token}"
- token = add_representative_atom(token, central_atom=self.central_atom)
+
+ if not self.association_scheme == "atom23":
+ assert token.is_protein.all(), f"Cannot unindex non-protein token: {token} unless using atom23 association scheme"
+ token = add_representative_atom(token, central_atom=self.central_atom)
+ else:
+ if token.is_protein.all():
+ token = add_representative_atom(
+ token, central_atom=self.central_atom
+ )
+ else:
+ token = add_representative_atom(token, central_atom="C1'")
+
unindexed_tokens.append(token)
# ... Remove original tokens e.g. during inference
@@ -404,7 +429,6 @@ def expand_unindexed_motifs(
f"Failed to create uniquely recognised tokens after concatenation.\n"
f"Concatenated tokens: {get_token_count(atom_array_full)}, unindexed: {n_unindexed_tokens}"
)
-
return atom_array_full
def create_unindexed_masks(
diff --git a/models/rfd3/src/rfd3/transforms/design_transforms.py b/models/rfd3/src/rfd3/transforms/design_transforms.py
index 38c79979..e7d2d4cb 100644
--- a/models/rfd3/src/rfd3/transforms/design_transforms.py
+++ b/models/rfd3/src/rfd3/transforms/design_transforms.py
@@ -38,6 +38,7 @@
UnindexFlaggedTokens,
get_motif_features,
)
+from rfd3.transforms.na_geom import na_ss_feats_from_annotation
from rfd3.transforms.rasa import discretize_rasa
from rfd3.transforms.util_transforms import (
AssignTypes,
@@ -70,8 +71,10 @@ class SubsampleToTypes(Transform):
def __init__(
self,
allowed_types: list | str = ["is_protein"],
+ association_scheme: str = "atom14",
):
self.allowed_types = allowed_types
+ self.association_scheme = association_scheme
if not self.allowed_types == "ALL":
for k in allowed_types:
if not k.startswith("is_"):
@@ -103,7 +106,7 @@ def forward(self, data):
)
)
- if atom_array.is_protein.sum() == 0:
+ if self.association_scheme != "atom23" and atom_array.is_protein.sum() == 0:
raise ValueError(
"No protein atoms found in the atom array. Example ID: {}".format(
data.get("example_id", "unknown")
@@ -697,10 +700,12 @@ def __init__(
token_1d_features,
atom_1d_features,
autofill_zeros_if_not_present_in_atomarray=False,
+ association_scheme="atom14",
):
self.autofill = autofill_zeros_if_not_present_in_atomarray
self.token_1d_features = token_1d_features
self.atom_1d_features = atom_1d_features
+ self.association_scheme = association_scheme
def check_input(self, data) -> None:
check_contains_keys(data, ["atom_array"])
@@ -753,6 +758,13 @@ def forward(self, data: Dict[str, Any]) -> Dict[str, Any]:
if "feats" not in data.keys():
data["feats"] = {}
+ if self.association_scheme == "atom23":
+ data["atom_array"].set_annotation(
+ "is_protein_token", data["atom_array"].is_protein
+ )
+ data["atom_array"].set_annotation("is_dna_token", data["atom_array"].is_dna)
+ data["atom_array"].set_annotation("is_rna_token", data["atom_array"].is_rna)
+
for feature_name, n_dims in self.token_1d_features.items():
data = self.generate_feature(feature_name, n_dims, data, "token")
@@ -762,6 +774,91 @@ def forward(self, data: Dict[str, Any]) -> Dict[str, Any]:
return data
class AddAdditional2dFeaturesToFeats(Transform):
    """
    Adds any net.token_initializer.token_2d_features present in the atomarray
    but not in data['feats'] to data['feats'].

    Args:
        - token_2d_features: dict mapping feature_name -> n_feature_dims. Should
          be hydra interpolated from net.token_initializer.token_2d_features.
        - autofill_zeros_if_not_present_in_atomarray: stored for interface
          compatibility; NOTE(review): currently unused — a feature without a
          registered constructor always raises. TODO confirm intended behavior.
        - association_scheme: atom association scheme name (e.g. "atom14", "atom23").
    """

    incompatible_previous_transforms = ["AddAdditional2dFeaturesToFeats"]

    def __init__(
        self,
        token_2d_features,
        autofill_zeros_if_not_present_in_atomarray=False,
        association_scheme="atom14",
    ):
        self.autofill = autofill_zeros_if_not_present_in_atomarray
        self.token_2d_features = token_2d_features
        self.association_scheme = association_scheme

        # Pre-defined constructor functions mapping atom-array annotations to
        # (n_tokens, n_tokens, n_dims) arrays; pair features may require
        # custom logic/conventions.
        self.constructor_functions = {
            "bp_partners": na_ss_feats_from_annotation,
        }

    def check_input(self, data) -> None:
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)

    def generate_token_feature(self, feature_name, n_dims, data):
        """Build `feature_name` from the atom array and store it in data['feats'].

        No-op if the feature is already present. Raises ValueError when no
        constructor is registered or the constructed array has a bad shape.
        """
        # Don't do this if we already have the feature
        if feature_name in data["feats"]:
            return data

        # Pair features need a custom constructor mapping annotations to arrays.
        if feature_name in self.constructor_functions:
            feature_array = self.constructor_functions[feature_name](data["atom_array"])
        else:
            raise ValueError(
                f"No constructor function found for 2d feature `{feature_name}`"
            )

        # An (I, I) matrix with a declared channel dim of 1 gets a trailing
        # channel axis appended.
        # BUGFIX: the previous `unsqueeze(1)` produced shape (I, 1, I), which
        # then failed the `shape[2] != n_dims` check below (and `unsqueeze` is
        # torch-only); `[..., None]` yields (I, I, 1) and works for both numpy
        # arrays and torch tensors.
        if len(feature_array.shape) == 2 and n_dims == 1:
            feature_array = feature_array[..., None]

        # ensure that feature_array is a 3d array with third dim == n_dims:
        if len(feature_array.shape) != 3:
            raise ValueError(
                f"token 2d_feature `{feature_name}` must be a 3d array, got {len(feature_array.shape)}d."
            )
        if feature_array.shape[2] != n_dims:
            raise ValueError(
                f"token 2d_feature `{feature_name}` dimensions in atomarray ({feature_array.shape[-1]}) does not match dimension declared in config, ({n_dims})"
            )
        # Ensure correct shape in first two dims (I,I,...)
        if feature_array.shape[0] != feature_array.shape[1]:
            raise ValueError(
                f"token 2d_feature `{feature_name}` first two dimensions must be equal (square matrix), got {feature_array.shape[0]} and {feature_array.shape[1]}"
            )

        data["feats"][feature_name] = feature_array

        return data

    def forward(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Checks if the 2d_features are present in data['feats']. If not present,
        adds them from the atomarray via the registered constructor functions.
        """
        if "feats" not in data:
            data["feats"] = {}
        # Only apply for features that the model is expecting:
        if self.token_2d_features is None:
            return data
        for feature_name, n_dims in self.token_2d_features.items():
            data = self.generate_token_feature(feature_name, n_dims, data)
        return data
+
+
class FeaturizepLDDT(Transform):
"""
Provides:
diff --git a/models/rfd3/src/rfd3/transforms/na_geom.py b/models/rfd3/src/rfd3/transforms/na_geom.py
new file mode 100644
index 00000000..ee264f2b
--- /dev/null
+++ b/models/rfd3/src/rfd3/transforms/na_geom.py
@@ -0,0 +1,323 @@
+from typing import Any
+
+import numpy as np
+from atomworks.ml.transforms._checks import (
+ check_atom_array_annotation,
+ check_contains_keys,
+ check_is_instance,
+)
+from atomworks.ml.transforms.base import Transform
+from atomworks.ml.utils.token import get_token_starts, spread_token_wise
+from biotite.structure import AtomArray
+from rfd3.transforms.conditioning_utils import sample_island_tokens
+from rfd3.transforms.na_geom_utils import (
+ DEFAULT_NA_SS_FEATURE_INFO,
+ annotate_na_ss,
+ annotate_na_ss_from_data_specification,
+)
+
+
def na_ss_feats_from_annotation(
    atom_array: AtomArray,
    token_starts=None,
    n_tokens=None,
    return_as_onehot=True,
) -> np.ndarray:
    """
    Build the token-level nucleic-acid secondary-structure matrix from the
    atom-level ``bp_partners`` annotation.

    Cell encoding (values taken from ``DEFAULT_NA_SS_FEATURE_INFO``):
        * MASK: annotation is ``None`` (unspecified) — the default fill value
        * PAIR: token j appears in token i's partner list (set symmetrically)
        * LOOP: token has an explicitly empty partner list; its entire row and
          column are striped, stating that it pairs with nothing

    Args:
        atom_array: AtomArray carrying an atom-level ``bp_partners`` annotation.
        token_starts: optional precomputed token start indices.
        n_tokens: optional number of tokens (length of ``token_starts``).
        return_as_onehot: when True (default) return a one-hot
            (n_tokens, n_tokens, n_classes) array; otherwise the integer
            (n_tokens, n_tokens) matrix.

    Returns:
        np.ndarray as described above.
    """
    # Derive token bookkeeping from the atom array unless supplied.
    if token_starts is None or n_tokens is None:
        token_starts = get_token_starts(atom_array)
        n_tokens = len(token_starts)

    # Reduce the atom-level annotation to one entry per token.
    token_bp_partners = atom_array.get_annotation("bp_partners")[token_starts]
    assert (
        len(token_bp_partners) == n_tokens
    ), "Length of token_bp_partners should match n_tokens"

    mask_value = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_MASK"]
    pair_value = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_PAIR"]
    loop_value = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_LOOP"]

    # Everything starts out masked/unspecified.
    na_ss_matrix = np.full((n_tokens, n_tokens), mask_value, dtype=np.int64)

    # Mark annotated base pairs symmetrically.
    for i, partners in enumerate(token_bp_partners):
        if partners is None or len(partners) == 0:
            continue
        for j in partners:
            na_ss_matrix[i, j] = pair_value
            na_ss_matrix[j, i] = pair_value

    # Stripe full rows/columns for explicitly unpaired (loop) tokens,
    # indicating that NONE of the other positions pair with them.
    # (Applied after the pair pass, matching the original fill order.)
    for i, partners in enumerate(token_bp_partners):
        if partners is not None and len(partners) == 0:
            na_ss_matrix[i, :] = loop_value
            na_ss_matrix[:, i] = loop_value

    # Optionally one-hot encode for model input.
    if return_as_onehot:
        n_classes = len(DEFAULT_NA_SS_FEATURE_INFO)
        na_ss_matrix = np.eye(n_classes, dtype=np.int64)[na_ss_matrix]

    return na_ss_matrix
+
+
class CalculateNucleicAcidGeomFeats(Transform):
    """
    Transform for constructing nucleic-acid conditioning features.

    This transform currently produces only nucleic-acid secondary-structure
    (NA-SS) information, exposed through the atom-level ``bp_partners``
    annotation with 3 states:
        * None: masked / unspecified
        * non-empty list: paired (partner token indices)
        * empty list: loop / explicitly unpaired

    Training:
        - Computes geometry/H-bond-based base pairs and writes them onto the
          AtomArray via ``bp_partners`` (annotation-first), then samples a
          token-level visibility mask (keeping pair partners in sync) and
          hides the unshown positions.

    Inference:
        - Interprets user-provided secondary-structure specifications and
          writes the same ``bp_partners`` annotation.

    Note: helical-parameter features are not implemented/used in this
    refactored path.
    """

    def __init__(
        self,
        is_inference,
        # Conditional sampling parameters all stored in this dict:
        meta_conditioning_probabilities: dict[str, float] = None,
        # Mask control parameters:
        nucleic_ss_min_shown: float = 0.2,
        nucleic_ss_max_shown: float = 1.0,
        n_islands_min: int = 1,
        n_islands_max: int = 6,
        NA_only: bool = False,
        planar_only: bool = True,
    ):
        """
        Args:
            is_inference: whether the transform runs in the inference pipeline.
            meta_conditioning_probabilities: optional dict read for
                ``p_is_nucleic_ss_example``, ``p_nucleic_ss_show_partial_feats``
                and ``p_canonical_bp_filter``; missing keys fall back to
                defaults below.
            nucleic_ss_min_shown / nucleic_ss_max_shown: bounds on the fraction
                of tokens whose SS is revealed when partial features are shown.
            n_islands_min / n_islands_max: bounds on the number of contiguous
                revealed "islands" of tokens.
            NA_only: only annotate base-like interactions for nucleic-acid
                residues.
            planar_only: only consider planar sidechain atoms for geometry
                calculations.
        """
        # Critical, must always have to know how to handle
        self.is_inference = is_inference

        self.meta_conditioning_probabilities = meta_conditioning_probabilities or {}

        # Control whether we show some nucleic SS or default to full 2D mask
        self.p_is_nucleic_ss_example = self.meta_conditioning_probabilities.get(
            "p_is_nucleic_ss_example", 0.0
        )

        # Control whether we define full SS or just part of it (only applies if is NA SS example)
        self.p_show_partial_feats = self.meta_conditioning_probabilities.get(
            "p_nucleic_ss_show_partial_feats", 0.0
        )

        # Some frac of time default to only showing canonical base pairs
        self.p_canonical_bp_filter = self.meta_conditioning_probabilities.get(
            "p_canonical_bp_filter", 0.5
        )

        # mask patterning control to make things resemble design scenarios
        self.nucleic_ss_min_shown = nucleic_ss_min_shown
        self.nucleic_ss_max_shown = nucleic_ss_max_shown
        self.n_islands_min = n_islands_min
        self.n_islands_max = n_islands_max

        # Filters for what can be considered a planar contact interaction
        self.NA_only = (
            NA_only  # only annotate base-like interactions for nucleic acid residues
        )
        self.planar_only = planar_only  # only consider planar atoms in sidechains for geometry calculations

    def check_input(self, data: dict[str, Any]) -> None:
        """Validate that the example carries an AtomArray with residue names."""
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)
        check_atom_array_annotation(data, ["res_name"])
        # maybe do later: check_atom_array_has_hydrogen(data)

    def _sample_training_flags(self) -> tuple[bool, bool]:
        """Sample booleans controlling whether/how features are shown in training."""
        is_nucleic_ss_example = bool(np.random.rand() < self.p_is_nucleic_ss_example)
        give_partial_feats = bool(np.random.rand() < self.p_show_partial_feats)
        return is_nucleic_ss_example, give_partial_feats

    def forward(self, data: dict) -> dict:
        """Annotate ``bp_partners`` on the atom array.

        Training: sampled ground truth, partially masked. Inference: built
        from the user-provided specification in ``data``.
        """
        atom_array = data["atom_array"]

        # Calculate n_tokens (assuming one token per residue for simplicity)
        token_starts = get_token_starts(atom_array)
        n_tokens = len(token_starts)

        # Handle the training case with ground truth and masking
        if not self.is_inference:
            # Decide whether/how much SS to reveal for this example
            is_nucleic_ss_example, give_partial_feats = self._sample_training_flags()

            if is_nucleic_ss_example:
                atom_array = annotate_na_ss(
                    atom_array,
                    NA_only=self.NA_only,
                    planar_only=self.planar_only,
                    p_canonical_bp_filter=self.p_canonical_bp_filter,
                )

                # Generate symmetric partner annotations at the token level for
                # masking purposes. Choice for object-consistency: if already
                # masked/undefined, use a list mapping to self-index.
                partner_sym_map = {
                    i: atom_array.bp_partners[ts_i]
                    if atom_array.bp_partners[ts_i] is not None
                    else [i]
                    for i, ts_i in enumerate(token_starts)
                }

                # Sample mask on token level (True where SS is shown):
                token_mask_to_show = self._sample_where_to_show_ss(
                    n_tokens,
                    is_nucleic_ss_example=is_nucleic_ss_example,
                    give_partial_feats=give_partial_feats,
                    partner_sym_map=partner_sym_map,
                )

                # Spread mask to atom level
                is_ss_shown = spread_token_wise(atom_array, token_mask_to_show)

                # Hide unshown positions in the bp_partners annotation.
                bp_partners_atom = atom_array.get_annotation("bp_partners")
                bp_partners_atom[~is_ss_shown] = None
                atom_array.set_annotation("bp_partners", bp_partners_atom)
            else:
                # Not an SS-conditioned example: everything masked.
                atom_array.set_annotation(
                    "bp_partners", np.array([None] * len(atom_array))
                )

        # Inference case: create from commandline args. Handles:
        #   1) a single dot-bracket string
        #   2) multiple dot-bracket strings with chain/index ranges specified
        #   3) lists of paired indices
        else:
            atom_array = annotate_na_ss_from_data_specification(
                data,
                overwrite=True,
            )

        # Check feats existence and update:
        if "feats" not in data:
            data["feats"] = {}
        data.setdefault("log_dict", {})
        data["atom_array"] = atom_array

        return data

    def _sample_where_to_show_ss(
        self,
        n_tokens: int,
        is_nucleic_ss_example: bool = True,
        give_partial_feats: bool = True,
        partner_sym_map: dict[int, list[int]] = None,
    ) -> np.ndarray:
        """Sample token-level islands indicating which SS rows/cols to reveal.

        Enforces symmetry via ``partner_sym_map`` (token -> partner token
        indices): partners are masked/unmasked together so the revealed SS
        stays self-consistent.

        Returns:
            Boolean array of shape (n_tokens,), True where SS is shown.
        """
        # BUGFIX: previously the all-hidden results for `not
        # is_nucleic_ss_example` and `max_length <= 0` were silently
        # overwritten by later assignments (missing early return / else);
        # return early instead.
        if not is_nucleic_ss_example:
            return np.zeros((n_tokens,), dtype=bool)

        if not give_partial_feats:
            # Full SS is revealed.
            token_mask_to_show = np.ones((n_tokens,), dtype=bool)
        else:
            # Sample the fraction of tokens to reveal within configured bounds.
            frac_shown = (
                self.nucleic_ss_min_shown
                + (self.nucleic_ss_max_shown - self.nucleic_ss_min_shown)
                * np.random.rand()
            )
            frac_shown = float(np.clip(frac_shown, 0.0, 1.0))
            max_length = int(np.ceil(frac_shown * n_tokens))
            if max_length <= 0:
                return np.zeros((n_tokens,), dtype=bool)

            # Island length bounds derived from the target fraction and the
            # allowed island counts.
            island_len_min = max(
                1, int(frac_shown * n_tokens // max(int(self.n_islands_max), 1))
            )
            island_len_max = max(
                1, int(frac_shown * n_tokens // max(int(self.n_islands_min), 1))
            )
            island_len_min = min(island_len_min, n_tokens)
            island_len_max = min(island_len_max, n_tokens)
            island_len_max = max(island_len_max, island_len_min)

            # Sample the actual mask using the utility function:
            token_mask_to_show = sample_island_tokens(
                n_tokens,
                island_len_min=island_len_min,
                island_len_max=island_len_max,
                n_islands_min=self.n_islands_min,
                n_islands_max=self.n_islands_max,
                max_length=max_length,
            )

        # Enforce symmetry: a token is shown only if its partners are shown.
        # Initialize as all shown so the effect comes purely from hiding +
        # the logical AND below. (Guarded: partner_sym_map may be None.)
        partner_mask_to_show = np.ones_like(token_mask_to_show)
        if partner_sym_map:
            for token_i, partner_ind_list in partner_sym_map.items():
                for partner_ind in partner_ind_list:
                    partner_mask_to_show[partner_ind] = token_mask_to_show[token_i]

        return token_mask_to_show & partner_mask_to_show
diff --git a/models/rfd3/src/rfd3/transforms/na_geom_utils.py b/models/rfd3/src/rfd3/transforms/na_geom_utils.py
new file mode 100644
index 00000000..7a4dd911
--- /dev/null
+++ b/models/rfd3/src/rfd3/transforms/na_geom_utils.py
@@ -0,0 +1,1866 @@
+import math
+import os
+import subprocess
+import tempfile
+from datetime import datetime
+from typing import Dict, Optional
+
+import biotite.structure as struc
+import numpy as np
+from atomworks.constants import (
+ STANDARD_AA,
+ STANDARD_DNA,
+ STANDARD_RNA,
+)
+from atomworks.io.utils.sequence import (
+ is_purine,
+ is_pyrimidine,
+)
+from atomworks.ml.encoding_definitions import AF3SequenceEncoding
+from atomworks.ml.utils.token import (
+ get_token_starts,
+ is_glycine,
+ is_protein_unknown,
+ is_standard_aa_not_glycine,
+ is_unknown_nucleotide,
+)
+from biotite.structure import AtomArray
+from rfd3.constants import (
+ ATOM_REGION_BY_RESI,
+ PLANAR_ATOMS_BY_RESI,
+)
+from rfd3.transforms.hbonds_hbplus import save_atomarray_to_pdb
+
# Derived lookup: True when the residue type has at least one planar
# sidechain atom listed in PLANAR_ATOMS_BY_RESI.
HAS_PLANAR_SC = {res: bool(atoms) for res, atoms in PLANAR_ATOMS_BY_RESI.items()}

# Integer codes for the nucleic-acid secondary-structure feature channel
# (per the key names: masked / paired / loop; consumers define semantics).
DEFAULT_NA_SS_FEATURE_INFO: dict[str, int] = {
    "NA_SS_MASK": 0,
    "NA_SS_PAIR": 1,
    "NA_SS_LOOP": 2,
}

# Sorted union of planar sidechain atom names over all standard amino acids.
AA_PLANAR_ATOMS = sorted(
    set(
        atom
        for res in STANDARD_AA
        if res in PLANAR_ATOMS_BY_RESI
        for atom in PLANAR_ATOMS_BY_RESI[res]
    )
)

# Sorted union of planar (base) atom names over standard RNA and DNA residues.
NA_PLANAR_ATOMS = sorted(
    set(
        atom
        for res in (*STANDARD_RNA, *STANDARD_DNA)
        if res in PLANAR_ATOMS_BY_RESI
        for atom in PLANAR_ATOMS_BY_RESI[res]
    )
)
+
+
class NucMolInfo:
    """Constants and parameters for nucleic-acid geometry and interaction scoring.

    All parameters are set to empirically validated defaults. No constructor
    arguments are currently accepted.
    """

    def __init__(self) -> None:
        # Interaction-class indices into the last axis of the (L, L, 3)
        # `hbond_count` array; the axis encodes the H-bond type between
        # tokens i and j.
        self.BB_BB = 0  # backbone-backbone H-bond interactions
        self.BB_SC = 1  # backbone-sidechain H-bond interactions
        self.SC_SC = 2  # sidechain-sidechain H-bond interactions

        # Per-class weights applied when summing the last axis of the
        # hbond_count array into a single interaction score.
        self.bp_weight_BB_BB = 0.0
        self.bp_weight_BB_SC = 0.5
        self.bp_weight_SC_SC = 1.0
        self.bp_summation_weights = [
            self.bp_weight_BB_BB,
            self.bp_weight_BB_SC,
            self.bp_weight_SC_SC,
        ]

        # Sigmoid parameters giving a continuous step function for the
        # H-bond-count base-pair criterion (first filter). Calibrated so:
        #   >= 2 base-base H-bonds                      -> ~1.0
        #   1 base-base + 1 base-backbone H-bond        -> ~0.5
        self.min_hbonds_for_bp = 2.0
        self.bp_hbond_coeff = 9.8  # determined heuristically
        # Minimum base-pairing score used when binarizing base pairs.
        self.bp_val_cutoff = 0.5

        # Hard geometric limits for accepting a base pair.
        self.base_geometry_limits = {}
        self.base_geometry_limits["D_ij"] = 16.0
        self.base_geometry_limits["H_ij"] = 1.5
        self.base_geometry_limits["P_ij"] = math.pi / 5
        self.base_geometry_limits["B_ij"] = math.pi / 5

        self.rep_atom_dict = {"protein": "CA", "rna": "C1'", "dna": "C1'"}

        # Edge-vector atom names per nucleotide: W = Watson-Crick edge,
        # H = Hoogsteen edge, S = sugar edge, B = glycosidic bond vector.
        # Purines share the sugar/glycosidic atoms, as do pyrimidines, and
        # the DNA/RNA variants of each base use identical atom names, so we
        # build each entry from shared base definitions.
        purine_tail = {
            "S_start": "C1'",
            "S_stop": "N3",
            "B_start": "C1'",
            "B_stop": "N9",
        }
        pyrimidine_tail = {
            "S_start": "C1'",
            "S_stop": "O2",
            "B_start": "C1'",
            "B_stop": "N1",
        }
        adenine = {
            "W_start": "N1",
            "W_stop": "N6",
            "H_start": "N7",
            "H_stop": "N6",
            **purine_tail,
        }
        guanine = {
            "W_start": "N1",
            "W_stop": "O6",
            "H_start": "N7",
            "H_stop": "O6",
            **purine_tail,
        }
        cytosine = {
            "W_start": "N3",
            "W_stop": "N4",
            "H_start": "C5",
            "H_stop": "N4",
            **pyrimidine_tail,
        }
        thymine_or_uracil = {
            "W_start": "N3",
            "W_stop": "O4",
            "H_start": "C5",
            "H_stop": "O4",
            **pyrimidine_tail,
        }
        self.vec_atom_dict = {
            "DA": dict(adenine),
            "DG": dict(guanine),
            "DC": dict(cytosine),
            "DT": dict(thymine_or_uracil),
            "A": dict(adenine),
            "G": dict(guanine),
            "C": dict(cytosine),
            "U": dict(thymine_or_uracil),
        }
+
+
def calculate_hb_counts(
    atom_array: AtomArray,
    token_level_data: dict,
    mol_info: NucMolInfo,
    cutoff_HA_dist: float = 2.5,
    cutoff_DA_dist: float = 3.9,
):
    """Count hydrogen bonds between residue pairs using HBPLUS.

    Writes the structure to a temporary PDB file, runs the external HBPLUS
    executable (located via the ``HBPLUS_PATH`` environment variable), and
    parses the fixed-column ``.hb2`` output into a per-token count tensor.

    Args:
        atom_array: Structure to analyse.
        token_level_data: Token-level metadata dict (must contain
            ``token_id_list`` and ``resi2index``).
        mol_info: Molecular-info object (kept for interface compatibility;
            not used by the current implementation).
        cutoff_HA_dist: H–A distance cutoff (Å) passed to HBPLUS.
        cutoff_DA_dist: D–A distance cutoff (Å) passed to HBPLUS.

    Returns:
        np.ndarray of shape ``(I, I, 3)`` (int32) where the last axis
        encodes: 0 = BB–BB, 1 = BB–SC, 2 = SC–SC H-bond counts. All zeros
        if HBPLUS produced no output file.

    Raises:
        ValueError: If ``HBPLUS_PATH`` is unset or empty.
    """
    hbplus_exe = os.environ.get("HBPLUS_PATH")

    if not hbplus_exe:
        raise ValueError(
            "HBPLUS_PATH environment variable not set. "
            "Please set it to the path of the hbplus executable in order to calculate hydrogen bonds."
        )
    with tempfile.TemporaryDirectory() as tmpdir:
        # Timestamp + random suffix so concurrent workers cannot collide.
        dtstr = datetime.now().strftime("%Y%m%d%H%M%S")
        pdb_filename = f"{dtstr}_{np.random.randint(10000)}.pdb"
        pdb_path = os.path.join(tmpdir, pdb_filename)
        atom_array, _nan_mask, chain_map = save_atomarray_to_pdb(atom_array, pdb_path)

        subprocess.call(
            [
                hbplus_exe,
                "-h",
                str(cutoff_HA_dist),
                "-d",
                str(cutoff_DA_dist),
                pdb_path,
                pdb_path,
            ],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            cwd=tmpdir,
        )

        num_resis_total = len(token_level_data["token_id_list"])

        hbond_count = np.zeros((num_resis_total, num_resis_total, 3), dtype=np.int32)

        # BUGFIX: swap only the file extension. The previous
        # `pdb_path.replace("pdb", "hb2")` also rewrote any "pdb" substring
        # that happened to occur in the temporary-directory component of the
        # path, producing a path that never exists.
        hb2_path = os.path.splitext(pdb_path)[0] + ".hb2"
        if not os.path.exists(hb2_path):
            print("WARNING: HB2 file could not be found; skipping NA SS metric")
            return hbond_count
        with open(hb2_path, "r") as hb2_f:
            for i, line in enumerate(hb2_f):
                # The first 8 lines of an .hb2 file are the header.
                if i < 8:
                    continue
                # Skip truncated records (fields below read up to column 27).
                if len(line) < 28:
                    continue

                # --- Donor fields (fixed-column HBPLUS record layout) ---
                d_chain_iid = chain_map[line[0]]
                d_resi = int(line[1:5].strip())
                d_resn = line[6:9].strip()
                d_atom_name = line[9:13].strip()

                # Backbone/sidechain flags default to False until classified.
                d_is_sc = False
                d_is_bb = False
                a_is_sc = False
                a_is_bb = False

                d_mask = (
                    (atom_array.atom_name == d_atom_name)
                    & (atom_array.res_name == d_resn)
                    & (atom_array.res_id == d_resi)
                    & (atom_array.chain_iid == d_chain_iid)
                )
                d_idx = token_level_data["resi2index"].get(
                    f"{d_chain_iid}__{d_resi}", None
                )
                if d_idx is None:
                    continue

                # Standard polymer residues: classify donor atom by region table.
                if d_resn in ATOM_REGION_BY_RESI:
                    d_is_sc = d_atom_name in ATOM_REGION_BY_RESI[d_resn]["sc"]
                    d_is_bb = d_atom_name in ATOM_REGION_BY_RESI[d_resn]["bb"]
                elif d_mask.sum() > 0:
                    # Non-polymer: treat any ligand H-bonding atom as backbone.
                    d_is_bb = atom_array[d_mask][0].is_ligand

                # --- Acceptor fields ---
                a_chain_iid = chain_map[line[14]]
                a_resi = int(line[15:19].strip())
                a_resn = line[20:23].strip()
                a_atom_name = line[23:27].strip()

                a_mask = (
                    (atom_array.atom_name == a_atom_name)
                    & (atom_array.res_name == a_resn)
                    & (atom_array.res_id == a_resi)
                    & (atom_array.chain_iid == a_chain_iid)
                )
                a_idx = token_level_data["resi2index"].get(
                    f"{a_chain_iid}__{a_resi}", None
                )
                if a_idx is None:
                    continue

                # Standard polymer residues: classify acceptor atom by region table.
                if a_resn in ATOM_REGION_BY_RESI:
                    a_is_sc = a_atom_name in ATOM_REGION_BY_RESI[a_resn]["sc"]
                    a_is_bb = a_atom_name in ATOM_REGION_BY_RESI[a_resn]["bb"]
                elif a_mask.sum() > 0:
                    # Non-polymer: treat any ligand H-bonding atom as backbone.
                    a_is_bb = atom_array[a_mask][0].is_ligand

                # Accumulate symmetrically into (acceptor, donor) and
                # (donor, acceptor) cells.
                # 0 -> both backbone (BB-BB)
                hbond_count[a_idx, d_idx, 0] += a_is_bb * d_is_bb
                hbond_count[d_idx, a_idx, 0] += d_is_bb * a_is_bb

                # 1 -> one backbone, one sidechain (BB-SC)
                hbond_count[a_idx, d_idx, 1] += (a_is_bb * d_is_sc) | (
                    a_is_sc * d_is_bb
                )
                hbond_count[d_idx, a_idx, 1] += (d_is_bb * a_is_sc) | (
                    d_is_sc * a_is_bb
                )

                # 2 -> both sidechain (SC-SC)
                hbond_count[a_idx, d_idx, 2] += a_is_sc * d_is_sc
                hbond_count[d_idx, a_idx, 2] += d_is_sc * a_is_sc
    # Temporary PDB/HB2 files are removed together with the TemporaryDirectory,
    # so no explicit cleanup is required here.
    return hbond_count
+
+
def find_planar_positions(
    atom_array: AtomArray,
    mol_info: NucMolInfo,
    tol: float = 1e-2,
) -> Dict:
    """Identify residues with planar sidechains via known atom lists or PCA plane-fitting.

    For canonical residues the planar atoms are looked up from
    ``PLANAR_ATOMS_BY_RESI``; for non-canonical residues a plane is fitted to
    the four tip-most sidechain atoms, and all atoms within *tol* of that
    plane are returned.

    Args:
        atom_array: Structure to analyse.
        mol_info: Molecular-info object (kept for interface compatibility;
            not used by the current implementation).
        tol: Distance tolerance (Å) from the fitted plane for an atom to be
            considered planar.

    Returns:
        Dictionary ``{(chain_iid, res_id): [atom_name, ...]}`` mapping each
        unique residue position to its list of planar sidechain atom names.
    """
    # Order-preserving de-duplication of (chain, residue, name) triples.
    # dict.fromkeys is O(n); the previous list-membership scan was O(n^2)
    # over atoms.
    unique_positions_list = list(
        dict.fromkeys((atm.chain_iid, atm.res_id, atm.res_name) for atm in atom_array)
    )

    # Candidate planar atoms per unique residue position:
    planar_atom_list_dict = {}

    for chain_iid, res_id, res_name in unique_positions_list:
        mask = (
            (atom_array.chain_iid == chain_iid)
            & (atom_array.res_id == res_id)
            & (atom_array.res_name == res_name)
        )
        res_atoms = atom_array[mask]

        # Fast path: known planar atoms for canonical residue types.
        if res_name in PLANAR_ATOMS_BY_RESI.keys():
            # Intersect the residue's atoms with the known planar set.
            planar_atom_list = list(
                set([atm.atom_name for atm in res_atoms])
                & set(PLANAR_ATOMS_BY_RESI[res_name])
            )
            planar_atom_list_dict[(chain_iid, res_id)] = planar_atom_list

        # Unknown or noncanonical residue: compute planar atoms geometrically.
        else:
            candidate_planar_atm_names = []
            candidate_planar_atm_coords = []

            # NOTE(review): the original branching on AA_PLANAR_ATOMS /
            # NA_PLANAR_ATOMS appended the atom in every branch, so no actual
            # pre-filtering occurred; collapsed to a single append here
            # (behavior unchanged). If filtering was intended, the
            # non-matching protein/NA atoms should be skipped — confirm intent.
            for atm in res_atoms:
                candidate_planar_atm_names.append(atm.atom_name)
                candidate_planar_atm_coords.append(atm.coord)

            # Reverse so atoms furthest from the backbone come first.
            candidate_planar_atm_names = list(reversed(candidate_planar_atm_names))
            candidate_planar_atm_coords = list(reversed(candidate_planar_atm_coords))

            planar_atom_list = []
            # Need at least 4 atoms to define a robust plane.
            if len(candidate_planar_atm_coords) >= 4:
                coords = np.asarray(candidate_planar_atm_coords, dtype=np.float32)

                # Use the first four candidate atoms only to define the plane.
                quad_coords = coords[:4, :]

                # Fit the plane via PCA: the smallest-variance eigenvector of
                # the covariance matrix is the plane normal.
                # NOTE(review): NaN coordinates in the defining quad would
                # propagate into the covariance matrix and np.linalg.eigh —
                # presumably missing atoms are filtered upstream; confirm.
                quad_center = quad_coords.mean(axis=0, keepdims=True)
                all_quad_centered = coords - quad_center
                quad_centered = quad_coords - quad_center
                quad_cov = (quad_centered.T @ quad_centered) / max(
                    quad_coords.shape[0] - 1, 1
                )
                _, quad_eigvecs = np.linalg.eigh(quad_cov)
                quad_normal = quad_eigvecs[:, 0]  # smallest-eigenvalue direction
                quad_normal = quad_normal / (np.linalg.norm(quad_normal) + 1e-8)

                # Distances of all candidate atoms from the fitted plane.
                quad_dists = np.abs(all_quad_centered @ quad_normal)
                quad_valid_mask = quad_dists <= tol

                # Accept only if the defining quad itself is planar...
                valid_plane_filter = np.nanmax(quad_dists[:4]) < tol
                # ...and at least four atoms lie within tolerance of the plane.
                plane_atom_filter = int(np.sum(quad_valid_mask)) >= 4
                if valid_plane_filter and plane_atom_filter:
                    planar_atom_list = [
                        n
                        for n, keep in zip(
                            candidate_planar_atm_names, quad_valid_mask.tolist()
                        )
                        if keep
                    ]
                # else: not enough atoms close to a common plane -> empty list.

            planar_atom_list_dict[(chain_iid, res_id)] = planar_atom_list

    return planar_atom_list_dict
+
+
def make_coord_list(
    atom_array: AtomArray,
    residue_list: list[int],
    chain_list: list[str],
    atom_list: list[str],
) -> list[list[float]]:
    """Extract per-residue representative coordinates from an AtomArray.

    All three input lists must have the same length. Missing atoms are
    filled with ``[NaN, NaN, NaN]``.

    Args:
        atom_array: Biotite AtomArray to query.
        residue_list: Residue IDs (one per token).
        chain_list: Chain identifiers (one per token).
        atom_list: Atom names to extract (use ``"atomized"`` to take the
            first atom of the residue).

    Returns:
        List of ``[x, y, z]`` coordinate lists (floats), same length as input.
    """
    # Note: return annotation fixed from list[list[str]] — coordinates are
    # floats, and residue IDs are ints per the callers in this module.
    coord_list: list[list[float]] = []
    for res_id, chain_id, atom_name in zip(residue_list, chain_list, atom_list):
        if atom_name == "atomized":
            # Atomized residue: take the first atom of the residue
            # (mask should select exactly one atom in that case).
            mask = (atom_array.chain_id == chain_id) & (atom_array.res_id == res_id)
        else:
            # General case for non-atomized residues; should have a unique
            # solution, but we take the first entry either way.
            mask = (
                (atom_array.chain_id == chain_id)
                & (atom_array.res_id == res_id)
                & (atom_array.atom_name == atom_name)
            )

        # Coordinates of the (first) masked atom; empty if no atom matched.
        coords = atom_array.coord[mask][0:1]

        if len(coords) < 1:
            coord_list.append([float("nan"), float("nan"), float("nan")])
        else:
            coord_list.append(coords[0].tolist())

    return coord_list
+
+
def get_token_level_metadata(
    atom_array: AtomArray,
    mol_info: "NucMolInfo",
    *,
    NA_only: bool = False,
    planar_only: bool = True,
    seq_cutoff: int = 2,
    gap_length: int = 200,
) -> dict:
    """Build lightweight token-level metadata (no coordinate geometry).

    Sufficient for SS reconstruction, loop labeling from ``bp_partners``,
    and inference-time SS specification parsing. For geometry keys
    (``xyz_planar``, ``frame_xyz``, ``M_i``), follow up with
    :func:`add_token_level_geometry_data`.

    Args:
        atom_array: Structure to analyse.
        mol_info: Molecular-info constants (not used in this function's
            visible body; kept for signature parity with the geometry step).
        NA_only: If True, restrict filter_mask to nucleic-acid tokens.
        planar_only: If True, restrict filter_mask to tokens with planar
            sidechains.
        seq_cutoff: Sequence-distance threshold for the ``seq_neighbors``
            boolean mask.
        gap_length: Artificial gap inserted between chains for relative
            sequence position computation.

    Returns:
        Dict with keys: ``token_starts``, ``token_index``, ``is_na``,
        ``is_planar``, ``chain_list``, ``chain_iid_list``, ``resi_list``,
        ``resn_list``, ``token_id_list``, ``resi2index``, ``len_s``,
        ``seq_neighbors``, ``na_inds``, ``na_tensor_inds``,
        ``filter_mask``, ``rep_atom_list``, ``S_start_atom_list``,
        ``S_stop_atom_list``, ``include_geometry`` (False).
    """

    # Use residue starts (not token starts) so atomized atoms within one residue
    # map to a single NA-SS position.
    token_starts = struc.get_residue_starts(atom_array)
    token_level_array = atom_array[token_starts]

    token_index = np.arange(len(token_starts))

    # Molecule-type flags.
    # Instantiate encoding locally to avoid retaining large arrays at module scope.
    sequence_encoding = AF3SequenceEncoding()

    is_rna = np.isin(
        token_level_array.res_name,
        sequence_encoding.all_res_names[sequence_encoding.is_rna_like],
    )
    is_dna = np.isin(
        token_level_array.res_name,
        sequence_encoding.all_res_names[sequence_encoding.is_dna_like],
    )

    is_na_arr = (is_dna | is_rna).astype(bool)

    # Per-token bookkeeping lists, filled in a single pass below.
    chain_list: list[str] = []
    chain_iid_list: list[str] = []
    resi_list: list[int] = []
    ind_list: list[int] = []
    res_name_list: list[str] = []
    token_id_list: list[str] = []

    rep_atom_list: list[str | None] = []
    S_start_atom_list: list[str | None] = []
    S_stop_atom_list: list[str | None] = []
    sc_planarity_list: list[bool] = []

    for i, atm in enumerate(token_level_array):
        chain_list.append(atm.chain_id)
        chain_iid_list.append(atm.chain_iid)
        resi_list.append(int(atm.res_id))
        ind_list.append(int(i))
        res_name_list.append(atm.res_name)
        token_id_list.append(str(atm.token_id))

        # Cheap name-based planarity lookup (refined later by the geometry
        # step, which fits actual planes to coordinates).
        if atm.is_polymer and (atm.res_name in HAS_PLANAR_SC.keys()):
            sc_planarity_list.append(bool(HAS_PLANAR_SC[atm.res_name]))
        else:
            sc_planarity_list.append(False)

        # representative & sugar-edge atoms, chosen by residue class
        if is_glycine(atm.res_name) | is_protein_unknown(atm.res_name):
            rep_atom_i = "CA"
            S_start_atom_i = None
            S_stop_atom_i = None
        elif is_standard_aa_not_glycine(atm.res_name):
            rep_atom_i = "CA"
            S_start_atom_i = "CA"
            S_stop_atom_i = "CB"
        elif is_pyrimidine(atm.res_name):
            rep_atom_i = "C1'"
            S_start_atom_i = "C1'"
            S_stop_atom_i = "O2"
        elif is_purine(atm.res_name):
            rep_atom_i = "C1'"
            S_start_atom_i = "C1'"
            S_stop_atom_i = "N3"
        elif is_unknown_nucleotide(atm.res_name):
            rep_atom_i = "C1'"
            S_start_atom_i = None
            S_stop_atom_i = None
        elif getattr(atm, "atomize", False):
            # Atomized token: the single atom represents itself.
            rep_atom_i = atm.atom_name
            S_start_atom_i = None
            S_stop_atom_i = None
        else:
            rep_atom_i = None
            S_start_atom_i = None
            S_stop_atom_i = None

        rep_atom_list.append(rep_atom_i)
        S_start_atom_list.append(S_start_atom_i)
        S_stop_atom_list.append(S_stop_atom_i)

    # residue index <-> token index map, keyed "chain_iid__res_id"
    resi2index = {
        f"{c}__{r}": i for c, r, i in zip(chain_iid_list, resi_list, ind_list)
    }

    # Relative sequence positions with an artificial gap between chains so
    # cross-chain tokens do not count as sequence neighbors.
    # NOTE(review): a chain spanning more than `gap_length` residue IDs could
    # still overlap the next chain's positions — confirm res_id ranges upstream.
    rel_pos_list: list[int] = []
    current_chain = ""
    chn_bias = -gap_length
    for r, c in zip(resi_list, chain_iid_list):
        if c != current_chain:
            chn_bias += gap_length
            current_chain = c
        rel_pos_list.append(int(r + chn_bias))

    rel_pos = np.asarray(rel_pos_list, dtype=np.int64)
    seq_neighbors = np.abs(rel_pos[:, None] - rel_pos[None, :]) <= int(seq_cutoff)

    # Nucleic-acid token indices, and their dense positions within NA-only tensors.
    na_inds = np.nonzero(is_na_arr)[0].tolist()
    na_tensor_inds = {na_i: i for i, na_i in enumerate(na_inds)}

    # Cheap planarity heuristic from residue name lookup
    is_planar_arr = np.asarray(sc_planarity_list, dtype=bool)

    # filter mask using NA_only / planar_only flags
    if NA_only and planar_only:
        filter_mask = is_na_arr & is_planar_arr
    elif NA_only and (not planar_only):
        filter_mask = is_na_arr.copy()
    elif (not NA_only) and planar_only:
        filter_mask = is_planar_arr.copy()
    else:
        filter_mask = np.ones_like(is_na_arr, dtype=bool)

    return {
        "token_starts": token_starts,
        "token_index": token_index,
        "is_na": is_na_arr,
        "is_planar": is_planar_arr,
        "chain_list": chain_list,
        "chain_iid_list": chain_iid_list,
        "resi_list": resi_list,
        "resn_list": res_name_list,
        "token_id_list": token_id_list,
        "resi2index": resi2index,
        "len_s": int(len(token_level_array)),
        "seq_neighbors": seq_neighbors,
        "na_inds": na_inds,
        "na_tensor_inds": na_tensor_inds,
        "filter_mask": filter_mask,
        "rep_atom_list": rep_atom_list,
        "S_start_atom_list": S_start_atom_list,
        "S_stop_atom_list": S_stop_atom_list,
        "include_geometry": False,
    }
+
+
def add_token_level_geometry_data(
    atom_array: AtomArray,
    mol_info: "NucMolInfo",
    token_level_data: dict,
    *,
    NA_only: bool = False,
    planar_only: bool = True,
) -> dict:
    """Augment token-level metadata with coordinate-derived geometry fields.

    Populates ``xyz_planar``, ``xyz_S_start``, ``xyz_S_stop``,
    ``frame_xyz``, ``M_i`` and updates ``is_planar`` / ``filter_mask``
    using coordinate-derived planarity. Sets ``include_geometry=True``.

    No-ops if geometry was already computed.

    Args:
        atom_array: Structure to extract coordinates from.
        mol_info: Molecular-info constants.
        token_level_data: Dict produced by :func:`get_token_level_metadata`
            (modified in-place and returned).
        NA_only: Restrict filter_mask to nucleic-acid tokens.
        planar_only: Restrict filter_mask to tokens with planar sidechains.

    Returns:
        The same ``token_level_data`` dict, augmented with geometry keys.
    """

    # Idempotency guard: geometry was already attached.
    if bool(token_level_data.get("include_geometry", False)):
        return token_level_data

    # Backward-compatibility: older token_level_data dicts (or user-provided ones)
    # may not contain the metadata keys this function needs.
    required_keys = (
        "chain_iid_list",
        "chain_list",
        "resi_list",
        "rep_atom_list",
        "S_start_atom_list",
        "S_stop_atom_list",
        "is_na",
    )
    if any(k not in token_level_data for k in required_keys):
        token_level_data = get_token_level_metadata(
            atom_array,
            mol_info,
            NA_only=NA_only,
            planar_only=planar_only,
        )

    chain_iid_list: list[str] = token_level_data["chain_iid_list"]
    chain_list: list[str] = token_level_data["chain_list"]
    resi_list: list[int] = token_level_data["resi_list"]
    rep_atom_list: list[str | None] = token_level_data["rep_atom_list"]
    S_start_atom_list: list[str | None] = token_level_data["S_start_atom_list"]
    S_stop_atom_list: list[str | None] = token_level_data["S_stop_atom_list"]

    planar_atom_list_dict = find_planar_positions(
        atom_array, mol_info
    )  # {(chain_iid, res_id): [atom_name, ...]}
    has_planar_sc: list[bool] = []

    xyz_planar: list[
        list[list[float]]
    ] = []  # list[I] of [K_i, 3] (K_i varies per residue)
    xyz_S_start: list[list[float]] = []  # list[I] of [3]
    xyz_S_stop: list[list[float]] = []  # list[I] of [3]

    # One pass per token: gather planar-atom coordinates and sugar-edge
    # endpoints, padding any missing atom with NaNs.
    for c, r, S_start_atm, S_stop_atm in zip(
        chain_iid_list,
        resi_list,
        S_start_atom_list,
        S_stop_atom_list,
    ):
        planar_atoms_i = planar_atom_list_dict[(c, r)]
        # A usable base/sidechain plane requires at least 4 planar atoms.
        has_planar_sc.append(bool(len(planar_atoms_i) >= 4))

        atom_array_i = atom_array[
            (atom_array.chain_iid == c) & (atom_array.res_id == r)
        ]

        planar_coords_i: list[list[float]] = []
        for pl_atm_name_j in planar_atoms_i:
            pl_atom_array_ij = atom_array_i[atom_array_i.atom_name == pl_atm_name_j]
            if len(pl_atom_array_ij) == 0:
                planar_coords_i.append([float("nan"), float("nan"), float("nan")])
            else:
                planar_coords_i.append(pl_atom_array_ij[0].coord)

        # Fewer than 4 planar atoms -> single NaN placeholder row.
        xyz_planar.append(
            planar_coords_i if len(planar_coords_i) > 3 else [[float("nan")] * 3]
        )

        if S_start_atm is None:
            xyz_S_start.append([float("nan"), float("nan"), float("nan")])
        else:
            S_start_atom_array_i = atom_array_i[atom_array_i.atom_name == S_start_atm]
            xyz_S_start.append(
                [float("nan"), float("nan"), float("nan")]
                if len(S_start_atom_array_i) == 0
                else S_start_atom_array_i[0].coord
            )

        if S_stop_atm is None:
            xyz_S_stop.append([float("nan"), float("nan"), float("nan")])
        else:
            S_stop_atom_array_i = atom_array_i[atom_array_i.atom_name == S_stop_atm]
            xyz_S_stop.append(
                [float("nan"), float("nan"), float("nan")]
                if len(S_stop_atom_array_i) == 0
                else S_stop_atom_array_i[0].coord
            )

        del atom_array_i

    # frame coordinates and backbone direction
    # NOTE(review): this lookup uses chain_list (chain_id) while the planar /
    # sugar-edge lookups above use chain_iid — if chain_id is not unique per
    # instance these could disagree; confirm intended.
    frame_xyz = np.asarray(  # [I, 3] representative-atom coordinates
        make_coord_list(atom_array, resi_list, chain_list, rep_atom_list),
        dtype=np.float32,
    )

    # Pad with duplicated end points so the central difference below is
    # defined at both chain termini.
    padded_centers = np.concatenate(
        [frame_xyz[:1], frame_xyz, frame_xyz[-1:]], axis=0
    )  # [I+2, 3]
    M_i = (
        (  # [I, 3] smoothed backbone-direction vectors (central difference)
            (padded_centers[1:-1] - padded_centers[:-2])
            + (padded_centers[2:] - padded_centers[1:-1])
        )
        / 2.0
    )

    # Replace the name-based planarity heuristic with the coordinate-derived one.
    is_planar_arr = np.asarray(has_planar_sc, dtype=bool)  # [I]
    token_level_data["is_planar"] = is_planar_arr

    # Recompute filter_mask with the refined planarity flags.
    is_na_arr = np.asarray(token_level_data["is_na"], dtype=bool)  # [I]
    if NA_only and planar_only:
        filter_mask = is_na_arr & is_planar_arr
    elif NA_only and (not planar_only):
        filter_mask = is_na_arr.copy()
    elif (not NA_only) and planar_only:
        filter_mask = is_planar_arr.copy()
    else:
        filter_mask = np.ones_like(is_na_arr, dtype=bool)
    token_level_data["filter_mask"] = filter_mask  # [I] bool

    token_level_data.update(
        {
            "xyz_planar": xyz_planar,
            "xyz_S_start": xyz_S_start,
            "xyz_S_stop": xyz_S_stop,
            "frame_xyz": frame_xyz,
            "M_i": M_i,
            "include_geometry": True,
        }
    )

    del planar_atom_list_dict, padded_centers
    return token_level_data
+
+
+# ---------------------------------------------------------------------------
+# Sub-calculations used by compute_nucleic_ss
+# ---------------------------------------------------------------------------
+
+
+def _compute_local_frames(
+ xyz_planar: list[np.ndarray],
+ planar_centers: np.ndarray,
+ M_i: np.ndarray,
+ *,
+ xyz_S_start: list | None = None,
+ xyz_S_stop: list | None = None,
+ compute_full_frame: bool = False,
+ eps: float = 1e-8,
+) -> dict[str, np.ndarray]:
+ """Build per-residue local coordinate frames from planar sidechain atoms.
+
+ The base-normal direction Z_i is always computed via PCA on the planar
+ atom cloud, corrected for backbone direction. When *compute_full_frame*
+ is True the sugar-edge vector is used to derive X_i and Y_i as well.
+
+ Args:
+ xyz_planar: Per-residue planar-atom coordinates, list[I] of [K_i, 3].
+ planar_centers: Sidechain planar-atom centroids, [I, 3].
+ M_i: Backbone-direction vectors, [I, 3].
+ xyz_S_start: Sugar-edge start coordinates, list[I] of [3].
+ Required when *compute_full_frame* is True.
+ xyz_S_stop: Sugar-edge stop coordinates, list[I] of [3].
+ Required when *compute_full_frame* is True.
+ compute_full_frame: If True, also compute X_i and Y_i.
+ eps: Small constant for numerical stability.
+
+ Returns:
+ Dict with ``"Z_i"`` (always), and ``"X_i"``, ``"Y_i"`` when
+ *compute_full_frame* is True. Each array has shape ``[I, 3]``.
+ """
+ n_tokens = len(xyz_planar)
+
+ # Mean-centre the planar atoms per residue
+ centered_points = [ # list[I] of [K_i, 3]
+ np.asarray(xyz_i, dtype=np.float32) - cen_i
+ for xyz_i, cen_i in zip(xyz_planar, planar_centers)
+ ]
+
+ # PCA → eigenvectors per residue
+ eigenvectors = np.full((n_tokens, 3, 3), np.nan, dtype=np.float32) # [I, 3, 3]
+
+ for i, xyz_i in enumerate(centered_points):
+ xyz_i = xyz_i[~np.isnan(xyz_i).any(axis=1)]
+ if xyz_i.shape[0] >= 3:
+ cov_matrix = np.einsum("ij,ik->jk", xyz_i, xyz_i) / max( # [3, 3]
+ xyz_i.shape[0] - 1, 1
+ )
+ _, eigvecs = np.linalg.eigh(cov_matrix) # [3, 3]
+ eigenvectors[i] = eigvecs
+
+ # Base-normal: smallest-eigenvalue direction, corrected for backbone dir
+ N_i = eigenvectors[:, :, 0] # [I, 3]
+ N_i = N_i / (np.linalg.norm(N_i, axis=1, keepdims=True) + eps)
+
+ Z_i = N_i * np.sum(M_i * N_i, axis=-1, keepdims=True) # [I, 3]
+ Z_i = Z_i / (np.linalg.norm(Z_i, axis=-1, keepdims=True) + eps)
+
+ result: dict[str, np.ndarray] = {"Z_i": Z_i}
+
+ if compute_full_frame:
+ if xyz_S_start is None or xyz_S_stop is None:
+ raise ValueError("xyz_S_start and xyz_S_stop are required for full frame")
+
+ X_s_i = ( # [I, 3] sugar-edge direction
+ np.asarray(xyz_S_stop, dtype=np.float32)
+ - np.asarray(xyz_S_start, dtype=np.float32)
+ )
+ X_s_i = X_s_i / (np.linalg.norm(X_s_i, axis=-1, keepdims=True) + eps)
+
+ X_i = np.cross(Z_i, X_s_i) # [I, 3]
+ X_i = X_i / (np.linalg.norm(X_i, axis=-1, keepdims=True) + eps)
+ result["X_i"] = X_i
+
+ Y_i = np.cross(X_i, Z_i) # [I, 3]
+ Y_i = Y_i / (np.linalg.norm(Y_i, axis=-1, keepdims=True) + eps)
+ result["Y_i"] = Y_i
+
+ return result
+
+
+def _compute_pairwise_geometry(
+ Z_i: np.ndarray,
+ frame_D_ij_vec: np.ndarray,
+ sc_D_ij_vec: np.ndarray,
+ *,
+ X_i: np.ndarray | None = None,
+ clamp: bool = True,
+ compute_opening: bool = False,
+ eps: float = 1e-8,
+) -> dict[str, np.ndarray]:
+ """Compute pairwise base-step geometry between all residue pairs.
+
+ Derives the pairwise coordinate frame (X_ij, Y_ij, Z_ij) and the
+ base-pair geometry parameters: rise (H_ij), buckle (B_ij), propeller
+ (P_ij), and optionally opening angle (O_ij).
+
+ Args:
+ Z_i: Per-residue base-normal vectors, [I, 3].
+ frame_D_ij_vec: Pairwise backbone displacement vectors, [I, I, 3].
+ sc_D_ij_vec: Pairwise sidechain-centroid displacement vectors, [I, I, 3].
+ X_i: Per-residue local X-axis, [I, 3]. Required when
+ *compute_opening* is True.
+ clamp: Clamp cosines to [-1, 1] before ``arccos``.
+ compute_opening: If True, compute opening angle O_ij.
+ eps: Small constant for numerical stability.
+
+ Returns:
+ Dict with keys ``"H_ij"`` [I, I], ``"B_ij"`` [I, I],
+ ``"P_ij"`` [I, I], ``"base_ori_ij"`` [I, I],
+ ``"X_ij"`` [I, I, 3], ``"Y_ij"`` [I, I, 3],
+ ``"Z_ij"`` [I, I, 3], and optionally ``"O_ij"`` [I, I].
+ """
+ # Orientation-selected pairwise Z-axis
+ Z_sum = Z_i[:, None, :] + Z_i[None, :, :] # [I, I, 3]
+ Z_diff = Z_i[:, None, :] - Z_i[None, :, :] # [I, I, 3]
+ Z_ij_oris = 0.5 * np.stack((Z_sum, Z_diff), axis=0) # [2, I, I, 3]
+
+ base_ori_ij = ( # [I, I] 0=parallel, 1=antiparallel
+ np.linalg.norm(Z_ij_oris[1], axis=-1) > np.linalg.norm(Z_ij_oris[0], axis=-1)
+ ).astype(np.int64)
+
+ Z_ij = np.where(
+ base_ori_ij[..., None] == 0, Z_ij_oris[0], Z_ij_oris[1]
+ ) # [I, I, 3]
+ Z_ij = Z_ij / (np.linalg.norm(Z_ij, axis=-1, keepdims=True) + eps)
+
+ # Pairwise Y (inter-residue direction) and X axes
+ Y_ij = frame_D_ij_vec / (
+ np.linalg.norm(frame_D_ij_vec, axis=-1, keepdims=True) + eps
+ ) # [I, I, 3]
+ X_ij = np.cross(Z_ij, Y_ij) # [I, I, 3]
+ X_ij = X_ij / (np.linalg.norm(X_ij, axis=-1, keepdims=True) + eps)
+
+ # Rise (H_ij)
+ H_ij = np.sum(sc_D_ij_vec * Z_ij, axis=-1) # [I, I]
+ D_ij = np.linalg.norm(sc_D_ij_vec, axis=-1) # [I, I]
+
+ # Buckle (B_ij)
+ proj_Z_i_YZ = ( # [I, I, 3]
+ np.sum(Z_i[:, None, :] * Y_ij, axis=-1, keepdims=True) * Y_ij
+ + np.sum(Z_i[:, None, :] * Z_ij, axis=-1, keepdims=True) * Z_ij
+ )
+ proj_Z_i_YZ_norm = proj_Z_i_YZ / (
+ np.linalg.norm(proj_Z_i_YZ, axis=-1, keepdims=True) + eps
+ )
+ cos_buckle = np.sum(
+ proj_Z_i_YZ_norm * (-proj_Z_i_YZ_norm.swapaxes(0, 1)), axis=-1
+ ) # [I, I]
+
+ # Propeller (P_ij)
+ proj_Z_i_ZX = ( # [I, I, 3]
+ np.sum(Z_i[:, None, :] * Z_ij, axis=-1, keepdims=True) * Z_ij
+ + np.sum(Z_i[:, None, :] * X_ij, axis=-1, keepdims=True) * X_ij
+ )
+ proj_Z_i_ZX_norm = proj_Z_i_ZX / (
+ np.linalg.norm(proj_Z_i_ZX, axis=-1, keepdims=True) + eps
+ )
+ cos_propeller = np.sum(
+ proj_Z_i_ZX_norm * (-proj_Z_i_ZX_norm.swapaxes(0, 1)), axis=-1
+ ) # [I, I]
+
+ if clamp:
+ cos_buckle = np.clip(cos_buckle, -1.0, 1.0)
+ cos_propeller = np.clip(cos_propeller, -1.0, 1.0)
+
+ B_ij = np.arccos(cos_buckle) # [I, I]
+ P_ij = np.arccos(cos_propeller) # [I, I]
+
+ result: dict[str, np.ndarray] = {
+ "H_ij": H_ij,
+ "B_ij": B_ij,
+ "P_ij": P_ij,
+ "D_ij": D_ij,
+ "base_ori_ij": base_ori_ij,
+ "X_ij": X_ij,
+ "Y_ij": Y_ij,
+ "Z_ij": Z_ij,
+ }
+
+ # Opening angle (O_ij) — purely diagnostic
+ if compute_opening:
+ if X_i is None:
+ raise ValueError("X_i is required to compute opening angle")
+
+ proj_X_i_XY = ( # [I, I, 3]
+ np.sum(X_i[:, None, :] * X_ij, axis=-1, keepdims=True) * X_ij
+ + np.sum(X_i[:, None, :] * Y_ij, axis=-1, keepdims=True) * Y_ij
+ )
+ proj_X_i_XY_norm = proj_X_i_XY / (
+ np.linalg.norm(proj_X_i_XY, axis=-1, keepdims=True) + eps
+ )
+ cos_opening = np.sum(
+ proj_X_i_XY_norm * proj_X_i_XY_norm.swapaxes(0, 1), axis=-1
+ ) # [I, I]
+ if clamp:
+ cos_opening = np.clip(cos_opening, -1.0, 1.0)
+ result["O_ij"] = np.arccos(cos_opening) # [I, I]
+
+ return result
+
+
+def _compute_basepair_mask(
+ hbond_count: np.ndarray,
+ seq_neighbors: np.ndarray,
+ H_ij: np.ndarray,
+ B_ij: np.ndarray,
+ P_ij: np.ndarray,
+ D_ij: np.ndarray,
+ mol_info,
+ *,
+ bool_only: bool = False,
+ eps: float = 1e-8,
+) -> dict[str, np.ndarray] | np.ndarray:
+ """Identify base pairs by combining H-bond scores with geometry filters.
+
+ Computes a sigmoid-based base-pair probability from weighted H-bond
+ counts and gates it with rise / buckle / propeller geometry limits.
+
+ Args:
+ hbond_count: H-bond counts, [I, I, 3] (BB-BB / BB-SC / SC-SC).
+ seq_neighbors: Sequence-neighbor boolean mask, [I, I].
+ H_ij: Rise displacement, [I, I].
+ B_ij: Buckle angle (radians), [I, I].
+ P_ij: Propeller angle (radians), [I, I].
+ mol_info: Molecular-info object with ``bp_summation_weights``,
+ ``bp_hbond_coeff``, ``min_hbonds_for_bp``, ``bp_val_cutoff``,
+ and ``base_geometry_limits``.
+ bool_only: If True, return only the boolean mask array.
+ eps: Small constant for numerical stability.
+
+ Returns:
+ If *bool_only*: ``np.ndarray`` of shape ``(I, I)`` (bool).
+ Otherwise: dict with ``"basepairs_bool_ij"`` [I, I] (bool),
+ ``"basepairs_ij"`` [I, I] (float), and
+ ``"hbond_summation"`` [I, I] (float).
+ """
+ hbond_summation = np.tensordot( # [I, I]
+ hbond_count.astype(np.float32),
+ np.asarray(mol_info.bp_summation_weights, dtype=np.float32),
+ axes=([2], [0]),
+ )
+
+ logits = mol_info.bp_hbond_coeff * ( # [I, I]
+ hbond_summation - (mol_info.min_hbonds_for_bp - 1)
+ )
+ bp_preds = (1.0 / (1.0 + np.exp(-logits))) + eps # [I, I]
+
+ # Geometry filters
+ H_ij_filter = ( # [I, I]
+ (H_ij >= -mol_info.base_geometry_limits["H_ij"])
+ & (H_ij <= mol_info.base_geometry_limits["H_ij"])
+ )
+ B_ij_filter = ( # [I, I]
+ (B_ij <= mol_info.base_geometry_limits["B_ij"])
+ | (B_ij >= math.pi - mol_info.base_geometry_limits["B_ij"])
+ )
+ P_ij_filter = ( # [I, I]
+ (P_ij <= mol_info.base_geometry_limits["P_ij"])
+ | (P_ij >= math.pi - mol_info.base_geometry_limits["P_ij"])
+ )
+
+ D_ij_filter = D_ij <= mol_info.base_geometry_limits["D_ij"]
+
+ bp_geom_filter = H_ij_filter & B_ij_filter & P_ij_filter & D_ij_filter # [I, I]
+
+ if bool_only:
+ basepairs_bool_ij = ( # [I, I]
+ (~seq_neighbors)
+ & bp_geom_filter
+ & (bp_preds >= float(mol_info.bp_val_cutoff))
+ )
+ return basepairs_bool_ij
+
+ basepairs_ij = ( # [I, I]
+ (~seq_neighbors).astype(np.float32)
+ * bp_geom_filter.astype(np.float32)
+ * bp_preds.astype(np.float32)
+ )
+ basepairs_bool_ij = basepairs_ij >= mol_info.bp_val_cutoff # [I, I]
+
+ return {
+ "basepairs_bool_ij": basepairs_bool_ij,
+ "basepairs_ij": basepairs_ij,
+ "hbond_summation": hbond_summation,
+ }
+
+
def compute_nucleic_ss(
    mol_info,
    token_level_data,
    hbond_count,
    clamp_pairwise_params=True,
    eps=1e-8,
    *,
    return_local_params: bool = False,
    return_pairwise_geometry: bool = False,
    return_opening_angle: bool = False,
    return_basepairs_only: bool = False,
):
    """Compute nucleic-acid pairwise base-pair geometry and filters.

    Operates in two modes:

    * **Fast annotation** (default / ``return_basepairs_only=True``): returns
      only ``basepairs_bool_ij`` and frees intermediate arrays.
    * **Diagnostic**: additionally returns local/pairwise geometry when
      ``return_pairwise_geometry``, ``return_local_params``, or
      ``return_opening_angle`` are set.

    Args:
        mol_info: Molecular-info constants (geometry limits, H-bond weights).
        token_level_data: Token-level dict with geometry (from
            :func:`add_token_level_geometry_data`).
        hbond_count: H-bond count array, shape ``(I_full, I_full, 3)``.
        clamp_pairwise_params: Clamp cosines to [-1, 1] before ``arccos``.
        eps: Small constant for numerical stability.
        return_local_params: Return per-residue X/Y/Z local frames.
        return_pairwise_geometry: Return pairwise X_ij/Y_ij/Z_ij arrays.
        return_opening_angle: Return pairwise opening angle O_ij.
        return_basepairs_only: Return only the boolean base-pair mask
            (fastest path).

    Returns:
        If ``return_basepairs_only``: ``np.ndarray`` of shape ``(I, I)``
        (bool) — the base-pair boolean mask.

        Otherwise: dict ``{"pair_params": {...}, "local_params": {...}}``
        containing the requested geometry arrays (all shape ``(I, I)`` or
        ``(I, 3)``), always including ``H_ij``/``B_ij``/``P_ij``/``D_ij``.
    """

    mask_1d = np.asarray(token_level_data["filter_mask"], dtype=bool)  # [I_full]

    # --- Unpack and filter token-level data ----------------------
    M_i = np.asarray(token_level_data["M_i"], dtype=np.float32)[mask_1d]  # [I, 3]
    frame_xyz = np.asarray(token_level_data["frame_xyz"], dtype=np.float32)[
        mask_1d
    ]  # [I, 3]
    xyz_S_start = [
        v for v, k in zip(token_level_data["xyz_S_start"], mask_1d) if k
    ]  # list[I] of [3]
    xyz_S_stop = [
        v for v, k in zip(token_level_data["xyz_S_stop"], mask_1d) if k
    ]  # list[I] of [3]
    xyz_planar = [
        v for v, k in zip(token_level_data["xyz_planar"], mask_1d) if k
    ]  # list[I] of [K_i, 3]

    hbond_count = np.asarray(hbond_count)[mask_1d, :][:, mask_1d]  # [I, I, 3]
    seq_neighbors = np.asarray(token_level_data["seq_neighbors"], dtype=bool)[
        mask_1d, :
    ][:, mask_1d]  # [I, I]

    # Nothing passed NA/planar filtering for this structure.
    # Return empty outputs instead of failing downstream on np.stack([]).
    if len(xyz_planar) == 0:
        if return_basepairs_only:
            return np.zeros((0, 0), dtype=bool)

        pair_params: dict[str, np.ndarray] = {
            "H_ij": np.zeros((0, 0), dtype=np.float32),
            "B_ij": np.zeros((0, 0), dtype=np.float32),
            "P_ij": np.zeros((0, 0), dtype=np.float32),
            "D_ij": np.zeros((0, 0), dtype=np.float32),
            "base_ori_ij": np.zeros((0, 0), dtype=np.float32),
            "basepairs_bool_ij": np.zeros((0, 0), dtype=bool),
            "basepairs_ij": np.zeros((0, 0), dtype=np.float32),
            "hbond_summation": np.zeros((0, 0), dtype=np.float32),
        }

        if return_opening_angle:
            pair_params["O_ij"] = np.zeros((0, 0), dtype=np.float32)

        if return_pairwise_geometry:
            pair_params["X_ij"] = np.zeros((0, 0), dtype=np.float32)
            pair_params["Y_ij"] = np.zeros((0, 0), dtype=np.float32)
            pair_params["Z_ij"] = np.zeros((0, 0), dtype=np.float32)

        nucleic_ss_data: dict = {"pair_params": pair_params}
        if return_local_params:
            nucleic_ss_data["local_params"] = {
                "X_i": np.zeros((0, 3), dtype=np.float32),
                "Y_i": np.zeros((0, 3), dtype=np.float32),
                "Z_i": np.zeros((0, 3), dtype=np.float32),
            }

        return nucleic_ss_data

    # --- Precompute centroids and displacement vectors -----------
    planar_centers = np.stack(  # [I, 3]
        [
            np.nanmean(np.asarray(xyz_i, dtype=np.float32), axis=0)
            for xyz_i in xyz_planar
        ],
        axis=0,
    ).astype(np.float32)

    frame_D_ij_vec = frame_xyz[None, :, :] - frame_xyz[:, None, :]  # [I, I, 3]
    sc_D_ij_vec = planar_centers[None, :, :] - planar_centers[:, None, :]  # [I, I, 3]

    # --- CALC I: per-residue local coordinate frames -------------
    # Full frames (X_i/Y_i) are only needed for local-params output or the
    # opening angle; otherwise only Z_i is computed.
    need_full_frame = return_local_params or return_opening_angle
    local_frames = _compute_local_frames(
        xyz_planar,
        planar_centers,
        M_i,
        xyz_S_start=xyz_S_start if need_full_frame else None,
        xyz_S_stop=xyz_S_stop if need_full_frame else None,
        compute_full_frame=need_full_frame,
        eps=eps,
    )
    Z_i = local_frames["Z_i"]  # [I, 3]
    X_i = local_frames.get("X_i")  # [I, 3] or None

    # --- CALC II: pairwise base-step geometry --------------------
    pw_geom = _compute_pairwise_geometry(
        Z_i,
        frame_D_ij_vec,
        sc_D_ij_vec,
        X_i=X_i,
        clamp=clamp_pairwise_params,
        compute_opening=return_opening_angle,
        eps=eps,
    )

    # --- CALC III: base-pair identification ----------------------
    bp_result = _compute_basepair_mask(
        hbond_count,
        seq_neighbors,
        pw_geom["H_ij"],
        pw_geom["B_ij"],
        pw_geom["P_ij"],
        pw_geom["D_ij"],
        mol_info,
        bool_only=return_basepairs_only,
        eps=eps,
    )

    if return_basepairs_only:
        return bp_result  # np.ndarray [I, I] bool

    # --- Assemble output dict ------------------------------------
    assert isinstance(bp_result, dict)

    # Keep keys consistent with the empty-structure early return above
    # (notably D_ij, which was previously missing on this path).
    pair_params: dict[str, np.ndarray] = {
        "H_ij": pw_geom["H_ij"],
        "B_ij": pw_geom["B_ij"],
        "P_ij": pw_geom["P_ij"],
        "D_ij": pw_geom["D_ij"],
        "base_ori_ij": pw_geom["base_ori_ij"],
        "basepairs_bool_ij": bp_result["basepairs_bool_ij"],
        "basepairs_ij": bp_result["basepairs_ij"],
        "hbond_summation": bp_result["hbond_summation"],
    }

    if return_opening_angle and "O_ij" in pw_geom:
        pair_params["O_ij"] = pw_geom["O_ij"]

    if return_pairwise_geometry:
        pair_params["X_ij"] = pw_geom["X_ij"]
        pair_params["Y_ij"] = pw_geom["Y_ij"]
        pair_params["Z_ij"] = pw_geom["Z_ij"]

    nucleic_ss_data: dict = {"pair_params": pair_params}
    if return_local_params and "Y_i" in local_frames:
        nucleic_ss_data["local_params"] = {
            "X_i": local_frames["X_i"],
            "Y_i": local_frames["Y_i"],
            "Z_i": local_frames["Z_i"],
        }

    return nucleic_ss_data
+
+
def annotate_na_ss(
    atom_array: AtomArray,
    *,
    NA_only: bool = False,
    planar_only: bool = True,
    p_canonical_bp_filter: float = 0.0,
    mol_info: Optional[NucMolInfo] = None,
    overwrite: bool = True,
    token_level_data: Optional[dict] = None,
    cutoff_HA_dist: float = 3.5,
    cutoff_DA_dist: float = 3.5,
) -> AtomArray:
    """Compute base pairs and write a ``bp_partners`` annotation onto *atom_array*.

    Uses H-bond counts and pairwise geometry filters to identify base pairs,
    then stores the result as a per-atom annotation with the following
    semantics:

    * ``[]`` — explicitly unpaired (loop)
    * ``[token_id, ...]`` — paired partner token IDs
    * ``None`` — unannotated / masked (non-NA or filtered-out tokens)

    Args:
        atom_array: Structure to annotate (modified in-place).
        NA_only: Restrict geometry filter to nucleic-acid tokens.
        planar_only: Restrict geometry filter to tokens with planar
            sidechains.
        p_canonical_bp_filter: Probability of discarding non-canonical
            base pairs (keeps only A–U, A–T, G–C).
        mol_info: Molecular-info constants; created if ``None``.
        overwrite: If False, merge with existing ``bp_partners``.
        token_level_data: Pre-computed metadata dict; augmented with
            geometry as needed.
        cutoff_HA_dist: H–A distance cutoff (Å) for HBPLUS.
        cutoff_DA_dist: D–A distance cutoff (Å) for HBPLUS.

    Returns:
        The same *atom_array* with the ``bp_partners`` annotation set.
    """

    if mol_info is None:
        mol_info = NucMolInfo()

    # Residue representatives (0..L-1) and their corresponding atom indices.
    # Keep this aligned with get_token_level_metadata(), which uses residue starts.
    if token_level_data is not None and "token_starts" in token_level_data:
        token_starts = np.asarray(token_level_data["token_starts"], dtype=int)
    else:
        token_starts = struc.get_residue_starts(atom_array)
    residue_start_end = np.concatenate([token_starts, [atom_array.array_length()]])
    token_level_array = atom_array[token_starts]
    # token_id is assigned token-wise and matches get_token_starts() segmentation.
    token_ids: list[int] = [int(t) for t in list(token_level_array.token_id)]
    token_res_names: list[str] = [str(rn) for rn in list(token_level_array.res_name)]

    # Compute basepairs on the token graph (respecting NA_only/planar_only filtering)
    if token_level_data is None:
        token_level_data = get_token_level_metadata(
            atom_array,
            mol_info,
            NA_only=NA_only,
            planar_only=planar_only,
        )
    token_level_data = add_token_level_geometry_data(
        atom_array,
        mol_info,
        token_level_data,
        NA_only=NA_only,
        planar_only=planar_only,
    )
    # Note: this mask gives positions that are *chemically valid* for forming
    # base pairs, which is different from custom mask-generation for features
    mask_1d = np.asarray(token_level_data["filter_mask"], dtype=bool)

    subset_idxs = np.nonzero(mask_1d)[0]

    is_na_full = np.asarray(token_level_data["is_na"], dtype=bool)

    hbond_count = calculate_hb_counts(
        atom_array,
        token_level_data,
        mol_info,
        cutoff_HA_dist=cutoff_HA_dist,
        cutoff_DA_dist=cutoff_DA_dist,
    )
    bp_bool = np.asarray(
        compute_nucleic_ss(
            mol_info,
            token_level_data,
            hbond_count,
            clamp_pairwise_params=True,
            eps=1e-8,
            return_local_params=False,
            return_pairwise_geometry=False,
            return_opening_angle=False,
            return_basepairs_only=True,
        ),
        dtype=bool,
    )

    # Apply optional filters.
    # BUGFIX: bp_bool is indexed over the *filtered* token subset (shape
    # [n_sub, n_sub]), while is_na_full / has_planar_sc are full-length
    # per-token arrays. They must be restricted via subset_idxs before
    # broadcasting; otherwise shapes mismatch whenever any token was
    # filtered out by mask_1d.
    if NA_only:
        is_na_sub = is_na_full[subset_idxs]
        bp_bool &= is_na_sub[:, None]
        bp_bool &= is_na_sub[None, :]
    if planar_only:
        n_tokens_full = mask_1d.shape[0]
        has_planar_sc = np.asarray(
            token_level_data.get("has_planar_sc", np.ones(n_tokens_full, dtype=bool)),
            dtype=bool,
        )
        # assumes has_planar_sc (when present) is full token length — TODO confirm
        has_planar_sub = has_planar_sc[subset_idxs]
        bp_bool &= has_planar_sub[:, None]
        bp_bool &= has_planar_sub[None, :]

    # Optional: filter to canonical Watson-Crick basepairs only.
    # Sampled probabilistically to allow mixed supervision during training.
    do_canonical_filter = bool(
        p_canonical_bp_filter and (np.random.rand() < float(p_canonical_bp_filter))
    )
    if do_canonical_filter:

        def _base_letter(res_name: str) -> str | None:
            """Map a residue name to its canonical base letter, or None."""
            rn = str(res_name).strip().upper()
            if rn in STANDARD_RNA:
                return rn
            if rn in STANDARD_DNA:
                return rn[1]  # DA/DC/DG/DT -> A/C/G/T
            return None

        allowed_pairs = {
            ("A", "U"),
            ("U", "A"),
            ("A", "T"),
            ("T", "A"),
            ("G", "C"),
            ("C", "G"),
        }
        base_letters_full: list[str | None] = [
            _base_letter(rn) for rn in token_res_names
        ]

        bp_bool = np.asarray(bp_bool, dtype=bool)
        bp_rows_tmp, bp_cols_tmp = np.nonzero(bp_bool)
        # Clear non-canonical pairs symmetrically (both (i,j) and (j,i)).
        for r, c in zip(bp_rows_tmp.tolist(), bp_cols_tmp.tolist()):
            full_i = int(subset_idxs[int(r)])
            full_j = int(subset_idxs[int(c)])
            bi = base_letters_full[full_i]
            bj = base_letters_full[full_j]
            if bi is None or bj is None or (bi, bj) not in allowed_pairs:
                bp_bool[int(r), int(c)] = False
                bp_bool[int(c), int(r)] = False

    bp_bool = np.asarray(bp_bool, dtype=bool)
    bp_rows, bp_cols = np.nonzero(bp_bool)

    # Build residue-level annotation first, then spread to all atoms in each residue.
    if (not overwrite) and ("bp_partners" in atom_array.get_annotation_categories()):
        existing_ann = atom_array.bp_partners
        if len(existing_ann) != len(atom_array):
            raise ValueError("Existing bp_partners annotation has wrong length")
        residue_bp_partners = np.empty(len(token_starts), dtype=object)
        residue_bp_partners[:] = None
        for i, start in enumerate(token_starts.tolist()):
            residue_bp_partners[i] = existing_ann[int(start)]
    else:
        residue_bp_partners = np.empty(len(token_starts), dtype=object)
        residue_bp_partners[:] = None

    # Explicit-loop semantics:
    # - Only nucleic-acid token-start atoms *within subset_idxs* get a list container.
    # - [] means explicitly unpaired loop.
    # - None means unannotated/masked.
    for full_i in subset_idxs.tolist():
        if not bool(is_na_full[int(full_i)]):
            continue
        if residue_bp_partners[int(full_i)] is None:
            residue_bp_partners[int(full_i)] = []

    # Populate partners using token_id ints
    # We only process each unordered pair once to avoid duplicates.
    for r, c in zip(bp_rows.tolist(), bp_cols.tolist()):
        if r == c:
            continue

        full_i = int(subset_idxs[int(r)])
        full_j = int(subset_idxs[int(c)])
        if full_i == full_j:
            continue

        # Only annotate NA-NA basepairs as nucleic secondary structure.
        if (not bool(is_na_full[int(full_i)])) or (not bool(is_na_full[int(full_j)])):
            continue

        # Enforce uniqueness: only handle (i,j) where i < j
        if full_j < full_i:
            continue

        partner_i = int(token_ids[full_j])
        partner_j = int(token_ids[full_i])

        if residue_bp_partners[full_i] is None:
            residue_bp_partners[full_i] = []
        if residue_bp_partners[full_j] is None:
            residue_bp_partners[full_j] = []

        # Add if not present
        if partner_i not in residue_bp_partners[full_i]:
            residue_bp_partners[full_i].append(partner_i)
        if partner_j not in residue_bp_partners[full_j]:
            residue_bp_partners[full_j].append(partner_j)

    # Project residue-level annotations back to atom-level storage:
    # - atomized residues: spread to all atoms in that residue
    # - non-atomized residues: keep only on token-start representative atom
    bp_partners_ann = np.empty(len(atom_array), dtype=object)
    bp_partners_ann[:] = None
    for i, start in enumerate(token_starts.tolist()):
        stop = int(residue_start_end[i + 1])
        value = residue_bp_partners[i]
        if value is None:
            continue
        # A residue is treated as atomized if any atom in the residue carries atomize=True.
        if "atomize" in atom_array.get_annotation_categories():
            residue_is_atomized = bool(
                np.any(np.asarray(atom_array.atomize[int(start) : stop], dtype=bool))
            )
        else:
            residue_is_atomized = False
        if residue_is_atomized:
            for atom_idx in range(int(start), stop):
                bp_partners_ann[atom_idx] = list(value)
        else:
            bp_partners_ann[int(start)] = list(value)

    atom_array.set_annotation("bp_partners", bp_partners_ann)
    return atom_array
+
+
def parse_dot_bracket(dot_bracket: str) -> tuple[list[tuple[int, int]], list[int]]:
    """Parse a dot-bracket string into base pairs and unpaired positions.

    Supports standard ``()``, ``[]``, ``{}``, ``<>`` and pseudoknot
    brackets ``A``–``E`` (opening) / ``a``–``e`` (closing). Unknown
    characters and unmatched closing brackets are silently ignored.

    Args:
        dot_bracket: Dot-bracket notation string.

    Returns:
        Tuple of ``(pairs, unpaired)`` where *pairs* is a list of 0-based
        ``(i, j)`` index tuples with ``i < j`` and *unpaired* is a list of
        0-based indices corresponding to ``.`` characters.
    """

    stack: dict[str, list[int]] = {}
    pairs: list[tuple[int, int]] = []
    unpaired: list[int] = []

    # Map each closing character to its opening partner.
    opener_for = {
        ")": "(",
        "]": "[",
        "}": "{",
        ">": "<",
        "a": "A",
        "b": "B",
        "c": "C",
        "d": "D",
        "e": "E",
    }
    openers = set(opener_for.values())

    for i, ch in enumerate(str(dot_bracket)):
        if ch == ".":
            unpaired.append(i)
        elif ch in openers:
            # BUGFIX: opening brackets were previously never pushed onto the
            # stack (the opener branch ran closer logic), so no pairs could
            # ever be produced. Push the opener's position per bracket type.
            stack.setdefault(ch, []).append(i)
        elif ch in opener_for:
            o = opener_for[ch]
            if o in stack and stack[o]:
                j = stack[o].pop()
                pairs.append((j, i))
        # Any other character is ignored.

    return pairs, unpaired
+
+
def annotate_na_ss_from_specification(
    atom_array: AtomArray,
    specification: dict,
    *,
    overwrite: bool = True,
) -> AtomArray:
    """Write ``bp_partners`` annotation from an inference-time specification.

    Inference analogue of :func:`annotate_na_ss`: interprets user-provided
    dot-bracket strings and/or residue ranges rather than computing base
    pairs from geometry.

    Supported *specification* keys (all optional):

    * ``ss_dbn``: global dot-bracket string (applied to the first *L* tokens).
    * ``ss_dbn_dict``: ``{"-": dbn_str, ...}``.
    * ``paired_region_list``: ``["A5-15,B1-11", ...]``.
    * ``paired_position_list``: ``["A19,A61,A20", ...]``.
    * ``loop_region_list``: ``["A5-10", ...]`` (forced unpaired).

    Args:
        atom_array: Structure to annotate (modified in-place).
        specification: Specification dict as described above.
        overwrite: If False, merge with existing ``bp_partners``.

    Returns:
        The same *atom_array* with the ``bp_partners`` annotation set.
    """

    spec = specification or {}
    token_starts = get_token_starts(atom_array)
    token_level_array = atom_array[token_starts]
    token_ids: list[int] = [int(t) for t in list(token_level_array.token_id)]
    n_tokens = len(token_starts)

    # Prepare/overwrite annotation array
    if (not overwrite) and ("bp_partners" in atom_array.get_annotation_categories()):
        bp_partners_ann = atom_array.bp_partners
        if len(bp_partners_ann) != len(atom_array):
            raise ValueError("Existing bp_partners annotation has wrong length")
    else:
        bp_partners_ann = np.empty(len(atom_array), dtype=object)
        bp_partners_ann[:] = None

    # Build chain/res -> token index map for region/position specs.
    # Accept both chain_iid-like keys (e.g. "A_1") and plain chain IDs (e.g. "A")
    # so CLI/json specs like "A1,B3" work reliably in inference.
    chain_iid_list: list[str] = [str(atm.chain_iid) for atm in token_level_array]
    chain_id_list: list[str] = [str(atm.chain_id) for atm in token_level_array]
    resi_list: list[int] = [int(atm.res_id) for atm in token_level_array]
    chain_res_to_tok: dict[tuple[str, int], int] = {}
    # setdefault => first occurrence wins if two keys collide across chains.
    for i, (chain_iid, chain_id, res_id) in enumerate(
        zip(chain_iid_list, chain_id_list, resi_list)
    ):
        key_iid = (chain_iid, int(res_id))
        key_chain = (chain_id, int(res_id))
        chain_res_to_tok.setdefault(key_iid, int(i))
        chain_res_to_tok.setdefault(key_chain, int(i))
        # Also support the short alias from chain_iid (e.g. "A_1" -> "A")
        short_chain = chain_iid.split("_", 1)[0]
        chain_res_to_tok.setdefault((short_chain, int(res_id)), int(i))

    def _parse_region(region_str: str) -> tuple[str, int, int] | None:
        """Parse ``"A5-15"`` into ``(chain, start, end)``; None if malformed.

        NOTE(review): assumes a single-character chain ID — confirm that
        multi-character chains are not expected in region specs.
        """
        region_str = str(region_str).strip()
        if not region_str:
            return None
        chain_id = region_str[0]
        rest = region_str[1:]
        if "-" not in rest:
            return None
        start_s, end_s = rest.split("-", 1)
        try:
            start_res = int(start_s)
            end_res = int(end_s)
        except ValueError:
            return None
        if start_res > end_res:
            start_res, end_res = end_res, start_res
        return chain_id, start_res, end_res

    def _parse_single_pos(pos_str: str) -> tuple[str, int] | None:
        """Parse ``"A19"`` into ``(chain, res_id)``; None if malformed."""
        pos_str = str(pos_str).strip()
        if not pos_str:
            return None
        chain_id = pos_str[0]
        rest = pos_str[1:]
        try:
            res_id = int(rest)
        except ValueError:
            return None
        return chain_id, res_id

    def _region_to_token_indices(region_str: str) -> list[int]:
        """Resolve a region string to the token indices it covers.

        Residues absent from the structure are silently skipped.
        """
        parsed = _parse_region(region_str)
        if parsed is None:
            return []
        chain_id, start_res, end_res = parsed
        token_indices: list[int] = []
        for res_id in range(start_res, end_res + 1):
            idx = chain_res_to_tok.get((chain_id, int(res_id)))
            if idx is not None:
                token_indices.append(int(idx))
        return token_indices

    def _pos_to_token_index(pos_str: str) -> int | None:
        """Resolve a single position string to a token index (or None)."""
        parsed = _parse_single_pos(pos_str)
        if parsed is None:
            return None
        chain_id, res_id = parsed
        return chain_res_to_tok.get((chain_id, int(res_id)))

    # Accumulate partners as token-index sets
    partners: list[set[int]] = [set() for _ in range(n_tokens)]
    loop_token_idxs: set[int] = set()

    def _add_pair(i: int, j: int) -> None:
        """Record a symmetric pair, skipping self-pairs and loop tokens."""
        if not (0 <= i < n_tokens and 0 <= j < n_tokens):
            return
        if i == j:
            return
        if i in loop_token_idxs or j in loop_token_idxs:
            return
        partners[i].add(j)
        partners[j].add(i)

    # Case 1: global ss_dbn
    # Dot-bracket positions map directly onto the first L token indices.
    ss_dbn = spec.get("ss_dbn")
    if isinstance(ss_dbn, str) and ss_dbn.strip():
        pairs, unpaired = parse_dot_bracket(ss_dbn.strip())
        L = min(len(ss_dbn), n_tokens)
        for i_local, j_local in pairs:
            if 0 <= i_local < L and 0 <= j_local < L:
                _add_pair(int(i_local), int(j_local))
        for i_local in unpaired:
            if 0 <= int(i_local) < L:
                loop_token_idxs.add(int(i_local))

    # Case 1b: ss_dbn_dict
    # Each entry maps a region string to a dot-bracket string of equal length.
    ss_dbn_dict = spec.get("ss_dbn_dict", {}) or {}
    if isinstance(ss_dbn_dict, dict):
        for region_str, dbn_str in ss_dbn_dict.items():
            if not isinstance(region_str, str) or not isinstance(dbn_str, str):
                continue
            dbn_str = dbn_str.strip()
            if not dbn_str:
                continue
            toks = _region_to_token_indices(region_str)
            if not toks or len(toks) != len(dbn_str):
                continue
            pairs, unpaired = parse_dot_bracket(dbn_str)
            for i_local, j_local in pairs:
                if 0 <= i_local < len(toks) and 0 <= j_local < len(toks):
                    _add_pair(int(toks[int(i_local)]), int(toks[int(j_local)]))
            for i_local in unpaired:
                if 0 <= i_local < len(toks):
                    loop_token_idxs.add(int(toks[int(i_local)]))

    # Case 2: paired_region_list
    # Each entry pairs *every* residue of region 1 with every residue of region 2.
    paired_region_list = spec.get("paired_region_list", [])
    if isinstance(paired_region_list, str):
        paired_region_list = [paired_region_list]
    if isinstance(paired_region_list, list):
        for region_entry in paired_region_list:
            if not isinstance(region_entry, str) or not region_entry.strip():
                continue
            region_parts = [p.strip() for p in region_entry.split(",") if p.strip()]
            if len(region_parts) != 2:
                continue
            toks1 = _region_to_token_indices(region_parts[0])
            toks2 = _region_to_token_indices(region_parts[1])
            if not toks1 or not toks2:
                continue
            for ti in toks1:
                for tj in toks2:
                    _add_pair(int(ti), int(tj))

    # Case 3: paired_position_list
    # Each entry is a comma-separated group; all positions pair pairwise.
    paired_position_list = spec.get("paired_position_list", [])
    if isinstance(paired_position_list, str):
        paired_position_list = [paired_position_list]
    if isinstance(paired_position_list, list):
        for group_str in paired_position_list:
            if not isinstance(group_str, str) or not group_str.strip():
                continue
            pos_parts = [p.strip() for p in group_str.split(",") if p.strip()]
            tok_indices: list[int] = []
            for pos_str in pos_parts:
                tok = _pos_to_token_index(pos_str)
                if tok is not None:
                    tok_indices.append(int(tok))
            for i in range(len(tok_indices)):
                for j in range(i + 1, len(tok_indices)):
                    _add_pair(tok_indices[i], tok_indices[j])

    # Case 4: loop_region_list
    loop_region_list = spec.get("loop_region_list", [])
    if isinstance(loop_region_list, str):
        loop_region_list = [loop_region_list]
    if isinstance(loop_region_list, list):
        for region_str in loop_region_list:
            if not isinstance(region_str, str) or not region_str.strip():
                continue
            for tok in _region_to_token_indices(region_str):
                loop_token_idxs.add(int(tok))

    # Enforce loop tokens as unpaired: remove any pairs involving them
    # (loop specs take precedence over any pairing added above).
    for i in list(loop_token_idxs):
        if not (0 <= i < n_tokens):
            continue
        for j in list(partners[i]):
            partners[j].discard(i)
        partners[i].clear()

    # Write lists of partner token_ids onto token-start atoms.
    # Unspecified tokens remain unannotated (None) -> NA_SS_MASK.
    for i in range(n_tokens):
        atom_i = int(token_starts[i])
        if len(partners[i]) > 0:
            bp_partners_ann[atom_i] = []
            for j in sorted(partners[i]):
                partner_token_id = int(token_ids[int(j)])
                bp_partners_ann[atom_i].append(partner_token_id)
        elif int(i) in loop_token_idxs:
            bp_partners_ann[atom_i] = []

    atom_array.set_annotation("bp_partners", bp_partners_ann)
    return atom_array
+
+
def annotate_na_ss_from_data_specification(
    data: dict,
    *,
    overwrite: bool = True,
) -> AtomArray:
    """Annotate ``bp_partners`` using the specification stored in *data*.

    Thin convenience wrapper: pulls ``atom_array`` and the (possibly
    missing or empty) ``specification`` out of the pipeline data dict and
    delegates to :func:`annotate_na_ss_from_specification`.

    Args:
        data: Pipeline data dict containing ``atom_array`` and optionally
            ``specification``.
        overwrite: If False, merge with existing ``bp_partners``.

    Returns:
        The annotated AtomArray (also stored back in *data*).
    """
    structure = data["atom_array"]
    specification = data.get("specification") or {}
    return annotate_na_ss_from_specification(
        structure, specification, overwrite=overwrite
    )
diff --git a/models/rfd3/src/rfd3/transforms/pipelines.py b/models/rfd3/src/rfd3/transforms/pipelines.py
index 89bc7289..8d8fee10 100644
--- a/models/rfd3/src/rfd3/transforms/pipelines.py
+++ b/models/rfd3/src/rfd3/transforms/pipelines.py
@@ -72,6 +72,7 @@
)
from rfd3.transforms.design_transforms import (
AddAdditional1dFeaturesToFeats,
+ AddAdditional2dFeaturesToFeats,
AddGroundTruthSequence,
AddIsXFeats,
AssignTypes,
@@ -84,6 +85,7 @@
)
from rfd3.transforms.dna_crop import ProteinDNAContactContiguousCrop
from rfd3.transforms.hbonds_hbplus import CalculateHbondsPlus
+from rfd3.transforms.na_geom import CalculateNucleicAcidGeomFeats
from rfd3.transforms.ppi_transforms import (
Add1DSSFeature,
AddGlobalIsNonLoopyFeature,
@@ -194,6 +196,7 @@ def get_crop_transform(
max_binder_length: int,
max_atoms_in_crop: int | None,
allowed_types: List[str],
+ association_scheme: str,
):
if (
crop_contiguous_probability > 0
@@ -213,7 +216,9 @@ def get_crop_transform(
), "Crop center cutoff distance must be greater than 0"
pre_crop_transforms = [
- SubsampleToTypes(allowed_types=allowed_types),
+ SubsampleToTypes(
+ allowed_types=allowed_types, association_scheme=association_scheme
+ ),
]
cropping_transform = RandomRoute(
@@ -350,6 +355,7 @@ def build_atom14_base_pipeline_(
center_option: str,
atom_1d_features: dict | None,
token_1d_features: dict | None,
+ token_2d_features: dict | None = None,
# PPI features
max_ppi_hotspots_frac_to_provide: float,
ppi_hotspot_max_distance: float,
@@ -357,6 +363,9 @@ def build_atom14_base_pipeline_(
max_ss_frac_to_provide: float,
min_ss_island_len: int,
max_ss_island_len: int,
+ ## Nucleic acid features #####
+ # add_na_pair_features: bool,
+ ## This should not be necessary, controlled through feature names in model, and meta conditioning probabilities, inference behavior handled in transform itself #####
**_, # dump additional kwargs (e.g. msa stuff)
):
"""
@@ -383,6 +392,7 @@ def build_atom14_base_pipeline_(
train_conditions=train_conditions,
meta_conditioning_probabilities=meta_conditioning_probabilities,
sequence_encoding=af3_sequence_encoding,
+ association_scheme=association_scheme,
),
),
]
@@ -405,6 +415,7 @@ def build_atom14_base_pipeline_(
max_binder_length=max_binder_length,
max_atoms_in_crop=max_atoms_in_crop,
allowed_types=allowed_types,
+ association_scheme=association_scheme,
)
if zero_occ_on_exposure_after_cropping:
@@ -422,7 +433,9 @@ def build_atom14_base_pipeline_(
# ... Add global token features (since number of tokens is fixed after cropping)
transforms.append(AddGlobalTokenIdAnnotation())
# ... Create masks (NOTE: Modulates token count, and resets global token id if necessary)
- transforms.append(TrainingRoute(SampleConditioningFlags()))
+ transforms.append(
+ TrainingRoute(SampleConditioningFlags(association_scheme=association_scheme))
+ )
# Post-crop transforms
transforms.append(
@@ -434,6 +447,16 @@ def build_atom14_base_pipeline_(
),
)
)
+ # Add nucleic acid geometry features
+ # if add_na_pair_features:
+ transforms.append(
+ CalculateNucleicAcidGeomFeats(
+ is_inference,
+ meta_conditioning_probabilities,
+ NA_only=False,
+ planar_only=True,
+ )
+ )
# Design Transforms
transforms += [
@@ -442,7 +465,9 @@ def build_atom14_base_pipeline_(
sharding_depth=1,
),
# ... Fuse inference and training conditioning assignments
- UnindexFlaggedTokens(central_atom=central_atom),
+ UnindexFlaggedTokens(
+ central_atom=central_atom, association_scheme=association_scheme
+ ),
# ... Virtual atom padding (NOTE: Last transform which modulates atom count)
PadTokensWithVirtualAtoms(
n_atoms_per_token=n_atoms_per_token,
@@ -518,6 +543,12 @@ def build_atom14_base_pipeline_(
autofill_zeros_if_not_present_in_atomarray=True,
token_1d_features=token_1d_features,
atom_1d_features=atom_1d_features,
+ association_scheme=association_scheme,
+ ),
+ AddAdditional2dFeaturesToFeats(
+ autofill_zeros_if_not_present_in_atomarray=True,
+ token_2d_features=token_2d_features,
+ association_scheme=association_scheme,
),
AddAF3TokenBondFeatures(),
AddGroundTruthSequence(sequence_encoding=af3_sequence_encoding),
@@ -593,7 +624,6 @@ def build_atom14_base_pipeline(
Wrapper around pipeline construction to handle empty training args
Sets default behaviour for inference to keep backward compatibility
"""
-
if is_inference:
# Provide explicit defaults for training-only args
kwargs.setdefault("crop_size", 512)
@@ -609,6 +639,8 @@ def build_atom14_base_pipeline(
kwargs.setdefault("min_ss_island_len", 0)
kwargs.setdefault("max_ss_island_len", 999)
kwargs.setdefault("max_binder_length", 999)
+ # This should not be necessary.
+ # kwargs.setdefault("add_na_pair_features", False)
kwargs.setdefault("b_factor_min", None)
kwargs.setdefault("zero_occ_on_exposure_after_cropping", False)
@@ -621,7 +653,7 @@ def build_atom14_base_pipeline(
kwargs.setdefault("residue_cache_dir", None)
# TODO: Delete these once all checkpoints are updated with the latest defaults
- kwargs.setdefault("generate_conformers_for_non_protein_only", True)
+ kwargs.setdefault("generate_conformers_for_non_protein_only", False)
kwargs.setdefault("return_atom_array", True)
kwargs.setdefault("provide_elements_for_unindexed_components", False)
kwargs.setdefault("center_option", "all")
diff --git a/models/rfd3/src/rfd3/transforms/training_conditions.py b/models/rfd3/src/rfd3/transforms/training_conditions.py
index f93e346f..67ced33a 100644
--- a/models/rfd3/src/rfd3/transforms/training_conditions.py
+++ b/models/rfd3/src/rfd3/transforms/training_conditions.py
@@ -13,6 +13,7 @@
spread_token_wise,
)
from biotite.structure import AtomArray, get_residue_starts
+from rfd3.constants import backbone_atoms_RNA
from rfd3.transforms.conditioning_utils import (
random_condition,
sample_island_tokens,
@@ -58,6 +59,8 @@ class IslandCondition(TrainingCondition):
Select islands as motif and assign conditioning strategies.
"""
+ association_scheme = "atom14"
+
def __init__(
self,
*,
@@ -70,9 +73,11 @@ def __init__(
p_fix_motif_coordinates,
p_fix_motif_sequence,
p_unindex_motif_tokens,
+ association_scheme="atom14",
):
self.name = name
self.frequency = frequency
+ self.association_scheme = association_scheme
# Token selection
self.island_sampling_kwargs = island_sampling_kwargs
@@ -88,11 +93,21 @@ def __init__(
self.p_fix_motif_sequence = p_fix_motif_sequence
self.p_unindex_motif_tokens = p_unindex_motif_tokens
+ self.association_scheme = association_scheme
+
def is_valid_for_example(self, data) -> bool:
is_protein = data["atom_array"].is_protein
- if not np.any(is_protein):
- return False
- return True
+ is_dna = data["atom_array"].is_dna
+ is_rna = data["atom_array"].is_rna
+ # Allow non-protein polymers (DNA/RNA) as valid examples under the atom23 scheme
+ if self.association_scheme == "atom23":
+ if np.any(is_protein | is_dna | is_rna):
+ return True
+ else:
+ if np.any(is_protein):
+ return True
+
+ return False
def sample_motif_tokens(self, atom_array):
"""
@@ -101,13 +116,30 @@ def sample_motif_tokens(self, atom_array):
token_level_array = atom_array[get_token_starts(atom_array)]
# initialize motif tokens as all non-protein tokens
- is_motif_token = np.asarray(~token_level_array.is_protein, dtype=bool).copy()
- n_protein_tokens = np.sum(token_level_array.is_protein)
- islands_mask = sample_island_tokens(
- n_protein_tokens,
- **self.island_sampling_kwargs,
- )
- is_motif_token[token_level_array.is_protein] = islands_mask
+ if self.association_scheme == "atom23":
+ polymer_mask = (
+ token_level_array.is_protein
+ | token_level_array.is_dna
+ | token_level_array.is_rna
+ )
+ is_motif_token = np.asarray(~polymer_mask, dtype=bool).copy()
+ n_polymer_tokens = np.sum(polymer_mask)
+ islands_mask = sample_island_tokens(
+ n_polymer_tokens,
+ **self.island_sampling_kwargs,
+ )
+ is_motif_token[polymer_mask] = islands_mask
+ else:
+ is_motif_token = np.asarray(
+ ~token_level_array.is_protein, dtype=bool
+ ).copy()
+ n_protein_tokens = np.sum(token_level_array.is_protein)
+
+ islands_mask = sample_island_tokens(
+ n_protein_tokens,
+ **self.island_sampling_kwargs,
+ )
+ is_motif_token[token_level_array.is_protein] = islands_mask
# TODO: Atoms with covalent bonds should be motif, needs FlagAndReassignCovalentModifications transform prior to this
# atom_with_coval_bond = token_level_array.covale # (n_atoms, )
@@ -127,7 +159,9 @@ def sample_motif_atoms(self, atom_array):
is_motif_atom = np.asarray(atom_array.is_motif_token, dtype=bool).copy()
if random_condition(self.p_diffuse_motif_sidechains):
- backbone_atoms = ["N", "C", "CA"]
+ backbone_atoms = backbone_atoms_RNA.copy()
+ backbone_atoms.remove("C1'")
+ backbone_atoms = ["N", "C", "CA"] + backbone_atoms # covers DNA also
if random_condition(self.p_include_oxygen_in_backbone_mask):
backbone_atoms.append("O")
is_motif_atom = is_motif_atom & np.isin(
@@ -137,11 +171,11 @@ def sample_motif_atoms(self, atom_array):
is_motif_atom = sample_motif_subgraphs(
atom_array=atom_array,
**self.subgraph_sampling_kwargs,
+ association_scheme=self.association_scheme,
)
# We also only want resolved atoms to be motif
is_motif_atom = (is_motif_atom) & (atom_array.occupancy > 0.0)
-
return is_motif_atom
def sample(self, data):
@@ -158,6 +192,7 @@ def sample(self, data):
p_fix_motif_sequence=self.p_fix_motif_sequence,
p_fix_motif_coordinates=self.p_fix_motif_coordinates,
p_unindex_motif_tokens=self.p_unindex_motif_tokens,
+ association_scheme=self.association_scheme,
)
atom_array.set_annotation(
@@ -169,7 +204,6 @@ def sample(self, data):
leak_global_index=data["conditions"]["unindex_leak_global_index"],
),
)
-
return atom_array
@@ -177,6 +211,7 @@ class PPICondition(TrainingCondition):
"""Get condition indicating what is motif and what is to be diffused for protein-protein interaction training."""
name = "ppi"
+ association_scheme = "atom14"
def is_valid_for_example(self, data):
# Extract relevant data
@@ -275,6 +310,7 @@ class SubtypeCondition(TrainingCondition):
"""
name = "subtype"
+ association_scheme = "atom14"
def __init__(self, frequency: float, subtype: list[str], fix_pos: bool = False):
self.frequency = frequency
@@ -370,6 +406,7 @@ def sample_motif_subgraphs(
hetatom_n_bond_expectation,
residue_p_fix_all,
hetatom_p_fix_all,
+ association_scheme="atom14",
):
"""
Returns a boolean mask over atoms, indicating which atoms are part of the sampled motif.
@@ -402,7 +439,17 @@ def sample_motif_subgraphs(
"n_bond_expectation": residue_n_bond_expectation,
"p_fix_all": residue_p_fix_all,
}
- if not atom_array_subset.is_protein.all():
+
+ if association_scheme == "atom23":
+ clause = (
+ atom_array_subset.is_protein.all()
+ | atom_array_subset.is_dna.all()
+ | atom_array_subset.is_rna.all()
+ )
+ else:
+ clause = atom_array_subset.is_protein.all()
+
+ if not clause:
args.update(
{
"p_seed_furthest_from_o": 0.0,
@@ -431,11 +478,14 @@ def sample_conditioning_strategy(
p_fix_motif_sequence,
p_fix_motif_coordinates,
p_unindex_motif_tokens,
+ association_scheme,
):
atom_array.set_annotation(
"is_motif_atom_with_fixed_seq",
sample_is_motif_atom_with_fixed_seq(
- atom_array, p_fix_motif_sequence=p_fix_motif_sequence
+ atom_array,
+ p_fix_motif_sequence=p_fix_motif_sequence,
+ association_scheme=association_scheme,
),
)
@@ -449,14 +499,17 @@ def sample_conditioning_strategy(
atom_array.set_annotation(
"is_motif_atom_unindexed",
sample_unindexed_atoms(
- atom_array, p_unindex_motif_tokens=p_unindex_motif_tokens
+ atom_array,
+ p_unindex_motif_tokens=p_unindex_motif_tokens,
+ association_scheme=association_scheme,
),
)
-
return atom_array
-def sample_is_motif_atom_with_fixed_seq(atom_array, p_fix_motif_sequence):
+def sample_is_motif_atom_with_fixed_seq(
+ atom_array, p_fix_motif_sequence, association_scheme
+):
"""
Samples what kind of conditioning to apply to motif tokens.
@@ -469,7 +522,12 @@ def sample_is_motif_atom_with_fixed_seq(atom_array, p_fix_motif_sequence):
is_motif_atom_with_fixed_seq = np.zeros(atom_array.array_length(), dtype=bool)
# By default reveal sequence for non-protein
- is_motif_atom_with_fixed_seq = is_motif_atom_with_fixed_seq | ~atom_array.is_protein
+
+ if not association_scheme == "atom23":
+ is_motif_atom_with_fixed_seq = (
+ is_motif_atom_with_fixed_seq | ~atom_array.is_protein
+ )
+
return is_motif_atom_with_fixed_seq
@@ -487,7 +545,9 @@ def sample_fix_motif_coordinates(atom_array, p_fix_motif_coordinates):
return is_motif_atom_with_fixed_coord
-def sample_unindexed_atoms(atom_array, p_unindex_motif_tokens):
+def sample_unindexed_atoms(
+ atom_array, p_unindex_motif_tokens, association_scheme="atom14"
+):
"""
Samples which atoms in motif tokens should be flagged for unindexing.
@@ -498,11 +558,16 @@ def sample_unindexed_atoms(atom_array, p_unindex_motif_tokens):
is_motif_atom_unindexed = atom_array.is_motif_atom.copy()
else:
is_motif_atom_unindexed = np.zeros(atom_array.array_length(), dtype=bool)
-
# ensure non-residue atoms are not already flagged
- is_motif_atom_unindexed = np.logical_and(
- is_motif_atom_unindexed, atom_array.is_residue
- )
+ if association_scheme == "atom23":
+ is_motif_atom_unindexed = np.logical_and(
+ is_motif_atom_unindexed,
+ (atom_array.is_residue | atom_array.is_dna | atom_array.is_rna),
+ ) # NOTE: is_residue only covers protein residues here, hence the explicit DNA/RNA union
+ else:
+ is_motif_atom_unindexed = np.logical_and(
+ is_motif_atom_unindexed, atom_array.is_residue
+ )
return is_motif_atom_unindexed
diff --git a/models/rfd3/src/rfd3/transforms/util_transforms.py b/models/rfd3/src/rfd3/transforms/util_transforms.py
index 024b1bf1..37f08f19 100644
--- a/models/rfd3/src/rfd3/transforms/util_transforms.py
+++ b/models/rfd3/src/rfd3/transforms/util_transforms.py
@@ -37,7 +37,7 @@ def assert_single_representative(token, central_atom="CB"):
mask = get_af3_token_representative_masks(token, central_atom=central_atom)
assert (
np.sum(mask) == 1
- ), f"No representative atom (CB) found. mask: {mask}\nToken: {token}"
+ ), f"No representative atom ({central_atom}) found. mask: {mask}\nToken: {token}"
def assert_single_token(token):
@@ -252,13 +252,13 @@ def get_af3_token_representative_masks(
atom_array: AtomArray, central_atom: str = "CA"
) -> np.ndarray:
pyrimidine_representative_atom = is_pyrimidine(atom_array.res_name) & (
- atom_array.atom_name == "C2"
+ atom_array.atom_name == "C1'"
)
purine_representative_atom = is_purine(atom_array.res_name) & (
- atom_array.atom_name == "C4"
+ atom_array.atom_name == "C1'"
)
unknown_na_representative_atom = is_unknown_nucleotide(atom_array.res_name) & (
- atom_array.atom_name == "C4"
+ atom_array.atom_name == "C1'"
)
glycine_representative_atom = is_glycine(atom_array.res_name) & (
diff --git a/models/rfd3/src/rfd3/transforms/virtual_atoms.py b/models/rfd3/src/rfd3/transforms/virtual_atoms.py
index ae40b3ce..d9097fea 100644
--- a/models/rfd3/src/rfd3/transforms/virtual_atoms.py
+++ b/models/rfd3/src/rfd3/transforms/virtual_atoms.py
@@ -10,8 +10,10 @@
)
from atomworks.ml.utils.token import get_token_starts
from rfd3.constants import (
- ATOM14_ATOM_NAME_TO_ELEMENT,
ATOM14_ATOM_NAMES,
+ ATOM23_ATOM_NAME_TO_ELEMENT,
+ ATOM23_ATOM_NAMES_DNA,
+ ATOM23_ATOM_NAMES_RNA,
VIRTUAL_ATOM_ELEMENT_NAME,
association_schemes,
association_schemes_stripped,
@@ -28,7 +30,9 @@
from foundry.common import exists
-def map_to_association_scheme(atom_names: list | str, res_name: str, scheme="atom14"):
+def map_to_association_scheme(
+ atom_names: list | str, res_name: str, scheme="atom14", ATOM_NAMES=None
+):
"""
Maps a list of names to the atom14 naming scheme for that particular name (within a specific residue)
NB this function is a bit more general since it is used to handle tipatoms too.
@@ -37,16 +41,17 @@ def map_to_association_scheme(atom_names: list | str, res_name: str, scheme="ato
raise ValueError(
f"Scheme {scheme} not found in association_schemes_stripped. Available schemes: {list(association_schemes_stripped.keys())}"
)
- atom_names = (
- [str(atom_names)] if isinstance(atom_names, (str, np.str_)) else atom_names
- )
+ atom_names = [atom_names] if isinstance(atom_names, str) else atom_names
idxs = np.array(
[
association_schemes_stripped[scheme][res_name].index(name)
for name in atom_names
]
)
- return ATOM14_ATOM_NAMES[idxs]
+ if ATOM_NAMES is None:
+ return ATOM14_ATOM_NAMES[idxs]
+ else:
+ return ATOM_NAMES[idxs]
def map_names_to_elements(
@@ -58,7 +63,7 @@ def map_names_to_elements(
then it returns the default value
"""
atom_names = [atom_names] if isinstance(atom_names, str) else atom_names
- elements = [ATOM14_ATOM_NAME_TO_ELEMENT.get(name, default) for name in atom_names]
+ elements = [ATOM23_ATOM_NAME_TO_ELEMENT.get(name, default) for name in atom_names]
return np.array(elements)
@@ -68,17 +73,20 @@ def generate_atom_mappings_(scheme="atom14"):
atom_mapping = {}
symmetry_mapping = {}
- for aaa, atom14_names in ccd_ordering_atomchar.items():
- mapping = list(range(14))
+ for aaa, atom_names in ccd_ordering_atomchar.items():
+ if aaa not in scheme:
+ continue
+
+ mapping = list(range(len(atom_names)))
scheme_names = scheme[aaa]
- for ccd_index in range(len(atom14_names)):
- atom14_name = atom14_names[ccd_index]
- if atom14_name is not None:
+ for ccd_index in range(len(atom_names)):
+ atom_name = atom_names[ccd_index]
+ if atom_name is not None:
assert (
- atom14_name in scheme_names
- ), f"{atom14_name} not in CCD ordering for {aaa}"
- scheme_index = scheme_names.index(atom14_name)
+ atom_name in scheme_names
+ ), f"{atom_name} not in CCD ordering for {aaa}"
+ scheme_index = scheme_names.index(atom_name)
scheme_index_in_cur_mapping = mapping.index(scheme_index)
mapping[ccd_index], mapping[scheme_index_in_cur_mapping] = (
mapping[scheme_index_in_cur_mapping],
@@ -121,6 +129,7 @@ def permute_symmetric_atom_names_(
) -> list:
# NB: Can leak GT sequence if the model receives the canconical ordering of atoms as input
# With the structure-local atom attention it will not unless N_keys(n_attn_seq_neighbours) > n_atom_attn_queries.
+
if res_name in association_map:
idx_to_swap = association_map[res_name]
atom_names = atom_names[idx_to_swap]
@@ -174,18 +183,35 @@ def forward(self, data: dict) -> dict:
), "Token ids and token level array have different lengths!"
# Unindexed tokens are never fully atomized, but may be assigned as atomized to have repr atoms:
- is_residue = (
- token_level_array.is_protein & ~token_level_array.atomize
- ) | is_motif_token_unindexed
-
- # Unindexed tokens are never padded, and so are treated as residues with fixed sequence.
- is_paddable = is_residue & ~(
- is_motif_atom_with_fixed_seq | is_motif_token_unindexed
- )
- is_non_paddable_residue = is_residue & (
- is_motif_atom_with_fixed_seq | is_motif_token_unindexed
- )
+ if self.association_scheme == "atom23":
+ is_residue = (
+ token_level_array.is_protein & ~token_level_array.atomize
+ ) | is_motif_token_unindexed
+
+ is_residue_NA = (
+ (token_level_array.is_dna | token_level_array.is_rna)
+ & ~token_level_array.atomize
+ ) | is_motif_token_unindexed
+
+ # Unindexed tokens are never padded, and so are treated as residues with fixed sequence.
+ is_paddable = (is_residue_NA | is_residue) & ~(
+ is_motif_atom_with_fixed_seq | is_motif_token_unindexed
+ )
+ is_non_paddable_residue = (is_residue_NA | is_residue) & (
+ is_motif_atom_with_fixed_seq | is_motif_token_unindexed
+ )
+ else:
+ is_residue = (
+ token_level_array.is_protein & ~token_level_array.atomize
+ ) | is_motif_token_unindexed
+ # Unindexed tokens are never padded, and so are treated as residues with fixed sequence.
+ is_paddable = is_residue & ~(
+ is_motif_atom_with_fixed_seq | is_motif_token_unindexed
+ )
+ is_non_paddable_residue = is_residue & (
+ is_motif_atom_with_fixed_seq | is_motif_token_unindexed
+ )
# Collect virtual atoms to insert (we will insert them all at once)
virtual_atoms_to_insert = []
insert_positions = []
@@ -194,13 +220,26 @@ def forward(self, data: dict) -> dict:
for token_id, (start, end) in enumerate(zip(starts[:-1], starts[1:])):
if is_paddable[token_id]:
token = atom_array[start:end]
+
# First, pad with virtual atoms if needed
- n_pad = self.n_atoms_per_token - len(token)
+ if self.association_scheme == "atom23" and atom_array[start].is_dna:
+ n_atoms_per_token = 22
+ central_atom = "C1'"
+ elif self.association_scheme == "atom23" and atom_array[start].is_rna:
+ n_atoms_per_token = 23
+ central_atom = "C1'"
+ else:
+ n_atoms_per_token = self.n_atoms_per_token
+ central_atom = self.atom_to_pad_from
+
+ n_pad = n_atoms_per_token - len(token)
+
if n_pad > 0:
mask = get_af3_token_representative_masks(
- token, central_atom=self.atom_to_pad_from
+ token, central_atom=central_atom
)
- assert_single_representative(token)
+
+ assert_single_representative(token, central_atom=central_atom)
# ... Create virtual atoms
pad_atoms = token[mask].copy()
@@ -263,17 +302,25 @@ def _fix_multidimensional_annotations_in_pad_array(
for token_id, (start, end) in enumerate(
zip(starts_padded[:-1], starts_padded[1:])
):
+ if atom_array_padded[start].is_dna:
+ ATOM_NAMES = ATOM23_ATOM_NAMES_DNA
+ elif atom_array_padded[start].is_rna:
+ ATOM_NAMES = ATOM23_ATOM_NAMES_RNA
+ else:
+ ATOM_NAMES = ATOM14_ATOM_NAMES
+
if is_paddable[token_id]:
# ... Permutation of atom names during training
if not data["is_inference"] and exists(self.association_scheme):
atom_names = permute_symmetric_atom_names_(
- ATOM14_ATOM_NAMES,
+ ATOM_NAMES,
atom_array_padded.res_name[start],
association_map=self.association_map_,
symmetry_map=self.symmetry_map_,
)
else:
- atom_names = ATOM14_ATOM_NAMES
+ atom_names = ATOM_NAMES
+
atom_array_padded.atom_name[start:end] = atom_names
atom_array_padded.get_annotation("gt_atom_name")[start:end] = atom_names
@@ -285,7 +332,10 @@ def _fix_multidimensional_annotations_in_pad_array(
)
atom_array_padded.get_annotation("gt_atom_name")[start:end] = atom_names
atom_names = map_to_association_scheme(
- atom_names, res_name, scheme=self.association_scheme
+ atom_names,
+ res_name,
+ scheme=self.association_scheme,
+ ATOM_NAMES=ATOM_NAMES,
)
atom_array_padded.atom_name[start:end] = atom_names
else:
diff --git a/models/rfd3/src/rfd3/utils/inference.py b/models/rfd3/src/rfd3/utils/inference.py
index 2d04e149..1e00f566 100644
--- a/models/rfd3/src/rfd3/utils/inference.py
+++ b/models/rfd3/src/rfd3/utils/inference.py
@@ -452,7 +452,10 @@ def infer_ori_from_com(atom_array):
def set_com(
- atom_array, ori_token: list | None = None, infer_ori_strategy: str | None = None
+ atom_array,
+ ori_token: list | None = None,
+ infer_ori_strategy: str | None = None,
+ ori_jitter: float | None = None,
):
if exists(ori_token):
center = np.array([float(x) for x in ori_token], dtype=atom_array.coord.dtype)
@@ -505,6 +508,17 @@ def set_com(
atom_array.coord = np.zeros_like(
atom_array.coord, dtype=atom_array.coord.dtype
)
+ if ori_jitter is not None:
+ # randomly jitter the origin by a uniformly random direction
+ direction = np.random.normal(size=3)
+ direction /= np.linalg.norm(direction)
+
+ # Random length (mean ~ ori_jitter); `ori_jitter` is the scale, not an undefined `scale` name
+ length = np.random.exponential(scale=ori_jitter)
+ jittered_offset = direction * length
+
+ atom_array.coord -= jittered_offset
+
return atom_array
diff --git a/src/foundry/utils/components.py b/src/foundry/utils/components.py
index 75bc87f3..0af9c7ea 100644
--- a/src/foundry/utils/components.py
+++ b/src/foundry/utils/components.py
@@ -96,8 +96,21 @@ def get_design_pattern_with_constraints(contig, length=None):
fixed_parts = []
pos_to_put_motif = []
+ suff = [] # polymer-type suffixes for diffused regions: "P" (protein, default), "R" (RNA), "D" (DNA)
+
for part in contig_parts:
- if any(c.isalpha() for c in part): # Detect parts containing letters as fixed
+ # Support DNA and RNA generation in addition to protein
+ if part[-1] in ["R", "D"]: # Detect non-fixed RNA/DNA contig parts by their trailing suffix
+ suff.append(part[-1])
+ part = part[:-1]
+ if "-" in part:
+ start, end = map(int, part.split("-"))
+ else:
+ start = end = int(part)
+ variable_ranges.append([start, end])
+ pos_to_put_motif.append(0)
+
+ elif any(c.isalpha() for c in part): # Detect parts containing letters as fixed
pn_unit_id, pn_unit_start, pn_unit_end = extract_pn_unit_info(part)
fixed_parts.append([pn_unit_id, pn_unit_start, pn_unit_end])
pos_to_put_motif.append(1)
@@ -110,6 +123,7 @@ def get_design_pattern_with_constraints(contig, length=None):
start = end = int(part)
variable_ranges.append([start, end])
pos_to_put_motif.append(0)
+ suff.append("P")
# adjust the total length to solely for free residues
num_motif_residues = sum([i[2] - i[1] + 1 for i in fixed_parts])
@@ -167,7 +181,7 @@ def get_design_pattern_with_constraints(contig, length=None):
atoms_with_motif.append(f"{pn_unit_id}{index}")
elif pos_to_put_motif[idx] == 0:
free_atom = num_free_atoms.pop(0)
- atoms_with_motif.append(free_atom)
+ atoms_with_motif.append(str(free_atom) + suff.pop(0))
elif pos_to_put_motif[idx] == 2:
atoms_with_motif.append("/0")