diff --git a/package/MDAnalysis/topology/MMCIFParser.py b/package/MDAnalysis/topology/MMCIFParser.py index 26b88a8bc8..ca19027cb2 100644 --- a/package/MDAnalysis/topology/MMCIFParser.py +++ b/package/MDAnalysis/topology/MMCIFParser.py @@ -46,7 +46,6 @@ else: HAS_GEMMI = True -import itertools import warnings import numpy as np @@ -54,7 +53,6 @@ from ..core.topology import Topology from ..core.topologyattrs import ( AltLocs, - AtomAttr, Atomids, Atomnames, Atomtypes, @@ -66,230 +64,12 @@ Occupancies, RecordTypes, Resids, - ResidueAttr, Resnames, Resnums, Segids, - SegmentAttr, Tempfactors, ) -from .base import TopologyReaderBase - - -def _into_idx(arr: list) -> list[int]: - """Replace consecutive identical elements of an array with their indices. - - Example - ------- - .. code-block:: python - - arr: list[int] = [1, 1, 5, 5, 7, 3, 3] - assert _into_idx(arr) == [0, 0, 1, 1, 2, 3, 3] - - Parameters - ---------- - arr - input array of elements that can be compared with `__eq__` - - Returns - ------- - list[int] -- array where these elements are replaced with their unique indices, in order of appearance. - - .. versionadded:: 2.9.0 - """ - return [ - idx - for idx, (_, group) in enumerate(itertools.groupby(arr)) - for _ in group - ] - - -def get_Atomattrs(model: "gemmi.Model") -> tuple[list[AtomAttr], np.ndarray]: - """Extract all attributes that are subclasses of :class:`..core.topologyattrs.AtomAttr` from a ``gemmi.Model`` object, - and a `residx` index with indices of all atoms in residues. - - Parameters - ---------- - model - input `gemmi.Model`, e.g. `gemmi.read_structure('file.cif')[0]` - - Returns - ------- - tuple[list[AtomAttr], np.ndarray] -- first element is list of all extracted attributes, second element is `segidx` - - Raises - ------ - ValueError - if any of the records is neither 'ATOM' nor 'HETATM' - - .. versionadded:: 2.9.0 - """ - ( - altlocs, # at.altloc - serials, # at.serial - names, # at.name - atomtypes, # at.name - # ------------------ - chainids, # chain.name - elements, # at.element.name - formalcharges, # at.charge - weights, # at.element.weight - # ------------------ - occupancies, # at.occ - record_types, # res.het_flag - tempfactors, # at.b_iso - residx, # _into_idx(res.label_seq or res.seqid.num) - ) = map( # this construct takes np.ndarray of all lists of attributes, extracted from the `gemmi.Model` - np.array, - list( - zip( - *[ - ( - # tuple of attributes - # extracted from residue, atom or chain in the structure - # ------------------ - atom.altloc, # altlocs - atom.serial, # serials - atom.name, # names - atom.name, # atomtypes - # ------------------ - chain.name, # chainids - atom.element.name, # elements - atom.charge, # formalcharges - atom.element.weight, # weights - # ------------------ - atom.occ, # occupancies - residue.het_flag, # record_types - atom.b_iso, # tempfactors - # residue.seqid.num, - ( - residue.label_seq - if residue.label_seq is not None - else residue.seqid.num - ), # residx, later translated into continious repr - ) - # the main loop over the `gemmi.Model` object - for chain in model - for residue in chain - for atom in residue - ] - ) - ), - ) - - # transform *idx into continious numpy arrays - print(f"Before: {len(residx)=}") - residx = np.array(_into_idx(residx)) - print(f"After: {len(residx)=}") - - # fill in altlocs, since gemmi has '' as default - altlocs = ["A" if not elem else elem for elem in altlocs] - - # convert default gemmi record types to default MDAnalysis record types - record_types = [ - "ATOM" if record == "A" else "HETATM" if record == "H" else None - for record in record_types - ] - if any((elem is None for elem in record_types)): - raise ValueError("Found an atom that is neither ATOM or HETATM") - - attrs = [ - AltLocs(altlocs), - Atomids(serials), - Atomnames(names), - Atomtypes(atomtypes), - # ---------------------------- - ChainIDs(chainids), - Elements(elements), - FormalCharges(formalcharges), - Masses(weights), - # ---------------------------- - Occupancies(occupancies), - RecordTypes(record_types), - Tempfactors(tempfactors), - ] - - return attrs, residx - - -def make_resid(residue: "gemmi.Residue") -> str: - # return residue.seqid.num - # return f'{residue.seqid.num}{residue.seqid.icode.strip()}' - return ( - residue.label_seq - if residue.label_seq is not None - else residue.seqid.num - ) - - -def get_Residueattrs( - model: "gemmi.Model", -) -> tuple[list[ResidueAttr], np.ndarray]: - """Extract all attributes that are subclasses of :class:`..core.topologyattrs.ResidueAttr` from a ``gemmi.Model`` object, - and a `segidx` index witn indices of all residues in segments. - - Parameters - ---------- - model - input `gemmi.Model`, e.g. `gemmi.read_structure('file.cif')[0]` - - Returns - ------- - tuple[list[ResidueAttr], np.ndarray] -- first element is list of all extracted attributes, second element is `segidx` - - .. versionadded:: 2.9.0 - """ - ( - icodes, # residue.seqid.icode - resids, # residue.seqid.num # FIXME: perhaps this is what's wrong, and not residx per se? - resnames, # residue.name - segidx, # chain.name - resnums, # residue.seqid.num - ) = map( - np.array, - list( - zip( - *[ - ( - residue.seqid.icode, - residue.seqid.num, - residue.name, - chain.name, - residue.seqid.num, - ) - for chain in model - for residue in chain - ] - ) - ), - ) - segidx = np.array(_into_idx(segidx)) - - attrs = [ - Resnums(resnums), - ICodes([icode.strip() for icode in icodes]), - Resids(resids), - Resnames(resnames), - ] - return attrs, segidx - - -def get_Segmentattrs(model: "gemmi.Model") -> list[SegmentAttr]: - """Extract all attributes that are subclasses of :class:`..core.topologyattrs.SegmentAttr` from a ``gemmi.Model`` object. - - Parameters - ---------- - model - input `gemmi.Model`, e.g. `gemmi.read_structure('file.cif')[0]` - - Returns - ------- - list[SegmentAttr] -- list of all extracted attributes - - .. versionadded:: 2.9.0 - """ - segids = [chain.name for chain in model] - return [Segids(segids)] +from .base import TopologyReaderBase, change_squash class MMCIFParser(TopologyReaderBase): @@ -337,16 +117,104 @@ def parse(self, **kwargs) -> Topology: ) model = structure[0] - atomAttrs, residx = get_Atomattrs(model) - residAttrs, segidx = get_Residueattrs(model) - segmentAttrs = get_Segmentattrs(model) - - attrs = atomAttrs + residAttrs + segmentAttrs + ( + altlocs, # at.altloc + serials, # at.serial + names, # at.name + atomtypes, # at.name + # ------------------ + chainids, # chain.name + elements, # at.element.name + formalcharges, # at.charge + weights, # at.element.weight + # ------------------ + occupancies, # at.occ + record_types, # res.het_flag + tempfactors, # at.b_iso + # ------------------ + icodes, # residue.seqid.icode + resids, # residue.seqid.num + resnames, # residue.name + ) = map( # this construct takes np.ndarray of all lists of attributes, extracted from the `gemmi.Model` + np.array, + list( + zip( + *[ + ( + # tuple of attributes + # extracted from residue, atom or chain in the structure + # ------------------ + atom.altloc, # altlocs + atom.serial, # serials + atom.name, # names + atom.name, # atomtypes + # ------------------ + chain.name, # chainids + atom.element.name, # elements + atom.charge, # formalcharges + atom.element.weight, # weights + # ------------------ + atom.occ, # occupancies + residue.het_flag, # record_types + atom.b_iso, # tempfactors + # ------------------ + residue.seqid.icode, # icodes + residue.seqid.num, # resids + residue.name, # resnames + ) + # the main loop over the `gemmi.Model` object + for chain in model + for residue in chain + for atom in residue + ] + ) + ), + ) - # due to the list(map(...)) construction, all elements in array have equal lengths - n_atoms = len(atomAttrs[0]) - n_residues = len(residAttrs[0]) - n_segments = len(segmentAttrs[0]) + # fill in altlocs, since gemmi has '' as default + altlocs = ["A" if not elem else elem for elem in altlocs] + + # convert default gemmi record types to default MDAnalysis record types + record_types = [ + "ATOM" if record == "A" else "HETATM" if record == "H" else None + for record in record_types + ] + if any((elem is None for elem in record_types)): + raise ValueError("Found an atom that is neither ATOM or HETATM") + + # Atom Attr's + attrs = [ + AltLocs(altlocs), + Atomids(serials), + Atomnames(names), + Atomtypes(atomtypes), + # ---------------------------- + ChainIDs(chainids), + Elements(elements), + FormalCharges(formalcharges), + Masses(weights), + # ---------------------------- + Occupancies(occupancies), + RecordTypes(record_types), + Tempfactors(tempfactors), + ] + n_atoms = len(altlocs) + + # Residue Attr's + residx, (resids, resnames, icodes, chainids) = change_squash( + (resids, resnames, icodes, chainids), + (resids, resnames, icodes, chainids), + ) + attrs.append(Resids(resids)) + attrs.append(Resnames(resnames)) + attrs.append(Resnums(resids.copy())) + attrs.append(ICodes([icode.strip() for icode in icodes])) + n_residues = len(resids) + + # Segment Attr's + segidx, (segids,) = change_squash((chainids,), (chainids,)) + attrs.append(Segids(segids)) + n_segments = len(segids) return Topology( n_atoms,