diff --git a/alphapulldown/scripts/create_individual_features.py b/alphapulldown/scripts/create_individual_features.py index 58a1ee76..867550d1 100644 --- a/alphapulldown/scripts/create_individual_features.py +++ b/alphapulldown/scripts/create_individual_features.py @@ -15,6 +15,7 @@ import tempfile from datetime import datetime from pathlib import Path +import numpy as np from absl import logging, app, flags from colabfold.utils import DEFAULT_API_SERVER @@ -29,6 +30,9 @@ from alphapulldown.objects import MonomericObject from alphapulldown.utils.file_handling import iter_seqs, parse_csv_file from alphapulldown.utils.modelling_setup import create_uniprot_runner +from alphapulldown.utils.multimeric_template_utils import ( + extract_multimeric_template_features_for_single_chain, +) from alphapulldown.utils import save_meta_data # Try to import AlphaFold3, but it's optional @@ -340,27 +344,151 @@ def create_individual_features(): monomer.uniprot_runner = uniprot_runner create_and_save_monomer_objects(monomer, pipeline) -def create_and_save_monomer_objects(monomer, pipeline, custom_template_path=None): - """Save a MonomericObject after feature creation (pickled, optionally compressed).""" - # Ensure output directory exists - os.makedirs(FLAGS.output_dir, exist_ok=True) - - pickle_path = os.path.join(FLAGS.output_dir, f"{monomer.description}.pkl") - if FLAGS.compress_features: - pickle_path += ".xz" - if FLAGS.skip_existing and os.path.exists(pickle_path): - logging.info(f"Feature file for {monomer.description} already exists. Skipping...") - return - meta_dict = save_meta_data.get_meta_dict(FLAGS.flag_values_dict()) - metadata_output_path = os.path.join( - FLAGS.output_dir, f"{monomer.description}_feature_metadata_{datetime.now().date()}.json" +def _feature_pickle_path(description, *, compress): + suffix = ".pkl.xz" if compress else ".pkl" + return Path(FLAGS.output_dir) / f"{description}{suffix}" + + +def _metadata_output_path(description): + return Path(FLAGS.output_dir) / ( + f"{description}_feature_metadata_{datetime.now().date()}.json" ) + + +def _should_skip_monomer_output(description): + pickle_path = _feature_pickle_path(description, compress=FLAGS.compress_features) + if FLAGS.skip_existing and pickle_path.exists(): + logging.info(f"Feature file for {description} already exists. Skipping...") + return True + return False + + +def _persist_monomer_outputs(monomer): + meta_dict = save_meta_data.get_meta_dict(FLAGS.flag_values_dict()) + metadata_output_path = _metadata_output_path(monomer.description) if FLAGS.compress_features: - with lzma.open(metadata_output_path + '.xz', "wt") as meta_data_outfile: + with lzma.open(str(metadata_output_path) + ".xz", "wt") as meta_data_outfile: json.dump(meta_dict, meta_data_outfile) + with lzma.open(_feature_pickle_path(monomer.description, compress=True), "wb") as pickle_file: + pickle.dump(monomer, pickle_file) else: with open(metadata_output_path, "w") as meta_data_outfile: json.dump(meta_dict, meta_data_outfile) + with open(_feature_pickle_path(monomer.description, compress=False), "wb") as pickle_file: + pickle.dump(monomer, pickle_file) + + +def _load_existing_monomer_from_output_dir(description): + for suffix in (".pkl", ".pkl.xz"): + pickle_path = Path(FLAGS.output_dir) / f"{description}{suffix}" + if not pickle_path.exists(): + continue + if suffix == ".pkl.xz": + with lzma.open(pickle_path, "rb") as handle: + return pickle.load(handle) + with open(pickle_path, "rb") as handle: + return pickle.load(handle) + return None + + +def _infer_truemultimer_source_name(protein, template_paths, chains): + if len(template_paths) != 1 or len(chains) != 1: + return protein + template_name = Path(template_paths[0]).name + suffix = f".{template_name}.{chains[0]}" + if protein.endswith(suffix): + return protein[: -len(suffix)] + return protein + + +def _replace_template_features(monomer, template_features): + monomer.feature_dict = { + key: value + for key, value in monomer.feature_dict.items() + if not key.startswith("template_") + } + monomer.feature_dict.update(template_features) + + template_count = 0 + if "template_aatype" in monomer.feature_dict: + template_count = int(monomer.feature_dict["template_aatype"].shape[0]) + elif "template_sequence" in monomer.feature_dict: + template_count = len(monomer.feature_dict["template_sequence"]) + template_count = max(template_count, 1) + + if "template_sum_probs" in monomer.feature_dict: + monomer.feature_dict["template_sum_probs"] = np.asarray( + monomer.feature_dict["template_sum_probs"], dtype=np.float32 + ) + else: + monomer.feature_dict["template_sum_probs"] = np.zeros( + template_count, dtype=np.float32 + ) + if "template_confidence_scores" not in monomer.feature_dict: + monomer.feature_dict["template_confidence_scores"] = np.ones( + (template_count, len(monomer.sequence)), dtype=np.float32 + ) + if "template_release_date" not in monomer.feature_dict: + monomer.feature_dict["template_release_date"] = np.array( + ["none"] * template_count, dtype=object + ) + + +def _reuse_truemultimer_monomer_features(feat): + if FLAGS.use_mmseqs2 or len(feat["templates"]) != 1 or len(feat["chains"]) != 1: + return None + + source_name = _infer_truemultimer_source_name( + feat["protein"], feat["templates"], feat["chains"] + ) + monomer = _load_existing_monomer_from_output_dir(source_name) + if monomer is None: + return None + if monomer.sequence != feat["sequence"]: + logging.warning( + "Existing monomer features for %s use sequence %s, but the current " + "TrueMultimer entry expects %s. Falling back to full feature generation.", + source_name, + monomer.sequence, + feat["sequence"], + ) + return None + + template_path = feat["templates"][0] + chain_id = feat["chains"][0] + template_result = extract_multimeric_template_features_for_single_chain( + query_seq=monomer.sequence, + pdb_id=Path(template_path).stem, + chain_id=chain_id, + mmcif_file=template_path, + threshold_clashes=FLAGS.threshold_clashes, + hb_allowance=FLAGS.hb_allowance, + plddt_threshold=FLAGS.plddt_threshold, + ) + if template_result is None or template_result.features is None: + raise RuntimeError( + f"Failed to extract template features from {template_path} chain {chain_id}." + ) + + monomer.description = feat["protein"] + monomer.sequence = feat["sequence"] + monomer.uniprot_runner = None + _replace_template_features(monomer, template_result.features) + logging.info( + "Reused existing monomer features from %s for TrueMultimer target %s.", + source_name, + feat["protein"], + ) + return monomer + + +def create_and_save_monomer_objects(monomer, pipeline, custom_template_path=None): + """Save a MonomericObject after feature creation (pickled, optionally compressed).""" + # Ensure output directory exists + os.makedirs(FLAGS.output_dir, exist_ok=True) + + if _should_skip_monomer_output(monomer.description): + return if FLAGS.use_mmseqs2: monomer.make_mmseq_features( DEFAULT_API_SERVER=DEFAULT_API_SERVER, @@ -374,12 +502,7 @@ def create_and_save_monomer_objects(monomer, pipeline, custom_template_path=None pipeline=pipeline, output_dir=FLAGS.output_dir, use_precomputed_msa=FLAGS.use_precomputed_msas, save_msa=FLAGS.save_msa_files) - if FLAGS.compress_features: - with lzma.open(pickle_path, "wb") as pickle_file: - pickle.dump(monomer, pickle_file) - else: - with open(pickle_path, "wb") as pickle_file: - pickle.dump(monomer, pickle_file) + _persist_monomer_outputs(monomer) def create_individual_features_truemultimer(): """Generate features in TrueMultimer mode, one set per entry in the description CSV.""" @@ -395,6 +518,12 @@ def process_multimeric_features(feat, idx): for temp_path in feat["templates"]: if not os.path.isfile(temp_path): raise FileNotFoundError(f"Template file {temp_path} does not exist.") + reused_monomer = _reuse_truemultimer_monomer_features(feat) + if reused_monomer is not None: + if _should_skip_monomer_output(reused_monomer.description): + return + _persist_monomer_outputs(reused_monomer) + return protein, chains, template_paths = feat["protein"], feat["chains"], feat["templates"] with tempfile.TemporaryDirectory() as temp_dir: local_path_to_custom_db = create_custom_db(temp_dir, protein, template_paths, chains) diff --git a/test/integration/test_create_individual_features.py b/test/integration/test_create_individual_features.py index a5969cbb..0b695bb7 100644 --- a/test/integration/test_create_individual_features.py +++ b/test/integration/test_create_individual_features.py @@ -434,6 +434,170 @@ def __init__(self, description, sequence): assert saved_monomer.sequence == "ACDEFG" assert saved_monomer.uniprot_runner == "runner" + @pytest.mark.parametrize("compressed_source", [False, True]) + def test_process_multimeric_features_reuses_existing_source_pickle( + self, tmp_flags, compressed_source + ): + template_path = Path(self.test_dir) / "template1.cif" + template_path.write_text("data_template\n", encoding="utf-8") + + from absl import flags + + FLAGS = flags.FLAGS + FLAGS(["test"]) + FLAGS.output_dir = os.path.join(self.test_dir, "reused_truemultimer_output") + FLAGS.use_mmseqs2 = False + FLAGS.compress_features = False + FLAGS.skip_existing = False + + output_dir = Path(FLAGS.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + source = MonomericObject("proteinA", "ACDE") + source.feature_dict = { + "msa": np.asarray([[1, 2, 3, 4]], dtype=np.int32), + "deletion_matrix_int": np.zeros((1, 4), dtype=np.int32), + "num_alignments": np.asarray([1, 1, 1, 1], dtype=np.int32), + "msa_species_identifiers": np.asarray([b"9606"], dtype=object), + "msa_all_seq": np.asarray([[1, 2, 3, 4]], dtype=np.int32), + "deletion_matrix_int_all_seq": np.zeros((1, 4), dtype=np.int32), + "msa_species_identifiers_all_seq": np.asarray([b"9606"], dtype=object), + "template_aatype": np.zeros((1, 4, 22), dtype=np.float32), + "template_all_atom_masks": np.ones((1, 4, 37), dtype=np.float32), + "template_all_atom_positions": np.ones((1, 4, 37, 3), dtype=np.float32), + "template_domain_names": np.asarray([b"old_template"], dtype=object), + "template_sequence": np.asarray([b"OLD"], dtype=object), + "template_sum_probs": np.asarray([0.5], dtype=np.float32), + "template_confidence_scores": np.full((1, 4), 0.75, dtype=np.float32), + "template_release_date": np.asarray(["2024-01-01"], dtype=object), + } + + if compressed_source: + with lzma.open(output_dir / "proteinA.pkl.xz", "wb") as handle: + pickle.dump(source, handle) + else: + with open(output_dir / "proteinA.pkl", "wb") as handle: + pickle.dump(source, handle) + + new_template_features = { + "template_aatype": np.ones((2, 4, 22), dtype=np.float32), + "template_all_atom_masks": np.full((2, 4, 37), 2.0, dtype=np.float32), + "template_all_atom_positions": np.full((2, 4, 37, 3), 3.0, dtype=np.float32), + "template_domain_names": np.asarray([b"newA", b"newB"], dtype=object), + "template_sequence": np.asarray([b"NEWA", b"NEWB"], dtype=object), + "template_sum_probs": np.asarray([0.1, 0.2], dtype=np.float32), + } + + feat = { + "protein": "proteinA.template1.cif.A", + "chains": ["A"], + "templates": [str(template_path)], + "sequence": "ACDE", + } + + with patch.object( + create_features, + "extract_multimeric_template_features_for_single_chain", + return_value=types.SimpleNamespace(features=new_template_features), + ) as mock_extract, \ + patch.object(create_features, "create_custom_db") as mock_custom_db, \ + patch.object(create_features, "create_arguments") as mock_arguments, \ + patch.object(create_features, "create_pipeline_af2") as mock_pipeline, \ + patch.object(create_features, "create_uniprot_runner") as mock_runner, \ + patch("alphapulldown.utils.save_meta_data.get_meta_dict", return_value={"source": "test"}): + create_features.process_multimeric_features(feat, 1) + + mock_extract.assert_called_once_with( + query_seq="ACDE", + pdb_id="template1", + chain_id="A", + mmcif_file=str(template_path), + threshold_clashes=create_features.FLAGS.threshold_clashes, + hb_allowance=create_features.FLAGS.hb_allowance, + plddt_threshold=create_features.FLAGS.plddt_threshold, + ) + mock_custom_db.assert_not_called() + mock_arguments.assert_not_called() + mock_pipeline.assert_not_called() + mock_runner.assert_not_called() + + output_pickle = output_dir / "proteinA.template1.cif.A.pkl" + assert output_pickle.exists() + with open(output_pickle, "rb") as handle: + reused = pickle.load(handle) + + assert reused.description == "proteinA.template1.cif.A" + assert np.array_equal(reused.feature_dict["msa"], source.feature_dict["msa"]) + assert reused.feature_dict["template_sequence"].tolist() == [b"NEWA", b"NEWB"] + assert np.array_equal( + reused.feature_dict["template_confidence_scores"], + np.ones((2, 4), dtype=np.float32), + ) + assert reused.feature_dict["template_release_date"].tolist() == ["none", "none"] + assert list(output_dir.glob("proteinA.template1.cif.A_feature_metadata_*.json")) + + def test_process_multimeric_features_falls_back_when_source_sequence_mismatches( + self, tmp_flags + ): + template_path = Path(self.test_dir) / "template1.cif" + template_path.write_text("data_template\n", encoding="utf-8") + + from absl import flags + + FLAGS = flags.FLAGS + FLAGS(["test"]) + FLAGS.output_dir = os.path.join(self.test_dir, "mismatched_truemultimer_output") + FLAGS.use_mmseqs2 = False + FLAGS.compress_features = False + FLAGS.skip_existing = False + FLAGS.jackhmmer_binary_path = "/usr/bin/jackhmmer" + FLAGS.uniprot_database_path = "/db/uniprot.fasta" + + output_dir = Path(FLAGS.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + source = MonomericObject("proteinA", "ACDE") + source.feature_dict = { + "msa": np.asarray([[1, 2, 3, 4]], dtype=np.int32), + "deletion_matrix_int": np.zeros((1, 4), dtype=np.int32), + "num_alignments": np.asarray([1, 1, 1, 1], dtype=np.int32), + "msa_species_identifiers": np.asarray([b"9606"], dtype=object), + } + with open(output_dir / "proteinA.pkl", "wb") as handle: + pickle.dump(source, handle) + + feat = { + "protein": "proteinA.template1.cif.A", + "chains": ["A"], + "templates": [str(template_path)], + "sequence": "ACDF", + } + + with patch.object( + create_features, + "extract_multimeric_template_features_for_single_chain", + ) as mock_extract, \ + patch.object(create_features, "create_custom_db", return_value="/tmp/custom_db") as mock_custom_db, \ + patch.object(create_features, "create_arguments") as mock_arguments, \ + patch.object(create_features, "create_pipeline_af2", return_value="pipeline") as mock_pipeline, \ + patch.object(create_features, "create_uniprot_runner", return_value="runner") as mock_runner, \ + patch.object(create_features, "create_and_save_monomer_objects") as mock_save: + create_features.process_multimeric_features(feat, 1) + + mock_extract.assert_not_called() + mock_custom_db.assert_called_once() + mock_arguments.assert_called_once_with("/tmp/custom_db") + mock_pipeline.assert_called_once_with() + mock_runner.assert_called_once_with( + FLAGS.jackhmmer_binary_path, + FLAGS.uniprot_database_path, + ) + saved_monomer, saved_pipeline = mock_save.call_args.args + assert saved_pipeline == "pipeline" + assert saved_monomer.description == "proteinA.template1.cif.A" + assert saved_monomer.sequence == "ACDF" + assert saved_monomer.uniprot_runner == "runner" + def test_main_dispatches_to_truemultimer_for_af2_template_runs(self): """The main entrypoint should route AF2 template jobs to the TrueMultimer path.""" from absl import flags