Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 150 additions & 21 deletions alphapulldown/scripts/create_individual_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import tempfile
from datetime import datetime
from pathlib import Path
import numpy as np

from absl import logging, app, flags
from colabfold.utils import DEFAULT_API_SERVER
Expand All @@ -29,6 +30,9 @@
from alphapulldown.objects import MonomericObject
from alphapulldown.utils.file_handling import iter_seqs, parse_csv_file
from alphapulldown.utils.modelling_setup import create_uniprot_runner
from alphapulldown.utils.multimeric_template_utils import (
extract_multimeric_template_features_for_single_chain,
)
from alphapulldown.utils import save_meta_data

# Try to import AlphaFold3, but it's optional
Expand Down Expand Up @@ -340,27 +344,151 @@ def create_individual_features():
monomer.uniprot_runner = uniprot_runner
create_and_save_monomer_objects(monomer, pipeline)

def create_and_save_monomer_objects(monomer, pipeline, custom_template_path=None):
"""Save a MonomericObject after feature creation (pickled, optionally compressed)."""
# Ensure output directory exists
os.makedirs(FLAGS.output_dir, exist_ok=True)

pickle_path = os.path.join(FLAGS.output_dir, f"{monomer.description}.pkl")
if FLAGS.compress_features:
pickle_path += ".xz"
if FLAGS.skip_existing and os.path.exists(pickle_path):
logging.info(f"Feature file for {monomer.description} already exists. Skipping...")
return
meta_dict = save_meta_data.get_meta_dict(FLAGS.flag_values_dict())
metadata_output_path = os.path.join(
FLAGS.output_dir, f"{monomer.description}_feature_metadata_{datetime.now().date()}.json"
def _feature_pickle_path(description, *, compress):
    """Return the output path for a monomer's pickled feature file.

    The file lives in ``FLAGS.output_dir`` and carries a ``.pkl.xz`` suffix
    when *compress* is true, a plain ``.pkl`` suffix otherwise.
    """
    if compress:
        extension = ".pkl.xz"
    else:
        extension = ".pkl"
    return Path(FLAGS.output_dir) / (description + extension)


def _metadata_output_path(description):
    """Return the JSON metadata path for *description*, stamped with today's date."""
    today = datetime.now().date()
    filename = f"{description}_feature_metadata_{today}.json"
    return Path(FLAGS.output_dir) / filename


def _should_skip_monomer_output(description):
    """Return True when --skip_existing is set and the feature pickle is already on disk."""
    target = _feature_pickle_path(description, compress=FLAGS.compress_features)
    skip = FLAGS.skip_existing and target.exists()
    if skip:
        logging.info(f"Feature file for {description} already exists. Skipping...")
    return skip


def _persist_monomer_outputs(monomer):
    """Write the run metadata (JSON) and the monomer itself (pickle) to the output dir.

    Both files are LZMA-compressed when --compress_features is set; the
    metadata file then gains an extra ``.xz`` suffix.
    """
    metadata = save_meta_data.get_meta_dict(FLAGS.flag_values_dict())
    meta_path = _metadata_output_path(monomer.description)
    compress = FLAGS.compress_features
    pickle_path = _feature_pickle_path(monomer.description, compress=compress)
    if compress:
        with lzma.open(str(meta_path) + ".xz", "wt") as handle:
            json.dump(metadata, handle)
        with lzma.open(pickle_path, "wb") as handle:
            pickle.dump(monomer, handle)
    else:
        with open(meta_path, "w") as handle:
            json.dump(metadata, handle)
        with open(pickle_path, "wb") as handle:
            pickle.dump(monomer, handle)


def _load_existing_monomer_from_output_dir(description):
    """Unpickle a previously generated monomer from ``FLAGS.output_dir``.

    The uncompressed ``.pkl`` file is preferred over the compressed
    ``.pkl.xz`` one; returns None when neither exists.
    NOTE: unpickling executes arbitrary code — these files are expected to be
    this tool's own prior outputs, not untrusted input.
    """
    base = Path(FLAGS.output_dir)
    plain = base / f"{description}.pkl"
    if plain.exists():
        with open(plain, "rb") as handle:
            return pickle.load(handle)
    compressed = base / f"{description}.pkl.xz"
    if compressed.exists():
        with lzma.open(compressed, "rb") as handle:
            return pickle.load(handle)
    return None


def _infer_truemultimer_source_name(protein, template_paths, chains):
    """Strip a ``.<template file name>.<chain>`` suffix from a TrueMultimer name.

    When exactly one template and one chain are given and *protein* ends with
    the suffix built from them, return the name with that suffix removed
    (the source monomer's name); in every other case return *protein* as-is.
    """
    single_pair = len(template_paths) == 1 and len(chains) == 1
    if not single_pair:
        return protein
    expected_tail = "." + Path(template_paths[0]).name + "." + chains[0]
    if protein.endswith(expected_tail):
        return protein[: -len(expected_tail)]
    return protein


def _replace_template_features(monomer, template_features):
    """Swap the monomer's ``template_*`` features for *template_features*.

    Every pre-existing ``template_*`` entry is dropped, the new features are
    merged in, and the auxiliary arrays (``template_sum_probs``,
    ``template_confidence_scores``, ``template_release_date``) are normalised
    or back-filled so downstream consumers always find them.
    """
    retained = {}
    for key, value in monomer.feature_dict.items():
        if not key.startswith("template_"):
            retained[key] = value
    retained.update(template_features)
    monomer.feature_dict = retained

    # Infer the number of templates from the merged features; always >= 1 so
    # the back-filled arrays are never empty.
    if "template_aatype" in retained:
        n_templates = int(retained["template_aatype"].shape[0])
    elif "template_sequence" in retained:
        n_templates = len(retained["template_sequence"])
    else:
        n_templates = 0
    n_templates = max(n_templates, 1)

    if "template_sum_probs" in retained:
        retained["template_sum_probs"] = np.asarray(
            retained["template_sum_probs"], dtype=np.float32
        )
    else:
        retained["template_sum_probs"] = np.zeros(n_templates, dtype=np.float32)
    if "template_confidence_scores" not in retained:
        retained["template_confidence_scores"] = np.ones(
            (n_templates, len(monomer.sequence)), dtype=np.float32
        )
    if "template_release_date" not in retained:
        retained["template_release_date"] = np.array(
            ["none"] * n_templates, dtype=object
        )


def _reuse_truemultimer_monomer_features(feat):
    """Build a TrueMultimer monomer by reusing an already-pickled source monomer.

    Reuse is only attempted for single-template / single-chain entries when
    mmseqs2 mode is off. The cached monomer (located via the inferred source
    name) must match the entry's expected sequence; otherwise a warning is
    logged and None is returned so the caller falls back to full feature
    generation. On success the monomer's ``template_*`` features are replaced
    with ones extracted from the given mmCIF template chain.

    Args:
        feat: dict with keys ``protein``, ``chains``, ``templates`` and
            ``sequence`` describing one TrueMultimer CSV entry.

    Returns:
        The updated MonomericObject, or None when reuse is not possible.

    Raises:
        RuntimeError: if template-feature extraction yields no features.
    """
    # NOTE: the original code had PR-review UI text pasted into the middle of
    # the extract_... call, which made it syntactically invalid; removed here.
    if FLAGS.use_mmseqs2 or len(feat["templates"]) != 1 or len(feat["chains"]) != 1:
        return None

    source_name = _infer_truemultimer_source_name(
        feat["protein"], feat["templates"], feat["chains"]
    )
    monomer = _load_existing_monomer_from_output_dir(source_name)
    if monomer is None:
        return None
    if monomer.sequence != feat["sequence"]:
        logging.warning(
            "Existing monomer features for %s use sequence %s, but the current "
            "TrueMultimer entry expects %s. Falling back to full feature generation.",
            source_name,
            monomer.sequence,
            feat["sequence"],
        )
        return None

    template_path = feat["templates"][0]
    chain_id = feat["chains"][0]
    template_result = extract_multimeric_template_features_for_single_chain(
        query_seq=monomer.sequence,
        pdb_id=Path(template_path).stem,
        chain_id=chain_id,
        mmcif_file=template_path,
        threshold_clashes=FLAGS.threshold_clashes,
        hb_allowance=FLAGS.hb_allowance,
        plddt_threshold=FLAGS.plddt_threshold,
    )
    if template_result is None or template_result.features is None:
        raise RuntimeError(
            f"Failed to extract template features from {template_path} chain {chain_id}."
        )

    # Rebrand the cached monomer as the TrueMultimer target and drop the
    # (unpicklable / stale) uniprot runner before it is re-pickled.
    monomer.description = feat["protein"]
    monomer.sequence = feat["sequence"]
    monomer.uniprot_runner = None
    _replace_template_features(monomer, template_result.features)
    logging.info(
        "Reused existing monomer features from %s for TrueMultimer target %s.",
        source_name,
        feat["protein"],
    )
    return monomer


def create_and_save_monomer_objects(monomer, pipeline, custom_template_path=None):
"""Save a MonomericObject after feature creation (pickled, optionally compressed)."""
# Ensure output directory exists
os.makedirs(FLAGS.output_dir, exist_ok=True)

if _should_skip_monomer_output(monomer.description):
return
if FLAGS.use_mmseqs2:
monomer.make_mmseq_features(
DEFAULT_API_SERVER=DEFAULT_API_SERVER,
Expand All @@ -374,12 +502,7 @@ def create_and_save_monomer_objects(monomer, pipeline, custom_template_path=None
pipeline=pipeline, output_dir=FLAGS.output_dir,
use_precomputed_msa=FLAGS.use_precomputed_msas,
save_msa=FLAGS.save_msa_files)
if FLAGS.compress_features:
with lzma.open(pickle_path, "wb") as pickle_file:
pickle.dump(monomer, pickle_file)
else:
with open(pickle_path, "wb") as pickle_file:
pickle.dump(monomer, pickle_file)
_persist_monomer_outputs(monomer)

def create_individual_features_truemultimer():
"""Generate features in TrueMultimer mode, one set per entry in the description CSV."""
Expand All @@ -395,6 +518,12 @@ def process_multimeric_features(feat, idx):
for temp_path in feat["templates"]:
if not os.path.isfile(temp_path):
raise FileNotFoundError(f"Template file {temp_path} does not exist.")
reused_monomer = _reuse_truemultimer_monomer_features(feat)
if reused_monomer is not None:
if _should_skip_monomer_output(reused_monomer.description):
return
_persist_monomer_outputs(reused_monomer)
return
protein, chains, template_paths = feat["protein"], feat["chains"], feat["templates"]
with tempfile.TemporaryDirectory() as temp_dir:
local_path_to_custom_db = create_custom_db(temp_dir, protein, template_paths, chains)
Expand Down
164 changes: 164 additions & 0 deletions test/integration/test_create_individual_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,170 @@ def __init__(self, description, sequence):
assert saved_monomer.sequence == "ACDEFG"
assert saved_monomer.uniprot_runner == "runner"

    @pytest.mark.parametrize("compressed_source", [False, True])
    def test_process_multimeric_features_reuses_existing_source_pickle(
        self, tmp_flags, compressed_source
    ):
        """A TrueMultimer entry whose source monomer pickle exists should reuse it.

        Seeds the output dir with a pickled source monomer (plain or .xz),
        then checks that process_multimeric_features extracts new template
        features instead of running the full pipeline, and that the persisted
        monomer keeps the MSA features, carries the new templates, and is
        accompanied by a metadata JSON file.
        """
        template_path = Path(self.test_dir) / "template1.cif"
        template_path.write_text("data_template\n", encoding="utf-8")

        from absl import flags

        FLAGS = flags.FLAGS
        FLAGS(["test"])
        FLAGS.output_dir = os.path.join(self.test_dir, "reused_truemultimer_output")
        FLAGS.use_mmseqs2 = False
        FLAGS.compress_features = False
        FLAGS.skip_existing = False

        output_dir = Path(FLAGS.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Cached source monomer with old template_* features that must be replaced.
        source = MonomericObject("proteinA", "ACDE")
        source.feature_dict = {
            "msa": np.asarray([[1, 2, 3, 4]], dtype=np.int32),
            "deletion_matrix_int": np.zeros((1, 4), dtype=np.int32),
            "num_alignments": np.asarray([1, 1, 1, 1], dtype=np.int32),
            "msa_species_identifiers": np.asarray([b"9606"], dtype=object),
            "msa_all_seq": np.asarray([[1, 2, 3, 4]], dtype=np.int32),
            "deletion_matrix_int_all_seq": np.zeros((1, 4), dtype=np.int32),
            "msa_species_identifiers_all_seq": np.asarray([b"9606"], dtype=object),
            "template_aatype": np.zeros((1, 4, 22), dtype=np.float32),
            "template_all_atom_masks": np.ones((1, 4, 37), dtype=np.float32),
            "template_all_atom_positions": np.ones((1, 4, 37, 3), dtype=np.float32),
            "template_domain_names": np.asarray([b"old_template"], dtype=object),
            "template_sequence": np.asarray([b"OLD"], dtype=object),
            "template_sum_probs": np.asarray([0.5], dtype=np.float32),
            "template_confidence_scores": np.full((1, 4), 0.75, dtype=np.float32),
            "template_release_date": np.asarray(["2024-01-01"], dtype=object),
        }

        if compressed_source:
            with lzma.open(output_dir / "proteinA.pkl.xz", "wb") as handle:
                pickle.dump(source, handle)
        else:
            with open(output_dir / "proteinA.pkl", "wb") as handle:
                pickle.dump(source, handle)

        # Two fresh templates; missing confidence/release-date arrays must be
        # back-filled by the reuse path.
        new_template_features = {
            "template_aatype": np.ones((2, 4, 22), dtype=np.float32),
            "template_all_atom_masks": np.full((2, 4, 37), 2.0, dtype=np.float32),
            "template_all_atom_positions": np.full((2, 4, 37, 3), 3.0, dtype=np.float32),
            "template_domain_names": np.asarray([b"newA", b"newB"], dtype=object),
            "template_sequence": np.asarray([b"NEWA", b"NEWB"], dtype=object),
            "template_sum_probs": np.asarray([0.1, 0.2], dtype=np.float32),
        }

        feat = {
            "protein": "proteinA.template1.cif.A",
            "chains": ["A"],
            "templates": [str(template_path)],
            "sequence": "ACDE",
        }

        with patch.object(
            create_features,
            "extract_multimeric_template_features_for_single_chain",
            return_value=types.SimpleNamespace(features=new_template_features),
        ) as mock_extract, \
            patch.object(create_features, "create_custom_db") as mock_custom_db, \
            patch.object(create_features, "create_arguments") as mock_arguments, \
            patch.object(create_features, "create_pipeline_af2") as mock_pipeline, \
            patch.object(create_features, "create_uniprot_runner") as mock_runner, \
            patch("alphapulldown.utils.save_meta_data.get_meta_dict", return_value={"source": "test"}):
            create_features.process_multimeric_features(feat, 1)

        mock_extract.assert_called_once_with(
            query_seq="ACDE",
            pdb_id="template1",
            chain_id="A",
            mmcif_file=str(template_path),
            threshold_clashes=create_features.FLAGS.threshold_clashes,
            hb_allowance=create_features.FLAGS.hb_allowance,
            plddt_threshold=create_features.FLAGS.plddt_threshold,
        )
        # The expensive full-pipeline machinery must never run on the reuse path.
        mock_custom_db.assert_not_called()
        mock_arguments.assert_not_called()
        mock_pipeline.assert_not_called()
        mock_runner.assert_not_called()

        output_pickle = output_dir / "proteinA.template1.cif.A.pkl"
        assert output_pickle.exists()
        with open(output_pickle, "rb") as handle:
            reused = pickle.load(handle)

        assert reused.description == "proteinA.template1.cif.A"
        assert np.array_equal(reused.feature_dict["msa"], source.feature_dict["msa"])
        assert reused.feature_dict["template_sequence"].tolist() == [b"NEWA", b"NEWB"]
        assert np.array_equal(
            reused.feature_dict["template_confidence_scores"],
            np.ones((2, 4), dtype=np.float32),
        )
        assert reused.feature_dict["template_release_date"].tolist() == ["none", "none"]
        assert list(output_dir.glob("proteinA.template1.cif.A_feature_metadata_*.json"))

    def test_process_multimeric_features_falls_back_when_source_sequence_mismatches(
        self, tmp_flags
    ):
        """A cached monomer whose sequence differs must NOT be reused.

        Seeds a source pickle with sequence "ACDE" while the TrueMultimer entry
        expects "ACDF", then checks that the reuse shortcut is skipped and the
        full pipeline (custom DB, AF2 pipeline, uniprot runner, save) runs
        instead.
        """
        template_path = Path(self.test_dir) / "template1.cif"
        template_path.write_text("data_template\n", encoding="utf-8")

        from absl import flags

        FLAGS = flags.FLAGS
        FLAGS(["test"])
        FLAGS.output_dir = os.path.join(self.test_dir, "mismatched_truemultimer_output")
        FLAGS.use_mmseqs2 = False
        FLAGS.compress_features = False
        FLAGS.skip_existing = False
        FLAGS.jackhmmer_binary_path = "/usr/bin/jackhmmer"
        FLAGS.uniprot_database_path = "/db/uniprot.fasta"

        output_dir = Path(FLAGS.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Cached monomer with a DIFFERENT sequence than the incoming entry.
        source = MonomericObject("proteinA", "ACDE")
        source.feature_dict = {
            "msa": np.asarray([[1, 2, 3, 4]], dtype=np.int32),
            "deletion_matrix_int": np.zeros((1, 4), dtype=np.int32),
            "num_alignments": np.asarray([1, 1, 1, 1], dtype=np.int32),
            "msa_species_identifiers": np.asarray([b"9606"], dtype=object),
        }
        with open(output_dir / "proteinA.pkl", "wb") as handle:
            pickle.dump(source, handle)

        feat = {
            "protein": "proteinA.template1.cif.A",
            "chains": ["A"],
            "templates": [str(template_path)],
            "sequence": "ACDF",
        }

        with patch.object(
            create_features,
            "extract_multimeric_template_features_for_single_chain",
        ) as mock_extract, \
            patch.object(create_features, "create_custom_db", return_value="/tmp/custom_db") as mock_custom_db, \
            patch.object(create_features, "create_arguments") as mock_arguments, \
            patch.object(create_features, "create_pipeline_af2", return_value="pipeline") as mock_pipeline, \
            patch.object(create_features, "create_uniprot_runner", return_value="runner") as mock_runner, \
            patch.object(create_features, "create_and_save_monomer_objects") as mock_save:
            create_features.process_multimeric_features(feat, 1)

        # Reuse path skipped entirely; the normal pipeline is exercised once.
        mock_extract.assert_not_called()
        mock_custom_db.assert_called_once()
        mock_arguments.assert_called_once_with("/tmp/custom_db")
        mock_pipeline.assert_called_once_with()
        mock_runner.assert_called_once_with(
            FLAGS.jackhmmer_binary_path,
            FLAGS.uniprot_database_path,
        )
        saved_monomer, saved_pipeline = mock_save.call_args.args
        assert saved_pipeline == "pipeline"
        assert saved_monomer.description == "proteinA.template1.cif.A"
        assert saved_monomer.sequence == "ACDF"
        assert saved_monomer.uniprot_runner == "runner"

def test_main_dispatches_to_truemultimer_for_af2_template_runs(self):
"""The main entrypoint should route AF2 template jobs to the TrueMultimer path."""
from absl import flags
Expand Down
Loading