Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 205 additions & 23 deletions atompack-py/src/database_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ fn schema_section(kind: u8, key: &str, type_tag: u8, slot_bytes: usize) -> Datab
}

fn schema_slot_bytes(kind: u8, key: &str, type_tag: u8, payload_slot_bytes: usize) -> usize {
if matches!(type_tag, TYPE_STRING | TYPE_NONE) {
if matches!(type_tag, TYPE_STRING | TYPE_NONE) || is_tensor_type_tag(type_tag) {
0
} else if is_per_atom(kind, key, type_tag) {
type_tag_elem_bytes(type_tag)
Expand Down Expand Up @@ -137,6 +137,18 @@ fn extract_string_column(
}))
}

fn reject_list_or_tuple_tensor_column(value: &Bound<'_, PyAny>, key: &str) -> PyResult<()> {
if value.cast::<PyList>().is_ok() || value.cast::<PyTuple>().is_ok() {
return Err(PyValueError::new_err(format!(
"custom property '{key}' must be a stacked ndarray tensor for add_arrays_batch; \
list/tuple tensor inputs are not supported. Ragged tensors should use \
per-molecule construction plus add_molecules(...) and non-flat retrieval with \
get_molecule(s)."
)));
}
Ok(())
}

fn extract_scalar_column_f64<T: Element + Copy + Into<f64>>(
arr: &Bound<'_, PyArray1<T>>,
batch: usize,
Expand Down Expand Up @@ -251,6 +263,152 @@ fn extract_vec3_column<T: Element + Copy + bytemuck::NoUninit>(
})
}

fn tensor_value_count(shape: &[usize]) -> PyResult<usize> {
shape.iter().try_fold(1usize, |acc, dim| {
acc.checked_mul(*dim)
.ok_or_else(|| PyValueError::new_err("Tensor shape overflows usize"))
})
}

fn tensor_payload_header(shape: &[usize], key: &str) -> PyResult<Vec<u8>> {
if shape.len() > u8::MAX as usize {
return Err(PyValueError::new_err(format!(
"custom property '{}' tensor rank {} exceeds maximum {}",
key,
shape.len(),
u8::MAX
)));
}
if shape.iter().any(|dim| *dim > u32::MAX as usize) {
return Err(PyValueError::new_err(format!(
"custom property '{}' tensor dimensions must fit in u32 for storage",
key
)));
}
let mut header = Vec::with_capacity(1 + shape.len() * 4);
header.push(shape.len() as u8);
for dim in shape {
header.extend_from_slice(&(*dim as u32).to_le_bytes());
}
Ok(header)
}

fn extract_tensor_column<T: Element + Copy + bytemuck::NoUninit>(
arr: &Bound<'_, PyArrayDyn<T>>,
batch: usize,
n_atoms: Option<usize>,
key: &str,
kind: u8,
type_tag: u8,
) -> PyResult<BatchSectionColumn> {
let readonly = arr.readonly();
let view = readonly.as_array();
let shape = view.shape();
if shape.len() < 3 || shape.first() != Some(&batch) {
return Err(PyValueError::new_err(format!(
"custom property '{}' must have stacked tensor shape ({}, ...)",
key, batch
)));
}
if let Some(expected_atoms) = n_atoms
&& shape.get(1) != Some(&expected_atoms)
{
return Err(PyValueError::new_err(format!(
"atom property '{}' must have stacked tensor shape ({}, {}, ...)",
key, batch, expected_atoms
)));
}

let per_record_shape = &shape[1..];
let values_per_record = tensor_value_count(per_record_shape)?;
let elem_bytes = std::mem::size_of::<T>();
let data_slot_bytes = values_per_record
.checked_mul(elem_bytes)
.ok_or_else(|| PyValueError::new_err("Tensor byte length overflow"))?;
let header = tensor_payload_header(per_record_shape, key)?;
let slot_bytes = header
.len()
.checked_add(data_slot_bytes)
.ok_or_else(|| PyValueError::new_err("Tensor payload length overflow"))?;
let values = readonly.as_slice().map_err(|_| {
PyValueError::new_err(format!("custom property '{}' must be C-contiguous", key))
})?;
let mut payload = Vec::with_capacity(
batch
.checked_mul(slot_bytes)
.ok_or_else(|| PyValueError::new_err("Tensor batch payload length overflow"))?,
);
for index in 0..batch {
let start = index
.checked_mul(values_per_record)
.ok_or_else(|| PyValueError::new_err("Tensor value offset overflow"))?;
let end = start
.checked_add(values_per_record)
.ok_or_else(|| PyValueError::new_err("Tensor value offset overflow"))?;
payload.extend_from_slice(&header);
payload.extend_from_slice(bytemuck::cast_slice::<T, u8>(&values[start..end]));
}

Ok(BatchSectionColumn {
key: key.to_string(),
kind,
type_tag,
slot_bytes,
payload,
strings: None,
})
}

