Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 30 additions & 7 deletions atompack-py/python/atompack/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Type stubs for atompack"""

from typing import Any, Sequence, overload
from typing import Any, Literal, Sequence, overload

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -336,20 +336,30 @@ class Molecule:
"""
...

def set_property(self, key: str, value: float | int | str | npt.NDArray | None) -> None:
def set_property(
self,
key: str,
value: float | int | str | npt.NDArray | None,
*,
scope: Literal["molecule", "atom"] | None = None,
) -> None:
"""
Set a custom property.

Supported types: None, float, int, str, 1D float32/float64/int32/int64 arrays,
and 2D float32/float64 arrays with shape (n, 3). Input dtype is preserved.
The key 'stress' is reserved; use the dedicated ``stress`` property instead.
Supported types: None, float, int, str, and numeric ndarrays with dtype
float32, float64, int32, or int64. Input dtype and tensor shape are preserved.
New atom properties require ``scope="atom"``; existing atom properties keep
atom scope when overwritten. The key 'stress' is reserved; use the dedicated
``stress`` property instead.

Parameters
----------
key : str
Property key
value : float, int, str, ndarray, or None
Property value
scope : {"molecule", "atom"}, optional
Property scope. Defaults to molecule for new keys.

Raises
------
Expand All @@ -358,7 +368,7 @@ class Molecule:
"""
...

def property_keys(self) -> list[str]:
def property_keys(self, *, scope: Literal["molecule", "atom"] | None = None) -> list[str]:
"""
Get all property keys.

Expand All @@ -369,14 +379,16 @@ class Molecule:
"""
...

def has_property(self, key: str) -> bool:
def has_property(self, key: str, *, scope: Literal["molecule", "atom"] | None = None) -> bool:
"""
Check if a property exists.

Parameters
----------
key : str
Property key
scope : {"molecule", "atom"}, optional
Restrict the lookup to one scope.

Returns
-------
Expand All @@ -385,6 +397,17 @@ class Molecule:
"""
...

def delete_property(self, key: str) -> None:
"""
Delete a custom property by key.

Raises
------
KeyError
If property key does not exist
"""
...

@overload
def __getitem__(self, index: int) -> Atom: ...
@overload
Expand Down
22 changes: 18 additions & 4 deletions atompack-py/python/atompack/_atompack_rs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Any, Sequence, overload
from typing import Any, Literal, Sequence, overload

import numpy as np

Expand Down Expand Up @@ -329,7 +329,13 @@ class PyMolecule:
"""
...

def set_property(self, key: str, value: Any) -> None:
def set_property(
self,
key: str,
value: Any,
*,
scope: Literal["molecule", "atom"] | None = None,
) -> None:
"""
Set a custom property.

Expand All @@ -339,10 +345,12 @@ class PyMolecule:
Property key
value : Any
Property value
scope : {"molecule", "atom"}, optional
Property scope. Defaults to molecule for new keys.
"""
...

def property_keys(self) -> list[str]:
def property_keys(self, *, scope: Literal["molecule", "atom"] | None = None) -> list[str]:
"""
Get all property keys.

Expand All @@ -353,14 +361,16 @@ class PyMolecule:
"""
...

def has_property(self, key: str) -> bool:
def has_property(self, key: str, *, scope: Literal["molecule", "atom"] | None = None) -> bool:
"""
Check if a property exists.

Parameters
----------
key : str
Property key
scope : {"molecule", "atom"}, optional
Restrict the lookup to one scope.

Returns
-------
Expand All @@ -369,6 +379,10 @@ class PyMolecule:
"""
...

def delete_property(self, key: str) -> None:
"""Delete a custom property by key."""
...

@overload
def __getitem__(self, index: int) -> PyAtom: ...
@overload
Expand Down
17 changes: 11 additions & 6 deletions atompack-py/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,24 +160,30 @@ impl PyAtomDatabase {
)
.map_err(|e| PyValueError::new_err(format!("{}", e)));
}
let owned = molecule.clone_as_owned()?;
let owned = molecule.as_owned().ok_or_else(|| {
PyValueError::new_err("Molecule is missing both owned and view state")
})?;
self.inner
.add_molecule(&owned)
.add_molecule(owned)
.map_err(|e| PyValueError::new_err(format!("{}", e)))
}

/// Add multiple molecules (processed in parallel)
fn add_molecules(&mut self, molecules: Vec<PyRef<PyMolecule>>) -> PyResult<()> {
let mut raw_records: Vec<(&[u8], u32)> = Vec::new();
let mut raw_views: Vec<&SoaMoleculeView> = Vec::new();
let mut owned_molecules: Vec<Molecule> = Vec::new();
let mut owned_molecules: Vec<&Molecule> = Vec::new();

for m in &molecules {
if let Some(view) = m.as_view() {
raw_records.push((view.raw_bytes(), view.n_atoms as u32));
raw_views.push(view);
} else if let Some(owned) = m.as_owned() {
owned_molecules.push(owned);
} else {
owned_molecules.push(m.clone_as_owned()?);
return Err(PyValueError::new_err(
"Molecule is missing both owned and view state",
));
}
}

Expand Down Expand Up @@ -211,9 +217,8 @@ impl PyAtomDatabase {
}
}
if !owned_molecules.is_empty() {
let mol_refs: Vec<&Molecule> = owned_molecules.iter().collect();
self.inner
.add_molecules(&mol_refs)
.add_molecules(&owned_molecules)
.map_err(|e| PyValueError::new_err(format!("{}", e)))?;
}
Ok(())
Expand Down
15 changes: 11 additions & 4 deletions atompack-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@

use atompack::{
Atom, AtomDatabase, FloatArrayData, FloatScalarData, Mat3Data, Molecule, SharedMmapBytes,
Vec3Data, atom::PropertyValue, compression::CompressionType,
Vec3Data,
atom::{PropertyValue, TensorData},
compression::CompressionType,
};
use numpy::{Element, PyArray1, PyArray2, PyArray3, PyArrayMethods};
use numpy::{Element, PyArray1, PyArray2, PyArray3, PyArrayDyn, PyArrayMethods};
use pyo3::exceptions::{PyFileExistsError, PyIndexError, PyKeyError, PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyTuple};
Expand Down Expand Up @@ -114,6 +116,10 @@ const TYPE_MAT3X3_F64: u8 = 10;
const TYPE_FLOAT32: u8 = 11;
const TYPE_MAT3X3_F32: u8 = 12;
const TYPE_NONE: u8 = 13;
const TYPE_TENSOR_F32: u8 = 14;
const TYPE_TENSOR_F64: u8 = 15;
const TYPE_TENSOR_I32: u8 = 16;
const TYPE_TENSOR_I64: u8 = 17;

const RECORD_FORMAT_SOA_V2: u32 = 2;
const RECORD_FORMAT_SOA_V3: u32 = 3;
Expand All @@ -126,8 +132,9 @@ pub(crate) use self::py_dtypes::{
parse_mat3_field, parse_positions_field, parse_property_value, parse_vec3_field,
};
pub(crate) use self::soa::{
LazySection, SectionSchema, SoaContext, SoaMoleculeView, parse_mol_fast_soa, read_f64_scalar,
read_i64_scalar, section_schema_from_ref, type_tag_elem_bytes,
LazySection, SectionSchema, SoaContext, SoaMoleculeView, decode_property_value,
is_tensor_type_tag, parse_mol_fast_soa, read_f64_scalar, read_i64_scalar,
section_schema_from_ref, type_tag_elem_bytes,
};

mod database;
Expand Down
Loading
Loading