diff --git a/alphapulldown/analysis_pipeline/diagnostics.py b/alphapulldown/analysis_pipeline/diagnostics.py new file mode 100644 index 00000000..c2dad42d --- /dev/null +++ b/alphapulldown/analysis_pipeline/diagnostics.py @@ -0,0 +1,308 @@ +"""Helpers for plotting AlphaPulldown diagnostic figures.""" + +from __future__ import annotations + +from pathlib import Path +import re +from typing import Iterable + +import matplotlib +import numpy as np + +matplotlib.use("Agg", force=True) + +from matplotlib import pyplot as plt + +from af2plots.plotter import plotter +from colabfold.plot import plot_msa_v2 + +from alphapulldown.utils.lightweight_pickles import extract_feature_dict, load_lightweight_pickle + + +_RESULT_PICKLE_PATTERNS = ("result*.pkl", "result*.pkl.gz", "result*.pkl.xz") +_RESULT_PICKLE_PATTERN = re.compile( + r"^result(?P_[\w\d]+)?_model_(?P\d+)(?:_\w+)?\.pkl(?:\.(?:gz|xz))?$" +) + + +def _normalise_stem(path: str | Path) -> str: + input_path = Path(path) + name = input_path.name + for suffix in (".pkl.xz", ".pkl.gz", ".pkl", ".json", ".xz", ".gz"): + if name.endswith(suffix): + name = name[: -len(suffix)] + break + return name or input_path.stem + + +def _ensure_output_dir(output_dir: str | Path) -> Path: + destination = Path(output_dir) + destination.mkdir(parents=True, exist_ok=True) + return destination + + +def _infer_asym_id_from_result_pickle(result_pickle: str | Path) -> tuple[list[int], int] | None: + payload = load_lightweight_pickle(result_pickle) + if not isinstance(payload, dict): + return None + seqs = payload.get("seqs") + if not isinstance(seqs, list) or not all(isinstance(sequence, str) for sequence in seqs): + return None + + asym_id: list[int] = [] + for index, sequence in enumerate(seqs, start=1): + asym_id.extend([index] * len(sequence)) + return asym_id, len(seqs) + + +def _ensure_chain_metadata(parsed_models: dict[str, dict]) -> None: + for model_data in parsed_models.values(): + if "asym_id" in model_data and "assembly_num_chains" in model_data: + continue + inferred = _infer_asym_id_from_result_pickle(model_data["fn"]) + if inferred is None: + continue + asym_id, assembly_num_chains = inferred + model_data["asym_id"] = np.asarray(asym_id, dtype=np.int32) + model_data["assembly_num_chains"] = assembly_num_chains + + +def _find_result_pickles(prediction_dir: str | Path) -> list[Path]: + prediction_root = Path(prediction_dir) + suffix_priority = {".pkl": 0, ".gz": 1, ".xz": 2} + selected_paths: dict[str, Path] = {} + + for pattern in _RESULT_PICKLE_PATTERNS: + for path in prediction_root.glob(pattern): + if _RESULT_PICKLE_PATTERN.fullmatch(path.name) is None: + continue + key = _normalise_stem(path) + current = selected_paths.get(key) + if current is None or suffix_priority[path.suffix] < suffix_priority[current.suffix]: + selected_paths[key] = path + + return [selected_paths[key] for key in sorted(selected_paths)] + + +def _parse_prediction_pickles(prediction_dir: str | Path) -> dict[str, dict]: + parsed_models: dict[str, dict] = {} + + for result_pickle in _find_result_pickles(prediction_dir): + match = _RESULT_PICKLE_PATTERN.fullmatch(result_pickle.name) + if match is None: + continue + + data = load_lightweight_pickle(result_pickle) + if not isinstance(data, dict): + raise TypeError(f"Expected dict payload in {result_pickle}, got {type(data)!r}") + + if "ptm" in data: + ptm = float(data["ptm"]) + elif "ranking_confidence" in data: + ptm = float(data["ranking_confidence"]) + else: + ptm = float(np.mean(data["plddt"], dtype=float)) + + 
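# Worked example: the stem normalisation above strips compound ".pkl.gz"/".pkl.xz"
# suffixes before the single-extension fallbacks, so compressed and uncompressed
# result pickles map to the same key. A standalone sketch of that rule (the helper
# name here is local, not the module's private function):
from pathlib import Path

def normalise_stem(path: str) -> str:
    name = Path(path).name
    for suffix in (".pkl.xz", ".pkl.gz", ".pkl", ".json", ".xz", ".gz"):
        if name.endswith(suffix):
            name = name[: -len(suffix)]
            break
    return name or Path(path).stem

print(normalise_stem("result_model_1_multimer_v3_pred_0.pkl.gz"))  # result_model_1_multimer_v3_pred_0
print(normalise_stem("features.pkl"))                              # features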
parsed_models[str(result_pickle)] = { + "datadir": str(prediction_dir), + "fn": str(result_pickle), + "idx": int(match.group("idx")), + "ptm": ptm, + "iptm": data.get("iptm"), + "distogram": data.get("distogram"), + "sm_contacts": data.get("sm_contacts"), + "pae": data.get("predicted_aligned_error"), + "plddt": data["plddt"], + } + + if not parsed_models: + raise FileNotFoundError( + f"{prediction_dir} does not contain result*.pkl, result*.pkl.gz, or result*.pkl.xz files" + ) + + for rank, path in enumerate( + sorted(parsed_models, key=lambda item: parsed_models[item]["ptm"], reverse=True) + ): + parsed_models[path]["rank"] = rank + 1 + parsed_models[path]["description"] = f"ranked_{rank}.pdb pTM={parsed_models[path]['ptm']:.2f}" + if parsed_models[path]["iptm"] is not None: + parsed_models[path]["description"] += f" iPTM={parsed_models[path]['iptm']:.2f}" + + return parsed_models + + +def _softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray: + shifted = logits - np.max(logits, axis=axis, keepdims=True) + exponentiated = np.exp(shifted) + return exponentiated / np.sum(exponentiated, axis=axis, keepdims=True) + + +def _plot_distogram_fallback( + parsed_models: dict[str, dict], + *, + dpi: int = 100, + distance: float = 8.0, +) -> tuple[plt.Figure, list[str]] | None: + top_model = next( + (model for model in parsed_models.values() if model.get("rank") == 1), + max(parsed_models.values(), key=lambda model: float(model["ptm"]), default=None), + ) + if top_model is None: + return None + + predicted_distogram = top_model.get("distogram") + if not isinstance(predicted_distogram, dict): + return None + + logits = predicted_distogram.get("logits") + bin_edges = predicted_distogram.get("bin_edges") + if logits is None or bin_edges is None: + return None + + logits_array = np.asarray(logits) + bin_edges_array = np.asarray(bin_edges) + if logits_array.ndim != 3 or bin_edges_array.ndim != 1: + return None + + upper_bounds = np.concatenate([bin_edges_array, [np.inf]]) + threshold_mask = upper_bounds < float(np.clip(distance, 3.0, 20.0)) + if not np.any(threshold_mask): + threshold_mask[0] = True + + contact_probabilities = _softmax(logits_array, axis=-1)[..., threshold_mask].sum(axis=-1) + + figure, axis = plt.subplots(figsize=(8, 8), dpi=dpi) + image = axis.imshow( + contact_probabilities, + cmap="coolwarm", + vmin=0.0, + vmax=1.0, + extent=(0, contact_probabilities.shape[0], contact_probabilities.shape[0], 0), + ) + colorbar = figure.colorbar(image, ax=axis, fraction=0.046, pad=0.04) + colorbar.ax.set_ylabel(f"Probability(distance<{int(distance)}A)") + axis.set_title("Predicted contacts") + axis.set_xlabel("Residue number") + axis.set_ylabel("Residue number") + + asym_id = top_model.get("asym_id") + assembly_num_chains = top_model.get("assembly_num_chains") + if asym_id is not None and assembly_num_chains is not None: + asym_id_array = np.asarray(asym_id) + for chain_index in range(int(assembly_num_chains) - 1): + chain_positions = np.where(asym_id_array == (chain_index + 1))[0] + if chain_positions.size == 0: + continue + chain_cut = int(chain_positions.max()) + axis.axvline(x=chain_cut, ls="--", c="k", lw=1) + axis.axhline(y=chain_cut, ls="--", c="k", lw=1) + + return figure, [] + + +def _plot_distogram_compat( + af2_plotter: object, + parsed_models: dict[str, dict], + *, + dpi: int = 100, +) -> tuple[plt.Figure, list[str]] | None: + plot_distogram = getattr(af2_plotter, "plot_distogram", None) + if callable(plot_distogram): + return plot_distogram(parsed_models, dpi=dpi) + return 
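# Illustrative sketch (dummy data, standard AlphaFold distogram shapes assumed):
# how the fallback above converts distogram logits into a contact-probability map.
# Probabilities for all bins whose upper edge lies below the distance cutoff are
# summed after a softmax over the bin axis.
import numpy as np

rng = np.random.default_rng(0)
num_res, num_bins = 5, 64
logits = rng.normal(size=(num_res, num_res, num_bins))
bin_edges = np.linspace(2.3125, 21.6875, num_bins - 1)  # 63 edges for 64 bins

shifted = logits - logits.max(axis=-1, keepdims=True)
probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)

upper_bounds = np.concatenate([bin_edges, [np.inf]])
contact_probabilities = probs[..., upper_bounds < 8.0].sum(axis=-1)
print(contact_probabilities.shape)  # (5, 5), entries in [0, 1]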
_plot_distogram_fallback(parsed_models, dpi=dpi) + + +def save_msa_coverage_plot( + feature_pickle: str | Path, + output_dir: str | Path, + *, + dpi: int = 100, + output_stem: str | None = None, +) -> Path: + """Save a ColabFold-style MSA coverage plot from a feature pickle.""" + + payload = load_lightweight_pickle(feature_pickle) + feature_dict = extract_feature_dict(payload) + destination = _ensure_output_dir(output_dir) + plot_module = plot_msa_v2(feature_dict, dpi=dpi) + output_path = destination / f"{output_stem or _normalise_stem(feature_pickle)}_msa_coverage.png" + plot_module.savefig(output_path, bbox_inches="tight") + plot_module.close() + return output_path + + +def save_prediction_plots( + prediction_dir: str | Path, + output_dir: str | Path, + *, + dpi: int = 100, +) -> list[Path]: + """Save pLDDT, PAE, and distogram plots from a prediction directory.""" + + prediction_root = Path(prediction_dir) + destination = _ensure_output_dir(output_dir) + af2_plotter = plotter() + parsed_models = _parse_prediction_pickles(prediction_root) + _ensure_chain_metadata(parsed_models) + output_prefix = destination / prediction_root.name + + written_paths: list[Path] = [] + + pae_figure = af2_plotter.plot_predicted_alignment_error(parsed_models, dpi=dpi) + pae_path = output_prefix.with_name(f"{output_prefix.name}_pae.png") + pae_figure.savefig(pae_path, bbox_inches="tight") + plt.close(pae_figure) + written_paths.append(pae_path) + + plddt_figure = af2_plotter.plot_plddts(parsed_models, dpi=dpi) + plddt_path = output_prefix.with_name(f"{output_prefix.name}_plddt.png") + plddt_figure.savefig(plddt_path, bbox_inches="tight") + plt.close(plddt_figure) + written_paths.append(plddt_path) + + distogram_result = _plot_distogram_compat(af2_plotter, parsed_models, dpi=dpi) + if distogram_result is not None: + distogram_figure, _ = distogram_result + distogram_path = output_prefix.with_name(f"{output_prefix.name}_distogram.png") + distogram_figure.savefig(distogram_path, bbox_inches="tight") + plt.close(distogram_figure) + written_paths.append(distogram_path) + + return written_paths + + +def plot_inputs( + inputs: Iterable[str | Path], + *, + output_dir: str | Path | None = None, + dpi: int = 100, +) -> list[Path]: + """Dispatch plotting based on the provided input paths.""" + + written_paths: list[Path] = [] + for raw_input in inputs: + input_path = Path(raw_input) + destination = Path(output_dir) if output_dir is not None else input_path.parent + + if input_path.is_dir(): + if _find_result_pickles(input_path): + written_paths.extend(save_prediction_plots(input_path, destination, dpi=dpi)) + continue + feature_pickle = input_path / "features.pkl" + if feature_pickle.exists(): + written_paths.append( + save_msa_coverage_plot( + feature_pickle, + destination, + dpi=dpi, + output_stem=input_path.name, + ) + ) + continue + raise FileNotFoundError( + f"{input_path} does not contain result*.pkl(.gz/.xz) files or a features.pkl file" + ) + + written_paths.append(save_msa_coverage_plot(input_path, destination, dpi=dpi)) + + return written_paths diff --git a/alphapulldown/analysis_pipeline/interaction_network.py b/alphapulldown/analysis_pipeline/interaction_network.py new file mode 100644 index 00000000..60a93b62 --- /dev/null +++ b/alphapulldown/analysis_pipeline/interaction_network.py @@ -0,0 +1,393 @@ +"""Utilities for plotting AlphaPulldown interaction networks.""" + +from __future__ import annotations + +from collections import defaultdict +from itertools import combinations +from pathlib import Path +import 
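# Usage sketch for the public helpers above (the input paths are hypothetical):
# plot_inputs() dispatches per path, writing pLDDT/PAE/distogram plots for
# prediction directories and an MSA coverage plot for feature pickles.
from alphapulldown.analysis_pipeline.diagnostics import plot_inputs

written = plot_inputs(
    ["models/proteinA_and_proteinB", "features/proteinA.pkl"],
    output_dir="diagnostic_plots",
    dpi=150,
)
for path in written:
    print(path)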
math
+import re
+
+import matplotlib
+
+matplotlib.use("Agg", force=True)
+
+import numpy as np
+import pandas as pd
+from matplotlib import pyplot as plt
+
+
+_INTERFACE_PATTERN = re.compile(r"^(?P<left>[A-Z]+)_(?P<right>[A-Z]+)$")
+_HOMO_OLIGOMER_PATTERN = re.compile(r"^(?P<name>.+)_homo_(?P<count>\d+)er$")
+
+
+def _expand_ap_style_homo_oligomer(token: str) -> list[str]:
+    match = _HOMO_OLIGOMER_PATTERN.fullmatch(token)
+    if match is None:
+        return [token] if token else []
+    return [match.group("name")] * int(match.group("count"))
+
+
+def split_job_name(job_name: str) -> list[str]:
+    """Split AlphaPulldown job names into ordered interactors."""
+
+    if "_and_" in job_name:
+        parts = [part for part in job_name.split("_and_") if part]
+    elif "+" in job_name:
+        parts = [part for part in job_name.split("+") if part]
+    else:
+        parts = [job_name] if job_name else []
+
+    expanded_parts: list[str] = []
+    for part in parts:
+        expanded_parts.extend(_expand_ap_style_homo_oligomer(part))
+    return expanded_parts
+
+
+def _parse_float(value) -> float | None:
+    try:
+        parsed = float(value)
+    except (TypeError, ValueError):
+        return None
+    return None if math.isnan(parsed) else parsed
+
+
+def _chain_label_to_index(label: str) -> int:
+    index = 0
+    for character in label:
+        index = index * 26 + (ord(character) - ord("A") + 1)
+    return index - 1
+
+
+def _extract_pairs_for_row(job_name: str, interface: str | None) -> list[tuple[str, str]]:
+    interactors = split_job_name(job_name)
+    if len(interactors) < 2:
+        return []
+
+    if isinstance(interface, str):
+        match = _INTERFACE_PATTERN.fullmatch(interface.strip())
+        if match is not None:
+            left_index = _chain_label_to_index(match.group("left"))
+            right_index = _chain_label_to_index(match.group("right"))
+            if left_index < len(interactors) and right_index < len(interactors):
+                return [(interactors[left_index], interactors[right_index])]
+
+    if len(interactors) == 2:
+        return [(interactors[0], interactors[1])]
+
+    return list(combinations(interactors, 2))
+
+
+def build_interaction_edge_table(
+    score_table: pd.DataFrame,
+    *,
+    score_column: str = "iptm_ptm",
+    min_score: float = 0.0,
+    max_pae: float | None = None,
+) -> pd.DataFrame:
+    """Collapse an AlphaPulldown score table into one undirected edge table."""
+
+    aggregated: dict[tuple[str, str], dict[str, object]] = {}
+    for row in score_table.to_dict(orient="records"):
+        job_name = str(row.get("jobs", "")).strip()
+        if not job_name:
+            continue
+
+        score = _parse_float(row.get(score_column))
+        if score is None or score < min_score:
+            continue
+
+        if max_pae is not None:
+            pae = _parse_float(row.get("average_interface_pae"))
+            if pae is None or pae > max_pae:
+                continue
+
+        pairs = _extract_pairs_for_row(job_name, row.get("interface"))
+        for left, right in pairs:
+            source, target = sorted((left, right))
+            key = (source, target)
+            record = aggregated.setdefault(
+                key,
+                {
+                    "source": source,
+                    "target": target,
+                    "score": score,
+                    "support": 0,
+                    "jobs": set(),
+                    "self_interaction": source == target,
+                },
+            )
+            record["score"] = max(float(record["score"]), score)
+            record["support"] = int(record["support"]) + 1
+            cast_jobs = record["jobs"]
+            assert isinstance(cast_jobs, set)
+            cast_jobs.add(job_name)
+
+    rows: list[dict[str, object]] = []
+    for record in aggregated.values():
+        jobs = sorted(record.pop("jobs"))  # type: ignore[arg-type]
+        rows.append({**record, "jobs": ";".join(jobs)})
+
+    edge_table = pd.DataFrame(rows)
+    if edge_table.empty:
+        return edge_table
+    return edge_table.sort_values(
+        by=["self_interaction", "score",
"source", "target"], + ascending=[True, False, True, True], + ).reset_index(drop=True) + + +def summarise_nodes(edge_table: pd.DataFrame) -> pd.DataFrame: + """Create a node summary table from an edge table.""" + + degree_by_node: dict[str, int] = defaultdict(int) + best_score_by_node: dict[str, float] = defaultdict(float) + self_score_by_node: dict[str, float] = defaultdict(float) + + for row in edge_table.to_dict(orient="records"): + source = str(row["source"]) + target = str(row["target"]) + score = float(row["score"]) + if bool(row.get("self_interaction")): + self_score_by_node[source] = max(self_score_by_node[source], score) + best_score_by_node[source] = max(best_score_by_node[source], score) + continue + + degree_by_node[source] += 1 + degree_by_node[target] += 1 + best_score_by_node[source] = max(best_score_by_node[source], score) + best_score_by_node[target] = max(best_score_by_node[target], score) + + nodes = sorted(set(degree_by_node) | set(best_score_by_node) | set(self_score_by_node)) + node_rows = [ + { + "node": node, + "degree": degree_by_node[node], + "best_score": best_score_by_node[node], + "self_interaction_score": self_score_by_node[node], + } + for node in nodes + ] + return pd.DataFrame(node_rows).sort_values( + by=["degree", "best_score", "node"], + ascending=[False, False, True], + ).reset_index(drop=True) + + +def _connected_components(nodes: list[str], edges: pd.DataFrame) -> list[list[str]]: + adjacency: dict[str, set[str]] = {node: set() for node in nodes} + for row in edges.to_dict(orient="records"): + if bool(row.get("self_interaction")): + continue + source = str(row["source"]) + target = str(row["target"]) + adjacency[source].add(target) + adjacency[target].add(source) + + components: list[list[str]] = [] + seen: set[str] = set() + for node in nodes: + if node in seen: + continue + stack = [node] + component: list[str] = [] + while stack: + current = stack.pop() + if current in seen: + continue + seen.add(current) + component.append(current) + stack.extend(sorted(adjacency[current] - seen)) + components.append(sorted(component)) + return components + + +def _spring_layout(component_nodes: list[str], component_edges: pd.DataFrame, *, seed: int) -> dict[str, np.ndarray]: + if len(component_nodes) == 1: + return {component_nodes[0]: np.zeros(2, dtype=float)} + + index_by_node = {node: index for index, node in enumerate(component_nodes)} + positions = np.random.default_rng(seed).normal(scale=0.25, size=(len(component_nodes), 2)) + weights = np.ones((len(component_nodes), len(component_nodes)), dtype=float) + + for row in component_edges.to_dict(orient="records"): + if bool(row.get("self_interaction")): + continue + left = index_by_node[str(row["source"])] + right = index_by_node[str(row["target"])] + weights[left, right] = max(float(row["score"]), 0.05) + weights[right, left] = weights[left, right] + + ideal_distance = math.sqrt(1.0 / max(len(component_nodes), 1)) + temperature = 0.25 + for _ in range(150): + displacement = np.zeros_like(positions) + delta = positions[:, np.newaxis, :] - positions[np.newaxis, :, :] + distance = np.linalg.norm(delta, axis=-1) + 1e-6 + + repulsion = delta * ((ideal_distance * ideal_distance) / (distance * distance))[:, :, np.newaxis] + displacement += np.nansum(repulsion, axis=1) + + for left in range(len(component_nodes)): + for right in range(left + 1, len(component_nodes)): + if weights[left, right] == 1.0 and weights[right, left] == 1.0: + continue + difference = positions[left] - positions[right] + edge_distance = 
np.linalg.norm(difference) + 1e-6 + attraction = difference * ((edge_distance / ideal_distance) * weights[left, right]) + displacement[left] -= attraction + displacement[right] += attraction + + norms = np.linalg.norm(displacement, axis=1) + 1e-6 + positions += (displacement / norms[:, np.newaxis]) * np.minimum(norms, temperature)[:, np.newaxis] + positions -= positions.mean(axis=0) + temperature *= 0.95 + + max_extent = np.max(np.linalg.norm(positions, axis=1)) + if max_extent > 0: + positions /= max_extent + return {node: positions[index] for node, index in index_by_node.items()} + + +def compute_network_layout(edge_table: pd.DataFrame, *, seed: int = 0) -> dict[str, np.ndarray]: + """Compute a dependency-free spring layout for an interaction network.""" + + if edge_table.empty: + return {} + + nodes = sorted(set(edge_table["source"]) | set(edge_table["target"])) + components = _connected_components(nodes, edge_table) + component_columns = max(1, math.ceil(math.sqrt(len(components)))) + + layout: dict[str, np.ndarray] = {} + for component_index, component_nodes in enumerate( + sorted(components, key=len, reverse=True) + ): + component_edges = edge_table[ + edge_table["source"].isin(component_nodes) + & edge_table["target"].isin(component_nodes) + ] + component_layout = _spring_layout( + component_nodes, + component_edges, + seed=seed + component_index, + ) + + row_index = component_index // component_columns + column_index = component_index % component_columns + offset = np.asarray([column_index * 3.0, -row_index * 3.0]) + for node, position in component_layout.items(): + layout[node] = position + offset + + return layout + + +def plot_interaction_network( + edge_table: pd.DataFrame, + output_path: str | Path, + *, + title: str = "Interaction Network", + label_top_n: int = 30, + dpi: int = 150, + seed: int = 0, +) -> Path: + """Render an interaction network plot to disk.""" + + if edge_table.empty: + raise ValueError("edge_table is empty; no interactions passed the selected filters") + + output_file = Path(output_path) + output_file.parent.mkdir(parents=True, exist_ok=True) + + layout = compute_network_layout(edge_table, seed=seed) + node_summary = summarise_nodes(edge_table) + component_ids = { + node: index + for index, component in enumerate( + _connected_components(sorted(layout), edge_table) + ) + for node in component + } + + fig, ax = plt.subplots(figsize=(11, 8), dpi=dpi) + ax.set_title(title) + + non_self_edges = edge_table[~edge_table["self_interaction"]] + if not non_self_edges.empty: + min_score = float(non_self_edges["score"].min()) + max_score = float(non_self_edges["score"].max()) + score_span = max(max_score - min_score, 1e-6) + for row in non_self_edges.to_dict(orient="records"): + source = str(row["source"]) + target = str(row["target"]) + x_values = [layout[source][0], layout[target][0]] + y_values = [layout[source][1], layout[target][1]] + normalized_score = (float(row["score"]) - min_score) / score_span + ax.plot( + x_values, + y_values, + color="0.55", + alpha=0.35 + 0.45 * normalized_score, + linewidth=1.0 + 3.0 * normalized_score, + zorder=1, + ) + + color_map = plt.get_cmap("tab20", max(len(component_ids), 1)) + max_degree = max(int(node_summary["degree"].max()), 1) + max_self_score = max(float(node_summary["self_interaction_score"].max()), 1.0) + + for row in node_summary.to_dict(orient="records"): + node = str(row["node"]) + x_coord, y_coord = layout[node] + degree = int(row["degree"]) + self_score = float(row["self_interaction_score"]) + node_size = 220 + 110 
* degree + (180 if self_score > 0 else 0) + face_color = color_map(component_ids.get(node, 0)) + ax.scatter( + [x_coord], + [y_coord], + s=node_size, + c=[face_color], + edgecolors="black", + linewidths=1.0, + zorder=2, + ) + if self_score > 0: + ax.scatter( + [x_coord], + [y_coord], + s=node_size * (1.15 + 0.35 * (self_score / max_self_score)), + facecolors="none", + edgecolors="black", + linewidths=1.2, + zorder=3, + ) + + labels_to_draw = node_summary.head(label_top_n) + for row in labels_to_draw.to_dict(orient="records"): + node = str(row["node"]) + x_coord, y_coord = layout[node] + ax.text( + x_coord, + y_coord + 0.09, + node, + fontsize=9, + ha="center", + va="bottom", + zorder=4, + ) + + ax.set_axis_off() + fig.tight_layout() + fig.savefig(output_file, bbox_inches="tight") + plt.close(fig) + return output_file + + +def write_table(table: pd.DataFrame, output_path: str | Path) -> Path: + """Write a pandas table to CSV.""" + + output_file = Path(output_path) + output_file.parent.mkdir(parents=True, exist_ok=True) + table.to_csv(output_file, index=False) + return output_file diff --git a/alphapulldown/analysis_pipeline/plot_diagnostics.py b/alphapulldown/analysis_pipeline/plot_diagnostics.py new file mode 100644 index 00000000..c2ce20b7 --- /dev/null +++ b/alphapulldown/analysis_pipeline/plot_diagnostics.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +"""Generate AlphaPulldown diagnostic plots similar to ColabFold outputs.""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +from alphapulldown.analysis_pipeline.diagnostics import plot_inputs + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Write MSA coverage, pLDDT, PAE, and distogram plots from " + "AlphaPulldown feature pickles or prediction directories." + ) + ) + parser.add_argument( + "inputs", + nargs="+", + help=( + "Feature pickles, directories containing features.pkl, or " + "prediction directories containing result*.pkl files." + ), + ) + parser.add_argument( + "--output_dir", + default=None, + help="Directory to write plots into. Defaults to the parent directory of each input.", + ) + parser.add_argument( + "--dpi", + type=int, + default=100, + help="Matplotlib DPI to use for saved plots.", + ) + return parser + + +def main(argv: list[str] | None = None) -> list[Path]: + parser = build_parser() + args = parser.parse_args(argv) + written_paths = plot_inputs( + args.inputs, + output_dir=args.output_dir, + dpi=args.dpi, + ) + for path in written_paths: + print(path) + return written_paths + + +if __name__ == "__main__": + main() diff --git a/alphapulldown/analysis_pipeline/plot_interaction_network.py b/alphapulldown/analysis_pipeline/plot_interaction_network.py new file mode 100644 index 00000000..f67f3720 --- /dev/null +++ b/alphapulldown/analysis_pipeline/plot_interaction_network.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python3 +"""Plot an interaction network from AlphaPulldown score tables.""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +import pandas as pd + +from alphapulldown.analysis_pipeline.interaction_network import ( + build_interaction_edge_table, + plot_interaction_network, + summarise_nodes, + write_table, +) + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Create an interaction-network plot from an AlphaPulldown score CSV " + "(for example good_interpae or pi_score outputs)." 
+ ) + ) + parser.add_argument("input_csv", help="Path to the score CSV.") + parser.add_argument("output_plot", help="Path to the output PNG/PDF plot.") + parser.add_argument( + "--score_column", + default="iptm_ptm", + help="Numeric column used for edge strength filtering and styling.", + ) + parser.add_argument( + "--min_score", + type=float, + default=0.0, + help="Minimum score required for an interaction to be plotted.", + ) + parser.add_argument( + "--max_pae", + type=float, + default=None, + help="Optional maximum average_interface_pae filter.", + ) + parser.add_argument( + "--label_top_n", + type=int, + default=30, + help="Number of top-ranked nodes to label.", + ) + parser.add_argument( + "--title", + default="Interaction Network", + help="Plot title.", + ) + parser.add_argument( + "--edges_out", + default=None, + help="Optional CSV path for the aggregated edge table.", + ) + parser.add_argument( + "--nodes_out", + default=None, + help="Optional CSV path for the node summary table.", + ) + parser.add_argument( + "--dpi", + type=int, + default=150, + help="Matplotlib DPI for raster outputs.", + ) + parser.add_argument( + "--seed", + type=int, + default=0, + help="Random seed for the force layout.", + ) + return parser + + +def main(argv: list[str] | None = None) -> list[Path]: + parser = build_parser() + args = parser.parse_args(argv) + + score_table = pd.read_csv(args.input_csv) + edge_table = build_interaction_edge_table( + score_table, + score_column=args.score_column, + min_score=args.min_score, + max_pae=args.max_pae, + ) + plot_path = plot_interaction_network( + edge_table, + args.output_plot, + title=args.title, + label_top_n=args.label_top_n, + dpi=args.dpi, + seed=args.seed, + ) + + written_paths = [plot_path] + if args.edges_out: + written_paths.append(write_table(edge_table, args.edges_out)) + if args.nodes_out: + node_table = summarise_nodes(edge_table) + written_paths.append(write_table(node_table, args.nodes_out)) + + for path in written_paths: + print(path) + return written_paths + + +if __name__ == "__main__": + main() diff --git a/alphapulldown/folding_backend/alphafold2_backend.py b/alphapulldown/folding_backend/alphafold2_backend.py index 42abae12..9c25e70e 100644 --- a/alphapulldown/folding_backend/alphafold2_backend.py +++ b/alphapulldown/folding_backend/alphafold2_backend.py @@ -39,6 +39,46 @@ RELAX_MAX_OUTER_ITERATIONS = 3 +def _select_models_to_relax( + ranked_order: List[str], + *, + models_to_relax: "ModelsToRelax", + iptm_scores: Dict[str, float], + ptm_scores: Dict[str, float], + relax_best_score_threshold, +) -> List[str]: + if models_to_relax == ModelsToRelax.ALL: + return ranked_order + if models_to_relax == ModelsToRelax.NONE: + return [] + if not ranked_order: + return [] + + best_model = ranked_order[0] + if relax_best_score_threshold is None: + return [best_model] + + if best_model in iptm_scores: + score_name = "iPTM" + score_value = iptm_scores[best_model] + else: + score_name = "pTM" + score_value = ptm_scores.get(best_model) + + if score_value is None or score_value < relax_best_score_threshold: + logging.info( + "Skipping relaxation for %s because its %s score %.3f is below the " + "requested threshold %.3f.", + best_model, + score_name, + 0.0 if score_value is None else score_value, + relax_best_score_threshold, + ) + return [] + + return [best_model] + + @enum.unique class ModelsToRelax(enum.Enum): ALL = 0 @@ -788,6 +828,7 @@ def postprocess( output_dir: str, features_directory: str, models_to_relax: ModelsToRelax, + 
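# Usage sketch for the relaxation gate above (hypothetical model names and
# scores; assumes the AF2 backend and its AlphaFold dependencies are importable).
# With ModelsToRelax.BEST, the best-ranked model is relaxed only when no
# threshold is set or its iPTM (pTM for monomers) reaches the threshold.
from alphapulldown.folding_backend.alphafold2_backend import (
    ModelsToRelax,
    _select_models_to_relax,
)

ranked = ["model_2_pred_0", "model_1_pred_0"]
iptm = {"model_2_pred_0": 0.61, "model_1_pred_0": 0.55}
ptm = {"model_2_pred_0": 0.74, "model_1_pred_0": 0.70}

print(_select_models_to_relax(ranked, models_to_relax=ModelsToRelax.BEST,
                              iptm_scores=iptm, ptm_scores=ptm,
                              relax_best_score_threshold=None))  # ['model_2_pred_0']
print(_select_models_to_relax(ranked, models_to_relax=ModelsToRelax.BEST,
                              iptm_scores=iptm, ptm_scores=ptm,
                              relax_best_score_threshold=0.7))   # [] (0.61 < 0.7)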
relax_best_score_threshold = None, compress_pickles: bool = False, remove_pickles: bool = False, remove_keys_from_pickles: bool = False, @@ -917,12 +958,13 @@ def postprocess( max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS, use_gpu=_resolve_gpu_relax(use_gpu_relax)) - if models_to_relax == ModelsToRelax.BEST: - to_relax = [ranked_order[0]] - elif models_to_relax == ModelsToRelax.ALL: - to_relax = ranked_order - elif models_to_relax == ModelsToRelax.NONE: - to_relax = [] + to_relax = _select_models_to_relax( + ranked_order, + models_to_relax=models_to_relax, + iptm_scores=iptm_scores, + ptm_scores=ptm_scores, + relax_best_score_threshold=relax_best_score_threshold, + ) for model_name in to_relax: if f'relax_{model_name}' in timings: diff --git a/alphapulldown/objects.py b/alphapulldown/objects.py index 765d3165..72adcd51 100644 --- a/alphapulldown/objects.py +++ b/alphapulldown/objects.py @@ -518,14 +518,18 @@ class MultimericObject: interactors: individual interactors that are to be concatenated pair_msa: boolean, tells the programme whether to pair MSA or not multimeric_template: boolean, tells the programme whether use multimeric templates or not - multimeric_template_meta_data: a csv with the format {"monomer_A":{"xxx.cif":"chainID"},"monomer_B":{"yyy.cif":"chainID"}} + multimeric_template_meta_data: a csv with the format + {"monomer_A": [("xxx.cif", "chainID")], "monomer_B": [("yyy.cif", "chainID")]} multimeric_template_dir: a directory where all the multimeric templates mmcifs files are stored """ def __init__(self, interactors: list, pair_msa: bool = True, multimeric_template: bool = False, multimeric_template_meta_data: str = None, - multimeric_template_dir:str = None) -> None: + multimeric_template_dir:str = None, + threshold_clashes: float = 1000, + hb_allowance: float = 0.4, + plddt_threshold: float = 0) -> None: self.description = "" self.interactors = interactors self.build_description_monomer_mapping() @@ -534,6 +538,9 @@ def __init__(self, interactors: list, pair_msa: bool = True, self.chain_id_map = dict() self.input_seqs = [] self.multimeric_template_dir = multimeric_template_dir + self.threshold_clashes = threshold_clashes + self.hb_allowance = hb_allowance + self.plddt_threshold = plddt_threshold self.create_output_name() if multimeric_template_meta_data is not None: @@ -624,6 +631,20 @@ def create_multichain_mask(self): # DEBUG #self.save_binary_matrix(multichain_mask, "multichain_mask.png") return multichain_mask + + def _get_matching_interactors_for_template_name(self, monomer_name): + """Resolve template CSV names to interactors, including chopped aliases.""" + exact_matches = [ + interactor for interactor in self.interactors + if interactor.description == monomer_name + ] + if exact_matches: + return exact_matches + + return [ + interactor for interactor in self.interactors + if getattr(interactor, "monomeric_description", None) == monomer_name + ] def create_multimeric_template_features(self): """A method of creating multimeric template features""" @@ -636,15 +657,31 @@ def create_multimeric_template_features(self): """) pass else: - for monomer_name in self.multimeric_template_meta_data: - for k,v in self.multimeric_template_meta_data[monomer_name].items(): - curr_monomer = self.monomers_mapping[monomer_name] + for monomer_name, template_entries in self.multimeric_template_meta_data.items(): + matching_monomers = self._get_matching_interactors_for_template_name(monomer_name) + if not matching_monomers: + raise KeyError(monomer_name) + + if len(matching_monomers) 
== 1: + monomer_assignments = [matching_monomers[0]] * len(template_entries) + elif len(template_entries) <= len(matching_monomers): + monomer_assignments = matching_monomers[:len(template_entries)] + else: + raise ValueError( + f"Found {len(template_entries)} template assignments for '{monomer_name}' " + f"but only {len(matching_monomers)} matching interactors." + ) + + for curr_monomer, (k, v) in zip(monomer_assignments, template_entries): assert k.endswith(".cif"), "The multimeric template file you provided does not seem to be a mmcif file. Please check your format and make sure it ends with .cif" assert os.path.exists(os.path.join(self.multimeric_template_dir,k)), f"Your provided {k} cannot be found in: {self.multimeric_template_dir}. Abort" pdb_id = k.split('.cif')[0] multimeric_template_features = extract_multimeric_template_features_for_single_chain(query_seq=curr_monomer.sequence, pdb_id=pdb_id,chain_id=v, - mmcif_file=os.path.join(self.multimeric_template_dir,k)) + mmcif_file=os.path.join(self.multimeric_template_dir,k), + threshold_clashes=getattr(self, "threshold_clashes", 1000), + hb_allowance=getattr(self, "hb_allowance", 0.4), + plddt_threshold=getattr(self, "plddt_threshold", 0)) curr_monomer.feature_dict.update(multimeric_template_features.features) diff --git a/alphapulldown/scripts/create_individual_features.py b/alphapulldown/scripts/create_individual_features.py index a1cfd326..b8fed124 100644 --- a/alphapulldown/scripts/create_individual_features.py +++ b/alphapulldown/scripts/create_individual_features.py @@ -68,7 +68,7 @@ "rna_central": "rnacentral_active_seq_id_90_cov_80_linclust.fasta", } -AF2_DATABASE_FLAGS = { +AF2_FULL_DATABASE_FLAGS = { "uniref90_database_path": "uniref90", "uniref30_database_path": "uniref30", "mgnify_database_path": "mgnify", @@ -81,6 +81,16 @@ "obsolete_pdbs_path": "obsolete_pdbs", } +AF2_REDUCED_DATABASE_FLAGS = { + "uniref90_database_path": "uniref90", + "mgnify_database_path": "mgnify", + "small_bfd_database_path": "small_bfd", + "uniprot_database_path": "uniprot", + "pdb_seqres_database_path": "pdb_seqres", + "template_mmcif_dir": "template_mmcif_dir", + "obsolete_pdbs_path": "obsolete_pdbs", +} + AF3_DATABASE_FLAGS = { "uniref90_database_path": "uniref90", "mgnify_database_path": "mgnify", @@ -90,7 +100,11 @@ "template_mmcif_dir": "template_mmcif_dir", } -DATABASE_PATH_FLAGS = frozenset(AF2_DATABASE_FLAGS) | frozenset(AF3_DATABASE_FLAGS) +DATABASE_PATH_FLAGS = ( + frozenset(AF2_FULL_DATABASE_FLAGS) + | frozenset(AF2_REDUCED_DATABASE_FLAGS) + | frozenset(AF3_DATABASE_FLAGS) +) # =================== Flags =================== flags.DEFINE_enum( @@ -180,9 +194,7 @@ def create_arguments(local_custom_template_db=None): Optionally override template paths with a local custom template DB.""" validate_data_pipeline_flags() - required_database_flags = ( - AF3_DATABASE_FLAGS if FLAGS.data_pipeline == 'alphafold3' else AF2_DATABASE_FLAGS - ) + required_database_flags = get_required_database_flags() # When using MMseqs2 (current implementation uses remote servers), database paths are not needed # Note: Current MMseqs2 implementation uses remote servers via DEFAULT_API_SERVER @@ -202,6 +214,20 @@ def create_arguments(local_custom_template_db=None): FLAGS.template_mmcif_dir = os.path.join(local_custom_template_db, "pdb_mmcif", "mmcif_files") FLAGS.obsolete_pdbs_path = os.path.join(local_custom_template_db, "pdb_mmcif", "obsolete.dat") + +def get_required_database_flags(): + """Return the database flags required by the selected pipeline and preset.""" + 
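# Standalone illustration (not the actual method) of the template-assignment
# rule in create_multimeric_template_features above: when a template name in
# the CSV matches exactly one interactor, every row for that name is applied
# to it; with several matching (chopped) interactors, rows are assigned in
# order, and more rows than matches is an error.
def assign_templates(template_entries, matching_monomers):
    if len(matching_monomers) == 1:
        return [matching_monomers[0]] * len(template_entries)
    if len(template_entries) <= len(matching_monomers):
        return matching_monomers[: len(template_entries)]
    raise ValueError("more template rows than matching interactors")

print(assign_templates([("x.cif", "A"), ("y.cif", "B")], ["protA"]))
# ['protA', 'protA']
print(assign_templates([("x.cif", "A")], ["protA_1-100", "protA_101-200"]))
# ['protA_1-100']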
if FLAGS.data_pipeline == "alphafold3": + return AF3_DATABASE_FLAGS + + if FLAGS.db_preset == "reduced_dbs": + required_flags = dict(AF2_REDUCED_DATABASE_FLAGS) + if FLAGS.use_hhsearch: + required_flags["pdb70_database_path"] = "pdb70" + return required_flags + + return AF2_FULL_DATABASE_FLAGS + def check_template_date(): """Check if the max_template_date is provided.""" if not FLAGS.max_template_date: diff --git a/alphapulldown/scripts/generate_alphafold_server_json.py b/alphapulldown/scripts/generate_alphafold_server_json.py new file mode 100644 index 00000000..5f1afe1e --- /dev/null +++ b/alphapulldown/scripts/generate_alphafold_server_json.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +"""Generate AlphaFold Server batch JSON files from AlphaPulldown job inputs.""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +from alphapulldown.utils.alphafold_server_json import ( + build_alphafold_server_jobs, + write_jobs_to_json_files, +) + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Convert AlphaPulldown protein list jobs into AlphaFold Server batch JSON files." + ) + ) + parser.add_argument( + "--protein_lists", + required=True, + nargs="+", + help="One or more AlphaPulldown protein list files.", + ) + parser.add_argument( + "--monomer_objects_dir", + required=True, + nargs="+", + help="One or more directories containing AlphaPulldown monomer feature pickles.", + ) + parser.add_argument( + "--output_path", + required=True, + help="Output JSON path. Files are automatically split if more than --jobs_per_file jobs are generated.", + ) + parser.add_argument( + "--mode", + default="pulldown", + choices=["pulldown", "all_vs_all", "homo-oligomer", "custom"], + help="Job generation mode, matching run_multimer_jobs.py.", + ) + parser.add_argument( + "--oligomer_state_file", + default=None, + help="Path to the oligomer-state file used for mode=homo-oligomer.", + ) + parser.add_argument( + "--protein_delimiter", + default="+", + help="Protein delimiter used by alphapulldown-input-parser.", + ) + parser.add_argument( + "--model_seeds", + default="", + help="Comma-separated list of AlphaFold Server seeds. Leave empty to request automatic seed assignment.", + ) + parser.add_argument( + "--job_index", + type=int, + default=None, + help="1-based job index to export. Export all jobs when omitted.", + ) + parser.add_argument( + "--jobs_per_file", + type=int, + default=100, + help="Maximum number of jobs per output JSON file. 
AlphaFold Server currently accepts up to 100 jobs per file.", + ) + return parser + + +def main(argv: list[str] | None = None) -> list[Path]: + parser = build_parser() + args = parser.parse_args(argv) + model_seeds = [seed.strip() for seed in args.model_seeds.split(",") if seed.strip()] + jobs = build_alphafold_server_jobs( + protein_lists=args.protein_lists, + monomer_directories=args.monomer_objects_dir, + mode=args.mode, + oligomer_state_file=args.oligomer_state_file, + protein_delimiter=args.protein_delimiter, + model_seeds=model_seeds, + job_index=args.job_index, + ) + written_paths = write_jobs_to_json_files( + jobs, + args.output_path, + jobs_per_file=args.jobs_per_file, + ) + for path in written_paths: + print(path) + return written_paths + + +if __name__ == "__main__": + main() diff --git a/alphapulldown/scripts/run_multimer_jobs.py b/alphapulldown/scripts/run_multimer_jobs.py index 51fd6c36..8268fb7b 100644 --- a/alphapulldown/scripts/run_multimer_jobs.py +++ b/alphapulldown/scripts/run_multimer_jobs.py @@ -163,7 +163,11 @@ def main(argv): "--protein_delimiter": FLAGS.protein_delimiter, "--desired_num_res": FLAGS.desired_num_res, "--desired_num_msa": FLAGS.desired_num_msa, - "--models_to_relax": FLAGS.models_to_relax + "--models_to_relax": FLAGS.models_to_relax, + "--relax_best_score_threshold": FLAGS.relax_best_score_threshold, + "--threshold_clashes": FLAGS.threshold_clashes, + "--hb_allowance": FLAGS.hb_allowance, + "--plddt_threshold": FLAGS.plddt_threshold, } command_args = {} diff --git a/alphapulldown/scripts/run_structure_prediction.py b/alphapulldown/scripts/run_structure_prediction.py index 8ce12124..6fc8efd9 100644 --- a/alphapulldown/scripts/run_structure_prediction.py +++ b/alphapulldown/scripts/run_structure_prediction.py @@ -73,6 +73,18 @@ 'Path to the text file with multimeric template instruction.') flags.DEFINE_string('path_to_mmt', None, 'Path to directory with multimeric template mmCIF files.') +flags.DEFINE_float( + 'threshold_clashes', + 1000, + 'Threshold for VDW overlap used to remove clashes from quick-mode multimeric templates.') +flags.DEFINE_float( + 'hb_allowance', + 0.4, + 'Allowance for hydrogen bonding when filtering quick-mode multimeric templates.') +flags.DEFINE_float( + 'plddt_threshold', + 0, + 'Threshold for removing low-pLDDT residues from quick-mode multimeric templates.') flags.DEFINE_integer('desired_num_res', None, 'A desired number of residues to pad') flags.DEFINE_integer('desired_num_msa', None, @@ -90,6 +102,12 @@ "in case you are having issues with the relaxation " "stage.", ) +flags.DEFINE_float( + "relax_best_score_threshold", + None, + "Optional minimum iPTM/pTM score required before relaxing the best-ranked " + "model when --models_to_relax=Best.", +) flags.DEFINE_enum('model_preset', 'monomer', ['monomer', 'monomer_casp14', 'monomer_ptm', 'multimer'], 'Choose preset model configuration - the monomer model, ' @@ -216,11 +234,13 @@ def _validate_flags_for_backend(backend_name: str) -> None: # Backend-specific flags af2_like_flags = { 'compress_result_pickles', 'remove_result_pickles', 'models_to_relax', - 'remove_keys_from_pickles', 'convert_to_modelcif', 'allow_resume', + 'relax_best_score_threshold', 'remove_keys_from_pickles', + 'convert_to_modelcif', 'allow_resume', 'num_cycle', 'num_predictions_per_model', 'pair_msa', 'save_features_for_multimeric_object', 'skip_templates', 'msa_depth_scan', 'multimeric_template', 'model_names', 'msa_depth', - 'description_file', 'path_to_mmt', 'desired_num_res', 'desired_num_msa', + 
'description_file', 'path_to_mmt', 'threshold_clashes', 'hb_allowance', + 'plddt_threshold', 'desired_num_res', 'desired_num_msa', 'benchmark', 'model_preset', 'use_ap_style', 'use_gpu_relax', 'dropout', } alphalink_extra = {'crosslinks'} @@ -336,6 +356,9 @@ def pre_modelling_setup( multimeric_template=FLAGS.multimeric_template, multimeric_template_meta_data=FLAGS.description_file, multimeric_template_dir=FLAGS.path_to_mmt, + threshold_clashes=FLAGS.threshold_clashes, + hb_allowance=FLAGS.hb_allowance, + plddt_threshold=FLAGS.plddt_threshold, ) if FLAGS.save_features_for_multimeric_object: pickle.dump(MultimericObject.feature_dict, open(join(output_dir, "multimeric_object_features.pkl"), "wb")) @@ -459,6 +482,7 @@ def main(argv): "remove_keys_from_pickles": FLAGS.remove_keys_from_pickles, "use_gpu_relax": FLAGS.use_gpu_relax, "models_to_relax": FLAGS.models_to_relax, + "relax_best_score_threshold": FLAGS.relax_best_score_threshold, "features_directory": FLAGS.features_directory, "convert_to_modelcif": FLAGS.convert_to_modelcif } diff --git a/alphapulldown/utils/alphafold_server_json.py b/alphapulldown/utils/alphafold_server_json.py new file mode 100644 index 00000000..91c10f0e --- /dev/null +++ b/alphapulldown/utils/alphafold_server_json.py @@ -0,0 +1,378 @@ +"""Build AlphaFold Server batch JSON inputs from AlphaPulldown jobs.""" + +from __future__ import annotations + +import json +import re +from collections import OrderedDict +from copy import deepcopy +from pathlib import Path +from typing import Any, Iterable + +from alphapulldown_input_parser import RegionSelection, generate_fold_specifications, parse_fold + +from alphapulldown.utils.file_handling import make_dir_monomer_dictionary +from alphapulldown.utils.lightweight_pickles import load_lightweight_pickle + + +_SERVER_DIALECT = "alphafoldserver" +_SERVER_VERSION = 1 +_SERVER_ENTITY_KEYS = ("proteinChain", "dnaSequence", "rnaSequence", "ligand", "ion") + + +def _json_input_basename(json_input_path: str) -> str: + stem = Path(json_input_path).stem + for suffix in ("_af3_input", "_input"): + if stem.endswith(suffix): + stem = stem[: -len(suffix)] + break + return stem or Path(json_input_path).stem + + +def _sanitize_job_name(name: str) -> str: + sanitized = re.sub(r"[^\w.-]+", "_", name.strip()) + sanitized = re.sub(r"_+", "_", sanitized).strip("._") + return sanitized or "alphafold_server_job" + + +def _regions_to_tuples(selection: RegionSelection | Any) -> str | list[tuple[int, int]]: + if isinstance(selection, RegionSelection): + if selection.is_all: + return "all" + return [(region.start, region.end) for region in selection.regions] + return selection + + +def _normalise_job_entries(job: list[dict[str, Any]]) -> list[dict[str, Any]]: + normalised: list[dict[str, Any]] = [] + for entry in job: + if "json_input" in entry: + normalised_entry: dict[str, Any] = {"json_input": entry["json_input"]} + regions = entry.get("regions") + if isinstance(regions, RegionSelection) and not regions.is_all: + normalised_entry["regions"] = _regions_to_tuples(regions) + normalised.append(normalised_entry) + continue + + name, selection = next(iter(entry.items())) + normalised.append({name: _regions_to_tuples(selection)}) + return normalised + + +def _slice_sequence(sequence: str, regions: list[tuple[int, int]] | None) -> str: + if not regions: + return sequence + chunks = [] + for start, end in regions: + if start < 1 or end < start: + raise ValueError(f"Invalid region range {(start, end)}") + chunks.append(sequence[start - 1 : end]) + return 
"".join(chunks) + + +def _resolve_feature_pickle_path( + monomer_directories: list[str], + protein_name: str, +) -> Path: + monomer_dir_map = make_dir_monomer_dictionary(monomer_directories) + for suffix in (".pkl", ".pkl.xz"): + filename = f"{protein_name}{suffix}" + directory = monomer_dir_map.get(filename) + if directory is not None: + return Path(directory) / filename + raise FileNotFoundError( + f"Could not find a feature pickle for {protein_name!r} in {monomer_directories!r}" + ) + + +def _protein_entity(sequence: str) -> dict[str, Any]: + return {"proteinChain": {"sequence": sequence, "count": 1}} + + +def _sequence_entity( + entity_key: str, + sequence: str, + *, + count: int = 1, +) -> dict[str, Any]: + return {entity_key: {"sequence": sequence, "count": count}} + + +def _simple_entity(entity_key: str, value_key: str, value: str, *, count: int = 1) -> dict[str, Any]: + return {entity_key: {value_key: value, "count": count}} + + +def _convert_local_af3_entity( + entity: dict[str, Any], +) -> list[dict[str, Any]]: + if len(entity) != 1: + raise ValueError(f"Expected one entity per AF3 JSON sequence entry, got {entity!r}") + + entity_type, payload = next(iter(entity.items())) + if not isinstance(payload, dict): + raise TypeError(f"Expected dict payload for {entity_type!r}, got {type(payload)!r}") + + if entity_type == "protein": + sequence = payload.get("sequence") + if not isinstance(sequence, str): + raise ValueError("AF3 protein entities must contain a sequence string") + return [_protein_entity(sequence)] + + if entity_type == "dna": + sequence = payload.get("sequence") + if not isinstance(sequence, str): + raise ValueError("AF3 DNA entities must contain a sequence string") + return [_sequence_entity("dnaSequence", sequence)] + + if entity_type == "rna": + sequence = payload.get("sequence") + if not isinstance(sequence, str): + raise ValueError("AF3 RNA entities must contain a sequence string") + return [_sequence_entity("rnaSequence", sequence)] + + if entity_type == "ligand": + ccd_codes = payload.get("ccdCodes") + if not isinstance(ccd_codes, list) or not all(isinstance(code, str) for code in ccd_codes): + raise ValueError("AF3 ligand entities must provide ccdCodes as a list of strings") + return [_simple_entity("ligand", "ligand", code) for code in ccd_codes] + + if entity_type == "ion": + ion_value = payload.get("ion") + if not isinstance(ion_value, str): + raise ValueError("AF3 ion entities must contain an ion string") + return [_simple_entity("ion", "ion", ion_value)] + + raise ValueError(f"Unsupported AF3 entity type {entity_type!r} for AlphaFold Server export") + + +def _convert_server_entity(entity: dict[str, Any]) -> dict[str, Any]: + if len(entity) != 1: + raise ValueError(f"Expected one entity per server JSON sequence entry, got {entity!r}") + entity_type, payload = next(iter(entity.items())) + if entity_type not in _SERVER_ENTITY_KEYS: + raise ValueError(f"Unsupported AlphaFold Server entity type {entity_type!r}") + if not isinstance(payload, dict): + raise TypeError(f"Expected dict payload for {entity_type!r}, got {type(payload)!r}") + return {entity_type: deepcopy(payload)} + + +def _load_json_input_entities( + json_input_path: str, +) -> list[dict[str, Any]]: + with open(json_input_path, "r", encoding="utf-8") as handle: + payload = json.load(handle) + + if isinstance(payload, list): + if len(payload) != 1: + raise ValueError( + f"JSON input {json_input_path} contains {len(payload)} jobs; expected exactly one" + ) + payload = payload[0] + + if not 
isinstance(payload, dict): + raise TypeError(f"Unsupported JSON root in {json_input_path!r}: {type(payload)!r}") + + dialect = payload.get("dialect") + sequences = payload.get("sequences") + if not isinstance(sequences, list): + raise ValueError(f"JSON input {json_input_path} does not define a sequences list") + + if dialect == _SERVER_DIALECT: + return [_convert_server_entity(entity) for entity in sequences] + if dialect == "alphafold3": + converted: list[dict[str, Any]] = [] + for entity in sequences: + if not isinstance(entity, dict): + raise TypeError(f"Unsupported AF3 entity payload in {json_input_path!r}: {entity!r}") + converted.extend(_convert_local_af3_entity(entity)) + return converted + + raise ValueError( + f"Unsupported JSON dialect {dialect!r} in {json_input_path!r}. " + f"Expected {_SERVER_DIALECT!r} or 'alphafold3'." + ) + + +def _slice_json_input_entities( + entities: list[dict[str, Any]], + json_input_path: str, + regions: list[tuple[int, int]] | None, +) -> list[dict[str, Any]]: + if not regions: + return entities + if len(entities) != 1: + raise ValueError( + "Region ranges for JSON inputs require exactly one entity, but " + f"{json_input_path!r} contains {len(entities)} entities" + ) + entity = deepcopy(entities[0]) + entity_type, payload = next(iter(entity.items())) + if entity_type not in {"proteinChain", "dnaSequence", "rnaSequence"}: + raise ValueError( + f"Region slicing is only supported for sequence entities, not {entity_type!r}" + ) + sequence = payload.get("sequence") + if not isinstance(sequence, str): + raise ValueError(f"{entity_type!r} in {json_input_path!r} does not contain a sequence string") + payload["sequence"] = _slice_sequence(sequence, regions) + return [entity] + + +def _collapse_duplicate_entities(entities: Iterable[dict[str, Any]]) -> list[dict[str, Any]]: + collapsed: OrderedDict[str, dict[str, Any]] = OrderedDict() + for entity in entities: + entity_type, payload = next(iter(entity.items())) + payload_copy = deepcopy(payload) + count = int(payload_copy.pop("count", 1)) + key = json.dumps({entity_type: payload_copy}, sort_keys=True) + if key not in collapsed: + payload_copy["count"] = count + collapsed[key] = {entity_type: payload_copy} + else: + existing_type, existing_payload = next(iter(collapsed[key].items())) + if existing_type != entity_type: + raise AssertionError("Entity deduplication key collision across types") + existing_payload["count"] = int(existing_payload.get("count", 1)) + count + return list(collapsed.values()) + + +def _build_job_name(name_fragments: list[str]) -> str: + return _sanitize_job_name("_and_".join(fragment for fragment in name_fragments if fragment)) + + +def build_job_name_from_entries(entries: list[dict[str, Any]]) -> str: + fragments: list[str] = [] + for entry in entries: + if "json_input" in entry: + fragment = _json_input_basename(entry["json_input"]) + regions = entry.get("regions") + if regions: + ranges = "_".join(f"{start}-{end}" for start, end in regions) + fragment = f"{fragment}_{ranges}" + fragments.append(fragment) + continue + + protein_name = next(iter(entry)) + fragments.append(protein_name) + return _build_job_name(fragments) + + +def build_alphafold_server_job( + job_entries: list[dict[str, Any]], + monomer_directories: list[str], + *, + model_seeds: list[str] | None = None, +) -> dict[str, Any]: + entities: list[dict[str, Any]] = [] + for entry in job_entries: + if "json_input" in entry: + json_entities = _load_json_input_entities(entry["json_input"]) + entities.extend( + 
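# Usage sketch of the entity deduplication above (hypothetical sequences;
# assumes alphapulldown and the alphapulldown-input-parser dependency are
# importable): identical entities are merged and their counts summed.
from alphapulldown.utils.alphafold_server_json import _collapse_duplicate_entities

entities = [
    {"proteinChain": {"sequence": "MKTAYIAK", "count": 1}},
    {"proteinChain": {"sequence": "MKTAYIAK", "count": 1}},
    {"ligand": {"ligand": "ATP", "count": 1}},
]
print(_collapse_duplicate_entities(entities))
# [{'proteinChain': {'sequence': 'MKTAYIAK', 'count': 2}},
#  {'ligand': {'ligand': 'ATP', 'count': 1}}]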
_slice_json_input_entities( + json_entities, + entry["json_input"], + entry.get("regions"), + ) + ) + continue + + protein_name, selection = next(iter(entry.items())) + feature_pickle = _resolve_feature_pickle_path(monomer_directories, protein_name) + payload = load_lightweight_pickle(feature_pickle) + sequence = getattr(payload, "sequence", None) + if not isinstance(sequence, str): + raise ValueError(f"Feature pickle {feature_pickle} does not contain a sequence string") + regions = None if selection == "all" else selection + entities.append(_protein_entity(_slice_sequence(sequence, regions))) + + return { + "name": build_job_name_from_entries(job_entries), + "modelSeeds": list(model_seeds or []), + "sequences": _collapse_duplicate_entities(entities), + "dialect": _SERVER_DIALECT, + "version": _SERVER_VERSION, + } + + +def build_alphafold_server_jobs( + *, + protein_lists: list[str], + monomer_directories: list[str], + mode: str = "pulldown", + oligomer_state_file: str | None = None, + protein_delimiter: str = "+", + model_seeds: list[str] | None = None, + job_index: int | None = None, +) -> list[dict[str, Any]]: + active_lists = list(protein_lists) + if mode == "all_vs_all": + active_lists = [protein_lists[0], protein_lists[0]] + elif mode == "homo-oligomer": + if oligomer_state_file is None: + raise ValueError("oligomer_state_file is required for mode='homo-oligomer'") + active_lists = [oligomer_state_file] + + specifications = generate_fold_specifications( + input_files=active_lists, + delimiter=protein_delimiter, + exclude_permutations=True, + ) + all_folds = [spec.replace(",", ":").replace(";", "+") for spec in specifications] + + if job_index is not None: + zero_based_index = job_index - 1 + if zero_based_index < 0 or zero_based_index >= len(all_folds): + raise IndexError( + f"job_index must be between 1 and {len(all_folds)}, got {job_index}" + ) + selected_folds = [all_folds[zero_based_index]] + else: + selected_folds = all_folds + + parsed_jobs = parse_fold(selected_folds, monomer_directories, protein_delimiter) + normalised_jobs = [_normalise_job_entries(job) for job in parsed_jobs] + return [ + build_alphafold_server_job( + normalised_job, + monomer_directories, + model_seeds=model_seeds, + ) + for normalised_job in normalised_jobs + ] + + +def write_jobs_to_json_files( + jobs: list[dict[str, Any]], + output_path: str | Path, + *, + jobs_per_file: int = 100, +) -> list[Path]: + if jobs_per_file < 1: + raise ValueError("jobs_per_file must be at least 1") + + destination = Path(output_path) + destination.parent.mkdir(parents=True, exist_ok=True) + job_batches = [ + jobs[index : index + jobs_per_file] + for index in range(0, len(jobs), jobs_per_file) + ] or [[]] + + written_paths: list[Path] = [] + if len(job_batches) == 1: + target_paths = [destination] + else: + target_paths = [ + destination.with_name(f"{destination.stem}_{index:03d}{destination.suffix or '.json'}") + for index in range(1, len(job_batches) + 1) + ] + + for batch, target_path in zip(job_batches, target_paths, strict=True): + suffix = target_path.suffix or ".json" + if target_path.suffix != suffix: + target_path = target_path.with_suffix(suffix) + with open(target_path, "w", encoding="utf-8") as handle: + json.dump(batch, handle, indent=2) + handle.write("\n") + written_paths.append(target_path) + return written_paths + diff --git a/alphapulldown/utils/create_custom_template_db.py b/alphapulldown/utils/create_custom_template_db.py index 2726e1b5..3154128e 100644 --- 
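# Usage sketch of the batching above (dummy jobs; assumes alphapulldown is
# importable): with the default limit of 100 jobs per file, 250 jobs are
# written to three numbered files derived from the requested output path.
from alphapulldown.utils.alphafold_server_json import write_jobs_to_json_files

jobs = [
    {"name": f"job_{i}", "modelSeeds": [], "sequences": [],
     "dialect": "alphafoldserver", "version": 1}
    for i in range(250)
]
paths = write_jobs_to_json_files(jobs, "batch/jobs.json", jobs_per_file=100)
print([p.name for p in paths])  # ['jobs_001.json', 'jobs_002.json', 'jobs_003.json']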
a/alphapulldown/utils/create_custom_template_db.py +++ b/alphapulldown/utils/create_custom_template_db.py @@ -69,7 +69,14 @@ def parse_code(template): # Generate a deterministic 4-character code if needed if len(code) != 4: + original_code = code code = generate_code(code) + logging.info( + "Template %s does not have a four-character PDB-style code, so " + "using deterministic code %s instead.", + original_code, + code, + ) return code.lower() @@ -132,9 +139,16 @@ def _prepare_template(template, code, chain_id, mmcif_dir, seqres_path, template Process and prepare each template. """ duplicate = number_of_templates == 1 + codes_to_process = [f"{code[:-1]}{i}" for i in range(1, 5)] if duplicate else [code] new_template = templates_dir / Path(code + Path(template).suffix) copy_file_exclude_lines('HETATM', template, new_template) logging.info(f"Processing template: {new_template} Chain {chain_id}") + if duplicate: + logging.info( + "Only one multimeric template was provided, so TrueMultimer will " + "duplicate it four times with codes %s to increase template influence.", + ", ".join(codes_to_process), + ) # Convert to (our) mmcif object mmcif_obj = MmcifChainFiltered(new_template, code, chain_id) @@ -152,7 +166,6 @@ def _prepare_template(template, code, chain_id, mmcif_dir, seqres_path, template sequence_ids = mmcif_obj.atom_site_label_seq_ids # Save to file and validate - codes_to_process = [f"{code[:-1]}{i}" for i in range(1, 5)] if duplicate else [code] for temp_code in codes_to_process: mmcif_string = to_mmcif(protein, f"{temp_code}_{chain_id}", "Monomer", chain_id, seqres, sequence_ids) fn = mmcif_dir / f"{temp_code}.cif" diff --git a/alphapulldown/utils/lightweight_pickles.py b/alphapulldown/utils/lightweight_pickles.py new file mode 100644 index 00000000..7940c5f1 --- /dev/null +++ b/alphapulldown/utils/lightweight_pickles.py @@ -0,0 +1,71 @@ +"""Helpers for reading AlphaPulldown pickles without importing heavy runtime modules.""" + +from __future__ import annotations + +import gzip +import lzma +import pickle +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + + +@dataclass +class LightweightMonomericObject: + """Pickle-compatible stand-in for alphapulldown.objects.MonomericObject.""" + + description: str = "" + sequence: str = "" + feature_dict: dict[str, Any] = field(default_factory=dict) + _uniprot_runner: Any = None + + +@dataclass +class LightweightChoppedObject(LightweightMonomericObject): + """Pickle-compatible stand-in for alphapulldown.objects.ChoppedObject.""" + + monomeric_description: str | None = None + regions: Any = None + + +class _AlphaPulldownObjectUnpickler(pickle.Unpickler): + """Unpickler that swaps heavy AlphaPulldown classes for lightweight stand-ins.""" + + _CLASS_MAP = { + ("alphapulldown.objects", "MonomericObject"): LightweightMonomericObject, + ("alphapulldown.objects", "ChoppedObject"): LightweightChoppedObject, + } + + def find_class(self, module: str, name: str) -> Any: + replacement = self._CLASS_MAP.get((module, name)) + if replacement is not None: + return replacement + return super().find_class(module, name) + + +def load_lightweight_pickle(path: str | Path) -> Any: + """Loads a pickle while avoiding imports from alphapulldown.objects.""" + + pickle_path = Path(path) + if pickle_path.suffix == ".xz": + opener = lzma.open + elif pickle_path.suffix == ".gz": + opener = gzip.open + else: + opener = open + with opener(pickle_path, "rb") as handle: + return _AlphaPulldownObjectUnpickler(handle).load() + + +def 
extract_feature_dict(payload: Any) -> dict[str, Any]: + """Returns a feature dictionary from either a raw dict or a monomer-like object.""" + + if isinstance(payload, dict): + return payload + + feature_dict = getattr(payload, "feature_dict", None) + if not isinstance(feature_dict, dict): + raise TypeError( + f"Expected a dict-like payload or an object with feature_dict, got {type(payload)!r}" + ) + return feature_dict diff --git a/alphapulldown/utils/multimeric_template_utils.py b/alphapulldown/utils/multimeric_template_utils.py index 863e9fec..92df850d 100644 --- a/alphapulldown/utils/multimeric_template_utils.py +++ b/alphapulldown/utils/multimeric_template_utils.py @@ -30,8 +30,8 @@ def prepare_multimeric_template_meta_info(csv_path: str, mmt_dir: str) -> dict: mmt_dir: Path to directory with multimeric template mmCIF files Returns: - a list of dictionaries with the following structure: - [{"protein": protein_name, "sequence" :sequence", templates": [pdb_files], "chains": [chain_id]}, ...]}] + A dictionary keyed by protein name where each value preserves the CSV row + order as a list of `(template_file, chain_id)` tuples. """ # Parse csv file parsed_dict = {} @@ -45,10 +45,7 @@ def prepare_multimeric_template_meta_info(csv_path: str, mmt_dir: str) -> dict: protein, template, chain = [item.strip() for item in row] assert os.path.exists(os.path.join( mmt_dir, template)), f"Provided {template} cannot be found in {mmt_dir}. Abort" - if protein not in parsed_dict: - parsed_dict[protein] = { - template: chain - } + parsed_dict.setdefault(protein, []).append((template, chain)) else: logging.error( f"Invalid line found in the file {csv_path}: {row}") @@ -63,7 +60,14 @@ def obtain_kalign_binary_path() -> Optional[str]: return shutil.which('kalign') -def parse_mmcif_file(file_id: str, mmcif_file: str, chain_id: str) -> ParsingResult: +def parse_mmcif_file( + file_id: str, + mmcif_file: str, + chain_id: str, + threshold_clashes: float = 1000, + hb_allowance: float = 0.4, + plddt_threshold: float = 0, +) -> ParsingResult: """ Args: file_id: A string identifier for this file. 
Should be unique within the @@ -76,6 +80,8 @@ def parse_mmcif_file(file_id: str, mmcif_file: str, chain_id: str) -> ParsingRes try: mmcif_filtered_obj = MmcifChainFiltered( Path(mmcif_file), file_id, chain_id=chain_id) + mmcif_filtered_obj.remove_clashes(threshold_clashes, hb_allowance) + mmcif_filtered_obj.remove_low_plddt(plddt_threshold) parsing_result = mmcif_filtered_obj.parsing_result except FileNotFoundError as e: parsing_result = None @@ -121,6 +127,9 @@ def extract_multimeric_template_features_for_single_chain( pdb_id: str, chain_id: str, mmcif_file: str, + threshold_clashes: float = 1000, + hb_allowance: float = 0.4, + plddt_threshold: float = 0, ) -> SingleHitResult: """ Args: @@ -134,7 +143,13 @@ def extract_multimeric_template_features_for_single_chain( A SingleHitResult object """ mmcif_parse_result = parse_mmcif_file( - pdb_id, mmcif_file, chain_id=chain_id) + pdb_id, + mmcif_file, + chain_id=chain_id, + threshold_clashes=threshold_clashes, + hb_allowance=hb_allowance, + plddt_threshold=plddt_threshold, + ) if (mmcif_parse_result is not None) and (mmcif_parse_result.mmcif_object is not None): mapping,template_sequence = _obtain_mapping(mmcif_parse_result=mmcif_parse_result, chain_id=chain_id, diff --git a/pyproject.toml b/pyproject.toml index 49157c9b..30f96d3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,8 +109,11 @@ packages = [ script-files = [ "./alphapulldown/scripts/create_individual_features.py", "./alphapulldown/scripts/run_multimer_jobs.py", + "./alphapulldown/scripts/generate_alphafold_server_json.py", "./alphapulldown/analysis_pipeline/create_notebook.py", "./alphapulldown/analysis_pipeline/get_good_inter_pae.py", + "./alphapulldown/analysis_pipeline/plot_diagnostics.py", + "./alphapulldown/analysis_pipeline/plot_interaction_network.py", "./alphapulldown/scripts/rename_colab_search_a3m.py", "./alphapulldown/scripts/prepare_seq_names.py", "./alphapulldown/scripts/generate_crosslink_pickle.py", diff --git a/test/integration/test_create_individual_features.py b/test/integration/test_create_individual_features.py index 4f94778e..3f1297af 100644 --- a/test/integration/test_create_individual_features.py +++ b/test/integration/test_create_individual_features.py @@ -760,6 +760,66 @@ def test_create_arguments_alphafold3_clears_af2_only_databases(self): assert FLAGS.obsolete_pdbs_path is None logger.info("AF3 argument creation only kept AF3-relevant database paths") + def test_create_arguments_reduced_dbs_clears_unused_af2_databases(self): + """Test that reduced_dbs only sets the AF2 paths it actually needs.""" + logger.info("Testing reduced_dbs argument creation without full-db leftovers") + + from absl import flags + FLAGS = flags.FLAGS + FLAGS(['test']) + + FLAGS.use_mmseqs2 = False + FLAGS.data_pipeline = "alphafold2" + FLAGS.db_preset = "reduced_dbs" + FLAGS.use_hhsearch = False + FLAGS.data_dir = "/test/db" + FLAGS.uniref90_database_path = None + FLAGS.mgnify_database_path = None + FLAGS.small_bfd_database_path = None + FLAGS.uniprot_database_path = None + FLAGS.pdb_seqres_database_path = None + FLAGS.template_mmcif_dir = None + FLAGS.obsolete_pdbs_path = None + FLAGS.uniref30_database_path = "/stale/uniref30" + FLAGS.bfd_database_path = "/stale/bfd" + FLAGS.pdb70_database_path = "/stale/pdb70" + + create_features.create_arguments() + + assert FLAGS.uniref90_database_path == "/test/db/uniref90/uniref90.fasta" + assert FLAGS.mgnify_database_path == "/test/db/mgnify/mgy_clusters_2022_05.fa" + assert FLAGS.small_bfd_database_path == 
"/test/db/small_bfd/bfd-first_non_consensus_sequences.fasta" + assert FLAGS.uniprot_database_path == "/test/db/uniprot/uniprot.fasta" + assert FLAGS.pdb_seqres_database_path == "/test/db/pdb_seqres/pdb_seqres.txt" + assert FLAGS.template_mmcif_dir == "/test/db/pdb_mmcif/mmcif_files" + assert FLAGS.obsolete_pdbs_path == "/test/db/pdb_mmcif/obsolete.dat" + assert FLAGS.uniref30_database_path is None + assert FLAGS.bfd_database_path is None + assert FLAGS.pdb70_database_path is None + logger.info("Reduced-dbs argument creation cleared unused full-database paths") + + def test_create_arguments_reduced_dbs_keeps_pdb70_for_hhsearch(self): + """Test that reduced_dbs still sets pdb70 when HHsearch templates are requested.""" + logger.info("Testing reduced_dbs HHsearch argument creation") + + from absl import flags + FLAGS = flags.FLAGS + FLAGS(['test']) + + FLAGS.use_mmseqs2 = False + FLAGS.data_pipeline = "alphafold2" + FLAGS.db_preset = "reduced_dbs" + FLAGS.use_hhsearch = True + FLAGS.data_dir = "/test/db" + FLAGS.pdb70_database_path = None + + create_features.create_arguments() + + assert FLAGS.pdb70_database_path == "/test/db/pdb70/pdb70" + assert FLAGS.bfd_database_path is None + assert FLAGS.uniref30_database_path is None + logger.info("Reduced-dbs HHsearch argument creation kept pdb70 without restoring full BFD") + def test_mmseqs2_without_data_dir(self): """Test that MMseqs2 works without data_dir flag.""" logger.info("Testing MMseqs2 without data_dir flag") diff --git a/test/unit/test_alphafold2_backend_helpers.py b/test/unit/test_alphafold2_backend_helpers.py index 4f61c4f2..565689c3 100644 --- a/test/unit/test_alphafold2_backend_helpers.py +++ b/test/unit/test_alphafold2_backend_helpers.py @@ -971,3 +971,102 @@ def test_postprocess_handles_monomers_without_relaxation_and_logs_modelcif_error assert plot_calls == [0] assert cleanup_calls assert modelcif_errors == ["Error: convert failed"] + + +def test_postprocess_skips_best_relaxation_below_score_threshold( + af2_backend_module, + monkeypatch, + tmp_path, +): + info_messages = [] + monkeypatch.setattr( + af2_backend_module, + "plot_pae_from_matrix", + lambda **kwargs: None, + ) + monkeypatch.setattr( + af2_backend_module, + "post_prediction_process", + lambda *args, **kwargs: None, + ) + monkeypatch.setattr( + af2_backend_module.logging, + "info", + lambda message, *args: info_messages.append(message % args if args else message), + ) + + multimer = af2_backend_module.MultimericObject( + description="complex", + input_seqs=["AA", "BB"], + feature_dict={}, + multimeric_mode=True, + ) + prediction_results = { + "model_high": { + "plddt": np.array([90.0, 91.0, 92.0, 93.0], dtype=np.float32), + "predicted_aligned_error": np.zeros((4, 4), dtype=np.float32), + "max_predicted_aligned_error": 31.0, + "ranking_confidence": 0.9, + "iptm": 0.7, + "ptm": 0.8, + "unrelaxed_protein": SimpleNamespace(name="high"), + "seqs": ["AA", "BB"], + } + } + + af2_backend_module.AlphaFold2Backend.postprocess( + prediction_results=prediction_results, + multimeric_object=multimer, + output_dir=tmp_path, + features_directory="/features", + models_to_relax=af2_backend_module.ModelsToRelax.BEST, + relax_best_score_threshold=0.8, + convert_to_modelcif=False, + ) + + assert not (tmp_path / "relaxed_model_high.pdb").exists() + assert (tmp_path / "ranked_0.pdb").read_text(encoding="utf-8") == "PDB:high" + assert any("Skipping relaxation for model_high" in message for message in info_messages) + + +def test_postprocess_uses_ptm_threshold_for_best_monomer_relaxation( + 
af2_backend_module, + monkeypatch, + tmp_path, +): + monkeypatch.setattr( + af2_backend_module, + "plot_pae_from_matrix", + lambda **kwargs: None, + ) + monkeypatch.setattr( + af2_backend_module, + "post_prediction_process", + lambda *args, **kwargs: None, + ) + + monomer = af2_backend_module.MonomericObject("single", "AB") + prediction_results = { + "modelA": { + "plddt": np.array([81.0, 82.0], dtype=np.float32), + "predicted_aligned_error": np.zeros((2, 2), dtype=np.float32), + "max_predicted_aligned_error": 31.0, + "ranking_confidence": 81.5, + "ptm": 0.4, + "seqs": ["AB"], + "unrelaxed_protein": SimpleNamespace(name="mono"), + } + } + + af2_backend_module.AlphaFold2Backend.postprocess( + prediction_results=prediction_results, + multimeric_object=monomer, + output_dir=tmp_path, + features_directory="/features", + models_to_relax=af2_backend_module.ModelsToRelax.BEST, + relax_best_score_threshold=0.3, + convert_to_modelcif=False, + ) + + assert (tmp_path / "relaxed_modelA.pdb").read_text(encoding="utf-8") == "RELAXED:mono" + assert (tmp_path / "ranked_0.pdb").read_text(encoding="utf-8") == "RELAXED:mono" diff --git a/test/unit/test_alphafold_server_json.py b/test/unit/test_alphafold_server_json.py new file mode 100644 index 00000000..0a7244c1 --- /dev/null +++ b/test/unit/test_alphafold_server_json.py @@ -0,0 +1,119 @@ +import json +from pathlib import Path + +import pytest + +from alphapulldown.utils.alphafold_server_json import ( + build_alphafold_server_jobs, + build_alphafold_server_job, + write_jobs_to_json_files, +) + + +TEST_DATA = Path(__file__).resolve().parents[1] / "test_data" + + +def test_build_server_job_collapses_homodimer_features(): + jobs = build_alphafold_server_jobs( + protein_lists=[str(TEST_DATA / "protein_lists" / "test_dimer.txt")], + monomer_directories=[str(TEST_DATA / "features")], + ) + + assert len(jobs) == 1 + job = jobs[0] + assert job["dialect"] == "alphafoldserver" + assert job["version"] == 1 + assert job["modelSeeds"] == [] + assert job["name"] == "TEST_and_TEST" + assert job["sequences"] == [ + { + "proteinChain": { + "sequence": "MESAIAEGGASRFSASSGGGGSRGAPQHYPKTAGNSEFLGKTPGQNAQKWIPARSTRRDDNSAA", + "count": 2, + } + } + ] + + +def test_build_server_job_slices_regions_for_fragments(): + job = build_alphafold_server_jobs( + protein_lists=[str(TEST_DATA / "protein_lists" / "test_dimer_chopped.txt")], + monomer_directories=[str(TEST_DATA / "features")], + )[0] + + assert job["name"] == "TEST_and_A0A075B6L2" + assert ( + job["sequences"][0]["proteinChain"]["sequence"] + == "MESAIAEGGASRFSASSGGGGSRGAPQHYPKTAGNSEFLGKTPGQNAQKWIPARSTRRDDNSAA" + ) + assert job["sequences"][0]["proteinChain"]["count"] == 1 + assert job["sequences"][1]["proteinChain"]["sequence"] == "MPLVVAVIFFPLVVLWVF" + assert job["sequences"][1]["proteinChain"]["count"] == 1 + + +def test_build_server_job_converts_af3_json_inputs_for_dna(): + jobs = build_alphafold_server_jobs( + protein_lists=[str(TEST_DATA / "protein_lists" / "test_monomer_with_dna.txt")], + monomer_directories=[str(TEST_DATA / "features")], + ) + + assert len(jobs) == 1 + job = jobs[0] + assert job["name"] == "A0A024R1R8_and_dna" + assert job["sequences"] == [ + { + "proteinChain": { + "sequence": "MSSHEGGKKKALKQPKKQAKEMDEEEKAFKQKQKEEQKKLEVLKAKVVGKGPLATGGIKKSGKK", + "count": 1, + } + }, + {"dnaSequence": {"sequence": "GATTACA", "count": 1}}, + {"dnaSequence": {"sequence": "TGTAATC", "count": 1}}, + ] + + +def test_build_server_job_converts_local_af3_json_homodimer(): + jobs = build_alphafold_server_jobs( + protein_lists=[ + 
str(TEST_DATA / "protein_lists" / "test_homodimer_from_json_features.txt") + ], + monomer_directories=[str(TEST_DATA / "features" / "af3_features" / "protein")], + ) + + assert len(jobs) == 1 + job = jobs[0] + assert job["name"] == "P61626_and_P61626" + assert len(job["sequences"]) == 1 + assert job["sequences"][0]["proteinChain"]["count"] == 2 + + +def test_write_jobs_to_json_files_splits_large_batches(tmp_path): + job = build_alphafold_server_job( + [{"TEST": "all"}], + [str(TEST_DATA / "features")], + ) + jobs = [job, job, job] + + written_paths = write_jobs_to_json_files( + jobs, + tmp_path / "server_jobs.json", + jobs_per_file=2, + ) + + assert [path.name for path in written_paths] == [ + "server_jobs_001.json", + "server_jobs_002.json", + ] + first_payload = json.loads(written_paths[0].read_text(encoding="utf-8")) + second_payload = json.loads(written_paths[1].read_text(encoding="utf-8")) + assert len(first_payload) == 2 + assert len(second_payload) == 1 + + +def test_job_index_is_one_based(): + with pytest.raises(IndexError): + build_alphafold_server_jobs( + protein_lists=[str(TEST_DATA / "protein_lists" / "test_dimer.txt")], + monomer_directories=[str(TEST_DATA / "features")], + job_index=2, + ) diff --git a/test/unit/test_custom_db.py b/test/unit/test_custom_db.py index ee3a8085..1830f272 100644 --- a/test/unit/test_custom_db.py +++ b/test/unit/test_custom_db.py @@ -154,4 +154,26 @@ def test_long_filename_generates_valid_code(): logger.info(f"Successfully parsed entry: {test_entry}") assert metadata.pdb_id == code, f"Parsed PDB ID {metadata.pdb_id} does not match generated code {code}" except ValueError as e: - pytest.fail(f"Generated code '{code}' cannot be parsed by AlphaFold parser: {e}") \ No newline at end of file + pytest.fail(f"Generated code '{code}' cannot be parsed by AlphaFold parser: {e}") + + +def test_create_db_logs_generated_code_and_duplication(caplog, tmp_path): + caplog.set_level(logging.INFO) + + create_db( + tmp_path / "custom_db", + ["./test/test_data/templates/RANdom_name1_.7-1_0.pdb"], + ["B"], + 1000, + 0.4, + 0, + ) + + assert ( + "does not have a four-character PDB-style code, so using deterministic code" + in caplog.text + ) + assert ( + "Only one multimeric template was provided, so TrueMultimer will duplicate it four times" + in caplog.text + ) diff --git a/test/unit/test_diagnostics.py b/test/unit/test_diagnostics.py new file mode 100644 index 00000000..91537cae --- /dev/null +++ b/test/unit/test_diagnostics.py @@ -0,0 +1,104 @@ +import gzip +from pathlib import Path +import shutil + +import alphapulldown.analysis_pipeline.diagnostics as diagnostics +from alphapulldown.analysis_pipeline.diagnostics import ( + plot_inputs, + save_msa_coverage_plot, + save_prediction_plots, +) + + +TEST_DATA = Path(__file__).resolve().parents[1] / "test_data" + + +def test_save_prediction_plots_writes_core_diagnostics(tmp_path): + written_paths = save_prediction_plots( + TEST_DATA / "predictions" / "TEST_homo_2er", + tmp_path, + ) + + assert [path.name for path in written_paths] == [ + "TEST_homo_2er_pae.png", + "TEST_homo_2er_plddt.png", + "TEST_homo_2er_distogram.png", + ] + for path in written_paths: + assert path.exists() + assert path.stat().st_size > 0 + + +def test_save_msa_coverage_plot_supports_monomer_pickles(tmp_path): + output_path = save_msa_coverage_plot( + TEST_DATA / "features" / "A0A024R1R8.pkl", + tmp_path, + ) + + assert output_path.name == "A0A024R1R8_msa_coverage.png" + assert output_path.exists() + assert output_path.stat().st_size > 0 + + +def 
test_plot_inputs_accepts_feature_directories_and_prediction_dirs(tmp_path): + written_paths = plot_inputs( + [ + TEST_DATA / "predictions" / "af_vs_ap" / "A0A024R1R8", + TEST_DATA / "predictions" / "TEST_homo_2er", + ], + output_dir=tmp_path, + ) + + assert sorted(path.name for path in written_paths) == [ + "A0A024R1R8_msa_coverage.png", + "TEST_homo_2er_distogram.png", + "TEST_homo_2er_pae.png", + "TEST_homo_2er_plddt.png", + ] + + +def test_plot_inputs_accepts_gzip_compressed_prediction_dirs(tmp_path): + source_dir = TEST_DATA / "predictions" / "TEST_homo_2er" + compressed_dir = tmp_path / "compressed_prediction" + shutil.copytree(source_dir, compressed_dir) + + for result_pickle in compressed_dir.glob("result*.pkl"): + compressed_path = result_pickle.with_suffix(f"{result_pickle.suffix}.gz") + with result_pickle.open("rb") as source_handle, gzip.open(compressed_path, "wb") as target_handle: + shutil.copyfileobj(source_handle, target_handle) + result_pickle.unlink() + + written_paths = plot_inputs([compressed_dir], output_dir=tmp_path / "plots") + + assert sorted(path.name for path in written_paths) == [ + "compressed_prediction_distogram.png", + "compressed_prediction_pae.png", + "compressed_prediction_plddt.png", + ] + + +def test_save_prediction_plots_falls_back_when_af2plots_has_no_distogram(monkeypatch, tmp_path): + real_plotter = diagnostics.plotter + + class LegacyPlotter: + def __init__(self): + self._delegate = real_plotter() + + def plot_predicted_alignment_error(self, *args, **kwargs): + return self._delegate.plot_predicted_alignment_error(*args, **kwargs) + + def plot_plddts(self, *args, **kwargs): + return self._delegate.plot_plddts(*args, **kwargs) + + monkeypatch.setattr(diagnostics, "plotter", LegacyPlotter) + + written_paths = save_prediction_plots( + TEST_DATA / "predictions" / "TEST_homo_2er", + tmp_path, + ) + + assert [path.name for path in written_paths] == [ + "TEST_homo_2er_pae.png", + "TEST_homo_2er_plddt.png", + "TEST_homo_2er_distogram.png", + ] diff --git a/test/unit/test_interaction_network.py b/test/unit/test_interaction_network.py new file mode 100644 index 00000000..57d727ce --- /dev/null +++ b/test/unit/test_interaction_network.py @@ -0,0 +1,160 @@ +from pathlib import Path + +import pandas as pd + +from alphapulldown.analysis_pipeline.interaction_network import ( + build_interaction_edge_table, + plot_interaction_network, + summarise_nodes, + write_table, +) + + +def test_build_interaction_edge_table_prefers_interface_pairs(): + table = pd.DataFrame( + [ + { + "jobs": "A_and_B_and_C", + "iptm_ptm": 0.91, + "average_interface_pae": 4.0, + "interface": "A_C", + }, + { + "jobs": "A_and_B", + "iptm_ptm": 0.72, + "average_interface_pae": 2.0, + "interface": "A_B", + }, + { + "jobs": "B_and_B", + "iptm_ptm": 0.88, + "average_interface_pae": 3.0, + "interface": "A_B", + }, + ] + ) + + edge_table = build_interaction_edge_table(table, min_score=0.7, max_pae=5.0) + + assert edge_table.to_dict(orient="records") == [ + { + "source": "A", + "target": "C", + "score": 0.91, + "support": 1, + "jobs": "A_and_B_and_C", + "self_interaction": False, + }, + { + "source": "A", + "target": "B", + "score": 0.72, + "support": 1, + "jobs": "A_and_B", + "self_interaction": False, + }, + { + "source": "B", + "target": "B", + "score": 0.88, + "support": 1, + "jobs": "B_and_B", + "self_interaction": True, + }, + ] + + +def test_build_interaction_edge_table_supports_ap_style_homo_job_names(): + table = pd.DataFrame( + [ + { + "jobs": "protA_homo_3er", + "iptm_ptm": 0.91, + 
"average_interface_pae": 4.0, + "interface": "A_C", + }, + ] + ) + + edge_table = build_interaction_edge_table(table, min_score=0.7, max_pae=5.0) + + assert edge_table.to_dict(orient="records") == [ + { + "source": "protA", + "target": "protA", + "score": 0.91, + "support": 1, + "jobs": "protA_homo_3er", + "self_interaction": True, + }, + ] + + +def test_build_interaction_edge_table_skips_nan_scores(): + table = pd.DataFrame( + [ + { + "jobs": "A_and_B", + "pi_score": float("nan"), + "average_interface_pae": 4.0, + "interface": "A_B", + }, + { + "jobs": "A_and_C", + "pi_score": 0.81, + "average_interface_pae": 4.0, + "interface": "A_B", + }, + ] + ) + + edge_table = build_interaction_edge_table( + table, + score_column="pi_score", + min_score=0.8, + max_pae=5.0, + ) + + assert edge_table.to_dict(orient="records") == [ + { + "source": "A", + "target": "C", + "score": 0.81, + "support": 1, + "jobs": "A_and_C", + "self_interaction": False, + }, + ] + + +def test_plot_interaction_network_and_tables_write_outputs(tmp_path): + edge_table = pd.DataFrame( + [ + { + "source": "A", + "target": "B", + "score": 0.82, + "support": 2, + "jobs": "A_and_B", + "self_interaction": False, + }, + { + "source": "B", + "target": "B", + "score": 0.77, + "support": 1, + "jobs": "B_and_B", + "self_interaction": True, + }, + ] + ) + + plot_path = plot_interaction_network(edge_table, tmp_path / "network.png", seed=7) + node_table = summarise_nodes(edge_table) + node_path = write_table(node_table, tmp_path / "nodes.csv") + edge_path = write_table(edge_table, tmp_path / "edges.csv") + + assert plot_path.exists() + assert plot_path.stat().st_size > 0 + assert node_path.read_text(encoding="utf-8").startswith("node,degree,best_score") + assert edge_path.read_text(encoding="utf-8").startswith("source,target,score") diff --git a/test/unit/test_multimeric_template_utils.py b/test/unit/test_multimeric_template_utils.py index ace80e61..02d867ed 100644 --- a/test/unit/test_multimeric_template_utils.py +++ b/test/unit/test_multimeric_template_utils.py @@ -16,7 +16,21 @@ def test_prepare_multimeric_template_meta_info_parses_valid_csv(tmp_path): result = mtu.prepare_multimeric_template_meta_info(str(csv_path), str(tmp_path)) - assert result == {"proteinA": {"template1.cif": "A"}} + assert result == {"proteinA": [("template1.cif", "A")]} + + +def test_prepare_multimeric_template_meta_info_keeps_duplicate_rows_for_homo_oligomers(tmp_path): + csv_path = tmp_path / "templates.csv" + template_path = tmp_path / "template1.cif" + template_path.write_text("data", encoding="utf-8") + csv_path.write_text( + "proteinA,template1.cif,A\nproteinA,template1.cif,B\n", + encoding="utf-8", + ) + + result = mtu.prepare_multimeric_template_meta_info(str(csv_path), str(tmp_path)) + + assert result == {"proteinA": [("template1.cif", "A"), ("template1.cif", "B")]} def test_prepare_multimeric_template_meta_info_exits_on_invalid_row(tmp_path): @@ -42,6 +56,7 @@ def test_obtain_kalign_binary_path_asserts_when_binary_missing(monkeypatch): def test_parse_mmcif_file_returns_parsing_result(monkeypatch, tmp_path): expected = SimpleNamespace(name="parsed") + calls = [] class FakeFiltered: def __init__(self, path, file_id, chain_id): @@ -50,11 +65,28 @@ def __init__(self, path, file_id, chain_id): assert chain_id == "A" self.parsing_result = expected + def remove_clashes(self, threshold, hb_allowance): + calls.append(("remove_clashes", threshold, hb_allowance)) + + def remove_low_plddt(self, threshold): + calls.append(("remove_low_plddt", threshold)) + 
monkeypatch.setattr(mtu, "MmcifChainFiltered", FakeFiltered) - result = mtu.parse_mmcif_file("1abc", str(tmp_path / "template.cif"), "A") + result = mtu.parse_mmcif_file( + "1abc", + str(tmp_path / "template.cif"), + "A", + threshold_clashes=12.5, + hb_allowance=0.7, + plddt_threshold=42.0, + ) assert result is expected + assert calls == [ + ("remove_clashes", 12.5, 0.7), + ("remove_low_plddt", 42.0), + ] def test_parse_mmcif_file_returns_none_when_file_missing(monkeypatch, tmp_path): diff --git a/test/unit/test_objects.py b/test/unit/test_objects.py index fb5e01c7..f99de4eb 100644 --- a/test/unit/test_objects.py +++ b/test/unit/test_objects.py @@ -772,7 +772,7 @@ def test_multimeric_object_init_calls_template_setup_and_feature_creation(monkey def fake_prepare(path, template_dir): calls.append(("prepare_meta", path, template_dir)) - return {"proteinA": {"file.cif": "A"}} + return {"proteinA": [("file.cif", "A")]} monkeypatch.setattr(objects_mod, "prepare_multimeric_template_meta_info", fake_prepare) monkeypatch.setattr( @@ -802,7 +802,7 @@ def fake_prepare(path, template_dir): assert multimer.description == "proteinA" assert multimer.pair_msa is False assert multimer.multimeric_template is True - assert multimer.multimeric_template_meta_data == {"proteinA": {"file.cif": "A"}} + assert multimer.multimeric_template_meta_data == {"proteinA": [("file.cif", "A")]} assert calls == [ ("prepare_meta", "meta.csv", "/tmp/templates"), "templates", @@ -965,25 +965,100 @@ def test_create_multimeric_template_features_updates_matching_monomer(monkeypatc template_file.write_text("data_1abc", encoding="utf-8") monomer = SimpleNamespace(sequence="ACDE", feature_dict={}) multimer = MultimericObject.__new__(MultimericObject) + multimer.interactors = [SimpleNamespace(description="proteinA", sequence="ACDE", feature_dict=monomer.feature_dict)] multimer.multimeric_template_dir = str(tmp_path) - multimer.multimeric_template_meta_data = {"proteinA": {"1abc.cif": "B"}} + multimer.multimeric_template_meta_data = {"proteinA": [("1abc.cif", "B")]} multimer.monomers_mapping = {"proteinA": monomer} + multimer.threshold_clashes = 12.5 + multimer.hb_allowance = 0.7 + multimer.plddt_threshold = 42.0 + extractor_calls = [] monkeypatch.setattr( objects_mod, "extract_multimeric_template_features_for_single_chain", - lambda **kwargs: SimpleNamespace(features={"templated": kwargs["chain_id"]}), + lambda **kwargs: extractor_calls.append(kwargs) + or SimpleNamespace(features={"templated": kwargs["chain_id"]}), ) multimer.create_multimeric_template_features() assert monomer.feature_dict["templated"] == "B" + assert extractor_calls == [{ + "query_seq": "ACDE", + "pdb_id": "1abc", + "chain_id": "B", + "mmcif_file": str(template_file), + "threshold_clashes": 12.5, + "hb_allowance": 0.7, + "plddt_threshold": 42.0, + }] + + +def test_create_multimeric_template_features_assigns_duplicate_rows_to_homo_oligomers(monkeypatch, tmp_path): + template_file = tmp_path / "templ.cif" + template_file.write_text("data_templ", encoding="utf-8") + monomer_a = SimpleNamespace(description="proteinA", sequence="AAAA", feature_dict={}) + monomer_b = SimpleNamespace(description="proteinA", sequence="BBBB", feature_dict={}) + multimer = MultimericObject.__new__(MultimericObject) + multimer.interactors = [monomer_a, monomer_b] + multimer.multimeric_template_dir = str(tmp_path) + multimer.multimeric_template_meta_data = { + "proteinA": [("templ.cif", "A"), ("templ.cif", "B")] + } + multimer.threshold_clashes = 1000 + multimer.hb_allowance = 0.4 + 
multimer.plddt_threshold = 0 + calls = [] + + monkeypatch.setattr( + objects_mod, + "extract_multimeric_template_features_for_single_chain", + lambda **kwargs: calls.append((kwargs["query_seq"], kwargs["chain_id"])) + or SimpleNamespace(features={"templated": kwargs["chain_id"]}), + ) + + multimer.create_multimeric_template_features() + + assert calls == [("AAAA", "A"), ("BBBB", "B")] + assert monomer_a.feature_dict["templated"] == "A" + assert monomer_b.feature_dict["templated"] == "B" + + +def test_create_multimeric_template_features_matches_chopped_objects_by_base_description( + monkeypatch, + tmp_path, +): + template_file = tmp_path / "1abc.cif" + template_file.write_text("data_1abc", encoding="utf-8") + chopped = ChoppedObject("P04051", "ACDE", {}, [(1, 2), (3, 4)]) + multimer = MultimericObject.__new__(MultimericObject) + multimer.interactors = [chopped] + multimer.multimeric_template_dir = str(tmp_path) + multimer.multimeric_template_meta_data = {"P04051": [("1abc.cif", "B")]} + multimer.threshold_clashes = 1000 + multimer.hb_allowance = 0.4 + multimer.plddt_threshold = 0 + calls = [] + + monkeypatch.setattr( + objects_mod, + "extract_multimeric_template_features_for_single_chain", + lambda **kwargs: calls.append((kwargs["query_seq"], kwargs["chain_id"])) + or SimpleNamespace(features={"templated": kwargs["chain_id"]}), + ) + + multimer.create_multimeric_template_features() + + assert calls == [("ACDE", "B")] + assert chopped.feature_dict["templated"] == "B" def test_create_multimeric_template_features_rejects_non_mmcif_files(tmp_path): multimer = MultimericObject.__new__(MultimericObject) + multimer.interactors = [SimpleNamespace(description="proteinA", sequence="ACDE", feature_dict={})] multimer.multimeric_template_dir = str(tmp_path) - multimer.multimeric_template_meta_data = {"proteinA": {"bad.pdb": "A"}} + multimer.multimeric_template_meta_data = {"proteinA": [("bad.pdb", "A")]} multimer.monomers_mapping = {"proteinA": SimpleNamespace(sequence="ACDE", feature_dict={})} with pytest.raises(AssertionError, match="does not seem to be a mmcif file"): diff --git a/test/unit/test_script_entrypoints.py b/test/unit/test_script_entrypoints.py index 7f41e36d..321fabbd 100644 --- a/test/unit/test_script_entrypoints.py +++ b/test/unit/test_script_entrypoints.py @@ -103,6 +103,9 @@ def DEFINE_list(self, name, default, *_args, **_kwargs): def DEFINE_integer(self, name, default, *_args, **_kwargs): return self.FLAGS.define(name, default) + def DEFINE_float(self, name, default, *_args, **_kwargs): + return self.FLAGS.define(name, default) + def DEFINE_boolean(self, name, default, *_args, **_kwargs): return self.FLAGS.define(name, default) @@ -208,12 +211,18 @@ def __init__( multimeric_template, multimeric_template_meta_data, multimeric_template_dir, + threshold_clashes=1000, + hb_allowance=0.4, + plddt_threshold=0, ): self.interactors = list(interactors) self.pair_msa = pair_msa self.multimeric_template = multimeric_template self.multimeric_template_meta_data = multimeric_template_meta_data self.multimeric_template_dir = multimeric_template_dir + self.threshold_clashes = threshold_clashes + self.hb_allowance = hb_allowance + self.plddt_threshold = plddt_threshold self.description = "_and_".join(interactor.description for interactor in interactors) self.input_seqs = [interactor.sequence for interactor in interactors] self.multimeric_mode = True @@ -304,6 +313,7 @@ def _load_run_multimer_jobs_module(): # Predefine the shared FLAGS that run_multimer_jobs expects from run_structure_prediction. 
shared_flag_defaults = { "models_to_relax": "NONE", + "relax_best_score_threshold": None, "num_cycle": 3, "num_predictions_per_model": 1, "pair_msa": True, @@ -315,6 +325,9 @@ def _load_run_multimer_jobs_module(): "fold_backend": "alphafold2", "description_file": None, "path_to_mmt": None, + "threshold_clashes": 1000, + "hb_allowance": 0.4, + "plddt_threshold": 0, "compress_result_pickles": False, "remove_result_pickles": False, "remove_keys_from_pickles": True, @@ -649,6 +662,45 @@ def test_pre_modelling_setup_saves_multimer_features_and_builds_unique_ap_style_ ] +def test_pre_modelling_setup_passes_multimeric_template_filters( + run_structure_prediction_module, + tmp_path, +): + _set_flag(run_structure_prediction_module.FLAGS, "pair_msa", True) + _set_flag(run_structure_prediction_module.FLAGS, "multimeric_template", True) + _set_flag(run_structure_prediction_module.FLAGS, "description_file", "meta.csv") + _set_flag(run_structure_prediction_module.FLAGS, "path_to_mmt", "/tmp/templates") + _set_flag(run_structure_prediction_module.FLAGS, "threshold_clashes", 12.5) + _set_flag(run_structure_prediction_module.FLAGS, "hb_allowance", 0.7) + _set_flag(run_structure_prediction_module.FLAGS, "plddt_threshold", 42.0) + _set_flag(run_structure_prediction_module.FLAGS, "save_features_for_multimeric_object", False) + _set_flag( + run_structure_prediction_module.FLAGS, + "features_directory", + [str(tmp_path / "features")], + ) + _set_flag(run_structure_prediction_module.FLAGS, "use_ap_style", False) + + feature_dir = tmp_path / "features" + feature_dir.mkdir() + for description in ("protA", "protB"): + (feature_dir / f"{description}_feature_metadata_2026-03-30.json").write_text( + '{"meta": 1}', + encoding="utf-8", + ) + + monomer_a = run_structure_prediction_module.MonomericObject("protA", "AAAA") + monomer_b = run_structure_prediction_module.MonomericObject("protB", "BBBB") + returned_object, _ = run_structure_prediction_module.pre_modelling_setup( + [monomer_a, monomer_b], + output_dir=str(tmp_path / "outputs"), + ) + + assert returned_object.threshold_clashes == 12.5 + assert returned_object.hb_allowance == 0.7 + assert returned_object.plddt_threshold == 42.0 + + def test_pre_modelling_setup_builds_ap_style_homo_oligomer_dir( run_structure_prediction_module, tmp_path, @@ -857,6 +909,7 @@ def test_main_sets_multimer_model_flags_for_multimer_jobs( _set_flag(run_structure_prediction_module.FLAGS, "msa_depth_scan", True) _set_flag(run_structure_prediction_module.FLAGS, "model_names", ["model_2_multimer_v3"]) _set_flag(run_structure_prediction_module.FLAGS, "msa_depth", 64) + _set_flag(run_structure_prediction_module.FLAGS, "relax_best_score_threshold", 0.6) monkeypatch.setattr(run_structure_prediction_module, "parse_fold", lambda *args: [["parsed"]]) monkeypatch.setattr(run_structure_prediction_module, "create_custom_info", lambda parsed: "data") @@ -883,6 +936,7 @@ def test_main_sets_multimer_model_flags_for_multimer_jobs( assert captured_calls[0]["model_flags"]["msa_depth_scan"] is True assert captured_calls[0]["model_flags"]["model_names_custom"] == ["model_2_multimer_v3"] assert captured_calls[0]["model_flags"]["msa_depth"] == 64 + assert captured_calls[0]["postprocess_flags"]["relax_best_score_threshold"] == 0.6 def test_main_rejects_mismatched_output_directories( @@ -1103,3 +1157,71 @@ def test_run_multimer_jobs_combines_inputs_when_padding_requested( input_index = calls[0].index("--input") assert calls[0][input_index + 1] == "job1,job2" assert "--nopair_msa" in calls[0] + + +def 
test_run_multimer_jobs_forwards_multimeric_template_filters( + run_multimer_jobs_module, + monkeypatch, +): + calls = [] + monkeypatch.setattr( + run_multimer_jobs_module.subprocess, + "run", + lambda command, check, env: calls.append(command), + ) + run_multimer_jobs_module.generate_fold_specifications = ( + lambda input_files, delimiter, exclude_permutations: ["job1"] + ) + + _set_flag(run_multimer_jobs_module.FLAGS, "mode", "custom") + _set_flag(run_multimer_jobs_module.FLAGS, "protein_lists", ["proteins.txt"]) + _set_flag(run_multimer_jobs_module.FLAGS, "dry_run", False) + _set_flag(run_multimer_jobs_module.FLAGS, "fold_backend", "alphafold2") + _set_flag(run_multimer_jobs_module.FLAGS, "output_path", "/tmp/output") + _set_flag(run_multimer_jobs_module.FLAGS, "data_dir", "/tmp/models") + _set_flag(run_multimer_jobs_module.FLAGS, "monomer_objects_dir", ["/tmp/features"]) + _set_flag(run_multimer_jobs_module.FLAGS, "multimeric_template", True) + _set_flag(run_multimer_jobs_module.FLAGS, "threshold_clashes", 12.5) + _set_flag(run_multimer_jobs_module.FLAGS, "hb_allowance", 0.7) + _set_flag(run_multimer_jobs_module.FLAGS, "plddt_threshold", 42.0) + + run_multimer_jobs_module.main(["prog"]) + + assert len(calls) == 1 + assert "--threshold_clashes" in calls[0] + assert calls[0][calls[0].index("--threshold_clashes") + 1] == "12.5" + assert "--hb_allowance" in calls[0] + assert calls[0][calls[0].index("--hb_allowance") + 1] == "0.7" + assert "--plddt_threshold" in calls[0] + assert calls[0][calls[0].index("--plddt_threshold") + 1] == "42.0" + + +def test_run_multimer_jobs_forwards_relax_best_score_threshold( + run_multimer_jobs_module, + monkeypatch, +): + calls = [] + monkeypatch.setattr( + run_multimer_jobs_module.subprocess, + "run", + lambda command, check, env: calls.append(command), + ) + run_multimer_jobs_module.generate_fold_specifications = ( + lambda input_files, delimiter, exclude_permutations: ["job1"] + ) + + _set_flag(run_multimer_jobs_module.FLAGS, "mode", "custom") + _set_flag(run_multimer_jobs_module.FLAGS, "protein_lists", ["proteins.txt"]) + _set_flag(run_multimer_jobs_module.FLAGS, "dry_run", False) + _set_flag(run_multimer_jobs_module.FLAGS, "fold_backend", "alphafold2") + _set_flag(run_multimer_jobs_module.FLAGS, "output_path", "/tmp/output") + _set_flag(run_multimer_jobs_module.FLAGS, "data_dir", "/tmp/models") + _set_flag(run_multimer_jobs_module.FLAGS, "monomer_objects_dir", ["/tmp/features"]) + _set_flag(run_multimer_jobs_module.FLAGS, "models_to_relax", "Best") + _set_flag(run_multimer_jobs_module.FLAGS, "relax_best_score_threshold", 0.6) + + run_multimer_jobs_module.main(["prog"]) + + assert len(calls) == 1 + assert "--relax_best_score_threshold" in calls[0] + assert calls[0][calls[0].index("--relax_best_score_threshold") + 1] == "0.6"
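For reference, a minimal usage sketch of the new AlphaFold Server JSON helpers introduced by this change, written against the signatures of build_alphafold_server_jobs and write_jobs_to_json_files as added above; the protein-list file and feature directory paths are hypothetical placeholders, not files shipped with the repository.

# Illustrative sketch only (assumed inputs): build AlphaFold Server job JSONs
# from an AlphaPulldown-style protein list plus a directory of monomer feature
# pickles, then write them out in batches. Paths are hypothetical placeholders.
from alphapulldown.utils.alphafold_server_json import (
    build_alphafold_server_jobs,
    write_jobs_to_json_files,
)

jobs = build_alphafold_server_jobs(
    protein_lists=["custom_pulldown.txt"],   # hypothetical fold specification list
    monomer_directories=["features/"],       # hypothetical dir with *.pkl feature pickles
    protein_delimiter="+",
)
written_paths = write_jobs_to_json_files(
    jobs,
    "server_jobs.json",
    jobs_per_file=100,  # batches larger than this are split into numbered files
)
print([path.name for path in written_paths])

When more than one batch is produced, the helper writes numbered files such as server_jobs_001.json, as exercised by test_write_jobs_to_json_files_splits_large_batches.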