From 60d1c42818090f5a546fc5f8ccb1c1537aca5a5d Mon Sep 17 00:00:00 2001 From: Ali Ramlaoui Date: Fri, 5 Jun 2026 11:58:14 +0200 Subject: [PATCH] feat: support tensor properties in flat batches --- atompack-py/src/database_batch.rs | 228 ++++++++++++++++++++++++++--- atompack-py/src/database_flat.rs | 224 +++++++++++++++++++++++++++- atompack-py/src/lib.rs | 2 +- atompack-py/tests/test_database.py | 140 +++++++++++++++++- 4 files changed, 561 insertions(+), 33 deletions(-) diff --git a/atompack-py/src/database_batch.rs b/atompack-py/src/database_batch.rs index 64d463d..ff417cb 100644 --- a/atompack-py/src/database_batch.rs +++ b/atompack-py/src/database_batch.rs @@ -61,7 +61,7 @@ fn schema_section(kind: u8, key: &str, type_tag: u8, slot_bytes: usize) -> Datab } fn schema_slot_bytes(kind: u8, key: &str, type_tag: u8, payload_slot_bytes: usize) -> usize { - if matches!(type_tag, TYPE_STRING | TYPE_NONE) { + if matches!(type_tag, TYPE_STRING | TYPE_NONE) || is_tensor_type_tag(type_tag) { 0 } else if is_per_atom(kind, key, type_tag) { type_tag_elem_bytes(type_tag) @@ -137,6 +137,18 @@ fn extract_string_column( })) } +fn reject_list_or_tuple_tensor_column(value: &Bound<'_, PyAny>, key: &str) -> PyResult<()> { + if value.cast::().is_ok() || value.cast::().is_ok() { + return Err(PyValueError::new_err(format!( + "custom property '{key}' must be a stacked ndarray tensor for add_arrays_batch; \ + list/tuple tensor inputs are not supported. Ragged tensors should use \ + per-molecule construction plus add_molecules(...) and non-flat retrieval with \ + get_molecule(s)." + ))); + } + Ok(()) +} + fn extract_scalar_column_f64>( arr: &Bound<'_, PyArray1>, batch: usize, @@ -251,6 +263,152 @@ fn extract_vec3_column( }) } +fn tensor_value_count(shape: &[usize]) -> PyResult { + shape.iter().try_fold(1usize, |acc, dim| { + acc.checked_mul(*dim) + .ok_or_else(|| PyValueError::new_err("Tensor shape overflows usize")) + }) +} + +fn tensor_payload_header(shape: &[usize], key: &str) -> PyResult> { + if shape.len() > u8::MAX as usize { + return Err(PyValueError::new_err(format!( + "custom property '{}' tensor rank {} exceeds maximum {}", + key, + shape.len(), + u8::MAX + ))); + } + if shape.iter().any(|dim| *dim > u32::MAX as usize) { + return Err(PyValueError::new_err(format!( + "custom property '{}' tensor dimensions must fit in u32 for storage", + key + ))); + } + let mut header = Vec::with_capacity(1 + shape.len() * 4); + header.push(shape.len() as u8); + for dim in shape { + header.extend_from_slice(&(*dim as u32).to_le_bytes()); + } + Ok(header) +} + +fn extract_tensor_column( + arr: &Bound<'_, PyArrayDyn>, + batch: usize, + n_atoms: Option, + key: &str, + kind: u8, + type_tag: u8, +) -> PyResult { + let readonly = arr.readonly(); + let view = readonly.as_array(); + let shape = view.shape(); + if shape.len() < 3 || shape.first() != Some(&batch) { + return Err(PyValueError::new_err(format!( + "custom property '{}' must have stacked tensor shape ({}, ...)", + key, batch + ))); + } + if let Some(expected_atoms) = n_atoms + && shape.get(1) != Some(&expected_atoms) + { + return Err(PyValueError::new_err(format!( + "atom property '{}' must have stacked tensor shape ({}, {}, ...)", + key, batch, expected_atoms + ))); + } + + let per_record_shape = &shape[1..]; + let values_per_record = tensor_value_count(per_record_shape)?; + let elem_bytes = std::mem::size_of::(); + let data_slot_bytes = values_per_record + .checked_mul(elem_bytes) + .ok_or_else(|| PyValueError::new_err("Tensor byte length overflow"))?; + let header = tensor_payload_header(per_record_shape, key)?; + let slot_bytes = header + .len() + .checked_add(data_slot_bytes) + .ok_or_else(|| PyValueError::new_err("Tensor payload length overflow"))?; + let values = readonly.as_slice().map_err(|_| { + PyValueError::new_err(format!("custom property '{}' must be C-contiguous", key)) + })?; + let mut payload = Vec::with_capacity( + batch + .checked_mul(slot_bytes) + .ok_or_else(|| PyValueError::new_err("Tensor batch payload length overflow"))?, + ); + for index in 0..batch { + let start = index + .checked_mul(values_per_record) + .ok_or_else(|| PyValueError::new_err("Tensor value offset overflow"))?; + let end = start + .checked_add(values_per_record) + .ok_or_else(|| PyValueError::new_err("Tensor value offset overflow"))?; + payload.extend_from_slice(&header); + payload.extend_from_slice(bytemuck::cast_slice::(&values[start..end])); + } + + Ok(BatchSectionColumn { + key: key.to_string(), + kind, + type_tag, + slot_bytes, + payload, + strings: None, + }) +} + +fn extract_stacked_tensor_column( + value: &Bound<'_, PyAny>, + batch: usize, + n_atoms: Option, + key: &str, + kind: u8, +) -> PyResult> { + if let Ok(arr) = value.cast::>() { + return Ok(Some(extract_tensor_column( + arr, + batch, + n_atoms, + key, + kind, + TYPE_TENSOR_F32, + )?)); + } + if let Ok(arr) = value.cast::>() { + return Ok(Some(extract_tensor_column( + arr, + batch, + n_atoms, + key, + kind, + TYPE_TENSOR_F64, + )?)); + } + if let Ok(arr) = value.cast::>() { + return Ok(Some(extract_tensor_column( + arr, + batch, + n_atoms, + key, + kind, + TYPE_TENSOR_I32, + )?)); + } + if let Ok(arr) = value.cast::>() { + return Ok(Some(extract_tensor_column( + arr, + batch, + n_atoms, + key, + kind, + TYPE_TENSOR_I64, + )?)); + } + Ok(None) +} + fn extract_property_column( value: &Bound<'_, PyAny>, batch: usize, @@ -260,6 +418,7 @@ fn extract_property_column( if let Some(column) = extract_string_column(value, batch, key, kind)? { return Ok(Some(column)); } + reject_list_or_tuple_tensor_column(value, key)?; if let Some(arr) = PyFloatArray1::from_any(value) { return Ok(Some(match arr { PyFloatArray1::F32(arr) => extract_scalar_column_f64(&arr, batch, key, kind)?, @@ -336,6 +495,9 @@ fn extract_property_column( } } } + if let Some(column) = extract_stacked_tensor_column(value, batch, None, key, kind)? { + return Ok(Some(column)); + } Ok(None) } @@ -345,6 +507,7 @@ fn extract_atom_property_column( n_atoms: usize, key: &str, ) -> PyResult> { + reject_list_or_tuple_tensor_column(value, key)?; if let Some(arr) = PyFloatArray2::from_any(value) { let (column, expected) = match arr { PyFloatArray2::F32(arr) => ( @@ -385,26 +548,45 @@ fn extract_atom_property_column( return Ok(Some(column)); } if let Some(arr) = PyFloatArray3::from_any(value) { - return Ok(Some(match arr { - PyFloatArray3::F32(arr) => extract_vec3_column( - &arr, - batch, - n_atoms, - key, - KIND_ATOM_PROP, - TYPE_VEC3_F32, - "atom properties", - )?, - PyFloatArray3::F64(arr) => extract_vec3_column( - &arr, - batch, - n_atoms, - key, - KIND_ATOM_PROP, - TYPE_VEC3_F64, - "atom properties", - )?, - })); + match arr { + PyFloatArray3::F32(arr) => { + let readonly = arr.readonly(); + let view = readonly.as_array(); + let shape = view.shape(); + if shape == [batch, n_atoms, 3] { + return Ok(Some(extract_vec3_column( + &arr, + batch, + n_atoms, + key, + KIND_ATOM_PROP, + TYPE_VEC3_F32, + "atom properties", + )?)); + } + } + PyFloatArray3::F64(arr) => { + let readonly = arr.readonly(); + let view = readonly.as_array(); + let shape = view.shape(); + if shape == [batch, n_atoms, 3] { + return Ok(Some(extract_vec3_column( + &arr, + batch, + n_atoms, + key, + KIND_ATOM_PROP, + TYPE_VEC3_F64, + "atom properties", + )?)); + } + } + } + } + if let Some(column) = + extract_stacked_tensor_column(value, batch, Some(n_atoms), key, KIND_ATOM_PROP)? + { + return Ok(Some(column)); } Ok(None) } @@ -425,7 +607,7 @@ fn extract_custom_columns( push_unique_key(&mut seen_keys, &key, "property")?; let Some(column) = extract_property_column(&value, batch, &key, KIND_MOL_PROP)? else { return Err(PyValueError::new_err(format!( - "Unsupported batched property '{}' type. Supported: list[str], 1D numeric arrays, 2D numeric arrays, or 3D float arrays with trailing dimension 3", + "Unsupported batched property '{}' type. Supported: list[str], 1D numeric arrays, 2D numeric arrays, 3D float arrays with trailing dimension 3, or stacked tensor ndarrays with dtype float32, float64, int32, or int64", key ))); }; @@ -440,7 +622,7 @@ fn extract_custom_columns( push_unique_key(&mut seen_keys, &key, "atom property")?; let Some(column) = extract_atom_property_column(&value, batch, n_atoms, &key)? else { return Err(PyValueError::new_err(format!( - "Unsupported batched atom property '{}' type. Supported: 2D numeric arrays with shape (batch, n_atoms) or 3D float arrays with shape (batch, n_atoms, 3)", + "Unsupported batched atom property '{}' type. Supported: 2D numeric arrays with shape (batch, n_atoms), 3D float arrays with shape (batch, n_atoms, 3), or stacked tensor ndarrays with dtype float32, float64, int32, or int64", key ))); }; diff --git a/atompack-py/src/database_flat.rs b/atompack-py/src/database_flat.rs index 8486633..3df4b82 100644 --- a/atompack-py/src/database_flat.rs +++ b/atompack-py/src/database_flat.rs @@ -9,6 +9,162 @@ enum FlatPositions { F64(Vec), } +type TensorSectionPayloads = Vec>>; + +fn missing_tensor_error(key: &str, index: usize) -> PyErr { + PyValueError::new_err(format!( + "Tensor property '{}' is missing for selected molecule {}; it cannot be \ + flat-batched/concatenated. Retrieve molecules individually with get_molecule(s) instead.", + key, index + )) +} + +fn incompatible_tensor_shapes_error(key: &str, left: &[usize], right: &[usize]) -> PyErr { + PyValueError::new_err(format!( + "Tensor property '{}' has incompatible shapes {:?} and {:?}; it cannot be \ + flat-batched/concatenated. Retrieve molecules individually with get_molecule(s) instead.", + key, left, right + )) +} + +fn incompatible_atom_tensor_suffix_error(key: &str, left: &[usize], right: &[usize]) -> PyErr { + PyValueError::new_err(format!( + "Atom tensor property '{}' has incompatible per-atom tensor suffix shapes {:?} and {:?}; \ + it cannot be flat-batched/concatenated. Retrieve molecules individually with \ + get_molecule(s) instead.", + key, left, right + )) +} + +fn invalid_atom_tensor_shape_error( + key: &str, + shape: &[usize], + first_dim: Option, + n_atoms: usize, +) -> PyErr { + PyValueError::new_err(format!( + "Atom tensor property '{}' has shape {:?}; first dimension {:?} does not match \ + atom count {}. It cannot be flat-batched/concatenated. Retrieve molecules \ + individually with get_molecule(s) instead.", + key, shape, first_dim, n_atoms + )) +} + +fn tensor_array_from_bytes<'py>( + py: Python<'py>, + type_tag: u8, + bytes: Vec, + shape: &[usize], +) -> PyResult> { + Ok(match type_tag { + TYPE_TENSOR_F32 => pyarray1_from_cow(py, cast_or_decode_f32(&bytes)?) + .reshape(shape)? + .into_any() + .unbind(), + TYPE_TENSOR_F64 => pyarray1_from_cow(py, cast_or_decode_f64(&bytes)?) + .reshape(shape)? + .into_any() + .unbind(), + TYPE_TENSOR_I32 => pyarray1_from_cow(py, cast_or_decode_i32(&bytes)?) + .reshape(shape)? + .into_any() + .unbind(), + TYPE_TENSOR_I64 => pyarray1_from_cow(py, cast_or_decode_i64(&bytes)?) + .reshape(shape)? + .into_any() + .unbind(), + _ => { + return Err(PyValueError::new_err(format!( + "Unsupported tensor type tag {}", + type_tag + ))); + } + }) +} + +fn flat_tensor_array<'py>( + py: Python<'py>, + section: &SectionSchema, + payloads: &[Option>], + n_atoms_vec: &[u32], + total_atoms: usize, +) -> PyResult> { + let mut data = Vec::new(); + + if section.per_atom { + let mut expected_suffix: Option> = None; + for (index, payload) in payloads.iter().enumerate() { + let payload = payload + .as_ref() + .ok_or_else(|| missing_tensor_error(§ion.key, index))?; + let (shape, data_offset) = + crate::soa::tensor_shape_from_payload(section.type_tag, payload.as_slice()) + .map_err(|e| PyValueError::new_err(format!("{}", e)))?; + let n_atoms = n_atoms_vec[index] as usize; + let Some((&first_dim, suffix)) = shape.split_first() else { + return Err(invalid_atom_tensor_shape_error( + §ion.key, + &shape, + None, + n_atoms, + )); + }; + if first_dim != n_atoms { + return Err(invalid_atom_tensor_shape_error( + §ion.key, + &shape, + Some(first_dim), + n_atoms, + )); + } + if let Some(expected) = &expected_suffix { + if expected.as_slice() != suffix { + return Err(incompatible_atom_tensor_suffix_error( + §ion.key, + expected, + suffix, + )); + } + } else { + expected_suffix = Some(suffix.to_vec()); + } + data.extend_from_slice(&payload[data_offset..]); + } + let suffix = expected_suffix.unwrap_or_default(); + let mut output_shape = Vec::with_capacity(1 + suffix.len()); + output_shape.push(total_atoms); + output_shape.extend(suffix); + tensor_array_from_bytes(py, section.type_tag, data, &output_shape) + } else { + let mut expected_shape: Option> = None; + for (index, payload) in payloads.iter().enumerate() { + let payload = payload + .as_ref() + .ok_or_else(|| missing_tensor_error(§ion.key, index))?; + let (shape, data_offset) = + crate::soa::tensor_shape_from_payload(section.type_tag, payload.as_slice()) + .map_err(|e| PyValueError::new_err(format!("{}", e)))?; + if let Some(expected) = &expected_shape { + if expected != &shape { + return Err(incompatible_tensor_shapes_error( + §ion.key, + expected, + &shape, + )); + } + } else { + expected_shape = Some(shape.clone()); + } + data.extend_from_slice(&payload[data_offset..]); + } + let tensor_shape = expected_shape.unwrap_or_default(); + let mut output_shape = Vec::with_capacity(1 + tensor_shape.len()); + output_shape.push(payloads.len()); + output_shape.extend(tensor_shape); + tensor_array_from_bytes(py, section.type_tag, data, &output_shape) + } +} + pub(super) fn get_molecules_flat_soa_impl<'py>( inner: &AtomDatabase, py: Python<'py>, @@ -223,7 +379,7 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( let mut section_buffers: Vec> = schema .iter() .map(|s| { - if s.slot_bytes == 0 { + if s.slot_bytes == 0 || is_tensor_type_tag(s.type_tag) { Vec::new() } else if s.per_atom { vec![0u8; total_atoms * s.elem_bytes] @@ -236,7 +392,17 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( let mut string_sections: Vec>>> = schema .iter() .map(|s| { - if s.slot_bytes == 0 { + if matches!(s.type_tag, TYPE_STRING | TYPE_NONE) { + Some(vec![None; n_mols]) + } else { + None + } + }) + .collect(); + let mut tensor_sections: Vec> = schema + .iter() + .map(|s| { + if is_tensor_type_tag(s.type_tag) { Some(vec![None; n_mols]) } else { None @@ -271,6 +437,11 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( .iter_mut() .map(|opt| opt.as_mut().map(std::sync::Mutex::new)) .collect(); + let tensor_mutexes: Vec>> = + tensor_sections + .iter_mut() + .map(|opt| opt.as_mut().map(std::sync::Mutex::new)) + .collect(); let process_mol = |i: usize, mol_bytes: &[u8]| -> atompack::Result<()> { let md = parse_mol_fast_soa(mol_bytes, ctx)?; @@ -340,7 +511,12 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( ))); } - if schema_entry.per_atom { + if is_tensor_type_tag(schema_entry.type_tag) { + let _ = crate::soa::tensor_shape_from_payload( + schema_entry.type_tag, + sec.payload, + )?; + } else if schema_entry.per_atom { let expected = n.checked_mul(schema_entry.elem_bytes).ok_or_else(|| { invalid_data(format!( @@ -367,7 +543,14 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( ))); } - if schema_entry.slot_bytes == 0 { + if is_tensor_type_tag(schema_entry.type_tag) { + if let Some(ref mtx) = tensor_mutexes[section_idx] { + let mut guard = mtx + .lock() + .map_err(|_| invalid_data("tensor section mutex poisoned"))?; + guard[i] = Some(sec.payload.to_vec()); + } + } else if schema_entry.slot_bytes == 0 { if schema_entry.type_tag == TYPE_NONE { continue; } @@ -424,7 +607,12 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( ))); } - if schema_entry.per_atom { + if is_tensor_type_tag(schema_entry.type_tag) { + let _ = crate::soa::tensor_shape_from_payload( + schema_entry.type_tag, + sec.payload, + )?; + } else if schema_entry.per_atom { let expected = n.checked_mul(schema_entry.elem_bytes).ok_or_else(|| { invalid_data(format!( @@ -451,7 +639,14 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( ))); } - if schema_entry.slot_bytes == 0 { + if is_tensor_type_tag(schema_entry.type_tag) { + if let Some(ref mtx) = tensor_mutexes[section_idx] { + let mut guard = mtx + .lock() + .map_err(|_| invalid_data("tensor section mutex poisoned"))?; + guard[i] = Some(sec.payload.to_vec()); + } + } else if schema_entry.slot_bytes == 0 { if schema_entry.type_tag == TYPE_NONE { continue; } @@ -541,6 +736,7 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( schema, section_buffers, string_sections, + tensor_sections, total_atoms, ))) }) @@ -553,6 +749,7 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( schema, section_buffers, string_results, + tensor_results, total_atoms, ) = match result { None => { @@ -571,7 +768,7 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( }; let dict = PyDict::new(py); - dict.set_item("n_atoms", PyArray1::from_vec(py, n_atoms_vec))?; + dict.set_item("n_atoms", PyArray1::from_slice(py, &n_atoms_vec))?; match positions { FlatPositions::F32(values) => { dict.set_item( @@ -591,10 +788,11 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( let atom_props_dict = PyDict::new(py); let mol_props_dict = PyDict::new(py); - for ((s, buf), str_result) in schema + for (((s, buf), str_result), tensor_result) in schema .iter() .zip(section_buffers.into_iter()) .zip(string_results.iter()) + .zip(tensor_results.iter()) { let target = match s.kind { KIND_ATOM_PROP => &atom_props_dict, @@ -602,6 +800,16 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( _ => &dict, }; + if is_tensor_type_tag(s.type_tag) { + if let Some(payloads) = tensor_result { + target.set_item( + &s.key, + flat_tensor_array(py, s, payloads, &n_atoms_vec, total_atoms)?, + )?; + } + continue; + } + if s.slot_bytes == 0 { if let Some(strings) = str_result { let py_list: Vec> = strings diff --git a/atompack-py/src/lib.rs b/atompack-py/src/lib.rs index 546d239..aa2daa9 100644 --- a/atompack-py/src/lib.rs +++ b/atompack-py/src/lib.rs @@ -16,7 +16,7 @@ use atompack::{ use numpy::{Element, PyArray1, PyArray2, PyArray3, PyArrayDyn, PyArrayMethods}; use pyo3::exceptions::{PyFileExistsError, PyIndexError, PyKeyError, PyTypeError, PyValueError}; use pyo3::prelude::*; -use pyo3::types::{PyBytes, PyDict, PyTuple}; +use pyo3::types::{PyBytes, PyDict, PyList, PyTuple}; use pyo3::{IntoPyObject, IntoPyObjectExt}; use std::borrow::Cow; use std::path::PathBuf; diff --git a/atompack-py/tests/test_database.py b/atompack-py/tests/test_database.py index 58a2ffe..0f7d1be 100644 --- a/atompack-py/tests/test_database.py +++ b/atompack-py/tests/test_database.py @@ -1,8 +1,8 @@ # Copyright 2026 Entalpic from __future__ import annotations -from pathlib import Path import pickle +from pathlib import Path import atompack import numpy as np @@ -296,6 +296,122 @@ def test_database_add_arrays_batch_roundtrip_with_custom_properties(tmp_path: Pa assert second.get_property("phase") == "valid" +def test_database_add_arrays_batch_roundtrip_with_tensor_properties(tmp_path: Path) -> None: + path = tmp_path / "batch_arrays_tensor_custom.atp" + positions = np.array( + [ + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], + [[0.5, 0.0, 0.0], [1.5, 0.0, 0.0]], + ], + dtype=np.float32, + ) + atomic_numbers = np.array([[6, 8], [1, 8]], dtype=np.uint8) + density = np.arange(16, dtype=np.float64).reshape(2, 2, 4) + atom_codes = np.arange(8, dtype=np.int32).reshape(2, 2, 2) + + db = atompack.Database(str(path)) + db.add_arrays_batch( + positions, + atomic_numbers, + properties={"density": density}, + atom_properties={"atom_codes": atom_codes}, + ) + db.flush() + + reopened = atompack.Database.open(str(path)) + molecules = reopened.get_molecules([0, 1]) + + first_density = molecules[0].get_property("density") + assert first_density.dtype == np.float64 + np.testing.assert_array_equal(first_density, density[0]) + + second_density = molecules[1].get_property("density") + assert second_density.dtype == np.float64 + np.testing.assert_array_equal(second_density, density[1]) + + first_codes = molecules[0].get_property("atom_codes") + assert first_codes.dtype == np.int32 + np.testing.assert_array_equal(first_codes, atom_codes[0]) + + second_codes = molecules[1].get_property("atom_codes") + assert second_codes.dtype == np.int32 + np.testing.assert_array_equal(second_codes, atom_codes[1]) + + +def test_get_molecules_flat_stacks_uniform_tensor_properties(tmp_path: Path) -> None: + path = tmp_path / "flat_tensor_custom.atp" + positions = np.array( + [ + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], + [[0.5, 0.0, 0.0], [1.5, 0.0, 0.0]], + ], + dtype=np.float32, + ) + atomic_numbers = np.array([[6, 8], [1, 8]], dtype=np.uint8) + density = np.arange(16, dtype=np.float32).reshape(2, 2, 4) + atom_descriptors = np.arange(16, dtype=np.int64).reshape(2, 2, 2, 2) + + db = atompack.Database(str(path)) + db.add_arrays_batch( + positions, + atomic_numbers, + properties={"density": density}, + atom_properties={"atom_descriptors": atom_descriptors}, + ) + db.flush() + + reopened = atompack.Database.open(str(path)) + flat = reopened.get_molecules_flat([0, 1]) + + flat_density = flat["properties"]["density"] + assert flat_density.dtype == np.float32 + assert flat_density.shape == (2, 2, 4) + np.testing.assert_array_equal(flat_density, density) + + flat_descriptors = flat["atom_properties"]["atom_descriptors"] + assert flat_descriptors.dtype == np.int64 + assert flat_descriptors.shape == (4, 2, 2) + np.testing.assert_array_equal(flat_descriptors, np.concatenate(atom_descriptors, axis=0)) + + +@pytest.mark.parametrize("container", [list, tuple]) +def test_database_add_arrays_batch_rejects_list_tuple_tensor_properties( + tmp_path: Path, + container: type, +) -> None: + path = tmp_path / f"batch_arrays_reject_{container.__name__}_tensor_list.atp" + positions = np.zeros((2, 2, 3), dtype=np.float32) + atomic_numbers = np.ones((2, 2), dtype=np.uint8) + ragged = container( + [ + np.zeros((2, 4), dtype=np.float32), + np.zeros((3, 4), dtype=np.float32), + ] + ) + + db = atompack.Database(str(path)) + with pytest.raises( + ValueError, + match=r"requires stacked ndarray tensors|must be a stacked ndarray tensor", + ): + db.add_arrays_batch( + positions, + atomic_numbers, + properties={"ragged_tensor": ragged}, + ) + + db.add_arrays_batch( + positions, + atomic_numbers, + properties={"phase": ["train", "valid"]}, + ) + db.flush() + + reopened = atompack.Database.open(str(path)) + assert reopened[0].get_property("phase") == "train" + assert reopened[1].get_property("phase") == "valid" + + def test_database_add_arrays_batch_appends_variable_size_atom_properties( tmp_path: Path, ) -> None: @@ -543,6 +659,28 @@ def test_database_tensor_properties_allow_value_level_shapes( np.testing.assert_allclose(second_descriptor, np.arange(8, dtype=np.float32).reshape(2, 4)) +def test_get_molecules_flat_rejects_incompatible_tensor_shapes( + tmp_path: Path, +) -> None: + path = tmp_path / "flat_ragged_tensor_shapes.atp" + mol1 = _make_molecule(-6.0) + mol1.set_property("density", np.zeros((2, 4), dtype=np.float32)) + + mol2 = _make_molecule(-7.0) + mol2.set_property("density", np.zeros((4, 4), dtype=np.float32)) + + db = atompack.Database(str(path)) + db.add_molecules([mol1, mol2]) + db.flush() + + reopened = atompack.Database.open(str(path)) + with pytest.raises( + ValueError, + match=r"density.*incompatible shapes.*cannot be flat-batched/concatenated.*get_molecule", + ): + reopened.get_molecules_flat([0, 1]) + + @pytest.mark.parametrize("mmap", [False, True]) @pytest.mark.parametrize("compression", ["none", "lz4", "zstd"]) def test_database_single_item_mutation_is_copy_on_write(