fn extract_stacked_tensor_column(
value: &Bound<'_, PyAny>,
batch: usize,
n_atoms: Option<usize>,
key: &str,
kind: u8,
) -> PyResult<Option<BatchSectionColumn>> {
if let Ok(arr) = value.cast::<PyArrayDyn<f32>>() {
return Ok(Some(extract_tensor_column(
arr,
batch,
n_atoms,
key,
kind,
TYPE_TENSOR_F32,
)?));
}
if let Ok(arr) = value.cast::<PyArrayDyn<f64>>() {
return Ok(Some(extract_tensor_column(
arr,
batch,
n_atoms,
key,
kind,
TYPE_TENSOR_F64,
)?));
}
if let Ok(arr) = value.cast::<PyArrayDyn<i32>>() {
return Ok(Some(extract_tensor_column(
arr,
batch,
n_atoms,
key,
kind,
TYPE_TENSOR_I32,
)?));
}
if let Ok(arr) = value.cast::<PyArrayDyn<i64>>() {
return Ok(Some(extract_tensor_column(
arr,
batch,
n_atoms,
key,
kind,
TYPE_TENSOR_I64,
)?));
}
Ok(None)
}

fn extract_property_column(
value: &Bound<'_, PyAny>,
batch: usize,
Expand All @@ -260,6 +418,7 @@ fn extract_property_column(
if let Some(column) = extract_string_column(value, batch, key, kind)? {
return Ok(Some(column));
}
reject_list_or_tuple_tensor_column(value, key)?;
if let Some(arr) = PyFloatArray1::from_any(value) {
return Ok(Some(match arr {
PyFloatArray1::F32(arr) => extract_scalar_column_f64(&arr, batch, key, kind)?,
Expand Down Expand Up @@ -336,6 +495,9 @@ fn extract_property_column(
}
}
}
if let Some(column) = extract_stacked_tensor_column(value, batch, None, key, kind)? {
return Ok(Some(column));
}
Ok(None)
}

Expand All @@ -345,6 +507,7 @@ fn extract_atom_property_column(
n_atoms: usize,
key: &str,
) -> PyResult<Option<BatchSectionColumn>> {
reject_list_or_tuple_tensor_column(value, key)?;
if let Some(arr) = PyFloatArray2::from_any(value) {
let (column, expected) = match arr {
PyFloatArray2::F32(arr) => (
Expand Down Expand Up @@ -385,26 +548,45 @@ fn extract_atom_property_column(
return Ok(Some(column));
}
if let Some(arr) = PyFloatArray3::from_any(value) {
return Ok(Some(match arr {
PyFloatArray3::F32(arr) => extract_vec3_column(
&arr,
batch,
n_atoms,
key,
KIND_ATOM_PROP,
TYPE_VEC3_F32,
"atom properties",
)?,
PyFloatArray3::F64(arr) => extract_vec3_column(
&arr,
batch,
n_atoms,
key,
KIND_ATOM_PROP,
TYPE_VEC3_F64,
"atom properties",
)?,
}));
match arr {
PyFloatArray3::F32(arr) => {
let readonly = arr.readonly();
let view = readonly.as_array();
let shape = view.shape();
if shape == [batch, n_atoms, 3] {
return Ok(Some(extract_vec3_column(
&arr,
batch,
n_atoms,
key,
KIND_ATOM_PROP,
TYPE_VEC3_F32,
"atom properties",
)?));
}
}
PyFloatArray3::F64(arr) => {
let readonly = arr.readonly();
let view = readonly.as_array();
let shape = view.shape();
if shape == [batch, n_atoms, 3] {
return Ok(Some(extract_vec3_column(
&arr,
batch,
n_atoms,
key,
KIND_ATOM_PROP,
TYPE_VEC3_F64,
"atom properties",
)?));
}
}
}
}
if let Some(column) =
extract_stacked_tensor_column(value, batch, Some(n_atoms), key, KIND_ATOM_PROP)?
{
return Ok(Some(column));
}
Ok(None)
}
Expand All @@ -425,7 +607,7 @@ fn extract_custom_columns(
push_unique_key(&mut seen_keys, &key, "property")?;
let Some(column) = extract_property_column(&value, batch, &key, KIND_MOL_PROP)? else {
return Err(PyValueError::new_err(format!(
"Unsupported batched property '{}' type. Supported: list[str], 1D numeric arrays, 2D numeric arrays, or 3D float arrays with trailing dimension 3",
"Unsupported batched property '{}' type. Supported: list[str], 1D numeric arrays, 2D numeric arrays, 3D float arrays with trailing dimension 3, or stacked tensor ndarrays with dtype float32, float64, int32, or int64",
key
)));
};
Expand All @@ -440,7 +622,7 @@ fn extract_custom_columns(
push_unique_key(&mut seen_keys, &key, "atom property")?;
let Some(column) = extract_atom_property_column(&value, batch, n_atoms, &key)? else {
return Err(PyValueError::new_err(format!(
"Unsupported batched atom property '{}' type. Supported: 2D numeric arrays with shape (batch, n_atoms) or 3D float arrays with shape (batch, n_atoms, 3)",
"Unsupported batched atom property '{}' type. Supported: 2D numeric arrays with shape (batch, n_atoms), 3D float arrays with shape (batch, n_atoms, 3), or stacked tensor ndarrays with dtype float32, float64, int32, or int64",
key
)));
};
Expand Down
Loading
Loading