From e361fdeb79d416b6b76678b8a4d32c626bd5d749 Mon Sep 17 00:00:00 2001 From: RuizhiPeng Date: Sat, 27 Sep 2025 23:27:53 -0400 Subject: [PATCH 1/3] fix cyclic import issue --- rfdiffusion/inference/inference/__init__.py | 0 .../inference/inference/model_runners.py | 1054 +++++++++++++++++ rfdiffusion/inference/inference/sym_rots.npz | Bin 0 -> 7694 bytes rfdiffusion/inference/inference/symmetry.py | 236 ++++ rfdiffusion/inference/inference/utils.py | 1015 ++++++++++++++++ 5 files changed, 2305 insertions(+) create mode 100644 rfdiffusion/inference/inference/__init__.py create mode 100644 rfdiffusion/inference/inference/model_runners.py create mode 100644 rfdiffusion/inference/inference/sym_rots.npz create mode 100644 rfdiffusion/inference/inference/symmetry.py create mode 100644 rfdiffusion/inference/inference/utils.py diff --git a/rfdiffusion/inference/inference/__init__.py b/rfdiffusion/inference/inference/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/rfdiffusion/inference/inference/model_runners.py b/rfdiffusion/inference/inference/model_runners.py new file mode 100644 index 00000000..f8f45ae4 --- /dev/null +++ b/rfdiffusion/inference/inference/model_runners.py @@ -0,0 +1,1054 @@ +import torch +import numpy as np +from omegaconf import DictConfig, OmegaConf +from rfdiffusion.RoseTTAFoldModel import RoseTTAFoldModule +from rfdiffusion.kinematics import get_init_xyz, xyz_to_t2d +from rfdiffusion.diffusion import Diffuser +from rfdiffusion.chemical import seq2chars +from rfdiffusion.util_module import ComputeAllAtomCoords +from rfdiffusion.contigs import ContigMap +from rfdiffusion.inference import utils as iu, symmetry +from rfdiffusion.potentials.manager import PotentialManager +import logging +import torch.nn.functional as nn +from rfdiffusion import util +from hydra.core.hydra_config import HydraConfig +import os +import string + +from rfdiffusion.model_input_logger import pickle_function_call +import sys + +SCRIPT_DIR=os.path.dirname(os.path.realpath(__file__)) + +TOR_INDICES = util.torsion_indices +TOR_CAN_FLIP = util.torsion_can_flip +REF_ANGLES = util.reference_angles + + +class Sampler: + + def __init__(self, conf: DictConfig): + """ + Initialize sampler. + Args: + conf: Configuration. + """ + self.initialized = False + self.initialize(conf) + + def initialize(self, conf: DictConfig) -> None: + """ + Initialize sampler. + Args: + conf: Configuration + + - Selects appropriate model from input + - Assembles Config from model checkpoint and command line overrides + + """ + self._log = logging.getLogger(__name__) + if torch.cuda.is_available(): + self.device = torch.device('cuda') + else: + self.device = torch.device('cpu') + needs_model_reload = not self.initialized or conf.inference.ckpt_override_path != self._conf.inference.ckpt_override_path + + # Assign config to Sampler + self._conf = conf + + ################################ + ### Select Appropriate Model ### + ################################ + + if conf.inference.model_directory_path is not None: + model_directory = conf.inference.model_directory_path + else: + model_directory = f"{SCRIPT_DIR}/../../models" + + print(f"Reading models from {model_directory}") + + # Initialize inference only helper objects to Sampler + if conf.inference.ckpt_override_path is not None: + self.ckpt_path = conf.inference.ckpt_override_path + print("WARNING: You're overriding the checkpoint path from the defaults. 
Check that the model you're providing can run with the inputs you're providing.") + else: + if conf.contigmap.inpaint_seq is not None or conf.contigmap.provide_seq is not None or conf.contigmap.inpaint_str: + # use model trained for inpaint_seq + if conf.contigmap.provide_seq is not None: + # this is only used for partial diffusion + assert conf.diffuser.partial_T is not None, "The provide_seq input is specifically for partial diffusion" + if conf.scaffoldguided.scaffoldguided: + self.ckpt_path = f'{model_directory}/InpaintSeq_Fold_ckpt.pt' + else: + self.ckpt_path = f'{model_directory}/InpaintSeq_ckpt.pt' + elif conf.ppi.hotspot_res is not None and conf.scaffoldguided.scaffoldguided is False: + # use complex trained model + self.ckpt_path = f'{model_directory}/Complex_base_ckpt.pt' + elif conf.scaffoldguided.scaffoldguided is True: + # use complex and secondary structure-guided model + self.ckpt_path = f'{model_directory}/Complex_Fold_base_ckpt.pt' + else: + # use default model + self.ckpt_path = f'{model_directory}/Base_ckpt.pt' + # for saving in trb file: + assert self._conf.inference.trb_save_ckpt_path is None, "trb_save_ckpt_path is not the place to specify an input model. Specify in inference.ckpt_override_path" + self._conf['inference']['trb_save_ckpt_path']=self.ckpt_path + + ####################### + ### Assemble Config ### + ####################### + + if needs_model_reload: + # Load checkpoint, so that we can assemble the config + self.load_checkpoint() + self.assemble_config_from_chk() + # Now actually load the model weights into RF + self.model = self.load_model() + else: + self.assemble_config_from_chk() + + # self.initialize_sampler(conf) + self.initialized=True + + # Initialize helper objects + self.inf_conf = self._conf.inference + self.contig_conf = self._conf.contigmap + self.denoiser_conf = self._conf.denoiser + self.ppi_conf = self._conf.ppi + self.potential_conf = self._conf.potentials + self.diffuser_conf = self._conf.diffuser + self.preprocess_conf = self._conf.preprocess + + if conf.inference.schedule_directory_path is not None: + schedule_directory = conf.inference.schedule_directory_path + else: + schedule_directory = f"{SCRIPT_DIR}/../../schedules" + + # Check for cache schedule + if not os.path.exists(schedule_directory): + os.mkdir(schedule_directory) + self.diffuser = Diffuser(**self._conf.diffuser, cache_dir=schedule_directory) + + ########################### + ### Initialise Symmetry ### + ########################### + + if self.inf_conf.symmetry is not None: + self.symmetry = symmetry.SymGen( + self.inf_conf.symmetry, + self.inf_conf.recenter, + self.inf_conf.radius, + self.inf_conf.model_only_neighbors, + ) + else: + self.symmetry = None + + self.allatom = ComputeAllAtomCoords().to(self.device) + + if self.inf_conf.input_pdb is None: + # set default pdb + script_dir=os.path.dirname(os.path.realpath(__file__)) + self.inf_conf.input_pdb=os.path.join(script_dir, '../../examples/input_pdbs/1qys.pdb') + self.target_feats = iu.process_target(self.inf_conf.input_pdb, parse_hetatom=True, center=False) + self.chain_idx = None + self.idx_pdb = None + + ############################## + ### Handle Partial Noising ### + ############################## + + if self.diffuser_conf.partial_T: + assert self.diffuser_conf.partial_T <= self.diffuser_conf.T + self.t_step_input = int(self.diffuser_conf.partial_T) + else: + self.t_step_input = int(self.diffuser_conf.T) + + @property + def T(self): + ''' + Return the maximum number of timesteps + that this design protocol will 
perform. + + Output: + T (int): The maximum number of timesteps to perform + ''' + return self.diffuser_conf.T + + def load_checkpoint(self) -> None: + """Loads RF checkpoint, from which config can be generated.""" + self._log.info(f'Reading checkpoint from {self.ckpt_path}') + print('This is inf_conf.ckpt_path') + print(self.ckpt_path) + self.ckpt = torch.load( + self.ckpt_path, map_location=self.device) + + def assemble_config_from_chk(self) -> None: + """ + Function for loading model config from checkpoint directly. + + Takes: + - config file + + Actions: + - Replaces all -model and -diffuser items + - Throws a warning if there are items in -model and -diffuser that aren't in the checkpoint + + This throws an error if there is a flag in the checkpoint 'config_dict' that isn't in the inference config. + This should ensure that whenever a feature is added in the training setup, it is accounted for in the inference script. + + """ + # get overrides to re-apply after building the config from the checkpoint + overrides = [] + if HydraConfig.initialized(): + overrides = HydraConfig.get().overrides.task + print("Assembling -model, -diffuser and -preprocess configs from checkpoint") + + for cat in ['model','diffuser','preprocess']: + for key in self._conf[cat]: + try: + print(f"USING MODEL CONFIG: self._conf[{cat}][{key}] = {self.ckpt['config_dict'][cat][key]}") + self._conf[cat][key] = self.ckpt['config_dict'][cat][key] + except: + pass + + # add overrides back in again + for override in overrides: + if override.split(".")[0] in ['model','diffuser','preprocess']: + print(f'WARNING: You are changing {override.split("=")[0]} from the value this model was trained with. Are you sure you know what you are doing?') + mytype = type(self._conf[override.split(".")[0]][override.split(".")[1].split("=")[0]]) + self._conf[override.split(".")[0]][override.split(".")[1].split("=")[0]] = mytype(override.split("=")[1]) + + def load_model(self): + """Create RosettaFold model from preloaded checkpoint.""" + + # Read input dimensions from checkpoint. + self.d_t1d=self._conf.preprocess.d_t1d + self.d_t2d=self._conf.preprocess.d_t2d + model = RoseTTAFoldModule(**self._conf.model, d_t1d=self.d_t1d, d_t2d=self.d_t2d, T=self._conf.diffuser.T).to(self.device) + if self._conf.logging.inputs: + pickle_dir = pickle_function_call(model, 'forward', 'inference') + print(f'pickle_dir: {pickle_dir}') + model = model.eval() + self._log.info(f'Loading checkpoint.') + model.load_state_dict(self.ckpt['model_state_dict'], strict=True) + return model + + def construct_contig(self, target_feats): + """ + Construct contig class describing the protein to be generated + """ + self._log.info(f'Using contig: {self.contig_conf.contigs}') + return ContigMap(target_feats, **self.contig_conf) + + def construct_denoiser(self, L, visible): + """Make length-specific denoiser.""" + denoise_kwargs = OmegaConf.to_container(self.diffuser_conf) + denoise_kwargs.update(OmegaConf.to_container(self.denoiser_conf)) + denoise_kwargs.update({ + 'L': L, + 'diffuser': self.diffuser, + 'potential_manager': self.potential_manager, + }) + return iu.Denoise(**denoise_kwargs) + + def sample_init(self, return_forward_trajectory=False): + """ + Initial features to start the sampling process. + + Modify signature and function body for different initialization + based on the config. + + Returns: + xt: Starting positions with a portion of them randomly sampled. + seq_t: Starting sequence with a portion of them set to unknown. 
+ """ + + ####################### + ### Parse input pdb ### + ####################### + + self.target_feats = iu.process_target(self.inf_conf.input_pdb, parse_hetatom=True, center=False) + + ################################ + ### Generate specific contig ### + ################################ + + # Generate a specific contig from the range of possibilities specified at input + + self.contig_map = self.construct_contig(self.target_feats) + self.mappings = self.contig_map.get_mappings() + self.mask_seq = torch.from_numpy(self.contig_map.inpaint_seq)[None,:] + self.mask_str = torch.from_numpy(self.contig_map.inpaint_str)[None,:] + self.binderlen = len(self.contig_map.inpaint) + + ####################################### + ### Resolve cyclic peptide indicies ### + ####################################### + if self._conf.inference.cyclic: + if self._conf.inference.cyc_chains is None: + # default to all residues being cyclized + self.cyclic_reses = ~self.mask_str.to(self.device).squeeze() + else: + # use cyc_chains arg to determine cyclic_reses mask + assert type(self._conf.inference.cyc_chains) is str, 'cyc_chains arg must be string' + cyc_chains = self._conf.inference.cyc_chains + cyc_chains = [i.upper() for i in cyc_chains] + hal_idx = self.contig_map.hal # the pdb indices of output, knowledge of different chains + is_cyclized = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() # initially empty + + for ch in cyc_chains: + ch_mask = torch.tensor([idx[0] == ch for idx in hal_idx]).bool() + is_cyclized[ch_mask] = True # set this whole chain to be cyclic + self.cyclic_reses = is_cyclized + else: + self.cyclic_reses = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() + + #################### + ### Get Hotspots ### + #################### + + self.hotspot_0idx=iu.get_idx0_hotspots(self.mappings, self.ppi_conf, self.binderlen) + + + ##################################### + ### Initialise Potentials Manager ### + ##################################### + + self.potential_manager = PotentialManager(self.potential_conf, + self.ppi_conf, + self.diffuser_conf, + self.inf_conf, + self.hotspot_0idx, + self.binderlen) + + ################################### + ### Initialize other attributes ### + ################################### + + xyz_27 = self.target_feats['xyz_27'] + mask_27 = self.target_feats['mask_27'] + seq_orig = self.target_feats['seq'] + L_mapped = len(self.contig_map.ref) + contig_map=self.contig_map + + self.diffusion_mask = self.mask_str + length_bound = self.contig_map.sampled_mask_length_bound.copy() + + first_res = 0 + self.chain_idx = [] + self.idx_pdb = [] + all_chains = {contig_ref[0] for contig_ref in self.contig_map.ref} + available_chains = sorted(list(set(string.ascii_letters) - all_chains)) + + # Iterate over each chain + for last_res in length_bound: + chain_ids = {contig_ref[0] for contig_ref in self.contig_map.ref[first_res: last_res]} + # If we are designing this chain, it will have a '-' in the contig map + # Renumber this chain from 1 + if "_" in chain_ids: + self.idx_pdb += [idx + 1 for idx in range(last_res - first_res)] + chain_ids = chain_ids - {"_"} + # If there are no fixed residues that have a chain id, pick the first available letter + if not chain_ids: + if not available_chains: + raise ValueError(f"No available chains! 
You are trying to design a new chain, and you have " + f"already used all upper- and lower-case chain ids (up to 52 chains): " + f"{','.join(all_chains)}.") + chain_id = available_chains[0] + available_chains.remove(chain_id) + # Otherwise, use the chain of the fixed (motif) residues + else: + assert len(chain_ids) == 1, f"Error: Multiple chain IDs in chain: {chain_ids}" + chain_id = list(chain_ids)[0] + self.chain_idx += [chain_id] * (last_res - first_res) + # If this is a fixed chain, maintain the chain and residue numbering + else: + self.idx_pdb += [contig_ref[1] for contig_ref in self.contig_map.ref[first_res: last_res]] + assert len(chain_ids) == 1, f"Error: Multiple chain IDs in chain: {chain_ids}" + self.chain_idx += [list(chain_ids)[0]] * (last_res - first_res) + first_res = last_res + + #################################### + ### Generate initial coordinates ### + #################################### + + if self.diffuser_conf.partial_T: + assert xyz_27.shape[0] == L_mapped, f"there must be a coordinate in the input PDB for \ + each residue implied by the contig string for partial diffusion. length of \ + input PDB != length of contig string: {xyz_27.shape[0]} != {L_mapped}" + assert contig_map.hal_idx0 == contig_map.ref_idx0, f'for partial diffusion there can \ + be no offset between the index of a residue in the input and the index of the \ + residue in the output, {contig_map.hal_idx0} != {contig_map.ref_idx0}' + # Partially diffusing from a known structure + xyz_mapped=xyz_27 + atom_mask_mapped = mask_27 + else: + # Fully diffusing from points initialised at the origin + # adjust size of input xt according to residue map + xyz_mapped = torch.full((1,1,L_mapped,27,3), np.nan) + xyz_mapped[:, :, contig_map.hal_idx0, ...] = xyz_27[contig_map.ref_idx0,...] 
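# --- Illustrative sketch (editor's addition, not part of the patch) ---------
# The scatter above is the core contig bookkeeping: ref_idx0 gives the motif
# positions in the input PDB, hal_idx0 gives where those residues land in the
# design. Motif coordinates are copied into a NaN canvas, and get_init_xyz()
# later fills the remaining positions. Sizes and indices below are made up.
import torch

xyz_in = torch.randn(5, 27, 3)                        # stand-in for parsed xyz_27 (L_in=5)
ref_idx0, hal_idx0 = [1, 2, 3], [4, 5, 6]             # motif mapping: input -> output
canvas = torch.full((1, 1, 8, 27, 3), float('nan'))   # design canvas (L_out=8)
canvas[:, :, hal_idx0, ...] = xyz_in[ref_idx0, ...]   # drop the motif into place
assert torch.isnan(canvas[0, 0, 0]).all()             # non-motif residues stay NaN
# --- end sketch --------------------------------------------------------------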
+ xyz_motif_prealign = xyz_mapped.clone() + motif_prealign_com = xyz_motif_prealign[0,0,:,1].mean(dim=0) + self.motif_com = xyz_27[contig_map.ref_idx0,1].mean(dim=0) + xyz_mapped = get_init_xyz(xyz_mapped).squeeze() + # adjust the size of the input atom map + atom_mask_mapped = torch.full((L_mapped, 27), False) + atom_mask_mapped[contig_map.hal_idx0] = mask_27[contig_map.ref_idx0] + + # Diffuse the contig-mapped coordinates + if self.diffuser_conf.partial_T: + assert self.diffuser_conf.partial_T <= self.diffuser_conf.T, "Partial_T must be less than T" + self.t_step_input = int(self.diffuser_conf.partial_T) + else: + self.t_step_input = int(self.diffuser_conf.T) + t_list = np.arange(1, self.t_step_input+1) + + ################################# + ### Generate initial sequence ### + ################################# + + seq_t = torch.full((1,L_mapped), 21).squeeze() # 21 is the mask token + seq_t[contig_map.hal_idx0] = seq_orig[contig_map.ref_idx0] + + # Unmask sequence if desired + if self._conf.contigmap.provide_seq is not None: + seq_t[self.mask_seq.squeeze()] = seq_orig[self.mask_seq.squeeze()] + + seq_t[~self.mask_seq.squeeze()] = 21 + seq_t = torch.nn.functional.one_hot(seq_t, num_classes=22).float() # [L,22] + seq_orig = torch.nn.functional.one_hot(seq_orig, num_classes=22).float() # [L,22] + + fa_stack, xyz_true = self.diffuser.diffuse_pose( + xyz_mapped, + torch.clone(seq_t), + atom_mask_mapped.squeeze(), + diffusion_mask=self.diffusion_mask.squeeze(), + t_list=t_list) + xT = fa_stack[-1].squeeze()[:,:14,:] + xt = torch.clone(xT) + + self.denoiser = self.construct_denoiser(len(self.contig_map.ref), visible=self.mask_seq.squeeze()) + + ###################### + ### Apply Symmetry ### + ###################### + + if self.symmetry is not None: + xt, seq_t = self.symmetry.apply_symmetry(xt, seq_t) + self._log.info(f'Sequence init: {seq2chars(torch.argmax(seq_t, dim=-1))}') + + self.msa_prev = None + self.pair_prev = None + self.state_prev = None + + ######################################### + ### Parse ligand for ligand potential ### + ######################################### + + if self.potential_conf.guiding_potentials is not None: + if any(list(filter(lambda x: "substrate_contacts" in x, self.potential_conf.guiding_potentials))): + assert len(self.target_feats['xyz_het']) > 0, "If you're using the Substrate Contact potential, \ + you need to make sure there's a ligand in the input_pdb file!" 
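# --- Illustrative sketch (editor's addition, not part of the patch) ---------
# What the substrate lookup below consumes: parse_pdb(..., parse_hetatom=True)
# returns 'xyz_het' as an [N,3] array and 'info_het' as per-atom dicts with a
# 'name' field. The potential only needs the atoms whose stripped residue name
# matches potentials.substrate. Names and coordinates here are hypothetical.
import numpy as np
import torch

info_het = [{'name': ' LIG'}, {'name': 'HOH '}, {'name': ' LIG'}]
xyz_het = np.random.rand(3, 3)
names = np.array([d['name'].strip() for d in info_het])
substrate = torch.from_numpy(xyz_het[names == 'LIG'])   # -> shape [2, 3]
com = substrate.mean(dim=0)                             # ligand centre of mass
# --- end sketch --------------------------------------------------------------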
+ het_names = np.array([i['name'].strip() for i in self.target_feats['info_het']]) + xyz_het = self.target_feats['xyz_het'][het_names == self._conf.potentials.substrate] + xyz_het = torch.from_numpy(xyz_het) + assert xyz_het.shape[0] > 0, f'expected >0 heteroatoms from ligand with name {self._conf.potentials.substrate}' + xyz_motif_prealign = xyz_motif_prealign[0,0][self.diffusion_mask.squeeze()] + motif_prealign_com = xyz_motif_prealign[:,1].mean(dim=0) + xyz_het_com = xyz_het.mean(dim=0) + for pot in self.potential_manager.potentials_to_apply: + pot.motif_substrate_atoms = xyz_het + pot.diffusion_mask = self.diffusion_mask.squeeze() + pot.xyz_motif = xyz_motif_prealign + pot.diffuser = self.diffuser + return xt, seq_t + + def _preprocess(self, seq, xyz_t, t, repack=False): + + """ + Function to prepare inputs to diffusion model + + seq (L,22) one-hot sequence + + msa_masked (1,1,L,48) + + msa_full (1,1,L,25) + + xyz_t (L,14,3) template crds (diffused) + + t1d (1,L,28) this is the t1d before tacking on the chi angles: + - seq + unknown/mask (21) + - global timestep (1-t/T if not motif else 1) (1) + + MODEL SPECIFIC: + - contacting residues: for ppi. Target residues in contact with binder (1) + - empty feature (legacy) (1) + - ss (H, E, L, MASK) (4) + + t2d (1, L, L, 45) + - last plane is block adjacency + """ + + L = seq.shape[0] + T = self.T + binderlen = self.binderlen + target_res = self.ppi_conf.hotspot_res + + ################## + ### msa_masked ### + ################## + msa_masked = torch.zeros((1,1,L,48)) + msa_masked[:,:,:,:22] = seq[None, None] + msa_masked[:,:,:,22:44] = seq[None, None] + msa_masked[:,:,0,46] = 1.0 + msa_masked[:,:,-1,47] = 1.0 + + ################ + ### msa_full ### + ################ + msa_full = torch.zeros((1,1,L,25)) + msa_full[:,:,:,:22] = seq[None, None] + msa_full[:,:,0,23] = 1.0 + msa_full[:,:,-1,24] = 1.0 + + ########### + ### t1d ### + ########### + + # Here we need to go from one hot with 22 classes to one hot with 21 classes (last plane is missing token) + t1d = torch.zeros((1,1,L,21)) + + seqt1d = torch.clone(seq) + for idx in range(L): + if seqt1d[idx,21] == 1: + seqt1d[idx,20] = 1 + seqt1d[idx,21] = 0 + + t1d[:,:,:,:21] = seqt1d[None,None,:,:21] + + + # Set timestep feature to 1 where diffusion mask is True, else 1-t/T + timefeature = torch.zeros((L)).float() + timefeature[self.mask_str.squeeze()] = 1 + timefeature[~self.mask_str.squeeze()] = 1 - t/self.T + timefeature = timefeature[None,None,...,None] + + t1d = torch.cat((t1d, timefeature), dim=-1).float() + + ############# + ### xyz_t ### + ############# + if self.preprocess_conf.sidechain_input: + xyz_t[torch.where(seq == 21, True, False),3:,:] = float('nan') + else: + xyz_t[~self.mask_str.squeeze(),3:,:] = float('nan') + + xyz_t=xyz_t[None, None] + xyz_t = torch.cat((xyz_t, torch.full((1,1,L,13,3), float('nan'))), dim=3) + + ########### + ### t2d ### + ########### + t2d = xyz_to_t2d(xyz_t) + + ########### + ### idx ### + ########### + idx = torch.tensor(self.contig_map.rf)[None] + + ############### + ### alpha_t ### + ############### + seq_tmp = t1d[...,:-1].argmax(dim=-1).reshape(-1,L) + alpha, _, alpha_mask, _ = util.get_torsions(xyz_t.reshape(-1, L, 27, 3), seq_tmp, TOR_INDICES, TOR_CAN_FLIP, REF_ANGLES) + alpha_mask = torch.logical_and(alpha_mask, ~torch.isnan(alpha[...,0])) + alpha[torch.isnan(alpha)] = 0.0 + alpha = alpha.reshape(1,-1,L,10,2) + alpha_mask = alpha_mask.reshape(1,-1,L,10,1) + alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(1, -1, L, 30) + + #put tensors on 
device + msa_masked = msa_masked.to(self.device) + msa_full = msa_full.to(self.device) + seq = seq.to(self.device) + xyz_t = xyz_t.to(self.device) + idx = idx.to(self.device) + t1d = t1d.to(self.device) + t2d = t2d.to(self.device) + alpha_t = alpha_t.to(self.device) + + ###################### + ### added_features ### + ###################### + if self.preprocess_conf.d_t1d >= 24: # add hotspot residues + hotspot_tens = torch.zeros(L).float() + if self.ppi_conf.hotspot_res is None: + print("WARNING: you're using a model trained on complexes and hotspot residues, without specifying hotspots.\ + If you're doing monomer diffusion this is fine") + hotspot_idx=[] + else: + hotspots = [(i[0],int(i[1:])) for i in self.ppi_conf.hotspot_res] + hotspot_idx=[] + for i,res in enumerate(self.contig_map.con_ref_pdb_idx): + if res in hotspots: + hotspot_idx.append(self.contig_map.hal_idx0[i]) + hotspot_tens[hotspot_idx] = 1.0 + + # Add blank (legacy) feature and hotspot tensor + t1d=torch.cat((t1d, torch.zeros_like(t1d[...,:1]), hotspot_tens[None,None,...,None].to(self.device)), dim=-1) + + return msa_masked, msa_full, seq[None], torch.squeeze(xyz_t, dim=0), idx, t1d, t2d, xyz_t, alpha_t + + def sample_step(self, *, t, x_t, seq_init, final_step): + '''Generate the next pose that the model should be supplied at timestep t-1. + + Args: + t (int): The timestep that has just been predicted + seq_t (torch.tensor): (L,22) The sequence at the beginning of this timestep + x_t (torch.tensor): (L,14,3) The residue positions at the beginning of this timestep + seq_init (torch.tensor): (L,22) The initialized sequence used in updating the sequence. + + Returns: + px0: (L,14,3) The model's prediction of x0. + x_t_1: (L,14,3) The updated positions of the next step. + seq_t_1: (L,22) The updated sequence of the next step. + tors_t_1: (L, ?) The updated torsion angles of the next step. + plddt: (L, 1) Predicted lDDT of x0. 
+ ''' + msa_masked, msa_full, seq_in, xt_in, idx_pdb, t1d, t2d, xyz_t, alpha_t = self._preprocess( + seq_init, x_t, t) + + N,L = msa_masked.shape[:2] + + if self.symmetry is not None: + idx_pdb, self.chain_idx = self.symmetry.res_idx_procesing(res_idx=idx_pdb) + + msa_prev = None + pair_prev = None + state_prev = None + + with torch.no_grad(): + msa_prev, pair_prev, px0, state_prev, alpha, logits, plddt = self.model(msa_masked, + msa_full, + seq_in, + xt_in, + idx_pdb, + t1d=t1d, + t2d=t2d, + xyz_t=xyz_t, + alpha_t=alpha_t, + msa_prev = msa_prev, + pair_prev = pair_prev, + state_prev = state_prev, + t=torch.tensor(t), + return_infer=True, + motif_mask=self.diffusion_mask.squeeze().to(self.device)) + + # prediction of X0 + _, px0 = self.allatom(torch.argmax(seq_in, dim=-1), px0, alpha) + px0 = px0.squeeze()[:,:14] + + ##################### + ### Get next pose ### + ##################### + + if t > final_step: + seq_t_1 = nn.one_hot(seq_init,num_classes=22).to(self.device) + x_t_1, px0 = self.denoiser.get_next_pose( + xt=x_t, + px0=px0, + t=t, + diffusion_mask=self.mask_str.squeeze(), + align_motif=self.inf_conf.align_motif + ) + else: + x_t_1 = torch.clone(px0).to(x_t.device) + seq_t_1 = torch.clone(seq_init) + px0 = px0.to(x_t.device) + + if self.symmetry is not None: + x_t_1, seq_t_1 = self.symmetry.apply_symmetry(x_t_1, seq_t_1) + + return px0, x_t_1, seq_t_1, plddt + + +class SelfConditioning(Sampler): + """ + Model Runner for self conditioning + pX0[t+1] is provided as a template input to the model at time t + """ + + def sample_step(self, *, t, x_t, seq_init, final_step): + ''' + Generate the next pose that the model should be supplied at timestep t-1. + Args: + t (int): The timestep that has just been predicted + seq_t (torch.tensor): (L,22) The sequence at the beginning of this timestep + x_t (torch.tensor): (L,14,3) The residue positions at the beginning of this timestep + seq_init (torch.tensor): (L,22) The initialized sequence used in updating the sequence. + Returns: + px0: (L,14,3) The model's prediction of x0. + x_t_1: (L,14,3) The updated positions of the next step. + seq_t_1: (L) The sequence to the next step (== seq_init) + plddt: (L, 1) Predicted lDDT of x0. 
+ ''' + + msa_masked, msa_full, seq_in, xt_in, idx_pdb, t1d, t2d, xyz_t, alpha_t = self._preprocess( + seq_init, x_t, t) + B,N,L = xyz_t.shape[:3] + + ################################## + ######## Str Self Cond ########### + ################################## + if (t < self.diffuser.T) and (t != self.diffuser_conf.partial_T): + zeros = torch.zeros(B,1,L,24,3).float().to(xyz_t.device) + xyz_t = torch.cat((self.prev_pred.unsqueeze(1),zeros), dim=-2) # [B,T,L,27,3] + t2d_44 = xyz_to_t2d(xyz_t) # [B,T,L,L,44] + else: + xyz_t = torch.zeros_like(xyz_t) + t2d_44 = torch.zeros_like(t2d[...,:44]) + # No effect if t2d is only dim 44 + t2d[...,:44] = t2d_44 + + if self.symmetry is not None: + idx_pdb, self.chain_idx = self.symmetry.res_idx_procesing(res_idx=idx_pdb) + + #################### + ### Forward Pass ### + #################### + + with torch.no_grad(): + msa_prev, pair_prev, px0, state_prev, alpha, logits, plddt = self.model(msa_masked, + msa_full, + seq_in, + xt_in, + idx_pdb, + t1d=t1d, + t2d=t2d, + xyz_t=xyz_t, + alpha_t=alpha_t, + msa_prev = None, + pair_prev = None, + state_prev = None, + t=torch.tensor(t), + return_infer=True, + motif_mask=self.diffusion_mask.squeeze().to(self.device), + cyclic_reses=self.cyclic_reses) + + if self.symmetry is not None and self.inf_conf.symmetric_self_cond: + px0 = self.symmetrise_prev_pred(px0=px0,seq_in=seq_in, alpha=alpha)[:,:,:3] + + self.prev_pred = torch.clone(px0) + + # prediction of X0 + _, px0 = self.allatom(torch.argmax(seq_in, dim=-1), px0, alpha) + px0 = px0.squeeze()[:,:14] + + ########################### + ### Generate Next Input ### + ########################### + + seq_t_1 = torch.clone(seq_init) + if t > final_step: + x_t_1, px0 = self.denoiser.get_next_pose( + xt=x_t, + px0=px0, + t=t, + diffusion_mask=self.mask_str.squeeze(), + align_motif=self.inf_conf.align_motif, + include_motif_sidechains=self.preprocess_conf.motif_sidechain_input + ) + self._log.info( + f'Timestep {t}, input to next step: { seq2chars(torch.argmax(seq_t_1, dim=-1).tolist())}') + else: + x_t_1 = torch.clone(px0).to(x_t.device) + px0 = px0.to(x_t.device) + + ###################### + ### Apply symmetry ### + ###################### + + if self.symmetry is not None: + x_t_1, seq_t_1 = self.symmetry.apply_symmetry(x_t_1, seq_t_1) + + return px0, x_t_1, seq_t_1, plddt + + def symmetrise_prev_pred(self, px0, seq_in, alpha): + """ + Method for symmetrising px0 output for self-conditioning + """ + _,px0_aa = self.allatom(torch.argmax(seq_in, dim=-1), px0, alpha) + px0_sym,_ = self.symmetry.apply_symmetry(px0_aa.to('cpu').squeeze()[:,:14], torch.argmax(seq_in, dim=-1).squeeze().to('cpu')) + px0_sym = px0_sym[None].to(self.device) + return px0_sym + +class ScaffoldedSampler(SelfConditioning): + """ + Model Runner for Scaffold-Constrained diffusion + """ + def __init__(self, conf: DictConfig): + """ + Initialize scaffolded sampler. + Two basic approaches here: + i) Given a block adjacency/secondary structure input, generate a fold (in the presence or absence of a target) + - This allows easy generation of binders or specific folds + - Allows simple expansion of an input, to sample different lengths + ii) Providing a contig input and corresponding block adjacency/secondary structure input + - This allows mixed motif scaffolding and fold-conditioning. 
+ - Adjacency/secondary structure inputs must correspond exactly in length to the contig string + """ + super().__init__(conf) + # initialize BlockAdjacency sampling class + if conf.scaffoldguided.scaffold_dir is None: + assert any(x is not None for x in (conf.contigmap.inpaint_str_helix, conf.contigmap.inpaint_str_strand, conf.contigmap.inpaint_str_loop)) + if conf.contigmap.inpaint_str_loop is not None: + assert conf.scaffoldguided.mask_loops == False, "You shouldn't be masking loops if you're specifying loop secondary structure" + else: + # initialize BlockAdjacency sampling class + assert all(x is None for x in (conf.contigmap.inpaint_str_helix, conf.contigmap.inpaint_str_strand, conf.contigmap.inpaint_str_loop)), "can't provide scaffold_dir if you're also specifying per-residue ss" + self.blockadjacency = iu.BlockAdjacency(conf.scaffoldguided, conf.inference.num_designs) + + + ################################################# + ### Initialize target, if doing binder design ### + ################################################# + + if conf.scaffoldguided.target_pdb: + self.target = iu.Target(conf.scaffoldguided, conf.ppi.hotspot_res) + self.target_pdb = self.target.get_target() + if conf.scaffoldguided.target_ss is not None: + self.target_ss = torch.load(conf.scaffoldguided.target_ss).long() + self.target_ss = torch.nn.functional.one_hot(self.target_ss, num_classes=4) + if self._conf.scaffoldguided.contig_crop is not None: + self.target_ss=self.target_ss[self.target_pdb['crop_mask']] + if conf.scaffoldguided.target_adj is not None: + self.target_adj = torch.load(conf.scaffoldguided.target_adj).long() + self.target_adj=torch.nn.functional.one_hot(self.target_adj, num_classes=3) + if self._conf.scaffoldguided.contig_crop is not None: + self.target_adj=self.target_adj[self.target_pdb['crop_mask']] + self.target_adj=self.target_adj[:,self.target_pdb['crop_mask']] + else: + self.target = None + self.target_pdb=False + + def sample_init(self): + """ + Wrapper method for taking secondary structure + adj, and outputting xt, seq_t + """ + + ########################## + ### Process Fold Input ### + ########################## + if hasattr(self, 'blockadjacency'): + self.L, self.ss, self.adj = self.blockadjacency.get_scaffold() + self.adj = nn.one_hot(self.adj.long(), num_classes=3) + else: + self.L=100 # shim. 
Gets overwritten
+
+        ##############################
+        ### Auto-contig generation ###
+        ##############################
+
+        if self.contig_conf.contigs is None:
+            # process target
+            xT = torch.full((self.L, 27,3), np.nan)
+            xT = get_init_xyz(xT[None,None]).squeeze()
+            seq_T = torch.full((self.L,),21)
+            self.diffusion_mask = torch.full((self.L,),False)
+            atom_mask = torch.full((self.L,27), False)
+            self.binderlen=self.L
+
+            if self.target:
+                target_L = np.shape(self.target_pdb['xyz'])[0]
+                # xyz
+                target_xyz = torch.full((target_L, 27, 3), np.nan)
+                target_xyz[:,:14,:] = torch.from_numpy(self.target_pdb['xyz'])
+                xT = torch.cat((xT, target_xyz), dim=0)
+                # seq
+                seq_T = torch.cat((seq_T, torch.from_numpy(self.target_pdb['seq'])), dim=0)
+                # diffusion mask
+                self.diffusion_mask = torch.cat((self.diffusion_mask, torch.full((target_L,), True)),dim=0)
+                # atom mask
+                mask_27 = torch.full((target_L, 27), False)
+                mask_27[:,:14] = torch.from_numpy(self.target_pdb['mask'])
+                atom_mask = torch.cat((atom_mask, mask_27), dim=0)
+                self.L += target_L
+                # generate contigmap object
+                contig = []
+                for idx,i in enumerate(self.target_pdb['pdb_idx'][:-1]):
+                    if idx==0:
+                        start=i[1]
+                    if i[1] + 1 != self.target_pdb['pdb_idx'][idx+1][1] or i[0] != self.target_pdb['pdb_idx'][idx+1][0]:
+                        contig.append(f'{i[0]}{start}-{i[1]}/0 ')
+                        start = self.target_pdb['pdb_idx'][idx+1][1]
+                contig.append(f"{self.target_pdb['pdb_idx'][-1][0]}{start}-{self.target_pdb['pdb_idx'][-1][1]}/0 ")
+                contig.append(f"{self.binderlen}-{self.binderlen}")
+                contig = ["".join(contig)]
+            else:
+                contig = [f"{self.binderlen}-{self.binderlen}"]
+            self.contig_map=ContigMap(self.target_pdb, contig)
+            self.mappings = self.contig_map.get_mappings()
+            self.mask_seq = self.diffusion_mask
+            self.mask_str = self.diffusion_mask
+            L_mapped=len(self.contig_map.ref)
+
+        ############################
+        ### Specific Contig mode ###
+        ############################
+
+        else:
+            # get contigmap from command line
+            assert self.target is None, "Giving a target is the wrong way of handling this if you're doing contigs and secondary structure"
+
+            # process target and reinitialise potential_manager. This is here because the 'target' is always set up to be the second chain in our inputs.
+            self.target_feats = iu.process_target(self.inf_conf.input_pdb)
+            self.contig_map = self.construct_contig(self.target_feats)
+            self.mappings = self.contig_map.get_mappings()
+            self.mask_seq = torch.from_numpy(self.contig_map.inpaint_seq)[None,:]
+            self.mask_str = torch.from_numpy(self.contig_map.inpaint_str)[None,:]
+            self.binderlen = len(self.contig_map.inpaint)
+            self.L = len(self.contig_map.inpaint_seq)
+            target_feats = self.target_feats
+            contig_map = self.contig_map
+
+            xyz_27 = target_feats['xyz_27']
+            mask_27 = target_feats['mask_27']
+            seq_orig = target_feats['seq']
+            L_mapped = len(self.contig_map.ref)
+            seq_T=torch.full((L_mapped,),21)
+            seq_T[contig_map.hal_idx0] = seq_orig[contig_map.ref_idx0]
+            seq_T[~self.mask_seq.squeeze()] = 21
+
+            diffusion_mask = self.mask_str
+            self.diffusion_mask = diffusion_mask
+
+            xT = torch.full((1,1,L_mapped,27,3), np.nan)
+            xT[:, :, contig_map.hal_idx0, ...] = xyz_27[contig_map.ref_idx0,...]
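# --- Illustrative sketch (editor's addition, not part of the patch) ---------
# The auto-contig branch above run-length-encodes the target's (chain, resnum)
# list into contig segments and then appends the binder length. A minimal
# stand-alone reproduction with a hypothetical two-chain target and a
# 75-residue binder:
pdb_idx = [('A', 25), ('A', 26), ('A', 27), ('B', 1), ('B', 2)]
binderlen = 75
contig, start = [], pdb_idx[0][1]
for (ch, num), (ch_nxt, num_nxt) in zip(pdb_idx[:-1], pdb_idx[1:]):
    if num + 1 != num_nxt or ch != ch_nxt:     # numbering gap or chain break
        contig.append(f'{ch}{start}-{num}/0 ')
        start = num_nxt
contig.append(f'{pdb_idx[-1][0]}{start}-{pdb_idx[-1][1]}/0 ')
contig.append(f'{binderlen}-{binderlen}')
print(["".join(contig)])                       # ['A25-27/0 B1-2/0 75-75']
# --- end sketch --------------------------------------------------------------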
+ xT = get_init_xyz(xT).squeeze() + atom_mask = torch.full((L_mapped, 27), False) + atom_mask[contig_map.hal_idx0] = mask_27[contig_map.ref_idx0] + + if hasattr(self.contig_map, 'ss_spec'): + self.adj=torch.full((L_mapped, L_mapped),2) # masked + self.adj=nn.one_hot(self.adj.long(), num_classes=3) + self.ss=iu.ss_from_contig(self.contig_map.ss_spec) + assert L_mapped==self.adj.shape[0] + + #################### + ### Get hotspots ### + #################### + self.hotspot_0idx=iu.get_idx0_hotspots(self.mappings, self.ppi_conf, self.binderlen) + + ######################### + ### Set up potentials ### + ######################### + + self.potential_manager = PotentialManager(self.potential_conf, + self.ppi_conf, + self.diffuser_conf, + self.inf_conf, + self.hotspot_0idx, + self.binderlen) + + self.chain_idx=['A' if i < self.binderlen else 'B' for i in range(self.L)] + + ######################## + ### Handle Partial T ### + ######################## + + if self.diffuser_conf.partial_T: + assert self.diffuser_conf.partial_T <= self.diffuser_conf.T + self.t_step_input = int(self.diffuser_conf.partial_T) + else: + self.t_step_input = int(self.diffuser_conf.T) + t_list = np.arange(1, self.t_step_input+1) + seq_T=torch.nn.functional.one_hot(seq_T, num_classes=22).float() + + fa_stack, xyz_true = self.diffuser.diffuse_pose( + xT, + torch.clone(seq_T), + atom_mask.squeeze(), + diffusion_mask=self.diffusion_mask.squeeze(), + t_list=t_list, + include_motif_sidechains=self.preprocess_conf.motif_sidechain_input) + + ####################### + ### Set up Denoiser ### + ####################### + + self.denoiser = self.construct_denoiser(self.L, visible=self.mask_seq.squeeze()) + + ####################################### + ### Resolve cyclic peptide indicies ### + ####################################### + if self._conf.inference.cyclic: + if self._conf.inference.cyc_chains is None: + # default to all residues being cyclized + self.cyclic_reses = ~self.mask_str.to(self.device).squeeze() + else: + # use cyc_chains arg to determine cyclic_reses mask + assert type(self._conf.inference.cyc_chains) is str, 'cyc_chains arg must be string' + cyc_chains = self._conf.inference.cyc_chains + cyc_chains = [i.upper() for i in cyc_chains] + hal_idx = self.contig_map.hal # the pdb indices of output, knowledge of different chains + is_cyclized = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() # initially empty + + for ch in cyc_chains: + ch_mask = torch.tensor([idx[0] == ch for idx in hal_idx]).bool() + is_cyclized[ch_mask] = True # set this whole chain to be cyclic + self.cyclic_reses = is_cyclized + else: + self.cyclic_reses = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() + + xT = torch.clone(fa_stack[-1].squeeze()[:,:14,:]) + return xT, seq_T + + def _preprocess(self, seq, xyz_t, t): + msa_masked, msa_full, seq, xyz_prev, idx_pdb, t1d, t2d, xyz_t, alpha_t = super()._preprocess(seq, xyz_t, t, repack=False) + + ################################### + ### Add Adj/Secondary Structure ### + ################################### + + assert self.preprocess_conf.d_t1d == 28, "The checkpoint you're using hasn't been trained with sec-struc/block adjacency features" + assert self.preprocess_conf.d_t2d == 47, "The checkpoint you're using hasn't been trained with sec-struc/block adjacency features" + + ##################### + ### Handle Target ### + ##################### + + if self.target: + blank_ss = torch.nn.functional.one_hot(torch.full((self.L-self.binderlen,), 3), num_classes=4) + full_ss = 
torch.cat((self.ss, blank_ss), dim=0) + if self._conf.scaffoldguided.target_ss is not None: + full_ss[self.binderlen:] = self.target_ss + else: + full_ss = self.ss + t1d=torch.cat((t1d, full_ss[None,None].to(self.device)), dim=-1) + + t1d = t1d.float() + + ########### + ### t2d ### + ########### + + if self.d_t2d == 47: + if self.target: + full_adj = torch.zeros((self.L, self.L, 3)) + full_adj[:,:,-1] = 1. #set to mask + full_adj[:self.binderlen, :self.binderlen] = self.adj + if self._conf.scaffoldguided.target_adj is not None: + full_adj[self.binderlen:,self.binderlen:] = self.target_adj + else: + full_adj = self.adj + t2d=torch.cat((t2d, full_adj[None,None].to(self.device)),dim=-1) + + ########### + ### idx ### + ########### + + if self.target: + idx_pdb[:,self.binderlen:] += 200 + + return msa_masked, msa_full, seq, xyz_prev, idx_pdb, t1d, t2d, xyz_t, alpha_t diff --git a/rfdiffusion/inference/inference/sym_rots.npz b/rfdiffusion/inference/inference/sym_rots.npz new file mode 100644 index 0000000000000000000000000000000000000000..8e6b38011a547d085099c0be682733f3aa573b1b GIT binary patch literal 7694 zcmb`ML5o~P6osqps1PA4%r?+jwZb4eifBk^agoiS^u~oNZDu+|2$`XKP=g5XAGmN4 zTo`cW!hN=Hmp|Yi7;xpv`3ZKV?ssPTl*JoN-lY4!Tj$(!?yc$9uQNU0A79x?<+zc4 zdi?17({W1q@p{@xC(D!7;@KY#wkXY(&d>D%VYmzPC7DZt}O@+uNPIzxns}?&QbG z|2nRgF3x5pww}-WK)U~C2nP-xxRTn2sJajAE*41yUhpfV7arwATbh+TIB^@jkVT#PC+GT&ywyKS@}ZvBCxihVucwUI#6y?< z#iP(7JX=o8Cmdbdhe3V#VIVyWLl-}m(IqcHvS`mE}}Z9e9x{7qLp)1e>jtL>9;uZtXA@oX;k zN4V92qtoMWI_w8Hx%xI;{9IRcjr{?y>U`fDe$?YnjvhIE!xPT@=v$xE^?@sgcgcOj z4;|(Bn+`nst>&O#tD`@!>+!KUD&59h-VeFvFkSh1fAAXpRJsHG$jA4?bNov5XYPwc zKBy>0L4_mKl@sot;(=R)XUpgR!FHf43+NKnU*MEU-_{+M4nQNKM-;8o*|U?TUhu%B zWA&)DaO&Y{{#c-nD4v}N?*g!hPm8B@)1{w5Jg!gi%wIaeBX>J}9^*O;e)DsbaIQnH zPrF|F;?(?H2hRHk{(k-92d*4He8>m-nSn!=Rg=dk`Xzt0KJ zbnwTg@;4pvOouw=QXahKdg!x{@T#0Xe17^c9q$ue^<_Hv3HQG7u{!zCA9am>;G;Pke$4AU>cPq3;iugCq)t3= z+b7qhZWjBdoPDu5;K|4KoBq&&=ecFL{aG{5Pz=2G9L!(8&AKjyHx z;K|4R>DTM#`=gh3#;xyAi(7u~M)}#kQT5{O?^kj`Il9q}E6KiX@!IXb4N;kme|s(Y WUyA1YFTav0bnk', coords_out[:subunit_len], self.sym_rots[i]) + seq_out[start_i:end_i] = seq_out[:subunit_len] + return coords_out, seq_out + + def _lin_chainbreaks(self, num_breaks, res_idx, offset=None): + assert res_idx.ndim == 2 + res_idx = torch.clone(res_idx) + subunit_len = res_idx.shape[-1] // num_breaks + chain_delimiters = [] + if offset is None: + offset = res_idx.shape[-1] + for i in range(num_breaks): + start_i = subunit_len * i + end_i = subunit_len * (i+1) + chain_labels = list(string.ascii_uppercase) + [str(i+j) for i in + string.ascii_uppercase for j in string.ascii_uppercase] + chain_delimiters.extend( + [chain_labels[i] for _ in range(subunit_len)] + ) + res_idx[:, start_i:end_i] = res_idx[:, start_i:end_i] + offset * (i+1) + return res_idx, chain_delimiters + + ####################### + ## Dihedral symmetry ## + ####################### + def _init_dihedral(self, order): + sym_rots = [] + flip = Rotation.from_euler('x', 180, degrees=True).as_matrix() + for i in range(order): + deg = i * 360.0 / order + rot = Rotation.from_euler('z', deg, degrees=True).as_matrix() + sym_rots.append(format_rots(rot)) + rot2 = flip @ rot + sym_rots.append(format_rots(rot2)) + self.sym_rots = sym_rots + self.order = order * 2 + + ######################### + ## Octahedral symmetry ## + ######################### + def _init_octahedral(self): + sym_rots = np.load(f"{pathlib.Path(__file__).parent.resolve()}/sym_rots.npz") + self.sym_rots = [ + torch.tensor(v_i, dtype=torch.float32) + for v_i in sym_rots['octahedral'] + ] 
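# --- Illustrative sketch (editor's addition, not part of the patch) ---------
# The point-group rotations above come precomputed from sym_rots.npz; for
# cyclic groups the equivalent list is just n evenly spaced z-rotations (the
# dihedral case above additionally x-flips each one). A minimal C_n generator,
# assuming scipy is available:
import torch
from scipy.spatial.transform import Rotation

def cyclic_rots(order):
    """Rotation matrices for C_<order> symmetry about z, identity first."""
    return [
        torch.tensor(
            Rotation.from_euler('z', 360.0 * i / order, degrees=True).as_matrix(),
            dtype=torch.float32,
        )
        for i in range(order)
    ]

rots = cyclic_rots(3)
assert torch.allclose(rots[0], torch.eye(3))   # identity comes first
# --- end sketch --------------------------------------------------------------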
+        self.order = len(self.sym_rots)
+
+    def _apply_octahedral(self, coords_in, seq_in):
+        coords_out = torch.clone(coords_in)
+        seq_out = torch.clone(seq_in)
+        if seq_out.shape[0] % self.order != 0:
+            raise ValueError(
+                f'Sequence length must be divisible by {self.order}')
+        subunit_len = seq_out.shape[0] // self.order
+        base_axis = torch.tensor([self._radius, 0., 0.])[None]
+        for i in range(self.order):
+            start_i = subunit_len * i
+            end_i = subunit_len * (i+1)
+            subunit_chain = torch.einsum(
+                'bnj,kj->bnk', coords_in[:subunit_len], self.sym_rots[i])
+
+            if self._recenter:
+                center = torch.mean(subunit_chain[:, 1, :], axis=0)
+                subunit_chain -= center[None, None, :]
+                rotated_axis = torch.einsum(
+                    'nj,kj->nk', base_axis, self.sym_rots[i])
+                subunit_chain += rotated_axis[:, None, :]
+
+            coords_out[start_i:end_i] = subunit_chain
+            seq_out[start_i:end_i] = seq_out[:subunit_len]
+        return coords_out, seq_out
+
+    #######################
+    ## symmetry from file #
+    #######################
+    def _init_from_symrots_file(self, name):
+        """ _init_from_symrots_file initializes using
+        ./inference/sym_rots.npz
+
+        Args:
+            name: name of symmetry (one of tetrahedral, octahedral, icosahedral)
+
+        sets self.sym_rots to be a list of torch.tensor of shape [3, 3]
+        """
+        assert name in saved_symmetries, name + " not in " + str(saved_symmetries)
+
+        # Load in list of rotation matrices for `name`
+        fn = f"{pathlib.Path(__file__).parent.resolve()}/sym_rots.npz"
+        obj = np.load(fn)
+        symms = None
+        for k, v in obj.items():
+            if str(k) == name: symms = v
+        assert symms is not None, "%s not found in %s"%(name, fn)
+
+        self.sym_rots = [torch.tensor(v_i, dtype=torch.float32) for v_i in symms]
+        self.order = len(self.sym_rots)
+
+        # If the identity is not already the first rotation, move it to the front
+        if not np.isclose(((self.sym_rots[0]-np.eye(3))**2).sum(), 0):
+            for i, rot in enumerate(self.sym_rots):
+                if np.isclose(((rot-np.eye(3))**2).sum(), 0):
+                    self.sym_rots = [self.sym_rots.pop(i)] + self.sym_rots
+
+        assert len(self.sym_rots) == self.order
+        assert np.isclose(((self.sym_rots[0]-np.eye(3))**2).sum(), 0)
+
+    def close_neighbors(self):
+        """close_neighbors finds the rotations within self.sym_rots that
+        correspond to close neighbors.
+ + Returns: + list of rotation matrices corresponding to the identity and close neighbors + """ + # set of small rotation angle rotations + rel_rot = lambda M: np.linalg.norm(Rotation.from_matrix(M).as_rotvec()) + rel_rots = [(i+1, rel_rot(M)) for i, M in enumerate(self.sym_rots[1:])] + min_rot = min(rel_rot_val[1] for rel_rot_val in rel_rots) + close_rots = [np.eye(3)] + [ + self.sym_rots[i] for i, rel_rot_val in rel_rots if + np.isclose(rel_rot_val, min_rot) + ] + return close_rots diff --git a/rfdiffusion/inference/inference/utils.py b/rfdiffusion/inference/inference/utils.py new file mode 100644 index 00000000..2ed6105b --- /dev/null +++ b/rfdiffusion/inference/inference/utils.py @@ -0,0 +1,1015 @@ +import numpy as np +import os +from omegaconf import DictConfig +import torch +import torch.nn.functional as nn +from rfdiffusion.diffusion import get_beta_schedule +from scipy.spatial.transform import Rotation as scipy_R +from rfdiffusion.util import rigid_from_3_points +from rfdiffusion.util_module import ComputeAllAtomCoords +from rfdiffusion import util +import random +import logging +from rfdiffusion.inference import model_runners +import glob + +########################################################### +#### Functions which can be called outside of Denoiser #### +########################################################### + + +def get_next_frames(xt, px0, t, diffuser, so3_type, diffusion_mask, noise_scale=1.0): + """ + get_next_frames gets updated frames using IGSO(3) + score_based reverse diffusion. + + + based on self.so3_type use score based update. + + Generate frames at t-1 + Rather than generating random rotations (as occurs during forward process), calculate rotation between xt and px0 + + Args: + xt: noised coordinates of shape [L, 14, 3] + px0: prediction of coordinates at t=0, of shape [L, 14, 3] + t: integer time step + diffuser: Diffuser object for reverse igSO3 sampling + so3_type: The type of SO3 noising being used ('igso3') + diffusion_mask: of shape [L] of type bool, True means not to be + updated (e.g. 
mask is true for motif residues) + noise_scale: scale factor for the noise added (IGSO3 only) + + Returns: + backbone coordinates for step x_t-1 of shape [L, 3, 3] + """ + N_0 = px0[None, :, 0, :] + Ca_0 = px0[None, :, 1, :] + C_0 = px0[None, :, 2, :] + + R_0, Ca_0 = rigid_from_3_points(N_0, Ca_0, C_0) + + N_t = xt[None, :, 0, :] + Ca_t = xt[None, :, 1, :] + C_t = xt[None, :, 2, :] + + R_t, Ca_t = rigid_from_3_points(N_t, Ca_t, C_t) + + # this must be to normalize them or something + R_0 = scipy_R.from_matrix(R_0.squeeze().numpy()).as_matrix() + R_t = scipy_R.from_matrix(R_t.squeeze().numpy()).as_matrix() + + L = R_t.shape[0] + all_rot_transitions = np.broadcast_to(np.identity(3), (L, 3, 3)).copy() + # Sample next frame for each residue + if so3_type == "igso3": + # don't do calculations on masked positions since they end up as identity matrix + all_rot_transitions[ + ~diffusion_mask + ] = diffuser.so3_diffuser.reverse_sample_vectorized( + R_t[~diffusion_mask], + R_0[~diffusion_mask], + t, + noise_level=noise_scale, + mask=None, + return_perturb=True, + ) + else: + assert False, "so3 diffusion type %s not implemented" % so3_type + + all_rot_transitions = all_rot_transitions[:, None, :, :] + + # Apply the interpolated rotation matrices to the coordinates + next_crds = ( + np.einsum( + "lrij,laj->lrai", + all_rot_transitions, + xt[:, :3, :] - Ca_t.squeeze()[:, None, ...].numpy(), + ) + + Ca_t.squeeze()[:, None, None, ...].numpy() + ) + + # (L,3,3) set of backbone coordinates with slight rotation + return next_crds.squeeze(1) + + +def get_mu_xt_x0(xt, px0, t, beta_schedule, alphabar_schedule, eps=1e-6): + """ + Given xt, predicted x0 and the timestep t, give mu of x(t-1) + Assumes t is 0 indexed + """ + # sigma is predefined from beta. Often referred to as beta tilde t + t_idx = t - 1 + sigma = ( + (1 - alphabar_schedule[t_idx - 1]) / (1 - alphabar_schedule[t_idx]) + ) * beta_schedule[t_idx] + + xt_ca = xt[:, 1, :] + px0_ca = px0[:, 1, :] + + a = ( + (torch.sqrt(alphabar_schedule[t_idx - 1] + eps) * beta_schedule[t_idx]) + / (1 - alphabar_schedule[t_idx]) + ) * px0_ca + b = ( + ( + torch.sqrt(1 - beta_schedule[t_idx] + eps) + * (1 - alphabar_schedule[t_idx - 1]) + ) + / (1 - alphabar_schedule[t_idx]) + ) * xt_ca + + mu = a + b + + return mu, sigma + + +def get_next_ca( + xt, + px0, + t, + diffusion_mask, + crd_scale, + beta_schedule, + alphabar_schedule, + noise_scale=1.0, +): + """ + Given full atom x0 prediction (xyz coordinates), diffuse to x(t-1) + + Parameters: + + xt (L, 14/27, 3) set of coordinates + + px0 (L, 14/27, 3) set of coordinates + + t: time step. Note this is zero-index current time step, so are generating t-1 + + logits_aa (L x 20 ) amino acid probabilities at each position + + seq_schedule (L): Tensor of bools, True is unmasked, False is masked. 
For this specific t + + diffusion_mask (torch.tensor, required): Tensor of bools, True means NOT diffused at this residue, False means diffused + + noise_scale: scale factor for the noise being added + + """ + get_allatom = ComputeAllAtomCoords().to(device=xt.device) + L = len(xt) + + # bring to origin after global alignment (when don't have a motif) or replace input motif and bring to origin, and then scale + px0 = px0 * crd_scale + xt = xt * crd_scale + + # get mu(xt, x0) + mu, sigma = get_mu_xt_x0( + xt, px0, t, beta_schedule=beta_schedule, alphabar_schedule=alphabar_schedule + ) + + sampled_crds = torch.normal(mu, torch.sqrt(sigma * noise_scale)) + delta = sampled_crds - xt[:, 1, :] # check sign of this is correct + + if not diffusion_mask is None: + # Don't move motif + delta[diffusion_mask, ...] = 0 + + out_crds = xt + delta[:, None, :] + + return out_crds / crd_scale, delta / crd_scale + + +def get_noise_schedule(T, noiseT, noise1, schedule_type): + """ + Function to create a schedule that varies the scale of noise given to the model over time + + Parameters: + + T: The total number of timesteps in the denoising trajectory + + noiseT: The inital (t=T) noise scale + + noise1: The final (t=1) noise scale + + schedule_type: The type of function to use to interpolate between noiseT and noise1 + + Returns: + + noise_schedule: A function which maps timestep to noise scale + + """ + + noise_schedules = { + "constant": lambda t: noiseT, + "linear": lambda t: ((t - 1) / (T - 1)) * (noiseT - noise1) + noise1, + } + + assert ( + schedule_type in noise_schedules + ), f"noise_schedule must be one of {noise_schedules.keys()}. Received noise_schedule={schedule_type}. Exiting." + + return noise_schedules[schedule_type] + + +class Denoise: + """ + Class for getting x(t-1) from predicted x0 and x(t) + Strategy: + Ca coordinates: Rediffuse to x(t-1) from predicted x0 + Frames: Approximate update from rotation score + Torsions: 1/t of the way to the x0 prediction + + """ + + def __init__( + self, + T, + L, + diffuser, + b_0=0.001, + b_T=0.1, + min_b=1.0, + max_b=12.5, + min_sigma=0.05, + max_sigma=1.5, + noise_level=0.5, + schedule_type="linear", + so3_schedule_type="linear", + schedule_kwargs={}, + so3_type="igso3", + noise_scale_ca=1.0, + final_noise_scale_ca=1, + ca_noise_schedule_type="constant", + noise_scale_frame=0.5, + final_noise_scale_frame=0.5, + frame_noise_schedule_type="constant", + crd_scale=1 / 15, + potential_manager=None, + partial_T=None, + ): + """ + + Parameters: + noise_level: scaling on the noise added (set to 0 to use no noise, + to 1 to have full noise) + + """ + self.T = T + self.L = L + self.diffuser = diffuser + self.b_0 = b_0 + self.b_T = b_T + self.noise_level = noise_level + self.schedule_type = schedule_type + self.so3_type = so3_type + self.crd_scale = crd_scale + self.noise_scale_ca = noise_scale_ca + self.final_noise_scale_ca = final_noise_scale_ca + self.ca_noise_schedule_type = ca_noise_schedule_type + self.noise_scale_frame = noise_scale_frame + self.final_noise_scale_frame = final_noise_scale_frame + self.frame_noise_schedule_type = frame_noise_schedule_type + self.potential_manager = potential_manager + self._log = logging.getLogger(__name__) + + self.schedule, self.alpha_schedule, self.alphabar_schedule = get_beta_schedule( + self.T, self.b_0, self.b_T, self.schedule_type, inference=True + ) + + self.noise_schedule_ca = get_noise_schedule( + self.T, + self.noise_scale_ca, + self.final_noise_scale_ca, + self.ca_noise_schedule_type, + ) + self.noise_schedule_frame 
= get_noise_schedule( + self.T, + self.noise_scale_frame, + self.final_noise_scale_frame, + self.frame_noise_schedule_type, + ) + + @property + def idx2steps(self): + return self.decode_scheduler.idx2steps.numpy() + + def align_to_xt_motif(self, px0, xT, diffusion_mask, eps=1e-6): + """ + Need to align px0 to motif in xT. This is to permit the swapping of residue positions in the px0 motif for the true coordinates. + First, get rotation matrix from px0 to xT for the motif residues. + Second, rotate px0 (whole structure) by that rotation matrix + Third, centre at origin + """ + + def rmsd(V, W, eps=0): + # First sum down atoms, then sum down xyz + N = V.shape[-2] + return np.sqrt(np.sum((V - W) * (V - W), axis=(-2, -1)) / N + eps) + + assert ( + xT.shape[1] == px0.shape[1] + ), f"xT has shape {xT.shape} and px0 has shape {px0.shape}" + + L, n_atom, _ = xT.shape # A is number of atoms + atom_mask = ~torch.isnan(px0) + # convert to numpy arrays + px0 = px0.cpu().detach().numpy() + xT = xT.cpu().detach().numpy() + diffusion_mask = diffusion_mask.cpu().detach().numpy() + + # 1 centre motifs at origin and get rotation matrix + px0_motif = px0[diffusion_mask, :3].reshape(-1, 3) + xT_motif = xT[diffusion_mask, :3].reshape(-1, 3) + px0_motif_mean = np.copy(px0_motif.mean(0)) # need later + xT_motif_mean = np.copy(xT_motif.mean(0)) + + # center at origin + px0_motif = px0_motif - px0_motif_mean + xT_motif = xT_motif - xT_motif_mean + + # A = px0_motif + # B = xT_motif + A = xT_motif + B = px0_motif + + C = np.matmul(A.T, B) + + # compute optimal rotation matrix using SVD + U, S, Vt = np.linalg.svd(C) + + # ensure right handed coordinate system + d = np.eye(3) + d[-1, -1] = np.sign(np.linalg.det(Vt.T @ U.T)) + + # construct rotation matrix + R = Vt.T @ d @ U.T + + # get rotated coords + rB = B @ R + + # calculate rmsd + rms = rmsd(A, rB) + self._log.info(f"Sampled motif RMSD: {rms:.2f}") + + # 2 rotate whole px0 by rotation matrix + atom_mask = atom_mask.cpu() + px0[~atom_mask] = 0 # convert nans to 0 + px0 = px0.reshape(-1, 3) - px0_motif_mean + px0_ = px0 @ R + + # 3 put in same global position as xT + px0_ = px0_ + xT_motif_mean + px0_ = px0_.reshape([L, n_atom, 3]) + px0_[~atom_mask] = float("nan") + return torch.Tensor(px0_) + + def get_potential_gradients(self, xyz, diffusion_mask): + """ + This could be moved into potential manager if desired - NRB + + Function to take a structure (x) and get per-atom gradients used to guide diffusion update + + Inputs: + + xyz (torch.tensor, required): [L,27,3] Coordinates at which the gradient will be computed + + Outputs: + + Ca_grads (torch.tensor): [L,3] The gradient at each Ca atom + """ + + if self.potential_manager == None or self.potential_manager.is_empty(): + return torch.zeros(xyz.shape[0], 3) + + use_Cb = False + + # seq.requires_grad = True + xyz.requires_grad = True + + if not xyz.grad is None: + xyz.grad.zero_() + + current_potential = self.potential_manager.compute_all_potentials(xyz) + current_potential.backward() + + # Since we are not moving frames, Cb grads are same as Ca grads + # Need access to calculated Cb coordinates to be able to get Cb grads though + Ca_grads = xyz.grad[:, 1, :] + + if not diffusion_mask == None: + Ca_grads[diffusion_mask, :] = 0 + + # check for NaN's + if torch.isnan(Ca_grads).any(): + print("WARNING: NaN in potential gradients, replacing with zero grad.") + Ca_grads[:] = 0 + + return Ca_grads + + def get_next_pose( + self, + xt, + px0, + t, + diffusion_mask, + fix_motif=True, + align_motif=True, + 
include_motif_sidechains=True, + ): + """ + Wrapper function to take px0, xt and t, and to produce xt-1 + First, aligns px0 to xt + Then gets coordinates, frames and torsion angles + + Parameters: + + xt (torch.tensor, required): Current coordinates at timestep t + + px0 (torch.tensor, required): Prediction of x0 + + t (int, required): timestep t + + diffusion_mask (torch.tensor, required): Mask for structure diffusion + + fix_motif (bool): Fix the motif structure + + align_motif (bool): Align the model's prediction of the motif to the input motif + + include_motif_sidechains (bool): Provide sidechains of the fixed motif to the model + """ + + get_allatom = ComputeAllAtomCoords().to(device=xt.device) + L, n_atom = xt.shape[:2] + assert (xt.shape[1] == 14) or (xt.shape[1] == 27) + assert (px0.shape[1] == 14) or (px0.shape[1] == 27) + + ############################### + ### Align pX0 onto Xt motif ### + ############################### + + if align_motif and diffusion_mask.any(): + px0 = self.align_to_xt_motif(px0, xt, diffusion_mask) + # xT_motif_aligned = self.align_to_xt_motif(px0, xt, diffusion_mask) + + px0 = px0.to(xt.device) + # Now done with diffusion mask. if fix motif is False, just set diffusion mask to be all True, and all coordinates can diffuse + if not fix_motif: + diffusion_mask[:] = False + + # get the next set of CA coordinates + noise_scale_ca = self.noise_schedule_ca(t) + _, ca_deltas = get_next_ca( + xt, + px0, + t, + diffusion_mask, + crd_scale=self.crd_scale, + beta_schedule=self.schedule, + alphabar_schedule=self.alphabar_schedule, + noise_scale=noise_scale_ca, + ) + + # get the next set of backbone frames (coordinates) + noise_scale_frame = self.noise_schedule_frame(t) + frames_next = get_next_frames( + xt, + px0, + t, + diffuser=self.diffuser, + so3_type=self.so3_type, + diffusion_mask=diffusion_mask, + noise_scale=noise_scale_frame, + ) + + # Apply gradient step from guiding potentials + # This can be moved to below where the full atom representation is calculated to allow for potentials involving sidechains + + grad_ca = self.get_potential_gradients( + xt.clone(), diffusion_mask=diffusion_mask + ) + + ca_deltas += self.potential_manager.get_guide_scale(t) * grad_ca + + # add the delta to the new frames + frames_next = torch.from_numpy(frames_next) + ca_deltas[:, None, :] # translate + + fullatom_next = torch.full_like(xt, float("nan")).unsqueeze(0) + fullatom_next[:, :, :3] = frames_next[None] + # This is never used so just make it a fudged tensor - NRB + torsions_next = torch.zeros(1, 1) + + if include_motif_sidechains: + fullatom_next[:, diffusion_mask, :14] = xt[None, diffusion_mask] + + return fullatom_next.squeeze()[:, :14, :], px0 + + +def sampler_selector(conf: DictConfig): + if conf.scaffoldguided.scaffoldguided: + sampler = model_runners.ScaffoldedSampler(conf) + else: + if conf.inference.model_runner == "default": + sampler = model_runners.Sampler(conf) + elif conf.inference.model_runner == "SelfConditioning": + sampler = model_runners.SelfConditioning(conf) + elif conf.inference.model_runner == "ScaffoldedSampler": + sampler = model_runners.ScaffoldedSampler(conf) + else: + raise ValueError(f"Unrecognized sampler {conf.model_runner}") + return sampler + + +def parse_pdb(filename, **kwargs): + """extract xyz coords for all heavy atoms""" + with open(filename,"r") as f: + lines=f.readlines() + return parse_pdb_lines(lines, **kwargs) + + +def parse_pdb_lines(lines, parse_hetatom=False, ignore_het_h=True): + # indices of residues observed in the structure + 
res, pdb_idx = [],[] + for l in lines: + if l[:4] == "ATOM" and l[12:16].strip() == "CA": + res.append((l[22:26], l[17:20])) + # chain letter, res num + pdb_idx.append((l[21:22].strip(), int(l[22:26].strip()))) + seq = [util.aa2num[r[1]] if r[1] in util.aa2num.keys() else 20 for r in res] + pdb_idx = [ + (l[21:22].strip(), int(l[22:26].strip())) + for l in lines + if l[:4] == "ATOM" and l[12:16].strip() == "CA" + ] # chain letter, res num + + # 4 BB + up to 10 SC atoms + xyz = np.full((len(res), 14, 3), np.nan, dtype=np.float32) + for l in lines: + if l[:4] != "ATOM": + continue + chain, resNo, atom, aa = ( + l[21:22], + int(l[22:26]), + " " + l[12:16].strip().ljust(3), + l[17:20], + ) + if (chain,resNo) in pdb_idx: + idx = pdb_idx.index((chain, resNo)) + # for i_atm, tgtatm in enumerate(util.aa2long[util.aa2num[aa]]): + for i_atm, tgtatm in enumerate( + util.aa2long[util.aa2num[aa]][:14] + ): + if ( + tgtatm is not None and tgtatm.strip() == atom.strip() + ): # ignore whitespace + xyz[idx, i_atm, :] = [float(l[30:38]), float(l[38:46]), float(l[46:54])] + break + + # save atom mask + mask = np.logical_not(np.isnan(xyz[..., 0])) + xyz[np.isnan(xyz[..., 0])] = 0.0 + + # remove duplicated (chain, resi) + new_idx = [] + i_unique = [] + for i, idx in enumerate(pdb_idx): + if idx not in new_idx: + new_idx.append(idx) + i_unique.append(i) + + pdb_idx = new_idx + xyz = xyz[i_unique] + mask = mask[i_unique] + + seq = np.array(seq)[i_unique] + + out = { + "xyz": xyz, # cartesian coordinates, [Lx14] + "mask": mask, # mask showing which atoms are present in the PDB file, [Lx14] + "idx": np.array( + [i[1] for i in pdb_idx] + ), # residue numbers in the PDB file, [L] + "seq": np.array(seq), # amino acid sequence, [L] + "pdb_idx": pdb_idx, # list of (chain letter, residue number) in the pdb file, [L] + } + + # heteroatoms (ligands, etc) + if parse_hetatom: + xyz_het, info_het = [], [] + for l in lines: + if l[:6] == "HETATM" and not (ignore_het_h and l[77] == "H"): + info_het.append( + dict( + idx=int(l[7:11]), + atom_id=l[12:16], + atom_type=l[77], + name=l[16:20], + ) + ) + xyz_het.append([float(l[30:38]), float(l[38:46]), float(l[46:54])]) + + out["xyz_het"] = np.array(xyz_het) + out["info_het"] = info_het + + return out + + +def process_target(pdb_path, parse_hetatom=False, center=True): + # Read target pdb and extract features. 
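
A minimal usage sketch of the parser above (the path is a placeholder):

    feats = parse_pdb("input.pdb", parse_hetatom=True)  # placeholder path
    print(feats["xyz"].shape)      # (L, 14, 3): 4 backbone + up to 10 sidechain atoms
    print(feats["mask"].shape)     # (L, 14): True where the atom was present in the file
    print(feats["pdb_idx"][:2])    # e.g. [('A', 1), ('A', 2)]
    print(len(feats["info_het"]))  # number of parsed HETATM records
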
+ target_struct = parse_pdb(pdb_path, parse_hetatom=parse_hetatom)
+
+ # Zero-center positions
+ ca_center = target_struct["xyz"][:, :1, :].mean(axis=0, keepdims=True)
+ if not center:
+ ca_center = 0
+ xyz = torch.from_numpy(target_struct["xyz"] - ca_center)
+ seq_orig = torch.from_numpy(target_struct["seq"])
+ atom_mask = torch.from_numpy(target_struct["mask"])
+ seq_len = len(xyz)
+
+ # Make 27 atom representation
+ xyz_27 = torch.full((seq_len, 27, 3), np.nan).float()
+ xyz_27[:, :14, :] = xyz[:, :14, :]
+
+ mask_27 = torch.full((seq_len, 27), False)
+ mask_27[:, :14] = atom_mask
+ out = {
+ "xyz_27": xyz_27,
+ "mask_27": mask_27,
+ "seq": seq_orig,
+ "pdb_idx": target_struct["pdb_idx"],
+ }
+ if parse_hetatom:
+ out["xyz_het"] = target_struct["xyz_het"]
+ out["info_het"] = target_struct["info_het"]
+ return out
+
+
+def get_idx0_hotspots(mappings, ppi_conf, binderlen):
+ """
+ Takes pdb-indexed hotspot residues and the length of the binder, and makes the 0-indexed list of hotspot positions
+ """
+
+ hotspot_idx = None
+ if binderlen > 0:
+ if ppi_conf.hotspot_res is not None:
+ assert all(
+ [i[0].isalpha() for i in ppi_conf.hotspot_res]
+ ), "Hotspot residues need to be provided in pdb-indexed form. E.g. A100,A103"
+ hotspots = [(i[0], int(i[1:])) for i in ppi_conf.hotspot_res]
+ hotspot_idx = []
+ for i, res in enumerate(mappings["receptor_con_ref_pdb_idx"]):
+ if res in hotspots:
+ hotspot_idx.append(mappings["receptor_con_hal_idx0"][i])
+ return hotspot_idx
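
For example, with invented mapping values, hotspot strings such as "A56" resolve to 0-indexed design positions as follows:

    hotspots = [(h[0], int(h[1:])) for h in ["A56", "A115"]]  # [('A', 56), ('A', 115)]
    mappings = {  # toy stand-in for ContigMap.get_mappings() output
        "receptor_con_ref_pdb_idx": [("A", 55), ("A", 56), ("A", 115)],
        "receptor_con_hal_idx0": [100, 101, 160],
    }
    hotspot_idx = [mappings["receptor_con_hal_idx0"][i]
                   for i, res in enumerate(mappings["receptor_con_ref_pdb_idx"])
                   if res in hotspots]
    print(hotspot_idx)  # [101, 160]
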
+
+
+class BlockAdjacency:
+ """
+ Class for handling PPI design inference with ss/block_adj inputs.
+ Basic idea is to provide a list of scaffolds, and to output ss and adjacency
+ matrices based on these, while sampling additional lengths.
+ Inputs:
+ - scaffold_list: list of scaffolds (e.g. ['2kl8','1cif']). Can also be a .txt file.
+ - scaffold_dir: directory where scaffold ss and adj are precalculated
+ - sampled_insertion: how many additional residues to add to each loop segment. Randomly sampled from 0 to this number (or within a given range)
+ - sampled_N: randomly sample up to this number of additional residues at the N-terminus
+ - sampled_C: randomly sample up to this number of additional residues at the C-terminus
+ - ss_mask: how many residues to mask at either end of a ss (H or E) block. Fixed value
+ - num_designs: how many designs to generate. Currently only used for bookkeeping
+ - systematic: whether to work systematically through the list of scaffolds, or to randomly sample (default)
+ - num_designs_per_input: not really implemented yet. Maybe not necessary
+ Outputs:
+ - L: new length of chain to be diffused
+ - ss: all loops and insertions, and ends of ss blocks (up to ss_mask), set to mask token (3). One-hot encoded. (L,4)
+ - adj: block adjacency with equivalent masking as ss (L,L)
+ """
+
+ def __init__(self, conf, num_designs):
+ """
+ Parameters:
+ conf: config; scaffold options are read from conf.scaffoldguided
+ num_designs: conf.inference.num_designs, for sanity checking
+ """
+
+ self.conf=conf
+ # either list or path to .txt file with list of scaffolds
+ if self.conf.scaffoldguided.scaffold_list is not None:
+ if type(self.conf.scaffoldguided.scaffold_list) == list:
+ self.scaffold_list = self.conf.scaffoldguided.scaffold_list
+ elif self.conf.scaffoldguided.scaffold_list[-4:] == ".txt":
+ # txt file with list of ids
+ list_from_file = []
+ with open(self.conf.scaffoldguided.scaffold_list, "r") as f:
+ for line in f:
+ list_from_file.append(line.strip())
+ self.scaffold_list = list_from_file
+ else:
+ raise NotImplementedError
+ else:
+ self.scaffold_list = [
+ os.path.split(i)[1][:-6]
+ for i in glob.glob(f"{self.conf.scaffoldguided.scaffold_dir}/*_ss.pt")
+ ]
+ self.scaffold_list.sort()
+
+ # path to directory with scaffolds, ss files and block_adjacency files
+ self.scaffold_dir = self.conf.scaffoldguided.scaffold_dir
+
+ # maximum sampled insertion in each loop segment
+ if "-" in str(self.conf.scaffoldguided.sampled_insertion):
+ self.sampled_insertion = [
+ int(str(self.conf.scaffoldguided.sampled_insertion).split("-")[0]),
+ int(str(self.conf.scaffoldguided.sampled_insertion).split("-")[1]),
+ ]
+ else:
+ self.sampled_insertion = [0, int(self.conf.scaffoldguided.sampled_insertion)]
+
+ # maximum sampled insertion at N- and C-terminus
+ if "-" in str(self.conf.scaffoldguided.sampled_N):
+ self.sampled_N = [
+ int(str(self.conf.scaffoldguided.sampled_N).split("-")[0]),
+ int(str(self.conf.scaffoldguided.sampled_N).split("-")[1]),
+ ]
+ else:
+ self.sampled_N = [0, int(self.conf.scaffoldguided.sampled_N)]
+ if "-" in str(self.conf.scaffoldguided.sampled_C):
+ self.sampled_C = [
+ int(str(self.conf.scaffoldguided.sampled_C).split("-")[0]),
+ int(str(self.conf.scaffoldguided.sampled_C).split("-")[1]),
+ ]
+ else:
+ self.sampled_C = [0, int(self.conf.scaffoldguided.sampled_C)]
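
The three sampled_* options share one range grammar: "lo-hi" is an explicit range, and a bare number means 0 up to that number. The repeated parsing above is equivalent to this small helper (parse_range is hypothetical, not part of the patch):

    def parse_range(val):
        s = str(val)
        if "-" in s:
            lo, hi = s.split("-")
            return [int(lo), int(hi)]
        return [0, int(s)]

    assert parse_range("2-6") == [2, 6]   # sample uniformly in [2, 6]
    assert parse_range(4) == [0, 4]       # sample uniformly in [0, 4]
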
+
+ # number of residues to mask ss identity of in H/E regions (from junction)
+ # e.g. if ss_mask = 2, L,L,L,H,H,H,H,H,H,H,L,L,E,E,E,E,E,E,L,L,L,L,L,L would become
+ # M,M,M,M,M,H,H,H,M,M,M,M,M,M,E,E,M,M,M,M,M,M,M,M where M is mask
+ self.ss_mask = self.conf.scaffoldguided.ss_mask
+
+ # whether or not to work systematically through the list
+ self.systematic = self.conf.scaffoldguided.systematic
+
+ self.num_designs = num_designs
+
+ if len(self.scaffold_list) > self.num_designs:
+ print(
+ "WARNING: Scaffold set is bigger than num_designs, so not every scaffold type will be sampled"
+ )
+
+ # for tracking number of designs
+ self.num_completed = 0
+ if self.systematic:
+ self.item_n = 0
+
+ # whether to mask loops or not
+ if not self.conf.scaffoldguided.mask_loops:
+ assert self.conf.scaffoldguided.sampled_N == 0, "can't add length if not masking loops"
+ assert self.conf.scaffoldguided.sampled_C == 0, "can't add length if not masking loops"
+ assert self.conf.scaffoldguided.sampled_insertion == 0, "can't add length if not masking loops"
+ self.mask_loops = False
+ else:
+ self.mask_loops = True
+
+ def get_ss_adj(self, item):
+ """
+ Given an item, get the ss tensor and block adjacency matrix for that item
+ """
+ ss = torch.load(os.path.join(self.scaffold_dir, f'{item.split(".")[0]}_ss.pt'))
+ adj = torch.load(
+ os.path.join(self.scaffold_dir, f'{item.split(".")[0]}_adj.pt')
+ )
+
+ return ss, adj
+
+ def mask_to_segments(self, mask):
+ """
+ Takes a mask of True (loop) and False (non-loop), and outputs a list of tuples (loop or not, length of element)
+ """
+ segments = []
+ begin = -1
+ end = -1
+ for i in range(mask.shape[0]):
+ # Starting edge case
+ if i == 0:
+ begin = 0
+ continue
+
+ if not mask[i] == mask[i - 1]:
+ end = i
+ if mask[i - 1].item() is True:
+ segments.append(("loop", end - begin))
+ else:
+ segments.append(("ss", end - begin))
+ begin = i
+
+ # Ending edge case: close out the final segment
+ if not end == mask.shape[0]:
+ if mask[i].item() is True:
+ segments.append(("loop", mask.shape[0] - begin))
+ else:
+ segments.append(("ss", mask.shape[0] - begin))
+ return segments
+
+ def expand_mask(self, mask, segments):
+ """
+ Function to generate a new mask with dilated loops and N- and C-terminal additions
+ """
+ N_add = random.randint(self.sampled_N[0], self.sampled_N[1])
+ C_add = random.randint(self.sampled_C[0], self.sampled_C[1])
+
+ output = N_add * [False]
+ for ss, length in segments:
+ if ss == "ss":
+ output.extend(length * [True])
+ else:
+ # randomly sample insertion length
+ ins = random.randint(
+ self.sampled_insertion[0], self.sampled_insertion[1]
+ )
+ output.extend((length + ins) * [False])
+ output.extend(C_add * [False])
+ assert torch.sum(torch.tensor(output)) == torch.sum(~mask)
+ return torch.tensor(output)
+
+ def expand_ss(self, ss, adj, mask, expanded_mask):
+ """
+ Given an expanded mask, populate a new ss and adj based on this
+ """
+ ss_out = torch.ones(expanded_mask.shape[0]) * 3 # set to mask token
+ adj_out = torch.full((expanded_mask.shape[0], expanded_mask.shape[0]), 0.0)
+ ss_out[expanded_mask] = ss[~mask]
+ expanded_mask_2d = torch.full(adj_out.shape, True)
+ # mask out loops/insertions, which is ~expanded_mask
+ expanded_mask_2d[~expanded_mask, :] = False
+ expanded_mask_2d[:, ~expanded_mask] = False
+
+ mask_2d = torch.full(adj.shape, True)
+ # mask out loops.
This mask is True=loop + mask_2d[mask, :] = False + mask_2d[:, mask] = False + adj_out[expanded_mask_2d] = adj[mask_2d] + adj_out = adj_out.reshape((expanded_mask.shape[0], expanded_mask.shape[0])) + + return ss_out, adj_out + + def mask_ss_adj(self, ss, adj, expanded_mask): + """ + Given an expanded ss and adj, mask some number of residues at either end of non-loop ss + """ + original_mask = torch.clone(expanded_mask) + if self.ss_mask > 0: + for i in range(1, self.ss_mask + 1): + expanded_mask[i:] *= original_mask[:-i] + expanded_mask[:-i] *= original_mask[i:] + + if self.mask_loops: + ss[~expanded_mask] = 3 + adj[~expanded_mask, :] = 0 + adj[:, ~expanded_mask] = 0 + + # mask adjacency + adj[~expanded_mask] = 2 + adj[:, ~expanded_mask] = 2 + + return ss, adj + + def get_scaffold(self): + """ + Wrapper method for pulling an item from the list, and preparing ss and block adj features + """ + + # Handle determinism. Useful for integration tests + if self.conf.inference.deterministic: + torch.manual_seed(self.num_completed) + np.random.seed(self.num_completed) + random.seed(self.num_completed) + + if self.systematic: + # reset if num designs > num_scaffolds + if self.item_n >= len(self.scaffold_list): + self.item_n = 0 + item = self.scaffold_list[self.item_n] + self.item_n += 1 + else: + item = random.choice(self.scaffold_list) + print("Scaffold constrained based on file: ", item) + # load files + ss, adj = self.get_ss_adj(item) + adj_orig = torch.clone(adj) + # separate into segments (loop or not) + mask = torch.where(ss == 2, 1, 0).bool() + segments = self.mask_to_segments(mask) + + # insert into loops to generate new mask + expanded_mask = self.expand_mask(mask, segments) + + # expand ss and adj + ss, adj = self.expand_ss(ss, adj, mask, expanded_mask) + + # finally, mask some proportion of the ss at either end of the non-loop ss blocks + ss, adj = self.mask_ss_adj(ss, adj, expanded_mask) + + # and then update num_completed + self.num_completed += 1 + + return ss.shape[0], torch.nn.functional.one_hot(ss.long(), num_classes=4), adj + + +class Target: + """ + Class to handle targets (fixed chains). + Inputs: + - path to pdb file + - hotspot residues, in the form B10,B12,B60 etc + - whether or not to crop, and with which method + Outputs: + - Dictionary of xyz coordinates, indices, pdb_indices, pdb mask + """ + + def __init__(self, conf: DictConfig, hotspots=None): + self.pdb = parse_pdb(conf.target_path) + + if hotspots is not None: + self.hotspots = hotspots + else: + self.hotspots = [] + self.pdb["hotspots"] = np.array( + [ + True if f"{i[0]}{i[1]}" in self.hotspots else False + for i in self.pdb["pdb_idx"] + ] + ) + + if conf.contig_crop: + self.contig_crop(conf.contig_crop) + + def parse_contig(self, contig_crop): + """ + Takes contig input and parses + """ + contig_list = [] + for contig in contig_crop[0].split(" "): + subcon = [] + for crop in contig.split("/"): + if crop[0].isalpha(): + subcon.extend( + [ + (crop[0], p) + for p in np.arange( + int(crop.split("-")[0][1:]), int(crop.split("-")[1]) + 1 + ) + ] + ) + contig_list.append(subcon) + return contig_list + + def contig_crop(self, contig_crop, residue_offset=200) -> None: + """ + Method to take a contig string referring to the receptor and output a pdb dictionary with just this crop + NB there are two ways to provide inputs: + - 1) e.g. B1-30,0 B50-60,0. This will add a residue offset between each chunk + - 2) e.g. B1-30,B50-60,B80-100. This will keep the original indexing of the pdb file. 
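
Putting the pieces together, get_scaffold segments the loop mask, dilates the loops, and copies ss/adj across. A standalone rendition of the dilation step with a toy mask (True = loop) and a fixed 2-residue insertion:

    import torch, random

    random.seed(0)
    mask = torch.tensor([True, True, False, False, False, True])  # True = loop
    # mask_to_segments(mask) -> [('loop', 2), ('ss', 3), ('loop', 1)]
    segments = [("loop", 2), ("ss", 3), ("loop", 1)]
    expanded = []
    for kind, length in segments:
        if kind == "ss":
            expanded.extend(length * [True])        # keep ss residues
        else:
            ins = random.randint(2, 2)              # fixed 2-residue dilation here
            expanded.extend((length + ins) * [False])
    expanded = torch.tensor(expanded)
    assert expanded.sum() == (~mask).sum()          # same invariant expand_mask asserts
    print(expanded.tolist())
    # 4x False, 3x True, 3x False: new length 10; loops/insertions get ss mask token 3
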
+ Can handle the target being on multiple chains + """ + + # add residue offset between chains if multiple chains in receptor file + for idx, val in enumerate(self.pdb["pdb_idx"]): + if idx != 0 and val != self.pdb["pdb_idx"][idx - 1]: + self.pdb["idx"][idx:] += residue_offset + idx + + # convert contig to mask + contig_list = self.parse_contig(contig_crop) + + # add residue offset to different parts of contig_list + for contig in contig_list[1:]: + start = int(contig[0][1]) + self.pdb["idx"][start:] += residue_offset + # flatten list + contig_list = [i for j in contig_list for i in j] + mask = np.array( + [True if i in contig_list else False for i in self.pdb["pdb_idx"]] + ) + + # sanity check + assert np.sum(self.pdb["hotspots"]) == np.sum( + self.pdb["hotspots"][mask] + ), "Supplied hotspot residues are missing from the target contig!" + # crop pdb + for key, val in self.pdb.items(): + try: + self.pdb[key] = val[mask] + except: + self.pdb[key] = [i for idx, i in enumerate(val) if mask[idx]] + self.pdb["crop_mask"] = mask + + def get_target(self): + return self.pdb + +def ss_from_contig(ss_masks: dict): + """ + Function for taking 1D masks for each of the ss types, and outputting a secondary structure input + """ + L=len(ss_masks['helix']) + ss=torch.zeros((L, 4)).long() + ss[:,3] = 1 #mask + for idx, mask in enumerate([ss_masks['helix'],ss_masks['strand'], ss_masks['loop']]): + ss[mask,idx] = 1 + ss[mask, 3] = 0 # remove the mask token + return ss \ No newline at end of file From 49869a976177ab04a6aeb0eabeba21cbd8f5abc2 Mon Sep 17 00:00:00 2001 From: RuizhiPeng Date: Sun, 28 Sep 2025 00:06:47 -0400 Subject: [PATCH 2/3] fix --- rfdiffusion/inference/model_runners.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/rfdiffusion/inference/model_runners.py b/rfdiffusion/inference/model_runners.py index f47d0e96..d51e11d2 100644 --- a/rfdiffusion/inference/model_runners.py +++ b/rfdiffusion/inference/model_runners.py @@ -939,7 +939,28 @@ def sample_init(self): ### Get hotspots ### #################### self.hotspot_0idx=iu.get_idx0_hotspots(self.mappings, self.ppi_conf, self.binderlen) - + + ####################################### + ### Resolve cyclic peptide indicies ### + ####################################### + if self._conf.inference.cyclic: + if self._conf.inference.cyc_chains is None: + # default to all residues being cyclized + self.cyclic_reses = ~self.mask_str.to(self.device).squeeze() + else: + # use cyc_chains arg to determine cyclic_reses mask + assert type(self._conf.inference.cyc_chains) is str, 'cyc_chains arg must be string' + cyc_chains = self._conf.inference.cyc_chains + cyc_chains = [i.upper() for i in cyc_chains] + hal_idx = self.contig_map.hal # the pdb indices of output, knowledge of different chains + is_cyclized = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() # initially empty + for ch in cyc_chains: + ch_mask = torch.tensor([idx[0] == ch for idx in hal_idx]).bool() + is_cyclized[ch_mask] = True # set this whole chain to be cyclic + self.cyclic_reses = is_cyclized + else: + self.cyclic_reses = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() + ######################### ### Set up potentials ### ######################### From 5636b67f1418525067be733398157f2815a14251 Mon Sep 17 00:00:00 2001 From: Ruizhi Peng <51493110+RuizhiPeng@users.noreply.github.com> Date: Sun, 28 Sep 2025 00:08:52 -0400 Subject: [PATCH 3/3] Delete rfdiffusion/inference/inference directory wrong upload --- 
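
The cyc_chains handling added in PATCH 2/3 above reduces to a per-chain boolean mask over the output indexing. A toy rendition (the hal_idx values are invented for illustration):

    import torch

    hal_idx = [("A", 1), ("A", 2), ("B", 1), ("B", 2), ("B", 3)]  # toy contig_map.hal
    cyc_chains = [c.upper() for c in "b"]       # e.g. inference.cyc_chains="b"
    is_cyclized = torch.zeros(len(hal_idx), dtype=torch.bool)
    for ch in cyc_chains:
        ch_mask = torch.tensor([idx[0] == ch for idx in hal_idx])
        is_cyclized[ch_mask] = True             # the whole chain becomes cyclic
    print(is_cyclized.tolist())                 # [False, False, True, True, True]
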
rfdiffusion/inference/inference/__init__.py | 0 .../inference/inference/model_runners.py | 1054 ----------------- rfdiffusion/inference/inference/sym_rots.npz | Bin 7694 -> 0 bytes rfdiffusion/inference/inference/symmetry.py | 236 ---- rfdiffusion/inference/inference/utils.py | 1015 ---------------- 5 files changed, 2305 deletions(-) delete mode 100644 rfdiffusion/inference/inference/__init__.py delete mode 100644 rfdiffusion/inference/inference/model_runners.py delete mode 100644 rfdiffusion/inference/inference/sym_rots.npz delete mode 100644 rfdiffusion/inference/inference/symmetry.py delete mode 100644 rfdiffusion/inference/inference/utils.py diff --git a/rfdiffusion/inference/inference/__init__.py b/rfdiffusion/inference/inference/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/rfdiffusion/inference/inference/model_runners.py b/rfdiffusion/inference/inference/model_runners.py deleted file mode 100644 index f8f45ae4..00000000 --- a/rfdiffusion/inference/inference/model_runners.py +++ /dev/null @@ -1,1054 +0,0 @@ -import torch -import numpy as np -from omegaconf import DictConfig, OmegaConf -from rfdiffusion.RoseTTAFoldModel import RoseTTAFoldModule -from rfdiffusion.kinematics import get_init_xyz, xyz_to_t2d -from rfdiffusion.diffusion import Diffuser -from rfdiffusion.chemical import seq2chars -from rfdiffusion.util_module import ComputeAllAtomCoords -from rfdiffusion.contigs import ContigMap -from rfdiffusion.inference import utils as iu, symmetry -from rfdiffusion.potentials.manager import PotentialManager -import logging -import torch.nn.functional as nn -from rfdiffusion import util -from hydra.core.hydra_config import HydraConfig -import os -import string - -from rfdiffusion.model_input_logger import pickle_function_call -import sys - -SCRIPT_DIR=os.path.dirname(os.path.realpath(__file__)) - -TOR_INDICES = util.torsion_indices -TOR_CAN_FLIP = util.torsion_can_flip -REF_ANGLES = util.reference_angles - - -class Sampler: - - def __init__(self, conf: DictConfig): - """ - Initialize sampler. - Args: - conf: Configuration. - """ - self.initialized = False - self.initialize(conf) - - def initialize(self, conf: DictConfig) -> None: - """ - Initialize sampler. - Args: - conf: Configuration - - - Selects appropriate model from input - - Assembles Config from model checkpoint and command line overrides - - """ - self._log = logging.getLogger(__name__) - if torch.cuda.is_available(): - self.device = torch.device('cuda') - else: - self.device = torch.device('cpu') - needs_model_reload = not self.initialized or conf.inference.ckpt_override_path != self._conf.inference.ckpt_override_path - - # Assign config to Sampler - self._conf = conf - - ################################ - ### Select Appropriate Model ### - ################################ - - if conf.inference.model_directory_path is not None: - model_directory = conf.inference.model_directory_path - else: - model_directory = f"{SCRIPT_DIR}/../../models" - - print(f"Reading models from {model_directory}") - - # Initialize inference only helper objects to Sampler - if conf.inference.ckpt_override_path is not None: - self.ckpt_path = conf.inference.ckpt_override_path - print("WARNING: You're overriding the checkpoint path from the defaults. 
Check that the model you're providing can run with the inputs you're providing.") - else: - if conf.contigmap.inpaint_seq is not None or conf.contigmap.provide_seq is not None or conf.contigmap.inpaint_str: - # use model trained for inpaint_seq - if conf.contigmap.provide_seq is not None: - # this is only used for partial diffusion - assert conf.diffuser.partial_T is not None, "The provide_seq input is specifically for partial diffusion" - if conf.scaffoldguided.scaffoldguided: - self.ckpt_path = f'{model_directory}/InpaintSeq_Fold_ckpt.pt' - else: - self.ckpt_path = f'{model_directory}/InpaintSeq_ckpt.pt' - elif conf.ppi.hotspot_res is not None and conf.scaffoldguided.scaffoldguided is False: - # use complex trained model - self.ckpt_path = f'{model_directory}/Complex_base_ckpt.pt' - elif conf.scaffoldguided.scaffoldguided is True: - # use complex and secondary structure-guided model - self.ckpt_path = f'{model_directory}/Complex_Fold_base_ckpt.pt' - else: - # use default model - self.ckpt_path = f'{model_directory}/Base_ckpt.pt' - # for saving in trb file: - assert self._conf.inference.trb_save_ckpt_path is None, "trb_save_ckpt_path is not the place to specify an input model. Specify in inference.ckpt_override_path" - self._conf['inference']['trb_save_ckpt_path']=self.ckpt_path - - ####################### - ### Assemble Config ### - ####################### - - if needs_model_reload: - # Load checkpoint, so that we can assemble the config - self.load_checkpoint() - self.assemble_config_from_chk() - # Now actually load the model weights into RF - self.model = self.load_model() - else: - self.assemble_config_from_chk() - - # self.initialize_sampler(conf) - self.initialized=True - - # Initialize helper objects - self.inf_conf = self._conf.inference - self.contig_conf = self._conf.contigmap - self.denoiser_conf = self._conf.denoiser - self.ppi_conf = self._conf.ppi - self.potential_conf = self._conf.potentials - self.diffuser_conf = self._conf.diffuser - self.preprocess_conf = self._conf.preprocess - - if conf.inference.schedule_directory_path is not None: - schedule_directory = conf.inference.schedule_directory_path - else: - schedule_directory = f"{SCRIPT_DIR}/../../schedules" - - # Check for cache schedule - if not os.path.exists(schedule_directory): - os.mkdir(schedule_directory) - self.diffuser = Diffuser(**self._conf.diffuser, cache_dir=schedule_directory) - - ########################### - ### Initialise Symmetry ### - ########################### - - if self.inf_conf.symmetry is not None: - self.symmetry = symmetry.SymGen( - self.inf_conf.symmetry, - self.inf_conf.recenter, - self.inf_conf.radius, - self.inf_conf.model_only_neighbors, - ) - else: - self.symmetry = None - - self.allatom = ComputeAllAtomCoords().to(self.device) - - if self.inf_conf.input_pdb is None: - # set default pdb - script_dir=os.path.dirname(os.path.realpath(__file__)) - self.inf_conf.input_pdb=os.path.join(script_dir, '../../examples/input_pdbs/1qys.pdb') - self.target_feats = iu.process_target(self.inf_conf.input_pdb, parse_hetatom=True, center=False) - self.chain_idx = None - self.idx_pdb = None - - ############################## - ### Handle Partial Noising ### - ############################## - - if self.diffuser_conf.partial_T: - assert self.diffuser_conf.partial_T <= self.diffuser_conf.T - self.t_step_input = int(self.diffuser_conf.partial_T) - else: - self.t_step_input = int(self.diffuser_conf.T) - - @property - def T(self): - ''' - Return the maximum number of timesteps - that this design protocol will 
perform. - - Output: - T (int): The maximum number of timesteps to perform - ''' - return self.diffuser_conf.T - - def load_checkpoint(self) -> None: - """Loads RF checkpoint, from which config can be generated.""" - self._log.info(f'Reading checkpoint from {self.ckpt_path}') - print('This is inf_conf.ckpt_path') - print(self.ckpt_path) - self.ckpt = torch.load( - self.ckpt_path, map_location=self.device) - - def assemble_config_from_chk(self) -> None: - """ - Function for loading model config from checkpoint directly. - - Takes: - - config file - - Actions: - - Replaces all -model and -diffuser items - - Throws a warning if there are items in -model and -diffuser that aren't in the checkpoint - - This throws an error if there is a flag in the checkpoint 'config_dict' that isn't in the inference config. - This should ensure that whenever a feature is added in the training setup, it is accounted for in the inference script. - - """ - # get overrides to re-apply after building the config from the checkpoint - overrides = [] - if HydraConfig.initialized(): - overrides = HydraConfig.get().overrides.task - print("Assembling -model, -diffuser and -preprocess configs from checkpoint") - - for cat in ['model','diffuser','preprocess']: - for key in self._conf[cat]: - try: - print(f"USING MODEL CONFIG: self._conf[{cat}][{key}] = {self.ckpt['config_dict'][cat][key]}") - self._conf[cat][key] = self.ckpt['config_dict'][cat][key] - except: - pass - - # add overrides back in again - for override in overrides: - if override.split(".")[0] in ['model','diffuser','preprocess']: - print(f'WARNING: You are changing {override.split("=")[0]} from the value this model was trained with. Are you sure you know what you are doing?') - mytype = type(self._conf[override.split(".")[0]][override.split(".")[1].split("=")[0]]) - self._conf[override.split(".")[0]][override.split(".")[1].split("=")[0]] = mytype(override.split("=")[1]) - - def load_model(self): - """Create RosettaFold model from preloaded checkpoint.""" - - # Read input dimensions from checkpoint. - self.d_t1d=self._conf.preprocess.d_t1d - self.d_t2d=self._conf.preprocess.d_t2d - model = RoseTTAFoldModule(**self._conf.model, d_t1d=self.d_t1d, d_t2d=self.d_t2d, T=self._conf.diffuser.T).to(self.device) - if self._conf.logging.inputs: - pickle_dir = pickle_function_call(model, 'forward', 'inference') - print(f'pickle_dir: {pickle_dir}') - model = model.eval() - self._log.info(f'Loading checkpoint.') - model.load_state_dict(self.ckpt['model_state_dict'], strict=True) - return model - - def construct_contig(self, target_feats): - """ - Construct contig class describing the protein to be generated - """ - self._log.info(f'Using contig: {self.contig_conf.contigs}') - return ContigMap(target_feats, **self.contig_conf) - - def construct_denoiser(self, L, visible): - """Make length-specific denoiser.""" - denoise_kwargs = OmegaConf.to_container(self.diffuser_conf) - denoise_kwargs.update(OmegaConf.to_container(self.denoiser_conf)) - denoise_kwargs.update({ - 'L': L, - 'diffuser': self.diffuser, - 'potential_manager': self.potential_manager, - }) - return iu.Denoise(**denoise_kwargs) - - def sample_init(self, return_forward_trajectory=False): - """ - Initial features to start the sampling process. - - Modify signature and function body for different initialization - based on the config. - - Returns: - xt: Starting positions with a portion of them randomly sampled. - seq_t: Starting sequence with a portion of them set to unknown. 
- """ - - ####################### - ### Parse input pdb ### - ####################### - - self.target_feats = iu.process_target(self.inf_conf.input_pdb, parse_hetatom=True, center=False) - - ################################ - ### Generate specific contig ### - ################################ - - # Generate a specific contig from the range of possibilities specified at input - - self.contig_map = self.construct_contig(self.target_feats) - self.mappings = self.contig_map.get_mappings() - self.mask_seq = torch.from_numpy(self.contig_map.inpaint_seq)[None,:] - self.mask_str = torch.from_numpy(self.contig_map.inpaint_str)[None,:] - self.binderlen = len(self.contig_map.inpaint) - - ####################################### - ### Resolve cyclic peptide indicies ### - ####################################### - if self._conf.inference.cyclic: - if self._conf.inference.cyc_chains is None: - # default to all residues being cyclized - self.cyclic_reses = ~self.mask_str.to(self.device).squeeze() - else: - # use cyc_chains arg to determine cyclic_reses mask - assert type(self._conf.inference.cyc_chains) is str, 'cyc_chains arg must be string' - cyc_chains = self._conf.inference.cyc_chains - cyc_chains = [i.upper() for i in cyc_chains] - hal_idx = self.contig_map.hal # the pdb indices of output, knowledge of different chains - is_cyclized = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() # initially empty - - for ch in cyc_chains: - ch_mask = torch.tensor([idx[0] == ch for idx in hal_idx]).bool() - is_cyclized[ch_mask] = True # set this whole chain to be cyclic - self.cyclic_reses = is_cyclized - else: - self.cyclic_reses = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() - - #################### - ### Get Hotspots ### - #################### - - self.hotspot_0idx=iu.get_idx0_hotspots(self.mappings, self.ppi_conf, self.binderlen) - - - ##################################### - ### Initialise Potentials Manager ### - ##################################### - - self.potential_manager = PotentialManager(self.potential_conf, - self.ppi_conf, - self.diffuser_conf, - self.inf_conf, - self.hotspot_0idx, - self.binderlen) - - ################################### - ### Initialize other attributes ### - ################################### - - xyz_27 = self.target_feats['xyz_27'] - mask_27 = self.target_feats['mask_27'] - seq_orig = self.target_feats['seq'] - L_mapped = len(self.contig_map.ref) - contig_map=self.contig_map - - self.diffusion_mask = self.mask_str - length_bound = self.contig_map.sampled_mask_length_bound.copy() - - first_res = 0 - self.chain_idx = [] - self.idx_pdb = [] - all_chains = {contig_ref[0] for contig_ref in self.contig_map.ref} - available_chains = sorted(list(set(string.ascii_letters) - all_chains)) - - # Iterate over each chain - for last_res in length_bound: - chain_ids = {contig_ref[0] for contig_ref in self.contig_map.ref[first_res: last_res]} - # If we are designing this chain, it will have a '-' in the contig map - # Renumber this chain from 1 - if "_" in chain_ids: - self.idx_pdb += [idx + 1 for idx in range(last_res - first_res)] - chain_ids = chain_ids - {"_"} - # If there are no fixed residues that have a chain id, pick the first available letter - if not chain_ids: - if not available_chains: - raise ValueError(f"No available chains! 
You are trying to design a new chain, and you have " - f"already used all upper- and lower-case chain ids (up to 52 chains): " - f"{','.join(all_chains)}.") - chain_id = available_chains[0] - available_chains.remove(chain_id) - # Otherwise, use the chain of the fixed (motif) residues - else: - assert len(chain_ids) == 1, f"Error: Multiple chain IDs in chain: {chain_ids}" - chain_id = list(chain_ids)[0] - self.chain_idx += [chain_id] * (last_res - first_res) - # If this is a fixed chain, maintain the chain and residue numbering - else: - self.idx_pdb += [contig_ref[1] for contig_ref in self.contig_map.ref[first_res: last_res]] - assert len(chain_ids) == 1, f"Error: Multiple chain IDs in chain: {chain_ids}" - self.chain_idx += [list(chain_ids)[0]] * (last_res - first_res) - first_res = last_res - - #################################### - ### Generate initial coordinates ### - #################################### - - if self.diffuser_conf.partial_T: - assert xyz_27.shape[0] == L_mapped, f"there must be a coordinate in the input PDB for \ - each residue implied by the contig string for partial diffusion. length of \ - input PDB != length of contig string: {xyz_27.shape[0]} != {L_mapped}" - assert contig_map.hal_idx0 == contig_map.ref_idx0, f'for partial diffusion there can \ - be no offset between the index of a residue in the input and the index of the \ - residue in the output, {contig_map.hal_idx0} != {contig_map.ref_idx0}' - # Partially diffusing from a known structure - xyz_mapped=xyz_27 - atom_mask_mapped = mask_27 - else: - # Fully diffusing from points initialised at the origin - # adjust size of input xt according to residue map - xyz_mapped = torch.full((1,1,L_mapped,27,3), np.nan) - xyz_mapped[:, :, contig_map.hal_idx0, ...] = xyz_27[contig_map.ref_idx0,...] 
- xyz_motif_prealign = xyz_mapped.clone() - motif_prealign_com = xyz_motif_prealign[0,0,:,1].mean(dim=0) - self.motif_com = xyz_27[contig_map.ref_idx0,1].mean(dim=0) - xyz_mapped = get_init_xyz(xyz_mapped).squeeze() - # adjust the size of the input atom map - atom_mask_mapped = torch.full((L_mapped, 27), False) - atom_mask_mapped[contig_map.hal_idx0] = mask_27[contig_map.ref_idx0] - - # Diffuse the contig-mapped coordinates - if self.diffuser_conf.partial_T: - assert self.diffuser_conf.partial_T <= self.diffuser_conf.T, "Partial_T must be less than T" - self.t_step_input = int(self.diffuser_conf.partial_T) - else: - self.t_step_input = int(self.diffuser_conf.T) - t_list = np.arange(1, self.t_step_input+1) - - ################################# - ### Generate initial sequence ### - ################################# - - seq_t = torch.full((1,L_mapped), 21).squeeze() # 21 is the mask token - seq_t[contig_map.hal_idx0] = seq_orig[contig_map.ref_idx0] - - # Unmask sequence if desired - if self._conf.contigmap.provide_seq is not None: - seq_t[self.mask_seq.squeeze()] = seq_orig[self.mask_seq.squeeze()] - - seq_t[~self.mask_seq.squeeze()] = 21 - seq_t = torch.nn.functional.one_hot(seq_t, num_classes=22).float() # [L,22] - seq_orig = torch.nn.functional.one_hot(seq_orig, num_classes=22).float() # [L,22] - - fa_stack, xyz_true = self.diffuser.diffuse_pose( - xyz_mapped, - torch.clone(seq_t), - atom_mask_mapped.squeeze(), - diffusion_mask=self.diffusion_mask.squeeze(), - t_list=t_list) - xT = fa_stack[-1].squeeze()[:,:14,:] - xt = torch.clone(xT) - - self.denoiser = self.construct_denoiser(len(self.contig_map.ref), visible=self.mask_seq.squeeze()) - - ###################### - ### Apply Symmetry ### - ###################### - - if self.symmetry is not None: - xt, seq_t = self.symmetry.apply_symmetry(xt, seq_t) - self._log.info(f'Sequence init: {seq2chars(torch.argmax(seq_t, dim=-1))}') - - self.msa_prev = None - self.pair_prev = None - self.state_prev = None - - ######################################### - ### Parse ligand for ligand potential ### - ######################################### - - if self.potential_conf.guiding_potentials is not None: - if any(list(filter(lambda x: "substrate_contacts" in x, self.potential_conf.guiding_potentials))): - assert len(self.target_feats['xyz_het']) > 0, "If you're using the Substrate Contact potential, \ - you need to make sure there's a ligand in the input_pdb file!" 
- het_names = np.array([i['name'].strip() for i in self.target_feats['info_het']]) - xyz_het = self.target_feats['xyz_het'][het_names == self._conf.potentials.substrate] - xyz_het = torch.from_numpy(xyz_het) - assert xyz_het.shape[0] > 0, f'expected >0 heteroatoms from ligand with name {self._conf.potentials.substrate}' - xyz_motif_prealign = xyz_motif_prealign[0,0][self.diffusion_mask.squeeze()] - motif_prealign_com = xyz_motif_prealign[:,1].mean(dim=0) - xyz_het_com = xyz_het.mean(dim=0) - for pot in self.potential_manager.potentials_to_apply: - pot.motif_substrate_atoms = xyz_het - pot.diffusion_mask = self.diffusion_mask.squeeze() - pot.xyz_motif = xyz_motif_prealign - pot.diffuser = self.diffuser - return xt, seq_t - - def _preprocess(self, seq, xyz_t, t, repack=False): - - """ - Function to prepare inputs to diffusion model - - seq (L,22) one-hot sequence - - msa_masked (1,1,L,48) - - msa_full (1,1,L,25) - - xyz_t (L,14,3) template crds (diffused) - - t1d (1,L,28) this is the t1d before tacking on the chi angles: - - seq + unknown/mask (21) - - global timestep (1-t/T if not motif else 1) (1) - - MODEL SPECIFIC: - - contacting residues: for ppi. Target residues in contact with binder (1) - - empty feature (legacy) (1) - - ss (H, E, L, MASK) (4) - - t2d (1, L, L, 45) - - last plane is block adjacency - """ - - L = seq.shape[0] - T = self.T - binderlen = self.binderlen - target_res = self.ppi_conf.hotspot_res - - ################## - ### msa_masked ### - ################## - msa_masked = torch.zeros((1,1,L,48)) - msa_masked[:,:,:,:22] = seq[None, None] - msa_masked[:,:,:,22:44] = seq[None, None] - msa_masked[:,:,0,46] = 1.0 - msa_masked[:,:,-1,47] = 1.0 - - ################ - ### msa_full ### - ################ - msa_full = torch.zeros((1,1,L,25)) - msa_full[:,:,:,:22] = seq[None, None] - msa_full[:,:,0,23] = 1.0 - msa_full[:,:,-1,24] = 1.0 - - ########### - ### t1d ### - ########### - - # Here we need to go from one hot with 22 classes to one hot with 21 classes (last plane is missing token) - t1d = torch.zeros((1,1,L,21)) - - seqt1d = torch.clone(seq) - for idx in range(L): - if seqt1d[idx,21] == 1: - seqt1d[idx,20] = 1 - seqt1d[idx,21] = 0 - - t1d[:,:,:,:21] = seqt1d[None,None,:,:21] - - - # Set timestep feature to 1 where diffusion mask is True, else 1-t/T - timefeature = torch.zeros((L)).float() - timefeature[self.mask_str.squeeze()] = 1 - timefeature[~self.mask_str.squeeze()] = 1 - t/self.T - timefeature = timefeature[None,None,...,None] - - t1d = torch.cat((t1d, timefeature), dim=-1).float() - - ############# - ### xyz_t ### - ############# - if self.preprocess_conf.sidechain_input: - xyz_t[torch.where(seq == 21, True, False),3:,:] = float('nan') - else: - xyz_t[~self.mask_str.squeeze(),3:,:] = float('nan') - - xyz_t=xyz_t[None, None] - xyz_t = torch.cat((xyz_t, torch.full((1,1,L,13,3), float('nan'))), dim=3) - - ########### - ### t2d ### - ########### - t2d = xyz_to_t2d(xyz_t) - - ########### - ### idx ### - ########### - idx = torch.tensor(self.contig_map.rf)[None] - - ############### - ### alpha_t ### - ############### - seq_tmp = t1d[...,:-1].argmax(dim=-1).reshape(-1,L) - alpha, _, alpha_mask, _ = util.get_torsions(xyz_t.reshape(-1, L, 27, 3), seq_tmp, TOR_INDICES, TOR_CAN_FLIP, REF_ANGLES) - alpha_mask = torch.logical_and(alpha_mask, ~torch.isnan(alpha[...,0])) - alpha[torch.isnan(alpha)] = 0.0 - alpha = alpha.reshape(1,-1,L,10,2) - alpha_mask = alpha_mask.reshape(1,-1,L,10,1) - alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(1, -1, L, 30) - - #put tensors on 
device - msa_masked = msa_masked.to(self.device) - msa_full = msa_full.to(self.device) - seq = seq.to(self.device) - xyz_t = xyz_t.to(self.device) - idx = idx.to(self.device) - t1d = t1d.to(self.device) - t2d = t2d.to(self.device) - alpha_t = alpha_t.to(self.device) - - ###################### - ### added_features ### - ###################### - if self.preprocess_conf.d_t1d >= 24: # add hotspot residues - hotspot_tens = torch.zeros(L).float() - if self.ppi_conf.hotspot_res is None: - print("WARNING: you're using a model trained on complexes and hotspot residues, without specifying hotspots.\ - If you're doing monomer diffusion this is fine") - hotspot_idx=[] - else: - hotspots = [(i[0],int(i[1:])) for i in self.ppi_conf.hotspot_res] - hotspot_idx=[] - for i,res in enumerate(self.contig_map.con_ref_pdb_idx): - if res in hotspots: - hotspot_idx.append(self.contig_map.hal_idx0[i]) - hotspot_tens[hotspot_idx] = 1.0 - - # Add blank (legacy) feature and hotspot tensor - t1d=torch.cat((t1d, torch.zeros_like(t1d[...,:1]), hotspot_tens[None,None,...,None].to(self.device)), dim=-1) - - return msa_masked, msa_full, seq[None], torch.squeeze(xyz_t, dim=0), idx, t1d, t2d, xyz_t, alpha_t - - def sample_step(self, *, t, x_t, seq_init, final_step): - '''Generate the next pose that the model should be supplied at timestep t-1. - - Args: - t (int): The timestep that has just been predicted - seq_t (torch.tensor): (L,22) The sequence at the beginning of this timestep - x_t (torch.tensor): (L,14,3) The residue positions at the beginning of this timestep - seq_init (torch.tensor): (L,22) The initialized sequence used in updating the sequence. - - Returns: - px0: (L,14,3) The model's prediction of x0. - x_t_1: (L,14,3) The updated positions of the next step. - seq_t_1: (L,22) The updated sequence of the next step. - tors_t_1: (L, ?) The updated torsion angles of the next step. - plddt: (L, 1) Predicted lDDT of x0. 
- ''' - msa_masked, msa_full, seq_in, xt_in, idx_pdb, t1d, t2d, xyz_t, alpha_t = self._preprocess( - seq_init, x_t, t) - - N,L = msa_masked.shape[:2] - - if self.symmetry is not None: - idx_pdb, self.chain_idx = self.symmetry.res_idx_procesing(res_idx=idx_pdb) - - msa_prev = None - pair_prev = None - state_prev = None - - with torch.no_grad(): - msa_prev, pair_prev, px0, state_prev, alpha, logits, plddt = self.model(msa_masked, - msa_full, - seq_in, - xt_in, - idx_pdb, - t1d=t1d, - t2d=t2d, - xyz_t=xyz_t, - alpha_t=alpha_t, - msa_prev = msa_prev, - pair_prev = pair_prev, - state_prev = state_prev, - t=torch.tensor(t), - return_infer=True, - motif_mask=self.diffusion_mask.squeeze().to(self.device)) - - # prediction of X0 - _, px0 = self.allatom(torch.argmax(seq_in, dim=-1), px0, alpha) - px0 = px0.squeeze()[:,:14] - - ##################### - ### Get next pose ### - ##################### - - if t > final_step: - seq_t_1 = nn.one_hot(seq_init,num_classes=22).to(self.device) - x_t_1, px0 = self.denoiser.get_next_pose( - xt=x_t, - px0=px0, - t=t, - diffusion_mask=self.mask_str.squeeze(), - align_motif=self.inf_conf.align_motif - ) - else: - x_t_1 = torch.clone(px0).to(x_t.device) - seq_t_1 = torch.clone(seq_init) - px0 = px0.to(x_t.device) - - if self.symmetry is not None: - x_t_1, seq_t_1 = self.symmetry.apply_symmetry(x_t_1, seq_t_1) - - return px0, x_t_1, seq_t_1, plddt - - -class SelfConditioning(Sampler): - """ - Model Runner for self conditioning - pX0[t+1] is provided as a template input to the model at time t - """ - - def sample_step(self, *, t, x_t, seq_init, final_step): - ''' - Generate the next pose that the model should be supplied at timestep t-1. - Args: - t (int): The timestep that has just been predicted - seq_t (torch.tensor): (L,22) The sequence at the beginning of this timestep - x_t (torch.tensor): (L,14,3) The residue positions at the beginning of this timestep - seq_init (torch.tensor): (L,22) The initialized sequence used in updating the sequence. - Returns: - px0: (L,14,3) The model's prediction of x0. - x_t_1: (L,14,3) The updated positions of the next step. - seq_t_1: (L) The sequence to the next step (== seq_init) - plddt: (L, 1) Predicted lDDT of x0. 
- ''' - - msa_masked, msa_full, seq_in, xt_in, idx_pdb, t1d, t2d, xyz_t, alpha_t = self._preprocess( - seq_init, x_t, t) - B,N,L = xyz_t.shape[:3] - - ################################## - ######## Str Self Cond ########### - ################################## - if (t < self.diffuser.T) and (t != self.diffuser_conf.partial_T): - zeros = torch.zeros(B,1,L,24,3).float().to(xyz_t.device) - xyz_t = torch.cat((self.prev_pred.unsqueeze(1),zeros), dim=-2) # [B,T,L,27,3] - t2d_44 = xyz_to_t2d(xyz_t) # [B,T,L,L,44] - else: - xyz_t = torch.zeros_like(xyz_t) - t2d_44 = torch.zeros_like(t2d[...,:44]) - # No effect if t2d is only dim 44 - t2d[...,:44] = t2d_44 - - if self.symmetry is not None: - idx_pdb, self.chain_idx = self.symmetry.res_idx_procesing(res_idx=idx_pdb) - - #################### - ### Forward Pass ### - #################### - - with torch.no_grad(): - msa_prev, pair_prev, px0, state_prev, alpha, logits, plddt = self.model(msa_masked, - msa_full, - seq_in, - xt_in, - idx_pdb, - t1d=t1d, - t2d=t2d, - xyz_t=xyz_t, - alpha_t=alpha_t, - msa_prev = None, - pair_prev = None, - state_prev = None, - t=torch.tensor(t), - return_infer=True, - motif_mask=self.diffusion_mask.squeeze().to(self.device), - cyclic_reses=self.cyclic_reses) - - if self.symmetry is not None and self.inf_conf.symmetric_self_cond: - px0 = self.symmetrise_prev_pred(px0=px0,seq_in=seq_in, alpha=alpha)[:,:,:3] - - self.prev_pred = torch.clone(px0) - - # prediction of X0 - _, px0 = self.allatom(torch.argmax(seq_in, dim=-1), px0, alpha) - px0 = px0.squeeze()[:,:14] - - ########################### - ### Generate Next Input ### - ########################### - - seq_t_1 = torch.clone(seq_init) - if t > final_step: - x_t_1, px0 = self.denoiser.get_next_pose( - xt=x_t, - px0=px0, - t=t, - diffusion_mask=self.mask_str.squeeze(), - align_motif=self.inf_conf.align_motif, - include_motif_sidechains=self.preprocess_conf.motif_sidechain_input - ) - self._log.info( - f'Timestep {t}, input to next step: { seq2chars(torch.argmax(seq_t_1, dim=-1).tolist())}') - else: - x_t_1 = torch.clone(px0).to(x_t.device) - px0 = px0.to(x_t.device) - - ###################### - ### Apply symmetry ### - ###################### - - if self.symmetry is not None: - x_t_1, seq_t_1 = self.symmetry.apply_symmetry(x_t_1, seq_t_1) - - return px0, x_t_1, seq_t_1, plddt - - def symmetrise_prev_pred(self, px0, seq_in, alpha): - """ - Method for symmetrising px0 output for self-conditioning - """ - _,px0_aa = self.allatom(torch.argmax(seq_in, dim=-1), px0, alpha) - px0_sym,_ = self.symmetry.apply_symmetry(px0_aa.to('cpu').squeeze()[:,:14], torch.argmax(seq_in, dim=-1).squeeze().to('cpu')) - px0_sym = px0_sym[None].to(self.device) - return px0_sym - -class ScaffoldedSampler(SelfConditioning): - """ - Model Runner for Scaffold-Constrained diffusion - """ - def __init__(self, conf: DictConfig): - """ - Initialize scaffolded sampler. - Two basic approaches here: - i) Given a block adjacency/secondary structure input, generate a fold (in the presence or absence of a target) - - This allows easy generation of binders or specific folds - - Allows simple expansion of an input, to sample different lengths - ii) Providing a contig input and corresponding block adjacency/secondary structure input - - This allows mixed motif scaffolding and fold-conditioning. 
- - Adjacency/secondary structure inputs must correspond exactly in length to the contig string - """ - super().__init__(conf) - # initialize BlockAdjacency sampling class - if conf.scaffoldguided.scaffold_dir is None: - assert any(x is not None for x in (conf.contigmap.inpaint_str_helix, conf.contigmap.inpaint_str_strand, conf.contigmap.inpaint_str_loop)) - if conf.contigmap.inpaint_str_loop is not None: - assert conf.scaffoldguided.mask_loops == False, "You shouldn't be masking loops if you're specifying loop secondary structure" - else: - # initialize BlockAdjacency sampling class - assert all(x is None for x in (conf.contigmap.inpaint_str_helix, conf.contigmap.inpaint_str_strand, conf.contigmap.inpaint_str_loop)), "can't provide scaffold_dir if you're also specifying per-residue ss" - self.blockadjacency = iu.BlockAdjacency(conf.scaffoldguided, conf.inference.num_designs) - - - ################################################# - ### Initialize target, if doing binder design ### - ################################################# - - if conf.scaffoldguided.target_pdb: - self.target = iu.Target(conf.scaffoldguided, conf.ppi.hotspot_res) - self.target_pdb = self.target.get_target() - if conf.scaffoldguided.target_ss is not None: - self.target_ss = torch.load(conf.scaffoldguided.target_ss).long() - self.target_ss = torch.nn.functional.one_hot(self.target_ss, num_classes=4) - if self._conf.scaffoldguided.contig_crop is not None: - self.target_ss=self.target_ss[self.target_pdb['crop_mask']] - if conf.scaffoldguided.target_adj is not None: - self.target_adj = torch.load(conf.scaffoldguided.target_adj).long() - self.target_adj=torch.nn.functional.one_hot(self.target_adj, num_classes=3) - if self._conf.scaffoldguided.contig_crop is not None: - self.target_adj=self.target_adj[self.target_pdb['crop_mask']] - self.target_adj=self.target_adj[:,self.target_pdb['crop_mask']] - else: - self.target = None - self.target_pdb=False - - def sample_init(self): - """ - Wrapper method for taking secondary structure + adj, and outputting xt, seq_t - """ - - ########################## - ### Process Fold Input ### - ########################## - if hasattr(self, 'blockadjacency'): - self.L, self.ss, self.adj = self.blockadjacency.get_scaffold() - self.adj = nn.one_hot(self.adj.long(), num_classes=3) - else: - self.L=100 # shim. 
Get's overwritten - - ############################## - ### Auto-contig generation ### - ############################## - - if self.contig_conf.contigs is None: - # process target - xT = torch.full((self.L, 27,3), np.nan) - xT = get_init_xyz(xT[None,None]).squeeze() - seq_T = torch.full((self.L,),21) - self.diffusion_mask = torch.full((self.L,),False) - atom_mask = torch.full((self.L,27), False) - self.binderlen=self.L - - if self.target: - target_L = np.shape(self.target_pdb['xyz'])[0] - # xyz - target_xyz = torch.full((target_L, 27, 3), np.nan) - target_xyz[:,:14,:] = torch.from_numpy(self.target_pdb['xyz']) - xT = torch.cat((xT, target_xyz), dim=0) - # seq - seq_T = torch.cat((seq_T, torch.from_numpy(self.target_pdb['seq'])), dim=0) - # diffusion mask - self.diffusion_mask = torch.cat((self.diffusion_mask, torch.full((target_L,), True)),dim=0) - # atom mask - mask_27 = torch.full((target_L, 27), False) - mask_27[:,:14] = torch.from_numpy(self.target_pdb['mask']) - atom_mask = torch.cat((atom_mask, mask_27), dim=0) - self.L += target_L - # generate contigmap object - contig = [] - for idx,i in enumerate(self.target_pdb['pdb_idx'][:-1]): - if idx==0: - start=i[1] - if i[1] + 1 != self.target_pdb['pdb_idx'][idx+1][1] or i[0] != self.target_pdb['pdb_idx'][idx+1][0]: - contig.append(f'{i[0]}{start}-{i[1]}/0 ') - start = self.target_pdb['pdb_idx'][idx+1][1] - contig.append(f"{self.target_pdb['pdb_idx'][-1][0]}{start}-{self.target_pdb['pdb_idx'][-1][1]}/0 ") - contig.append(f"{self.binderlen}-{self.binderlen}") - contig = ["".join(contig)] - else: - contig = [f"{self.binderlen}-{self.binderlen}"] - self.contig_map=ContigMap(self.target_pdb, contig) - self.mappings = self.contig_map.get_mappings() - self.mask_seq = self.diffusion_mask - self.mask_str = self.diffusion_mask - L_mapped=len(self.contig_map.ref) - - ############################ - ### Specific Contig mode ### - ############################ - - else: - # get contigmap from command line - assert self.target is None, "Giving a target is the wrong way of handling this is you're doing contigs and secondary structure" - - # process target and reinitialise potential_manager. This is here because the 'target' is always set up to be the second chain in out inputs. - self.target_feats = iu.process_target(self.inf_conf.input_pdb) - self.contig_map = self.construct_contig(self.target_feats) - self.mappings = self.contig_map.get_mappings() - self.mask_seq = torch.from_numpy(self.contig_map.inpaint_seq)[None,:] - self.mask_str = torch.from_numpy(self.contig_map.inpaint_str)[None,:] - self.binderlen = len(self.contig_map.inpaint) - self.L = len(self.contig_map.inpaint_seq) - target_feats = self.target_feats - contig_map = self.contig_map - - xyz_27 = target_feats['xyz_27'] - mask_27 = target_feats['mask_27'] - seq_orig = target_feats['seq'] - L_mapped = len(self.contig_map.ref) - seq_T=torch.full((L_mapped,),21) - seq_T[contig_map.hal_idx0] = seq_orig[contig_map.ref_idx0] - seq_T[~self.mask_seq.squeeze()] = 21 - - diffusion_mask = self.mask_str - self.diffusion_mask = diffusion_mask - - xT = torch.full((1,1,L_mapped,27,3), np.nan) - xT[:, :, contig_map.hal_idx0, ...] = xyz_27[contig_map.ref_idx0,...] 
- xT = get_init_xyz(xT).squeeze() - atom_mask = torch.full((L_mapped, 27), False) - atom_mask[contig_map.hal_idx0] = mask_27[contig_map.ref_idx0] - - if hasattr(self.contig_map, 'ss_spec'): - self.adj=torch.full((L_mapped, L_mapped),2) # masked - self.adj=nn.one_hot(self.adj.long(), num_classes=3) - self.ss=iu.ss_from_contig(self.contig_map.ss_spec) - assert L_mapped==self.adj.shape[0] - - #################### - ### Get hotspots ### - #################### - self.hotspot_0idx=iu.get_idx0_hotspots(self.mappings, self.ppi_conf, self.binderlen) - - ######################### - ### Set up potentials ### - ######################### - - self.potential_manager = PotentialManager(self.potential_conf, - self.ppi_conf, - self.diffuser_conf, - self.inf_conf, - self.hotspot_0idx, - self.binderlen) - - self.chain_idx=['A' if i < self.binderlen else 'B' for i in range(self.L)] - - ######################## - ### Handle Partial T ### - ######################## - - if self.diffuser_conf.partial_T: - assert self.diffuser_conf.partial_T <= self.diffuser_conf.T - self.t_step_input = int(self.diffuser_conf.partial_T) - else: - self.t_step_input = int(self.diffuser_conf.T) - t_list = np.arange(1, self.t_step_input+1) - seq_T=torch.nn.functional.one_hot(seq_T, num_classes=22).float() - - fa_stack, xyz_true = self.diffuser.diffuse_pose( - xT, - torch.clone(seq_T), - atom_mask.squeeze(), - diffusion_mask=self.diffusion_mask.squeeze(), - t_list=t_list, - include_motif_sidechains=self.preprocess_conf.motif_sidechain_input) - - ####################### - ### Set up Denoiser ### - ####################### - - self.denoiser = self.construct_denoiser(self.L, visible=self.mask_seq.squeeze()) - - ####################################### - ### Resolve cyclic peptide indicies ### - ####################################### - if self._conf.inference.cyclic: - if self._conf.inference.cyc_chains is None: - # default to all residues being cyclized - self.cyclic_reses = ~self.mask_str.to(self.device).squeeze() - else: - # use cyc_chains arg to determine cyclic_reses mask - assert type(self._conf.inference.cyc_chains) is str, 'cyc_chains arg must be string' - cyc_chains = self._conf.inference.cyc_chains - cyc_chains = [i.upper() for i in cyc_chains] - hal_idx = self.contig_map.hal # the pdb indices of output, knowledge of different chains - is_cyclized = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() # initially empty - - for ch in cyc_chains: - ch_mask = torch.tensor([idx[0] == ch for idx in hal_idx]).bool() - is_cyclized[ch_mask] = True # set this whole chain to be cyclic - self.cyclic_reses = is_cyclized - else: - self.cyclic_reses = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() - - xT = torch.clone(fa_stack[-1].squeeze()[:,:14,:]) - return xT, seq_T - - def _preprocess(self, seq, xyz_t, t): - msa_masked, msa_full, seq, xyz_prev, idx_pdb, t1d, t2d, xyz_t, alpha_t = super()._preprocess(seq, xyz_t, t, repack=False) - - ################################### - ### Add Adj/Secondary Structure ### - ################################### - - assert self.preprocess_conf.d_t1d == 28, "The checkpoint you're using hasn't been trained with sec-struc/block adjacency features" - assert self.preprocess_conf.d_t2d == 47, "The checkpoint you're using hasn't been trained with sec-struc/block adjacency features" - - ##################### - ### Handle Target ### - ##################### - - if self.target: - blank_ss = torch.nn.functional.one_hot(torch.full((self.L-self.binderlen,), 3), num_classes=4) - full_ss = 
torch.cat((self.ss, blank_ss), dim=0) - if self._conf.scaffoldguided.target_ss is not None: - full_ss[self.binderlen:] = self.target_ss - else: - full_ss = self.ss - t1d=torch.cat((t1d, full_ss[None,None].to(self.device)), dim=-1) - - t1d = t1d.float() - - ########### - ### t2d ### - ########### - - if self.d_t2d == 47: - if self.target: - full_adj = torch.zeros((self.L, self.L, 3)) - full_adj[:,:,-1] = 1. #set to mask - full_adj[:self.binderlen, :self.binderlen] = self.adj - if self._conf.scaffoldguided.target_adj is not None: - full_adj[self.binderlen:,self.binderlen:] = self.target_adj - else: - full_adj = self.adj - t2d=torch.cat((t2d, full_adj[None,None].to(self.device)),dim=-1) - - ########### - ### idx ### - ########### - - if self.target: - idx_pdb[:,self.binderlen:] += 200 - - return msa_masked, msa_full, seq, xyz_prev, idx_pdb, t1d, t2d, xyz_t, alpha_t diff --git a/rfdiffusion/inference/inference/sym_rots.npz b/rfdiffusion/inference/inference/sym_rots.npz deleted file mode 100644 index 8e6b38011a547d085099c0be682733f3aa573b1b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 7694 zcmb`ML5o~P6osqps1PA4%r?+jwZb4eifBk^agoiS^u~oNZDu+|2$`XKP=g5XAGmN4 zTo`cW!hN=Hmp|Yi7;xpv`3ZKV?ssPTl*JoN-lY4!Tj$(!?yc$9uQNU0A79x?<+zc4 zdi?17({W1q@p{@xC(D!7;@KY#wkXY(&d>D%VYmzPC7DZt}O@+uNPIzxns}?&QbG z|2nRgF3x5pww}-WK)U~C2nP-xxRTn2sJajAE*41yUhpfV7arwATbh+TIB^@jkVT#PC+GT&ywyKS@}ZvBCxihVucwUI#6y?< z#iP(7JX=o8Cmdbdhe3V#VIVyWLl-}m(IqcHvS`mE}}Z9e9x{7qLp)1e>jtL>9;uZtXA@oX;k zN4V92qtoMWI_w8Hx%xI;{9IRcjr{?y>U`fDe$?YnjvhIE!xPT@=v$xE^?@sgcgcOj z4;|(Bn+`nst>&O#tD`@!>+!KUD&59h-VeFvFkSh1fAAXpRJsHG$jA4?bNov5XYPwc zKBy>0L4_mKl@sot;(=R)XUpgR!FHf43+NKnU*MEU-_{+M4nQNKM-;8o*|U?TUhu%B zWA&)DaO&Y{{#c-nD4v}N?*g!hPm8B@)1{w5Jg!gi%wIaeBX>J}9^*O;e)DsbaIQnH zPrF|F;?(?H2hRHk{(k-92d*4He8>m-nSn!=Rg=dk`Xzt0KJ zbnwTg@;4pvOouw=QXahKdg!x{@T#0Xe17^c9q$ue^<_Hv3HQG7u{!zCA9am>;G;Pke$4AU>cPq3;iugCq)t3= z+b7qhZWjBdoPDu5;K|4KoBq&&=ecFL{aG{5Pz=2G9L!(8&AKjyHx z;K|4R>DTM#`=gh3#;xyAi(7u~M)}#kQT5{O?^kj`Il9q}E6KiX@!IXb4N;kme|s(Y WUyA1YFTav0bnk', coords_out[:subunit_len], self.sym_rots[i]) - seq_out[start_i:end_i] = seq_out[:subunit_len] - return coords_out, seq_out - - def _lin_chainbreaks(self, num_breaks, res_idx, offset=None): - assert res_idx.ndim == 2 - res_idx = torch.clone(res_idx) - subunit_len = res_idx.shape[-1] // num_breaks - chain_delimiters = [] - if offset is None: - offset = res_idx.shape[-1] - for i in range(num_breaks): - start_i = subunit_len * i - end_i = subunit_len * (i+1) - chain_labels = list(string.ascii_uppercase) + [str(i+j) for i in - string.ascii_uppercase for j in string.ascii_uppercase] - chain_delimiters.extend( - [chain_labels[i] for _ in range(subunit_len)] - ) - res_idx[:, start_i:end_i] = res_idx[:, start_i:end_i] + offset * (i+1) - return res_idx, chain_delimiters - - ####################### - ## Dihedral symmetry ## - ####################### - def _init_dihedral(self, order): - sym_rots = [] - flip = Rotation.from_euler('x', 180, degrees=True).as_matrix() - for i in range(order): - deg = i * 360.0 / order - rot = Rotation.from_euler('z', deg, degrees=True).as_matrix() - sym_rots.append(format_rots(rot)) - rot2 = flip @ rot - sym_rots.append(format_rots(rot2)) - self.sym_rots = sym_rots - self.order = order * 2 - - ######################### - ## Octahedral symmetry ## - ######################### - def _init_octahedral(self): - sym_rots = np.load(f"{pathlib.Path(__file__).parent.resolve()}/sym_rots.npz") - self.sym_rots = [ - torch.tensor(v_i, dtype=torch.float32) - for v_i in 
-
-    #########################
-    ## Octahedral symmetry ##
-    #########################
-    def _init_octahedral(self):
-        sym_rots = np.load(f"{pathlib.Path(__file__).parent.resolve()}/sym_rots.npz")
-        self.sym_rots = [
-            torch.tensor(v_i, dtype=torch.float32)
-            for v_i in sym_rots['octahedral']
-        ]
-        self.order = len(self.sym_rots)
-
-    def _apply_octahedral(self, coords_in, seq_in):
-        coords_out = torch.clone(coords_in)
-        seq_out = torch.clone(seq_in)
-        if seq_out.shape[0] % self.order != 0:
-            raise ValueError(
-                f'Sequence length must be divisible by {self.order}')
-        subunit_len = seq_out.shape[0] // self.order
-        base_axis = torch.tensor([self._radius, 0., 0.])[None]
-        for i in range(self.order):
-            start_i = subunit_len * i
-            end_i = subunit_len * (i+1)
-            subunit_chain = torch.einsum(
-                'bnj,kj->bnk', coords_in[:subunit_len], self.sym_rots[i])
-
-            if self._recenter:
-                center = torch.mean(subunit_chain[:, 1, :], axis=0)
-                subunit_chain -= center[None, None, :]
-                rotated_axis = torch.einsum(
-                    'nj,kj->nk', base_axis, self.sym_rots[i])
-                subunit_chain += rotated_axis[:, None, :]
-
-            coords_out[start_i:end_i] = subunit_chain
-            seq_out[start_i:end_i] = seq_out[:subunit_len]
-        return coords_out, seq_out
-
-    #######################
-    ## symmetry from file #
-    #######################
-    def _init_from_symrots_file(self, name):
-        """ _init_from_symrots_file initializes using
-        ./inference/sym_rots.npz
-
-        Args:
-            name: name of symmetry (one of tetrahedral, octahedral, icosahedral)
-
-        sets self.sym_rots to be a list of torch.tensor of shape [3, 3]
-        """
-        assert name in saved_symmetries, name + " not in " + str(saved_symmetries)
-
-        # Load in list of rotation matrices for `name`
-        fn = f"{pathlib.Path(__file__).parent.resolve()}/sym_rots.npz"
-        obj = np.load(fn)
-        symms = None
-        for k, v in obj.items():
-            if str(k) == name: symms = v
-        assert symms is not None, "%s not found in %s" % (name, fn)
-
-        self.sym_rots = [torch.tensor(v_i, dtype=torch.float32) for v_i in symms]
-        self.order = len(self.sym_rots)
-
-        # If the identity is not already first, move it to the front
-        if not np.isclose(((self.sym_rots[0]-np.eye(3))**2).sum(), 0):
-            for i, rot in enumerate(self.sym_rots):
-                if np.isclose(((rot-np.eye(3))**2).sum(), 0):
-                    self.sym_rots = [self.sym_rots.pop(i)] + self.sym_rots
-
-        assert len(self.sym_rots) == self.order
-        assert np.isclose(((self.sym_rots[0]-np.eye(3))**2).sum(), 0)
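A quick way to sanity-check the bundled rotation sets, as a standalone sketch (not part of the patch); the path is assumed to be a local copy of the sym_rots.npz that sits next to the symmetry module:

    import numpy as np

    obj = np.load("sym_rots.npz")  # assumed local copy of the bundled file
    for name, rots in obj.items():
        # each entry is a stack of [3, 3] rotation matrices, e.g. 24 for octahedral
        has_identity = [np.isclose(((R - np.eye(3)) ** 2).sum(), 0) for R in rots]
        assert any(has_identity), f"{name}: no identity element found"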
-
-    def close_neighbors(self):
-        """close_neighbors finds the rotations within self.sym_rots that
-        correspond to close neighbors.
-
-        Returns:
-            list of rotation matrices corresponding to the identity and close neighbors
-        """
-        # rotation angle (in radians) of each non-identity element
-        rel_rot = lambda M: np.linalg.norm(Rotation.from_matrix(M).as_rotvec())
-        rel_rots = [(i+1, rel_rot(M)) for i, M in enumerate(self.sym_rots[1:])]
-        min_rot = min(rel_rot_val[1] for rel_rot_val in rel_rots)
-        close_rots = [np.eye(3)] + [
-            self.sym_rots[i] for i, rel_rot_val in rel_rots if
-            np.isclose(rel_rot_val, min_rot)
-        ]
-        return close_rots
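To make the neighbor selection concrete, a standalone sketch (not part of the patch) of the same computation for a cyclic C4 group: the non-identity elements closest to the identity are the two 90-degree rotations, since scipy folds a 270-degree rotation to a rotation-vector magnitude of pi/2.

    import numpy as np
    from scipy.spatial.transform import Rotation

    sym_rots = [Rotation.from_euler('z', 90 * i, degrees=True).as_matrix() for i in range(4)]
    angles = [np.linalg.norm(Rotation.from_matrix(R).as_rotvec()) for R in sym_rots[1:]]
    min_angle = min(angles)  # pi/2
    neighbors = [np.eye(3)] + [R for R, a in zip(sym_rots[1:], angles) if np.isclose(a, min_angle)]
    assert len(neighbors) == 3  # identity plus the +/-90 degree rotations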
diff --git a/rfdiffusion/inference/inference/utils.py b/rfdiffusion/inference/inference/utils.py
deleted file mode 100644
index 2ed6105b..00000000
--- a/rfdiffusion/inference/inference/utils.py
+++ /dev/null
@@ -1,1015 +0,0 @@
-import numpy as np
-import os
-from omegaconf import DictConfig
-import torch
-import torch.nn.functional as nn
-from rfdiffusion.diffusion import get_beta_schedule
-from scipy.spatial.transform import Rotation as scipy_R
-from rfdiffusion.util import rigid_from_3_points
-from rfdiffusion.util_module import ComputeAllAtomCoords
-from rfdiffusion import util
-import random
-import logging
-from rfdiffusion.inference import model_runners
-import glob
-
-###########################################################
-#### Functions which can be called outside of Denoiser ####
-###########################################################
-
-
-def get_next_frames(xt, px0, t, diffuser, so3_type, diffusion_mask, noise_scale=1.0):
-    """
-    get_next_frames gets updated frames using IGSO(3) + score-based reverse diffusion.
-
-    Based on so3_type, uses a score-based update.
-
-    Generates frames at t-1. Rather than generating random rotations (as occurs
-    during the forward process), calculates the rotation between xt and px0.
-
-    Args:
-        xt: noised coordinates of shape [L, 14, 3]
-        px0: prediction of coordinates at t=0, of shape [L, 14, 3]
-        t: integer time step
-        diffuser: Diffuser object for reverse igSO3 sampling
-        so3_type: The type of SO3 noising being used ('igso3')
-        diffusion_mask: of shape [L] of type bool, True means not to be
-            updated (e.g. the mask is True for motif residues)
-        noise_scale: scale factor for the noise added (IGSO3 only)
-
-    Returns:
-        backbone coordinates for step x_t-1 of shape [L, 3, 3]
-    """
-    N_0 = px0[None, :, 0, :]
-    Ca_0 = px0[None, :, 1, :]
-    C_0 = px0[None, :, 2, :]
-
-    R_0, Ca_0 = rigid_from_3_points(N_0, Ca_0, C_0)
-
-    N_t = xt[None, :, 0, :]
-    Ca_t = xt[None, :, 1, :]
-    C_t = xt[None, :, 2, :]
-
-    R_t, Ca_t = rigid_from_3_points(N_t, Ca_t, C_t)
-
-    # round-trip through scipy to re-orthonormalize the frames
-    R_0 = scipy_R.from_matrix(R_0.squeeze().numpy()).as_matrix()
-    R_t = scipy_R.from_matrix(R_t.squeeze().numpy()).as_matrix()
-
-    L = R_t.shape[0]
-    all_rot_transitions = np.broadcast_to(np.identity(3), (L, 3, 3)).copy()
-    # Sample next frame for each residue
-    if so3_type == "igso3":
-        # don't do calculations on masked positions since they end up as identity matrix
-        all_rot_transitions[
-            ~diffusion_mask
-        ] = diffuser.so3_diffuser.reverse_sample_vectorized(
-            R_t[~diffusion_mask],
-            R_0[~diffusion_mask],
-            t,
-            noise_level=noise_scale,
-            mask=None,
-            return_perturb=True,
-        )
-    else:
-        assert False, "so3 diffusion type %s not implemented" % so3_type
-
-    all_rot_transitions = all_rot_transitions[:, None, :, :]
-
-    # Apply the interpolated rotation matrices to the coordinates
-    next_crds = (
-        np.einsum(
-            "lrij,laj->lrai",
-            all_rot_transitions,
-            xt[:, :3, :] - Ca_t.squeeze()[:, None, ...].numpy(),
-        )
-        + Ca_t.squeeze()[:, None, None, ...].numpy()
-    )
-
-    # (L,3,3) set of backbone coordinates with slight rotation
-    return next_crds.squeeze(1)
-
-
-def get_mu_xt_x0(xt, px0, t, beta_schedule, alphabar_schedule, eps=1e-6):
-    """
-    Given xt, predicted x0 and the timestep t, give mu of x(t-1).
-    t is 1-indexed; t_idx converts it to the 0-indexed schedule position.
-    """
-    # sigma is predefined from beta. Often referred to as beta tilde t
-    t_idx = t - 1
-    sigma = (
-        (1 - alphabar_schedule[t_idx - 1]) / (1 - alphabar_schedule[t_idx])
-    ) * beta_schedule[t_idx]
-
-    xt_ca = xt[:, 1, :]
-    px0_ca = px0[:, 1, :]
-
-    a = (
-        (torch.sqrt(alphabar_schedule[t_idx - 1] + eps) * beta_schedule[t_idx])
-        / (1 - alphabar_schedule[t_idx])
-    ) * px0_ca
-    b = (
-        (
-            torch.sqrt(1 - beta_schedule[t_idx] + eps)
-            * (1 - alphabar_schedule[t_idx - 1])
-        )
-        / (1 - alphabar_schedule[t_idx])
-    ) * xt_ca
-
-    mu = a + b
-
-    return mu, sigma
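The a + b expression above is the standard DDPM posterior mean q(x_{t-1} | x_t, x_0); a standalone numeric check (not part of the patch) against the equivalent epsilon parameterization from Ho et al. 2020:

    import torch

    T = 5
    beta = torch.linspace(1e-2, 2e-1, T)
    alphabar = torch.cumprod(1 - beta, dim=0)

    t = 3  # 1-indexed timestep; t_idx = t - 1 as in get_mu_xt_x0
    ti = t - 1
    x0, eps = torch.randn(3), torch.randn(3)
    xt = torch.sqrt(alphabar[ti]) * x0 + torch.sqrt(1 - alphabar[ti]) * eps

    # posterior-mean form used above
    mu = (torch.sqrt(alphabar[ti - 1]) * beta[ti] / (1 - alphabar[ti])) * x0 \
       + (torch.sqrt(1 - beta[ti]) * (1 - alphabar[ti - 1]) / (1 - alphabar[ti])) * xt
    # equivalent epsilon form
    mu_eps = (xt - beta[ti] / torch.sqrt(1 - alphabar[ti]) * eps) / torch.sqrt(1 - beta[ti])
    assert torch.allclose(mu, mu_eps, atol=1e-5)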
-
-
-def get_next_ca(
-    xt,
-    px0,
-    t,
-    diffusion_mask,
-    crd_scale,
-    beta_schedule,
-    alphabar_schedule,
-    noise_scale=1.0,
-):
-    """
-    Given full atom x0 prediction (xyz coordinates), diffuse to x(t-1)
-
-    Parameters:
-        - xt (L, 14/27, 3) set of coordinates
-        - px0 (L, 14/27, 3) set of coordinates
-        - t: current time step (so we are generating t-1)
-        - diffusion_mask (torch.tensor, required): Tensor of bools, True means NOT diffused at this residue, False means diffused
-        - noise_scale: scale factor for the noise being added
-    """
-    get_allatom = ComputeAllAtomCoords().to(device=xt.device)
-    L = len(xt)
-
-    # bring to origin after global alignment (when don't have a motif) or replace input motif and bring to origin, and then scale
-    px0 = px0 * crd_scale
-    xt = xt * crd_scale
-
-    # get mu(xt, x0)
-    mu, sigma = get_mu_xt_x0(
-        xt, px0, t, beta_schedule=beta_schedule, alphabar_schedule=alphabar_schedule
-    )
-
-    sampled_crds = torch.normal(mu, torch.sqrt(sigma * noise_scale))
-    delta = sampled_crds - xt[:, 1, :]  # Ca displacement, applied to every atom of each residue
-
-    if diffusion_mask is not None:
-        # Don't move motif
-        delta[diffusion_mask, ...] = 0
-
-    out_crds = xt + delta[:, None, :]
-
-    return out_crds / crd_scale, delta / crd_scale
-
-
-def get_noise_schedule(T, noiseT, noise1, schedule_type):
-    """
-    Function to create a schedule that varies the scale of noise given to the model over time
-
-    Parameters:
-        - T: The total number of timesteps in the denoising trajectory
-        - noiseT: The initial (t=T) noise scale
-        - noise1: The final (t=1) noise scale
-        - schedule_type: The type of function to use to interpolate between noiseT and noise1
-
-    Returns:
-        - noise_schedule: A function which maps timestep to noise scale
-    """
-    noise_schedules = {
-        "constant": lambda t: noiseT,
-        "linear": lambda t: ((t - 1) / (T - 1)) * (noiseT - noise1) + noise1,
-    }
-
-    assert (
-        schedule_type in noise_schedules
-    ), f"noise_schedule must be one of {noise_schedules.keys()}. Received noise_schedule={schedule_type}. Exiting."
-
-    return noise_schedules[schedule_type]
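A standalone check (not part of the patch) of the schedule endpoints: the linear schedule returns noiseT at the first denoising step (t=T) and noise1 at the last (t=1), while "constant" ignores t entirely.

    T, noiseT, noise1 = 50, 1.0, 0.5
    linear = lambda t: ((t - 1) / (T - 1)) * (noiseT - noise1) + noise1

    assert abs(linear(T) - noiseT) < 1e-9  # start of trajectory
    assert abs(linear(1) - noise1) < 1e-9  # end of trajectory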
-
-
-class Denoise:
-    """
-    Class for getting x(t-1) from predicted x0 and x(t)
-    Strategy:
-        Ca coordinates: Rediffuse to x(t-1) from predicted x0
-        Frames: Approximate update from rotation score
-        Torsions: 1/t of the way to the x0 prediction
-    """
-
-    def __init__(
-        self,
-        T,
-        L,
-        diffuser,
-        b_0=0.001,
-        b_T=0.1,
-        min_b=1.0,
-        max_b=12.5,
-        min_sigma=0.05,
-        max_sigma=1.5,
-        noise_level=0.5,
-        schedule_type="linear",
-        so3_schedule_type="linear",
-        schedule_kwargs={},
-        so3_type="igso3",
-        noise_scale_ca=1.0,
-        final_noise_scale_ca=1,
-        ca_noise_schedule_type="constant",
-        noise_scale_frame=0.5,
-        final_noise_scale_frame=0.5,
-        frame_noise_schedule_type="constant",
-        crd_scale=1 / 15,
-        potential_manager=None,
-        partial_T=None,
-    ):
-        """
-        Parameters:
-            noise_level: scaling on the noise added (set to 0 to use no noise,
-                to 1 to have full noise)
-        """
-        self.T = T
-        self.L = L
-        self.diffuser = diffuser
-        self.b_0 = b_0
-        self.b_T = b_T
-        self.noise_level = noise_level
-        self.schedule_type = schedule_type
-        self.so3_type = so3_type
-        self.crd_scale = crd_scale
-        self.noise_scale_ca = noise_scale_ca
-        self.final_noise_scale_ca = final_noise_scale_ca
-        self.ca_noise_schedule_type = ca_noise_schedule_type
-        self.noise_scale_frame = noise_scale_frame
-        self.final_noise_scale_frame = final_noise_scale_frame
-        self.frame_noise_schedule_type = frame_noise_schedule_type
-        self.potential_manager = potential_manager
-        self._log = logging.getLogger(__name__)
-
-        self.schedule, self.alpha_schedule, self.alphabar_schedule = get_beta_schedule(
-            self.T, self.b_0, self.b_T, self.schedule_type, inference=True
-        )
-
-        self.noise_schedule_ca = get_noise_schedule(
-            self.T,
-            self.noise_scale_ca,
-            self.final_noise_scale_ca,
-            self.ca_noise_schedule_type,
-        )
-        self.noise_schedule_frame = get_noise_schedule(
-            self.T,
-            self.noise_scale_frame,
-            self.final_noise_scale_frame,
-            self.frame_noise_schedule_type,
-        )
-
-    @property
-    def idx2steps(self):
-        return self.decode_scheduler.idx2steps.numpy()
-
-    def align_to_xt_motif(self, px0, xT, diffusion_mask, eps=1e-6):
-        """
-        Aligns px0 to the motif in xT. This permits the swapping of residue positions in the px0 motif for the true coordinates.
-        First, get the rotation matrix from px0 to xT for the motif residues.
-        Second, rotate px0 (the whole structure) by that rotation matrix.
-        Third, centre at the origin.
-        """
-
-        def rmsd(V, W, eps=0):
-            # First sum down atoms, then sum down xyz
-            N = V.shape[-2]
-            return np.sqrt(np.sum((V - W) * (V - W), axis=(-2, -1)) / N + eps)
-
-        assert (
-            xT.shape[1] == px0.shape[1]
-        ), f"xT has shape {xT.shape} and px0 has shape {px0.shape}"
-
-        L, n_atom, _ = xT.shape  # n_atom is the number of atoms per residue
-        atom_mask = ~torch.isnan(px0)
-        # convert to numpy arrays
-        px0 = px0.cpu().detach().numpy()
-        xT = xT.cpu().detach().numpy()
-        diffusion_mask = diffusion_mask.cpu().detach().numpy()
-
-        # 1 centre motifs at origin and get rotation matrix
-        px0_motif = px0[diffusion_mask, :3].reshape(-1, 3)
-        xT_motif = xT[diffusion_mask, :3].reshape(-1, 3)
-        px0_motif_mean = np.copy(px0_motif.mean(0))  # needed later
-        xT_motif_mean = np.copy(xT_motif.mean(0))
-
-        # center at origin
-        px0_motif = px0_motif - px0_motif_mean
-        xT_motif = xT_motif - xT_motif_mean
-
-        A = xT_motif
-        B = px0_motif
-
-        C = np.matmul(A.T, B)
-
-        # compute optimal rotation matrix using SVD
-        U, S, Vt = np.linalg.svd(C)
-
-        # ensure right handed coordinate system
-        d = np.eye(3)
-        d[-1, -1] = np.sign(np.linalg.det(Vt.T @ U.T))
-
-        # construct rotation matrix
-        R = Vt.T @ d @ U.T
-
-        # get rotated coords
-        rB = B @ R
-
-        # calculate rmsd
-        rms = rmsd(A, rB)
-        self._log.info(f"Sampled motif RMSD: {rms:.2f}")
-
-        # 2 rotate whole px0 by rotation matrix
-        atom_mask = atom_mask.cpu()
-        px0[~atom_mask] = 0  # convert nans to 0
-        px0 = px0.reshape(-1, 3) - px0_motif_mean
-        px0_ = px0 @ R
-
-        # 3 put in same global position as xT
-        px0_ = px0_ + xT_motif_mean
-        px0_ = px0_.reshape([L, n_atom, 3])
-        px0_[~atom_mask] = float("nan")
-        return torch.Tensor(px0_)
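The SVD construction above is the Kabsch algorithm; a standalone sketch (not part of the patch) showing it recovers a known rotation exactly:

    import numpy as np
    from scipy.spatial.transform import Rotation

    rng = np.random.default_rng(0)
    B = rng.normal(size=(10, 3))
    B -= B.mean(0)
    R_true = Rotation.random(random_state=0).as_matrix()
    A = B @ R_true.T  # each row of B rotated by R_true

    U, S, Vt = np.linalg.svd(A.T @ B)  # same C = A^T B as in align_to_xt_motif
    d = np.eye(3)
    d[-1, -1] = np.sign(np.linalg.det(Vt.T @ U.T))
    R = Vt.T @ d @ U.T

    assert np.allclose(B @ R, A, atol=1e-8)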
-
-    def get_potential_gradients(self, xyz, diffusion_mask):
-        """
-        This could be moved into the potential manager if desired - NRB
-
-        Function to take a structure (x) and get per-atom gradients used to guide the diffusion update
-
-        Inputs:
-            - xyz (torch.tensor, required): [L,27,3] Coordinates at which the gradient will be computed
-
-        Outputs:
-            - Ca_grads (torch.tensor): [L,3] The gradient at each Ca atom
-        """
-        if self.potential_manager is None or self.potential_manager.is_empty():
-            return torch.zeros(xyz.shape[0], 3)
-
-        use_Cb = False
-
-        xyz.requires_grad = True
-
-        if xyz.grad is not None:
-            xyz.grad.zero_()
-
-        current_potential = self.potential_manager.compute_all_potentials(xyz)
-        current_potential.backward()
-
-        # Since we are not moving frames, Cb grads are the same as Ca grads.
-        # Would need access to the calculated Cb coordinates to get true Cb grads.
-        Ca_grads = xyz.grad[:, 1, :]
-
-        if diffusion_mask is not None:
-            Ca_grads[diffusion_mask, :] = 0
-
-        # check for NaNs
-        if torch.isnan(Ca_grads).any():
-            print("WARNING: NaN in potential gradients, replacing with zero grad.")
-            Ca_grads[:] = 0
-
-        return Ca_grads
-
-    def get_next_pose(
-        self,
-        xt,
-        px0,
-        t,
-        diffusion_mask,
-        fix_motif=True,
-        align_motif=True,
-        include_motif_sidechains=True,
-    ):
-        """
-        Wrapper function to take px0, xt and t, and produce x(t-1).
-        First aligns px0 to xt, then gets coordinates, frames and torsion angles.
-
-        Parameters:
-            - xt (torch.tensor, required): Current coordinates at timestep t
-            - px0 (torch.tensor, required): Prediction of x0
-            - t (int, required): timestep t
-            - diffusion_mask (torch.tensor, required): Mask for structure diffusion
-            - fix_motif (bool): Fix the motif structure
-            - align_motif (bool): Align the model's prediction of the motif to the input motif
-            - include_motif_sidechains (bool): Provide sidechains of the fixed motif to the model
-        """
-        get_allatom = ComputeAllAtomCoords().to(device=xt.device)
-        L, n_atom = xt.shape[:2]
-        assert (xt.shape[1] == 14) or (xt.shape[1] == 27)
-        assert (px0.shape[1] == 14) or (px0.shape[1] == 27)
-
-        ###############################
-        ### Align pX0 onto Xt motif ###
-        ###############################
-
-        if align_motif and diffusion_mask.any():
-            px0 = self.align_to_xt_motif(px0, xt, diffusion_mask)
-
-        px0 = px0.to(xt.device)
-        # Now done with the diffusion mask. If fix_motif is False, set the mask to all False so every coordinate can diffuse
-        if not fix_motif:
-            diffusion_mask[:] = False
-
-        # get the next set of CA coordinates
-        noise_scale_ca = self.noise_schedule_ca(t)
-        _, ca_deltas = get_next_ca(
-            xt,
-            px0,
-            t,
-            diffusion_mask,
-            crd_scale=self.crd_scale,
-            beta_schedule=self.schedule,
-            alphabar_schedule=self.alphabar_schedule,
-            noise_scale=noise_scale_ca,
-        )
-
-        # get the next set of backbone frames (coordinates)
-        noise_scale_frame = self.noise_schedule_frame(t)
-        frames_next = get_next_frames(
-            xt,
-            px0,
-            t,
-            diffuser=self.diffuser,
-            so3_type=self.so3_type,
-            diffusion_mask=diffusion_mask,
-            noise_scale=noise_scale_frame,
-        )
-
-        # Apply gradient step from guiding potentials
-        # This can be moved to below where the full atom representation is calculated to allow for potentials involving sidechains
-        grad_ca = self.get_potential_gradients(
-            xt.clone(), diffusion_mask=diffusion_mask
-        )
-
-        ca_deltas += self.potential_manager.get_guide_scale(t) * grad_ca
-
-        # add the delta to the new frames
-        frames_next = torch.from_numpy(frames_next) + ca_deltas[:, None, :]  # translate
-
-        fullatom_next = torch.full_like(xt, float("nan")).unsqueeze(0)
-        fullatom_next[:, :, :3] = frames_next[None]
-        # This is never used so just make it a fudged tensor - NRB
-        torsions_next = torch.zeros(1, 1)
-
-        if include_motif_sidechains:
-            fullatom_next[:, diffusion_mask, :14] = xt[None, diffusion_mask]
-
-        return fullatom_next.squeeze()[:, :14, :], px0
-
-
-def sampler_selector(conf: DictConfig):
-    if conf.scaffoldguided.scaffoldguided:
-        sampler = model_runners.ScaffoldedSampler(conf)
-    else:
-        if conf.inference.model_runner == "default":
-            sampler = model_runners.Sampler(conf)
-        elif conf.inference.model_runner == "SelfConditioning":
-            sampler = model_runners.SelfConditioning(conf)
-        elif conf.inference.model_runner == "ScaffoldedSampler":
-            sampler = model_runners.ScaffoldedSampler(conf)
-        else:
-            raise ValueError(f"Unrecognized sampler {conf.inference.model_runner}")
-    return sampler
-
-
-def parse_pdb(filename, **kwargs):
-    """extract xyz coords for all heavy atoms"""
-    with open(filename, "r") as f:
-        lines = f.readlines()
-    return parse_pdb_lines(lines, **kwargs)
-
-
-def parse_pdb_lines(lines, parse_hetatom=False, ignore_het_h=True):
-    # indices of residues observed in the structure
-    res, pdb_idx = [], []
-    for l in lines:
-        if l[:4] == "ATOM" and l[12:16].strip() == "CA":
-            res.append((l[22:26], l[17:20]))
-            # chain letter, res num
-            pdb_idx.append((l[21:22].strip(), int(l[22:26].strip())))
-    seq = [util.aa2num[r[1]] if r[1] in util.aa2num.keys() else 20 for r in res]
-    pdb_idx = [
-        (l[21:22].strip(), int(l[22:26].strip()))
-        for l in lines
-        if l[:4] == "ATOM" and l[12:16].strip() == "CA"
-    ]  # chain letter, res num
-
-    # 4 BB + up to 10 SC atoms
-    xyz = np.full((len(res), 14, 3), np.nan, dtype=np.float32)
-    for l in lines:
-        if l[:4] != "ATOM":
-            continue
-        chain, resNo, atom, aa = (
-            l[21:22],
-            int(l[22:26]),
-            " " + l[12:16].strip().ljust(3),
-            l[17:20],
-        )
-        if (chain, resNo) in pdb_idx:
-            idx = pdb_idx.index((chain, resNo))
-            for i_atm, tgtatm in enumerate(
-                util.aa2long[util.aa2num[aa]][:14]
-            ):
-                # ignore whitespace when matching atom names
-                if tgtatm is not None and tgtatm.strip() == atom.strip():
-                    xyz[idx, i_atm, :] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
-                    break
-
-    # save atom mask
-    mask = np.logical_not(np.isnan(xyz[..., 0]))
-    xyz[np.isnan(xyz[..., 0])] = 0.0
-
-    # remove duplicated (chain, resi)
-    new_idx = []
-    i_unique = []
-    for i, idx in enumerate(pdb_idx):
-        if idx not in new_idx:
-            new_idx.append(idx)
-            i_unique.append(i)
-
-    pdb_idx = new_idx
-    xyz = xyz[i_unique]
-    mask = mask[i_unique]
-    seq = np.array(seq)[i_unique]
-
-    out = {
-        "xyz": xyz,  # cartesian coordinates, [Lx14]
-        "mask": mask,  # mask showing which atoms are present in the PDB file, [Lx14]
-        "idx": np.array(
-            [i[1] for i in pdb_idx]
-        ),  # residue numbers in the PDB file, [L]
-        "seq": np.array(seq),  # amino acid sequence, [L]
-        "pdb_idx": pdb_idx,  # list of (chain letter, residue number) in the pdb file, [L]
-    }
-
-    # heteroatoms (ligands, etc)
-    if parse_hetatom:
-        xyz_het, info_het = [], []
-        for l in lines:
-            if l[:6] == "HETATM" and not (ignore_het_h and l[77] == "H"):
-                info_het.append(
-                    dict(
-                        idx=int(l[7:11]),
-                        atom_id=l[12:16],
-                        atom_type=l[77],
-                        name=l[16:20],
-                    )
-                )
-                xyz_het.append([float(l[30:38]), float(l[38:46]), float(l[46:54])])
-
-        out["xyz_het"] = np.array(xyz_het)
-        out["info_het"] = info_het
-
-    return out
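The parser relies on the fixed PDB column layout; a standalone illustration (not part of the patch) on a made-up ATOM record:

    l = "ATOM      2  CA  ALA A  15      11.100  22.200  33.300  1.00  0.00           C"
    assert l[:4] == "ATOM"
    assert l[12:16].strip() == "CA"  # atom name
    assert l[17:20] == "ALA"  # residue name
    assert (l[21:22].strip(), int(l[22:26])) == ("A", 15)  # chain, residue number
    assert [float(l[30:38]), float(l[38:46]), float(l[46:54])] == [11.1, 22.2, 33.3]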
-
-
-def process_target(pdb_path, parse_hetatom=False, center=True):
-    # Read target pdb and extract features
-    target_struct = parse_pdb(pdb_path, parse_hetatom=parse_hetatom)
-
-    # Zero-center positions
-    ca_center = target_struct["xyz"][:, :1, :].mean(axis=0, keepdims=True)
-    if not center:
-        ca_center = 0
-    xyz = torch.from_numpy(target_struct["xyz"] - ca_center)
-    seq_orig = torch.from_numpy(target_struct["seq"])
-    atom_mask = torch.from_numpy(target_struct["mask"])
-    seq_len = len(xyz)
-
-    # Make 27 atom representation
-    xyz_27 = torch.full((seq_len, 27, 3), np.nan).float()
-    xyz_27[:, :14, :] = xyz[:, :14, :]
-
-    mask_27 = torch.full((seq_len, 27), False)
-    mask_27[:, :14] = atom_mask
-    out = {
-        "xyz_27": xyz_27,
-        "mask_27": mask_27,
-        "seq": seq_orig,
-        "pdb_idx": target_struct["pdb_idx"],
-    }
-    if parse_hetatom:
-        out["xyz_het"] = target_struct["xyz_het"]
-        out["info_het"] = target_struct["info_het"]
-    return out
-
-
-def get_idx0_hotspots(mappings, ppi_conf, binderlen):
-    """
-    Take pdb-indexed hotspot residues and the length of the binder, and make the 0-indexed tensor of hotspots
-    """
-    hotspot_idx = None
-    if binderlen > 0:
-        if ppi_conf.hotspot_res is not None:
-            assert all(
-                [i[0].isalpha() for i in ppi_conf.hotspot_res]
-            ), "Hotspot residues need to be provided in pdb-indexed form. E.g. A100,A103"
-            hotspots = [(i[0], int(i[1:])) for i in ppi_conf.hotspot_res]
-            hotspot_idx = []
-            for i, res in enumerate(mappings["receptor_con_ref_pdb_idx"]):
-                if res in hotspots:
-                    hotspot_idx.append(mappings["receptor_con_hal_idx0"][i])
-    return hotspot_idx
-
-
-class BlockAdjacency:
-    """
-    Class for handling PPI design inference with ss/block_adj inputs.
-    Basic idea is to provide a list of scaffolds, and to output ss and adjacency
-    matrices based off of these, while sampling additional lengths.
-    Inputs:
-        - scaffold_list: list of scaffolds (e.g. ['2kl8','1cif']). Can also be a .txt file.
-        - scaffold_dir: directory where scaffold ss and adj are precalculated
-        - sampled_insertion: how many additional residues do you want to add to each loop segment? Randomly sampled 0-this number (or within a given range)
-        - sampled_N: randomly sample up to this number of additional residues at the N-terminus
-        - sampled_C: randomly sample up to this number of additional residues at the C-terminus
-        - ss_mask: how many residues do you want to mask at either end of a ss (H or E) block. Fixed value
-        - num_designs: how many designs are you wanting to generate? Currently only used for bookkeeping
-        - systematic: do you want to systematically work through the list of scaffolds, or randomly sample (default)
-        - num_designs_per_input: Not really implemented yet. Maybe not necessary
-    Outputs:
-        - L: new length of chain to be diffused
-        - ss: all loops and insertions, and ends of ss blocks (up to ss_mask) set to mask token (3). One-hot encoded. (L,4)
-        - adj: block adjacency with equivalent masking as ss (L,L)
-    """
-
-    def __init__(self, conf, num_designs):
-        """
-        Parameters:
-            inputs:
-                conf: the full config; scaffold options are read from conf.scaffoldguided
-                num_designs: conf.inference.num_designs, for sanity checking
-        """
-        self.conf = conf
-        # either a list, or a path to a .txt file with a list of scaffolds
-        if self.conf.scaffoldguided.scaffold_list is not None:
-            if type(self.conf.scaffoldguided.scaffold_list) == list:
-                self.scaffold_list = self.conf.scaffoldguided.scaffold_list
-            elif self.conf.scaffoldguided.scaffold_list[-4:] == ".txt":
-                # txt file with list of ids
-                list_from_file = []
-                with open(self.conf.scaffoldguided.scaffold_list, "r") as f:
-                    for line in f:
-                        list_from_file.append(line.strip())
-                self.scaffold_list = list_from_file
-            else:
-                raise NotImplementedError
-        else:
-            self.scaffold_list = [
-                os.path.split(i)[1][:-6]
-                for i in glob.glob(f"{self.conf.scaffoldguided.scaffold_dir}/*_ss.pt")
-            ]
-            self.scaffold_list.sort()
-
-        # path to directory with scaffolds, ss files and block_adjacency files
-        self.scaffold_dir = self.conf.scaffoldguided.scaffold_dir
-
-        # maximum sampled insertion in each loop segment
-        if "-" in str(self.conf.scaffoldguided.sampled_insertion):
-            self.sampled_insertion = [
-                int(str(self.conf.scaffoldguided.sampled_insertion).split("-")[0]),
-                int(str(self.conf.scaffoldguided.sampled_insertion).split("-")[1]),
-            ]
-        else:
-            self.sampled_insertion = [0, int(self.conf.scaffoldguided.sampled_insertion)]
-
-        # maximum sampled insertion at N- and C-terminus
-        if "-" in str(self.conf.scaffoldguided.sampled_N):
-            self.sampled_N = [
-                int(str(self.conf.scaffoldguided.sampled_N).split("-")[0]),
-                int(str(self.conf.scaffoldguided.sampled_N).split("-")[1]),
-            ]
-        else:
-            self.sampled_N = [0, int(self.conf.scaffoldguided.sampled_N)]
-        if "-" in str(self.conf.scaffoldguided.sampled_C):
-            self.sampled_C = [
-                int(str(self.conf.scaffoldguided.sampled_C).split("-")[0]),
-                int(str(self.conf.scaffoldguided.sampled_C).split("-")[1]),
-            ]
-        else:
-            self.sampled_C = [0, int(self.conf.scaffoldguided.sampled_C)]
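The same "min-max" convention is parsed three times above; a standalone helper (hypothetical, not part of the patch) showing the behaviour:

    def parse_range(v):  # hypothetical helper mirroring the inline logic
        s = str(v)
        if "-" in s:
            lo, hi = s.split("-")
            return [int(lo), int(hi)]
        return [0, int(s)]

    assert parse_range(15) == [0, 15]  # plain number: sample 0..15 extra residues
    assert parse_range("5-15") == [5, 15]  # explicit range: sample 5..15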
-        # number of residues to mask the ss identity of in H/E regions (from the junction)
-        # e.g. if ss_mask = 2, L,L,L,H,H,H,H,H,H,H,L,L,E,E,E,E,E,E,L,L,L,L,L,L would become
-        # M,M,M,M,M,H,H,H,M,M,M,M,M,M,E,E,M,M,M,M,M,M,M,M where M is mask
-        self.ss_mask = self.conf.scaffoldguided.ss_mask
-
-        # whether or not to work systematically through the list
-        self.systematic = self.conf.scaffoldguided.systematic
-
-        self.num_designs = num_designs
-
-        if len(self.scaffold_list) > self.num_designs:
-            print(
-                "WARNING: Scaffold set is bigger than num_designs, so not every scaffold type will be sampled"
-            )
-
-        # for tracking number of designs
-        self.num_completed = 0
-        if self.systematic:
-            self.item_n = 0
-
-        # whether to mask loops or not
-        if not self.conf.scaffoldguided.mask_loops:
-            assert self.conf.scaffoldguided.sampled_N == 0, "can't add length if not masking loops"
-            assert self.conf.scaffoldguided.sampled_C == 0, "can't add length if not masking loops"
-            assert self.conf.scaffoldguided.sampled_insertion == 0, "can't add length if not masking loops"
-            self.mask_loops = False
-        else:
-            self.mask_loops = True
-
-    def get_ss_adj(self, item):
-        """
-        Given an item, get the ss tensor and block adjacency matrix for that item
-        """
-        ss = torch.load(os.path.join(self.scaffold_dir, f'{item.split(".")[0]}_ss.pt'))
-        adj = torch.load(
-            os.path.join(self.scaffold_dir, f'{item.split(".")[0]}_adj.pt')
-        )
-
-        return ss, adj
-
-    def mask_to_segments(self, mask):
-        """
-        Takes a mask of True (loop) and False (non-loop), and outputs a list of tuples (loop or not, length of element)
-        """
-        segments = []
-        begin = -1
-        end = -1
-        for i in range(mask.shape[0]):
-            # Starting edge case
-            if i == 0:
-                begin = 0
-                continue
-
-            if not mask[i] == mask[i - 1]:
-                end = i
-                if mask[i - 1].item() is True:
-                    segments.append(("loop", end - begin))
-                else:
-                    segments.append(("ss", end - begin))
-                begin = i
-
-        # Ending edge case: append the final segment
-        if not end == mask.shape[0]:
-            if mask[i].item() is True:
-                segments.append(("loop", mask.shape[0] - begin))
-            else:
-                segments.append(("ss", mask.shape[0] - begin))
-        return segments
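A standalone worked example (not part of the patch) of the segment walk on a toy loop mask (True = loop, False = H/E):

    import torch

    mask = torch.tensor([True, True, False, False, False, True])  # True = loop
    segments = []
    begin = 0
    for i in range(1, len(mask)):
        if mask[i] != mask[i - 1]:
            segments.append(("loop" if mask[i - 1] else "ss", i - begin))
            begin = i
    segments.append(("loop" if mask[-1] else "ss", len(mask) - begin))
    assert segments == [("loop", 2), ("ss", 3), ("loop", 1)]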
-
-    def expand_mask(self, mask, segments):
-        """
-        Function to generate a new mask with dilated loops and N- and C-terminal additions
-        """
-        N_add = random.randint(self.sampled_N[0], self.sampled_N[1])
-        C_add = random.randint(self.sampled_C[0], self.sampled_C[1])
-
-        output = N_add * [False]
-        for ss, length in segments:
-            if ss == "ss":
-                output.extend(length * [True])
-            else:
-                # randomly sample insertion length
-                ins = random.randint(
-                    self.sampled_insertion[0], self.sampled_insertion[1]
-                )
-                output.extend((length + ins) * [False])
-        output.extend(C_add * [False])
-        assert torch.sum(torch.tensor(output)) == torch.sum(~mask)
-        return torch.tensor(output)
-
-    def expand_ss(self, ss, adj, mask, expanded_mask):
-        """
-        Given an expanded mask, populate a new ss and adj based on this
-        """
-        ss_out = torch.ones(expanded_mask.shape[0]) * 3  # set to mask token
-        adj_out = torch.full((expanded_mask.shape[0], expanded_mask.shape[0]), 0.0)
-        ss_out[expanded_mask] = ss[~mask]
-        expanded_mask_2d = torch.full(adj_out.shape, True)
-        # mask out loops/insertions, which is ~expanded_mask
-        expanded_mask_2d[~expanded_mask, :] = False
-        expanded_mask_2d[:, ~expanded_mask] = False
-
-        mask_2d = torch.full(adj.shape, True)
-        # mask out loops. This mask is True=loop
-        mask_2d[mask, :] = False
-        mask_2d[:, mask] = False
-        adj_out[expanded_mask_2d] = adj[mask_2d]
-        adj_out = adj_out.reshape((expanded_mask.shape[0], expanded_mask.shape[0]))
-
-        return ss_out, adj_out
-
-    def mask_ss_adj(self, ss, adj, expanded_mask):
-        """
-        Given an expanded ss and adj, mask some number of residues at either end of non-loop ss
-        """
-        original_mask = torch.clone(expanded_mask)
-        if self.ss_mask > 0:
-            for i in range(1, self.ss_mask + 1):
-                expanded_mask[i:] *= original_mask[:-i]
-                expanded_mask[:-i] *= original_mask[i:]
-
-        if self.mask_loops:
-            ss[~expanded_mask] = 3
-            adj[~expanded_mask, :] = 0
-            adj[:, ~expanded_mask] = 0
-
-        # mask adjacency
-        adj[~expanded_mask] = 2
-        adj[:, ~expanded_mask] = 2
-
-        return ss, adj
-
-    def get_scaffold(self):
-        """
-        Wrapper method for pulling an item from the list, and preparing ss and block adj features
-        """
-        # Handle determinism. Useful for integration tests
-        if self.conf.inference.deterministic:
-            torch.manual_seed(self.num_completed)
-            np.random.seed(self.num_completed)
-            random.seed(self.num_completed)
-
-        if self.systematic:
-            # reset if num_designs > num_scaffolds
-            if self.item_n >= len(self.scaffold_list):
-                self.item_n = 0
-            item = self.scaffold_list[self.item_n]
-            self.item_n += 1
-        else:
-            item = random.choice(self.scaffold_list)
-        print("Scaffold constrained based on file: ", item)
-        # load files
-        ss, adj = self.get_ss_adj(item)
-        adj_orig = torch.clone(adj)
-        # separate into segments (loop or not)
-        mask = torch.where(ss == 2, 1, 0).bool()
-        segments = self.mask_to_segments(mask)
-
-        # insert into loops to generate new mask
-        expanded_mask = self.expand_mask(mask, segments)
-
-        # expand ss and adj
-        ss, adj = self.expand_ss(ss, adj, mask, expanded_mask)
-
-        # finally, mask some proportion of the ss at either end of the non-loop ss blocks
-        ss, adj = self.mask_ss_adj(ss, adj, expanded_mask)
-
-        # and then update num_completed
-        self.num_completed += 1
-
-        return ss.shape[0], torch.nn.functional.one_hot(ss.long(), num_classes=4), adj
-
-
-class Target:
-    """
-    Class to handle targets (fixed chains).
-    Inputs:
-        - path to pdb file
-        - hotspot residues, in the form B10,B12,B60 etc
-        - whether or not to crop, and with which method
-    Outputs:
-        - Dictionary of xyz coordinates, indices, pdb_indices, pdb mask
-    """
-
-    def __init__(self, conf: DictConfig, hotspots=None):
-        self.pdb = parse_pdb(conf.target_path)
-
-        if hotspots is not None:
-            self.hotspots = hotspots
-        else:
-            self.hotspots = []
-        self.pdb["hotspots"] = np.array(
-            [
-                True if f"{i[0]}{i[1]}" in self.hotspots else False
-                for i in self.pdb["pdb_idx"]
-            ]
-        )
-
-        if conf.contig_crop:
-            self.contig_crop(conf.contig_crop)
-
-    def parse_contig(self, contig_crop):
-        """
-        Takes the contig input and parses it into a list of (chain, residue number) tuples
-        """
-        contig_list = []
-        for contig in contig_crop[0].split(" "):
-            subcon = []
-            for crop in contig.split("/"):
-                if crop[0].isalpha():
-                    subcon.extend(
-                        [
-                            (crop[0], p)
-                            for p in np.arange(
-                                int(crop.split("-")[0][1:]), int(crop.split("-")[1]) + 1
-                            )
-                        ]
-                    )
-            contig_list.append(subcon)
-        return contig_list
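A standalone sketch (not part of the patch) of how one contig chunk expands into (chain, residue) tuples, mirroring parse_contig:

    crop = "B1-3/B10-11"
    subcon = []
    for piece in crop.split("/"):
        chain, (lo, hi) = piece[0], piece[1:].split("-")
        subcon.extend((chain, p) for p in range(int(lo), int(hi) + 1))

    assert subcon == [("B", 1), ("B", 2), ("B", 3), ("B", 10), ("B", 11)]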
-    def contig_crop(self, contig_crop, residue_offset=200) -> None:
-        """
-        Method to take a contig string referring to the receptor and output a pdb dictionary with just this crop.
-        NB there are two ways to provide inputs:
-            1) e.g. B1-30,0 B50-60,0. This will add a residue offset between each chunk
-            2) e.g. B1-30,B50-60,B80-100. This will keep the original indexing of the pdb file.
-        Can handle the target being on multiple chains.
-        """
-        # add residue offset between chains if multiple chains in receptor file
-        for idx, val in enumerate(self.pdb["pdb_idx"]):
-            if idx != 0 and val[0] != self.pdb["pdb_idx"][idx - 1][0]:
-                self.pdb["idx"][idx:] += residue_offset + idx
-
-        # convert contig to mask
-        contig_list = self.parse_contig(contig_crop)
-
-        # add residue offset to different parts of contig_list
-        for contig in contig_list[1:]:
-            start = int(contig[0][1])
-            self.pdb["idx"][start:] += residue_offset
-        # flatten list
-        contig_list = [i for j in contig_list for i in j]
-        mask = np.array(
-            [True if i in contig_list else False for i in self.pdb["pdb_idx"]]
-        )
-
-        # sanity check
-        assert np.sum(self.pdb["hotspots"]) == np.sum(
-            self.pdb["hotspots"][mask]
-        ), "Supplied hotspot residues are missing from the target contig!"
-        # crop pdb
-        for key, val in self.pdb.items():
-            try:
-                self.pdb[key] = val[mask]
-            except TypeError:
-                self.pdb[key] = [i for idx, i in enumerate(val) if mask[idx]]
-        self.pdb["crop_mask"] = mask
-
-    def get_target(self):
-        return self.pdb
-
-
-def ss_from_contig(ss_masks: dict):
-    """
-    Function for taking 1D masks for each of the ss types, and outputting a secondary structure input
-    """
-    L = len(ss_masks['helix'])
-    ss = torch.zeros((L, 4)).long()
-    ss[:, 3] = 1  # mask
-    for idx, mask in enumerate([ss_masks['helix'], ss_masks['strand'], ss_masks['loop']]):
-        ss[mask, idx] = 1
-        ss[mask, 3] = 0  # remove the mask token
-    return ss
\ No newline at end of file
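Finally, a standalone worked example (not part of the patch) of ss_from_contig semantics: helix/strand/loop masks map to one-hot columns 0/1/2, and any position not covered keeps the mask token in column 3.

    import torch

    masks = {
        'helix':  torch.tensor([True, False, False]),
        'strand': torch.tensor([False, True, False]),
        'loop':   torch.tensor([False, False, False]),
    }
    L = len(masks['helix'])
    ss = torch.zeros((L, 4)).long()
    ss[:, 3] = 1  # start fully masked
    for idx, m in enumerate([masks['helix'], masks['strand'], masks['loop']]):
        ss[m, idx] = 1
        ss[m, 3] = 0
    assert ss.tolist() == [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]]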