Skip to content

Commit db61778

Browse files
committed
Release v0.13.1
1 parent af42f20 commit db61778

9 files changed

Lines changed: 7262 additions & 86 deletions

File tree

openprotein/data/data.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,7 @@ def create(
5959
metadata = api.assaydata_post(
6060
self.session, stream, name, assay_description=description
6161
)
62+
table.columns = [str(column).lower() for column in table.columns]
6263
metadata.sequence_length = len(table["sequence"].values[0])
6364
return AssayDataset(self.session, metadata)
6465

openprotein/data/schemas.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@ class AssayMetadata(BaseModel):
1313
num_entries: int
1414
measurement_names: list[str]
1515
sequence_length: int | None = None
16+
has_name_column: bool = False
1617

1718

1819
class AssayDataRow(BaseModel):

openprotein/embeddings/embeddings.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
from openprotein.base import APISession
44

55
from . import api
6+
from .ablang import AbLang2Model
67
from .esm import ESMModel
78
from .future import EmbeddingsResultFuture
89
from .models import EmbeddingModel
@@ -55,6 +56,7 @@ class EmbeddingsAPI:
5556
#: Rotaprot model trained on UniRef90
5657
rotaprot_large_uniref90_ft: OpenProteinModel
5758
poet_2: PoET2Model
59+
ablang2: AbLang2Model
5860

5961
#: ESM1b model
6062
esm1b: ESMModel # alias

openprotein/fold/common.py

Lines changed: 5 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -163,14 +163,11 @@ def serialize_input(session: APISession, complexes: list[Complex], needs_msa: bo
163163
msa_to_seed: dict[str, set[str]] = dict()
164164
for complex in complexes:
165165
_complex: list[dict[str, Any]] = []
166-
for chain_ids, chain in complex.get_id_groups():
167-
id_field: str | list[str] = (
168-
chain_ids[0] if len(chain_ids) == 1 else list(chain_ids)
169-
)
166+
for chain_id, chain in complex.get_chains().items():
170167
if isinstance(chain, Protein):
171168
# add the protein in the unified format
172169
p: dict = {
173-
"id": id_field,
170+
"id": chain_id,
174171
"sequence": chain.sequence.decode(),
175172
}
176173
if needs_msa:
@@ -208,17 +205,17 @@ def serialize_input(session: APISession, complexes: list[Complex], needs_msa: bo
208205
p["msa_id"] = msa_id
209206
_complex.append({"protein": p})
210207
elif isinstance(chain, Ligand):
211-
ligand_payload: dict[str, Any] = {"id": id_field}
208+
ligand_payload: dict[str, Any] = {"id": chain_id}
212209
if chain.smiles is not None:
213210
ligand_payload["smiles"] = chain.smiles
214211
if chain.ccd is not None:
215212
ligand_payload["ccd"] = chain.ccd
216213
_complex.append({"ligand": ligand_payload})
217214
elif isinstance(chain, DNA):
218-
d: dict[str, Any] = {"id": id_field, "sequence": chain.sequence}
215+
d: dict[str, Any] = {"id": chain_id, "sequence": chain.sequence}
219216
_complex.append({"dna": d})
220217
elif isinstance(chain, RNA):
221-
r: dict[str, Any] = {"id": id_field, "sequence": chain.sequence}
218+
r: dict[str, Any] = {"id": chain_id, "sequence": chain.sequence}
222219
_complex.append({"rna": r})
223220
else:
224221
raise ValueError(f"Unexpected chain type: {chain}")

openprotein/molecules/complex.py

Lines changed: 9 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -24,32 +24,16 @@
2424
class Complex:
2525
def __init__(
2626
self,
27-
chains: Mapping[
28-
str | tuple[str, ...], Protein | DNA | RNA | Ligand
29-
]
30-
| None = None,
27+
chains: Mapping[str, Protein | DNA | RNA | Ligand] | None = None,
3128
name: bytes | str | None = None,
3229
):
33-
expanded: dict[str, Protein | DNA | RNA | Ligand] = {}
34-
groups: list[tuple[str, ...]] = []
30+
collected: dict[str, Protein | DNA | RNA | Ligand] = {}
3531
if chains is not None:
3632
for key, value in chains.items():
37-
ids = (key,) if isinstance(key, str) else key
38-
if not isinstance(ids, tuple) or not all(
39-
isinstance(cid, str) for cid in ids
40-
):
41-
raise TypeError(
42-
f"chain id must be str or tuple[str, ...]; got {key!r}"
43-
)
44-
if len(ids) == 0:
45-
raise ValueError("tuple chain id must be non-empty")
46-
for cid in ids:
47-
if cid in expanded:
48-
raise ValueError(f"duplicate chain id: {cid!r}")
49-
expanded[cid] = value
50-
groups.append(tuple(ids))
51-
self._chains = dict(sorted(expanded.items()))
52-
self._id_groups: list[tuple[str, ...]] = sorted(groups, key=lambda g: g[0])
33+
if not isinstance(key, str):
34+
raise TypeError(f"chain id must be str; got {key!r}")
35+
collected[key] = value
36+
self._chains = dict(sorted(collected.items()))
5337
self._templates: "Sequence[Protein | Complex | Template]" = ()
5438
self.name = name
5539

@@ -89,17 +73,6 @@ def set_templates(
8973
def get_chains(self) -> Mapping[str, Protein | DNA | RNA | Ligand]:
9074
return MappingProxyType(self._chains)
9175

92-
def get_id_groups(
93-
self,
94-
) -> "list[tuple[tuple[str, ...], Protein | DNA | RNA | Ligand]]":
95-
"""Return ordered (chain_ids, chain) pairs grouped by entity.
96-
97-
Each group's ``chain_ids`` is the tuple originally passed to the
98-
constructor (a 1-tuple for scalar keys), and the chain object is
99-
shared across all ids in the group.
100-
"""
101-
return [(ids, self._chains[ids[0]]) for ids in self._id_groups]
102-
10376
def get_proteins(self) -> Mapping[str, Protein]:
10477
return MappingProxyType(
10578
{k: v for k, v in self._chains.items() if isinstance(v, Protein)}
@@ -143,18 +116,8 @@ def get_ligand(self, chain_id: str) -> Ligand:
143116
def set_chain(
144117
self, chain_id: str, value: Protein | DNA | RNA | Ligand
145118
) -> "Complex":
146-
new_groups: list[tuple[str, ...]] = []
147-
for ids in self._id_groups:
148-
if chain_id in ids:
149-
remaining = tuple(i for i in ids if i != chain_id)
150-
if remaining:
151-
new_groups.append(remaining)
152-
else:
153-
new_groups.append(ids)
154-
new_groups.append((chain_id,))
155119
self._chains[chain_id] = value
156120
self._chains = dict(sorted(self._chains.items()))
157-
self._id_groups = sorted(new_groups, key=lambda g: g[0])
158121
return self
159122

160123
def __rand__(self, left: "Complex | Protein | str") -> "Complex":
@@ -184,9 +147,6 @@ def __and__(self, right: "Complex | Protein | str") -> "Complex":
184147
f"Trying to combine two sets of chains with overlapping chain ids: {overlapping_chain_ids}"
185148
)
186149
self._chains = dict(sorted((self._chains | right._chains).items()))
187-
self._id_groups = sorted(
188-
self._id_groups + right._id_groups, key=lambda g: g[0]
189-
)
190150
return self
191151

192152
@overload
@@ -309,12 +269,9 @@ def from_string(
309269
)
310270

311271
def copy(self) -> "Complex":
312-
chains_copy: dict[
313-
str | tuple[str, ...], Protein | DNA | RNA | Ligand
314-
] = {}
315-
for ids in self._id_groups:
316-
value = self._chains[ids[0]].copy()
317-
chains_copy[ids if len(ids) > 1 else ids[0]] = value
272+
chains_copy: dict[str, Protein | DNA | RNA | Ligand] = {
273+
chain_id: chain.copy() for chain_id, chain in self._chains.items()
274+
}
318275
return Complex(chains=chains_copy, name=self._name)
319276

320277
def _assert_valid_templates(self):

0 commit comments

Comments
 (0)