Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions models/rfd3/src/rfd3/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,15 @@ class RFD3Output:
denoised_trajectory_stack: Optional[AtomArrayStack] = None
noisy_trajectory_stack: Optional[AtomArrayStack] = None

@staticmethod
def _strip_atom_id(atoms):
"""Strip ``atom_id`` so the CIF writer auto-generates unique
sequential IDs instead of using the duplicated values left by
virtual-atom padding."""
if "atom_id" in atoms.get_annotation_categories():
atoms.del_annotation("atom_id")
return atoms

def dump(
self,
out_dir,
Expand All @@ -103,7 +112,7 @@ def dump(
base_path = os.path.join(out_dir, self.example_id)
base_path = Path(base_path).absolute()
to_cif_file(
self.atom_array,
self._strip_atom_id(self.atom_array),
base_path,
file_type="cif.gz",
include_entity_poly=False,
Expand All @@ -118,15 +127,15 @@ def dump(
suffix = str(base_path)[-1]
if self.denoised_trajectory_stack is not None:
to_cif_file(
self.denoised_trajectory_stack,
self._strip_atom_id(self.denoised_trajectory_stack),
"_denoised_model_".join([prefix, suffix]),
file_type="cif.gz",
include_entity_poly=False,
)

if self.noisy_trajectory_stack is not None:
to_cif_file(
self.noisy_trajectory_stack,
self._strip_atom_id(self.noisy_trajectory_stack),
"_noisy_model_".join([prefix, suffix]),
file_type="cif.gz",
include_entity_poly=False,
Expand Down
12 changes: 11 additions & 1 deletion models/rfd3/src/rfd3/trainer/trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,17 @@ def _cleanup_virtual_atoms_and_assign_atom_name_elements(
atom_array.element,
)
atom_array.res_name[invalid_mask] = np.array(["UNK"] * sum(invalid_mask))
return atom_array[ret_mask]
result = atom_array[ret_mask]

# Strip atom_id annotation to prevent duplicate _atom_site.id values
# in CIF output. PadTokensWithVirtualAtoms copies the central atom (CB)
# including its atom_id; after virtual atoms are removed, sidechain atoms
# retain the duplicated atom_id. Removing it lets the CIF writer
# auto-generate unique sequential IDs.
if "atom_id" in result.get_annotation_categories():
result.del_annotation("atom_id")

return result


def _readout_seq_from_struc(
Expand Down
104 changes: 104 additions & 0 deletions models/rfd3/tests/test_cif_duplicate_ids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Regression tests for duplicate CIF atom_id (#148).

PadTokensWithVirtualAtoms copies the central atom (CB) including its
atom_id annotation. After virtual atoms are removed, sidechain atoms
retain CB's duplicated atom_id. The CIF writer uses these values for
_atom_site.id, producing duplicate IDs that violate the mmCIF spec.
"""

import tempfile
from pathlib import Path

import numpy as np
from atomworks.io.utils.io_utils import to_cif_file
from biotite.structure import AtomArray
from rfd3.engine import RFD3Output
from rfd3.trainer.trainer_utils import (
_cleanup_virtual_atoms_and_assign_atom_name_elements,
)


def _make_alanine_array_with_duplicate_atom_ids(n_residues=4):
    """Build a protein AtomArray where every atom in each residue carries
    the same atom_id as CB — the corruption caused by virtual-atom padding.

    Each residue is a 5-atom alanine backbone+CB laid out on a simple
    grid; annotations are filled with vectorized numpy assignments.
    """
    names = ["N", "CA", "C", "O", "CB"]
    per_res = len(names)
    n_atoms = n_residues * per_res
    atoms = AtomArray(n_atoms)

    # Per-atom residue index (0-based) and within-residue atom index.
    res_index = np.repeat(np.arange(n_residues), per_res)
    name_index = np.tile(np.arange(per_res), n_residues)

    atoms.chain_id[:] = "A"
    atoms.res_id[:] = res_index + 1
    atoms.res_name[:] = "ALA"
    atoms.atom_name[:] = np.array(names)[name_index]
    atoms.element[:] = np.array([n[0] for n in names])[name_index]
    atoms.coord[:, 0] = res_index * 4.0
    atoms.coord[:, 1] = name_index * 1.5
    atoms.coord[:, 2] = 0.0

    # Simulate the bug: every atom in a residue shares CB's atom_id
    # (CB is the fifth atom, offset 4 within each residue).
    atoms.set_annotation("atom_id", res_index * per_res + 4)

    # Annotations required by _cleanup_virtual_atoms_and_assign_atom_name_elements
    atoms.set_annotation(
        "is_motif_atom_with_fixed_seq", np.ones(n_atoms, dtype=bool)
    )
    atoms.set_annotation(
        "is_motif_atom_unindexed", np.zeros(n_atoms, dtype=bool)
    )
    atoms.set_annotation("gt_atom_name", atoms.atom_name.copy())

    return atoms


def test_cleanup_strips_atom_id():
    """The cleanup helper must remove the atom_id annotation so the CIF
    writer generates fresh, unique IDs."""
    corrupted = _make_alanine_array_with_duplicate_atom_ids()

    # Sanity-check the fixture: atom_id is present and contains duplicates.
    assert "atom_id" in corrupted.get_annotation_categories()
    unique_ids = np.unique(corrupted.atom_id)
    assert unique_ids.size < corrupted.atom_id.size

    cleaned = _cleanup_virtual_atoms_and_assign_atom_name_elements(corrupted)
    assert "atom_id" not in cleaned.get_annotation_categories()


def test_cif_output_has_unique_ids():
    """CIF _atom_site.id values must be unique after _strip_atom_id.

    Writes a corrupted AtomArray to a temporary CIF file, then parses the
    _atom_site loop back out and asserts the id column has no duplicates.
    """
    atoms = _make_alanine_array_with_duplicate_atom_ids()
    atoms = RFD3Output._strip_atom_id(atoms)

    with tempfile.TemporaryDirectory() as tmpdir:
        out_path = Path(tmpdir) / "test"
        to_cif_file(atoms, out_path, file_type="cif")

        cif_text = Path(f"{out_path}.cif").read_text()

    # Parse _atom_site loop to find the id column
    lines = cif_text.splitlines()
    col_names = []
    data_start = None
    for i, line in enumerate(lines):
        if line.strip().startswith("_atom_site."):
            col_names.append(line.strip().split()[0])
        elif col_names and not line.strip().startswith("_") and line.strip():
            data_start = i
            break

    assert "_atom_site.id" in col_names, "missing _atom_site.id column"
    # Guard: if no data row follows the header, data_start is None and
    # lines[None:] would silently be a full slice from the top of the
    # file — fail loudly here instead of parsing the wrong region.
    assert data_start is not None, "no data rows after _atom_site header"
    id_col = col_names.index("_atom_site.id")

    ids = []
    for line in lines[data_start:]:
        stripped = line.strip()
        if not stripped or stripped.startswith("#") or stripped.startswith("loop_"):
            break
        parts = stripped.split()
        if len(parts) > id_col:
            ids.append(parts[id_col])

    assert len(ids) > 0, "no atom records found in CIF output"
    assert len(ids) == len(set(ids)), f"duplicate _atom_site.id values: {ids}"
Loading