Skip to content

Commit cb6add2

Browse files
committed
Handle NumPy ndarray query parameters
Fix ndarray parameter conversion by importing importlib.util directly, inferring NumPy parameter types from dtype and shape, reading ndarray buffers through Python's buffer metadata instead of materializing tolist(), and using stable homogeneous-list inference for Python bool, int, and float lists.
1 parent e36e29d commit cb6add2

3 files changed

Lines changed: 244 additions & 3 deletions

File tree

src_cpp/include/cached_import/py_cached_modules.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,13 @@ class DecimalCachedItem : public PythonCachedItem {
2525
class ImportLibCachedItem : public PythonCachedItem {
2626
class UtilCachedItem : public PythonCachedItem {
2727
public:
28-
explicit UtilCachedItem(PythonCachedItem* parent)
29-
: PythonCachedItem{"util", parent}, find_spec{"find_spec", this} {}
28+
UtilCachedItem() : PythonCachedItem{"importlib.util"}, find_spec{"find_spec", this} {}
3029

3130
PythonCachedItem find_spec;
3231
};
3332

3433
public:
35-
ImportLibCachedItem() : PythonCachedItem("importlib"), util(this) {}
34+
ImportLibCachedItem() : PythonCachedItem("importlib"), util() {}
3635

3736
UtilCachedItem util;
3837
};

src_cpp/py_connection.cpp

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "include/py_udf.h"
1616
#include "main/connection.h"
1717
#include "main/query_result/materialized_query_result.h"
18+
#include "numpy/numpy_type.h"
1819
#include "pandas/pandas_scan.h"
1920
#include "processor/result/factorized_table.h"
2021
#include "pyarrow/pyarrow_scan.h"
@@ -358,6 +359,47 @@ bool integerFitsIn<uint8_t>(int64_t val) {
358359
return val >= 0 && val <= UINT8_MAX;
359360
}
360361

362+
/**
 * Fast-path type inference for a Python list whose non-None entries all share one exact
 * Python type that is a bool, an int or a float.
 *
 * @param lst the Python list to inspect.
 * @return LIST(ANY) when the list holds no non-None entry; LIST(BOOL), LIST(INT64) or
 *         LIST(DOUBLE) when every non-None entry has the same exact type; plain ANY when the
 *         fast path does not apply (unsupported first entry or mixed exact types).
 */
static LogicalType pyHomogeneousListType(const py::list& lst) {
    enum class ElementKind { UNSEEN, BOOLEAN, INTEGER, FLOATING };
    auto kind = ElementKind::UNSEEN;
    // Exact type object of the first non-None entry; every later entry must match it.
    PyObject* expectedType = nullptr;
    for (auto element : lst) {
        if (element.is_none()) {
            continue;
        }
        auto* elementType = element.get_type().ptr();
        if (kind != ElementKind::UNSEEN) {
            if (elementType != expectedType) {
                return LogicalType::ANY();
            }
            continue;
        }
        // First non-None entry: classify it, testing bool, then int, then float.
        if (py::isinstance<py::bool_>(element)) {
            kind = ElementKind::BOOLEAN;
        } else if (py::isinstance<py::int_>(element)) {
            kind = ElementKind::INTEGER;
        } else if (py::isinstance<py::float_>(element)) {
            kind = ElementKind::FLOATING;
        } else {
            return LogicalType::ANY();
        }
        expectedType = elementType;
    }
    switch (kind) {
    case ElementKind::UNSEEN:
        return LogicalType::LIST(LogicalType::ANY());
    case ElementKind::BOOLEAN:
        return LogicalType::LIST(LogicalType::BOOL());
    case ElementKind::INTEGER:
        return LogicalType::LIST(LogicalType::INT64());
    default:
        return LogicalType::LIST(LogicalType::DOUBLE());
    }
}
393+
394+
/**
 * Derives the logical type of a NumPy ndarray parameter: the element type obtained from the
 * array's dtype, wrapped in one LIST level per array dimension.
 *
 * @param arr the ndarray to describe.
 * @return the element logical type nested in arr.ndim() LIST layers (no layer when ndim is 0).
 */
static LogicalType pyNumpyArrayLogicalType(const py::array& arr) {
    auto dtypeInfo = NumpyTypeUtils::convertNumpyType(arr.attr("dtype"));
    auto result = NumpyTypeUtils::numpyToLogicalType(dtypeInfo);
    auto remainingDims = arr.ndim();
    while (remainingDims-- > 0) {
        result = LogicalType::LIST(std::move(result));
    }
    return result;
}
402+
361403
static LogicalType pyLogicalType(const py::handle& val) {
362404
auto datetime_datetime = importCache->datetime.datetime();
363405
auto time_delta = importCache->datetime.timedelta();
@@ -436,8 +478,14 @@ static LogicalType pyLogicalType(const py::handle& val) {
436478
childValueType = std::move(resultValue);
437479
}
438480
return LogicalType::MAP(std::move(childKeyType), std::move(childValueType));
481+
} else if (py::isinstance<py::array>(val)) {
482+
return pyNumpyArrayLogicalType(py::reinterpret_borrow<py::array>(val));
439483
} else if (py::isinstance<py::list>(val)) {
440484
py::list lst = py::reinterpret_borrow<py::list>(val);
485+
auto homogeneousType = pyHomogeneousListType(lst);
486+
if (homogeneousType.getLogicalTypeID() != LogicalTypeID::ANY) {
487+
return homogeneousType;
488+
}
441489
auto childType = LogicalType::ANY();
442490
for (auto child : lst) {
443491
auto curChildType = pyLogicalType(child);
@@ -535,8 +583,14 @@ static LogicalType pyLogicalTypeFromParameter(const py::handle& val) {
535583
structFields.emplace_back(std::move(keyName), std::move(keyType));
536584
}
537585
return LogicalType::STRUCT(std::move(structFields));
586+
} else if (py::isinstance<py::array>(val)) {
587+
return pyNumpyArrayLogicalType(py::reinterpret_borrow<py::array>(val));
538588
} else if (py::isinstance<py::list>(val)) {
539589
py::list lst = py::reinterpret_borrow<py::list>(val);
590+
auto homogeneousType = pyHomogeneousListType(lst);
591+
if (homogeneousType.getLogicalTypeID() != LogicalTypeID::ANY) {
592+
return homogeneousType;
593+
}
540594
auto childType = LogicalType::ANY();
541595
for (auto child : lst) {
542596
auto curChildType = pyLogicalTypeFromParameter(child);
@@ -569,6 +623,90 @@ static std::string pythonObjectToJsonString(const py::handle& val) {
569623
return py::cast<std::string>(jsonStr);
570624
}
571625

626+
template<typename T>
627+
static Value transformNumpyScalarAs(const void* ptr, const LogicalType& type) {
628+
auto value = *reinterpret_cast<const T*>(ptr);
629+
switch (type.getLogicalTypeID()) {
630+
case LogicalTypeID::BOOL:
631+
return Value::createValue<bool>(static_cast<bool>(value));
632+
case LogicalTypeID::INT64:
633+
return Value::createValue<int64_t>(static_cast<int64_t>(value));
634+
case LogicalTypeID::UINT32:
635+
return Value::createValue<uint32_t>(static_cast<uint32_t>(value));
636+
case LogicalTypeID::INT32:
637+
return Value::createValue<int32_t>(static_cast<int32_t>(value));
638+
case LogicalTypeID::UINT16:
639+
return Value::createValue<uint16_t>(static_cast<uint16_t>(value));
640+
case LogicalTypeID::INT16:
641+
return Value::createValue<int16_t>(static_cast<int16_t>(value));
642+
case LogicalTypeID::UINT8:
643+
return Value::createValue<uint8_t>(static_cast<uint8_t>(value));
644+
case LogicalTypeID::INT8:
645+
return Value::createValue<int8_t>(static_cast<int8_t>(value));
646+
case LogicalTypeID::FLOAT:
647+
return Value(static_cast<float>(value));
648+
case LogicalTypeID::DOUBLE:
649+
return Value::createValue<double>(static_cast<double>(value));
650+
default:
651+
throw RuntimeException("Unsupported numpy ndarray parameter child type " + type.toString());
652+
}
653+
}
654+
655+
/**
 * Dispatches on the ndarray element dtype and converts the element at `ptr` into a Value of
 * logical type `type`.
 *
 * @param ptr address of the element inside the ndarray buffer.
 * @param npType dtype of the stored element.
 * @param type target logical type of the resulting Value.
 * @throws RuntimeException when `npType` is not a bool, integer or floating-point dtype
 *         handled below.
 */
static Value transformNumpyScalarAs(const void* ptr, NumpyNullableType npType,
    const LogicalType& type) {
    // The tag argument only carries the storage type; its value is never used.
    const auto readAs = [&](auto storageTag) {
        return transformNumpyScalarAs<decltype(storageTag)>(ptr, type);
    };
    switch (npType) {
    case NumpyNullableType::BOOL:
        return readAs(bool{});
    case NumpyNullableType::INT_8:
        return readAs(int8_t{});
    case NumpyNullableType::UINT_8:
        return readAs(uint8_t{});
    case NumpyNullableType::INT_16:
        return readAs(int16_t{});
    case NumpyNullableType::UINT_16:
        return readAs(uint16_t{});
    case NumpyNullableType::INT_32:
        return readAs(int32_t{});
    case NumpyNullableType::UINT_32:
        return readAs(uint32_t{});
    case NumpyNullableType::INT_64:
        return readAs(int64_t{});
    case NumpyNullableType::UINT_64:
        return readAs(uint64_t{});
    case NumpyNullableType::FLOAT_32:
        return readAs(float{});
    case NumpyNullableType::FLOAT_64:
        return readAs(double{});
    default:
        throw RuntimeException("Unsupported numpy ndarray parameter dtype");
    }
}
684+
685+
/**
 * Recursively converts the sub-array starting at `ptr` into a nested LIST Value by walking
 * the buffer's shape and strides, one dimension per recursion level.
 *
 * @param type logical type expected at this level: a LIST while dimensions remain, the
 *        element type once all dimensions are consumed.
 * @param dimension index of the dimension handled by this call.
 * @param ptr address of the first byte of the sub-array for this dimension.
 * @param info buffer description (ndim, shape, strides) of the whole ndarray.
 * @param npType dtype of the stored elements.
 * @throws RuntimeException when dimensions remain but `type` is not a LIST, or when the
 *         element conversion is unsupported.
 */
static Value transformNumpyArrayAs(const LogicalType& type, uint64_t dimension, const uint8_t* ptr,
    const py::buffer_info& info, NumpyNullableType npType) {
    if (dimension == static_cast<uint64_t>(info.ndim)) {
        // All dimensions consumed: ptr now addresses a single element.
        return transformNumpyScalarAs(ptr, npType, type);
    }
    if (type.getLogicalTypeID() != LogicalTypeID::LIST) {
        throw RuntimeException("Cannot convert numpy ndarray parameter to " + type.toString());
    }
    const auto extent = info.shape[dimension];
    const auto stride = info.strides[dimension];
    std::vector<std::unique_ptr<Value>> children;
    children.reserve(extent);
    const auto& childType = ListType::getChildType(type);
    // Index with the buffer's own signed size type so large extents cannot overflow an int.
    for (py::ssize_t i = 0; i < extent; ++i) {
        // The stride is added to a byte pointer, so the offset is computed in bytes.
        const auto* childPtr = ptr + i * stride;
        children.push_back(std::make_unique<Value>(
            transformNumpyArrayAs(childType, dimension + 1, childPtr, info, npType)));
    }
    return Value(type.copy(), std::move(children));
}
703+
704+
/**
 * Converts a NumPy ndarray into a Value of logical type `type` by reading the array's
 * buffer description and delegating to the recursive overload starting at dimension 0.
 *
 * @param arr the ndarray to convert.
 * @param type the logical type the resulting Value should have.
 */
static Value transformNumpyArrayAs(const py::array& arr, const LogicalType& type) {
    const auto bufferInfo = arr.request();
    const auto elementType = NumpyTypeUtils::convertNumpyType(arr.attr("dtype")).type;
    const auto* firstByte = reinterpret_cast<const uint8_t*>(bufferInfo.ptr);
    return transformNumpyArrayAs(type, 0, firstByte, bufferInfo, elementType);
}
709+
572710
Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalType& type) {
573711
// ignore the type of the actual python object, just directly cast
574712
auto datetime_datetime = importCache->datetime.datetime();
@@ -598,6 +736,8 @@ Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalT
598736
return Value::createValue<int8_t>(py::cast<py::int_>(val).cast<int8_t>());
599737
case LogicalTypeID::DOUBLE:
600738
return Value::createValue<double>(py::cast<py::float_>(val).cast<double>());
739+
case LogicalTypeID::FLOAT:
740+
return Value(py::cast<py::float_>(val).cast<float>());
601741
case LogicalTypeID::DECIMAL: {
602742
auto str = py::cast<std::string>(py::str(val));
603743
int128_t result = 0;
@@ -674,6 +814,9 @@ Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalT
674814
return Value{uuidToAppend};
675815
}
676816
case LogicalTypeID::LIST: {
817+
if (py::isinstance<py::array>(val)) {
818+
return transformNumpyArrayAs(py::reinterpret_borrow<py::array>(val), type);
819+
}
677820
py::list lst = py::reinterpret_borrow<py::list>(val);
678821
std::vector<std::unique_ptr<Value>> children;
679822
for (auto child : lst) {
@@ -729,6 +872,9 @@ Value PyConnection::transformPythonValueFromParameterAs(const py::handle& val,
729872
auto jsonStr = pythonObjectToJsonString(val);
730873
return Value::createValue<std::string>(jsonStr);
731874
}
875+
if (py::isinstance<py::array>(val)) {
876+
return transformNumpyArrayAs(py::reinterpret_borrow<py::array>(val), type);
877+
}
732878
py::list lst = py::reinterpret_borrow<py::list>(val);
733879
std::vector<std::unique_ptr<Value>> children;
734880
for (auto child : lst) {

src_py/_lbug_capi.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,8 @@ def _resolve_library_path() -> str:
182182
_LBUG_MAP = 55
183183
_LBUG_UNION = 56
184184
_LBUG_UUID = 59
185+
_NUMPY_MODULE: Any | None = None
186+
_NUMPY_IMPORT_ATTEMPTED = False
185187

186188

187189
def _setup_signatures() -> None:
@@ -293,8 +295,24 @@ def _setup_signatures() -> None:
293295
_LIB.lbug_value_create_null.restype = ctypes.POINTER(_LbugValue)
294296
_LIB.lbug_value_create_bool.argtypes = [ctypes.c_bool]
295297
_LIB.lbug_value_create_bool.restype = ctypes.POINTER(_LbugValue)
298+
_LIB.lbug_value_create_int8.argtypes = [ctypes.c_int8]
299+
_LIB.lbug_value_create_int8.restype = ctypes.POINTER(_LbugValue)
300+
_LIB.lbug_value_create_int16.argtypes = [ctypes.c_int16]
301+
_LIB.lbug_value_create_int16.restype = ctypes.POINTER(_LbugValue)
302+
_LIB.lbug_value_create_int32.argtypes = [ctypes.c_int32]
303+
_LIB.lbug_value_create_int32.restype = ctypes.POINTER(_LbugValue)
296304
_LIB.lbug_value_create_int64.argtypes = [ctypes.c_int64]
297305
_LIB.lbug_value_create_int64.restype = ctypes.POINTER(_LbugValue)
306+
_LIB.lbug_value_create_uint8.argtypes = [ctypes.c_uint8]
307+
_LIB.lbug_value_create_uint8.restype = ctypes.POINTER(_LbugValue)
308+
_LIB.lbug_value_create_uint16.argtypes = [ctypes.c_uint16]
309+
_LIB.lbug_value_create_uint16.restype = ctypes.POINTER(_LbugValue)
310+
_LIB.lbug_value_create_uint32.argtypes = [ctypes.c_uint32]
311+
_LIB.lbug_value_create_uint32.restype = ctypes.POINTER(_LbugValue)
312+
_LIB.lbug_value_create_uint64.argtypes = [ctypes.c_uint64]
313+
_LIB.lbug_value_create_uint64.restype = ctypes.POINTER(_LbugValue)
314+
_LIB.lbug_value_create_float.argtypes = [ctypes.c_float]
315+
_LIB.lbug_value_create_float.restype = ctypes.POINTER(_LbugValue)
298316
_LIB.lbug_value_create_double.argtypes = [ctypes.c_double]
299317
_LIB.lbug_value_create_double.restype = ctypes.POINTER(_LbugValue)
300318
_LIB.lbug_value_create_string.argtypes = [ctypes.c_char_p]
@@ -815,9 +833,87 @@ def _parse_rendered_value(value: str) -> Any:
815833
return value
816834

817835

836+
def _numpy_module() -> Any | None:
    """Return the imported ``numpy`` module, or ``None`` when it cannot be imported.

    The import is attempted at most once; the outcome is cached in the module-level
    ``_NUMPY_MODULE`` / ``_NUMPY_IMPORT_ATTEMPTED`` globals.
    """
    global _NUMPY_IMPORT_ATTEMPTED, _NUMPY_MODULE
    if _NUMPY_IMPORT_ATTEMPTED:
        return _NUMPY_MODULE
    _NUMPY_IMPORT_ATTEMPTED = True
    try:
        import numpy as np
    except ImportError:
        # ImportError also covers ModuleNotFoundError. Catching the base class keeps the
        # first call consistent with later ones: since the flag is already set, an import
        # failure that escaped here would raise once and then silently turn into None.
        return None
    _NUMPY_MODULE = np
    return np
847+
848+
849+
def _is_numpy_scalar(value: Any) -> bool:
    """Return True when NumPy is importable and ``value`` is a ``numpy.generic`` instance."""
    numpy = _numpy_module()
    if numpy is None:
        return False
    return isinstance(value, numpy.generic)
852+
853+
854+
def _is_numpy_array(value: Any) -> bool:
    """Return True when NumPy is importable and ``value`` is a ``numpy.ndarray`` instance."""
    numpy = _numpy_module()
    if numpy is None:
        return False
    return isinstance(value, numpy.ndarray)
857+
858+
859+
def _numpy_scalar_value_from_python(value: Any) -> ctypes.POINTER(_LbugValue):
    """Create a native value from a NumPy scalar, keeping its integer/float width.

    Bool, signed-integer, unsigned-integer and float dtypes (``dtype.kind`` of ``b``, ``i``,
    ``u``, ``f``) are mapped to the matching ``lbug_value_create_*`` call chosen by
    ``dtype.itemsize``. Any other dtype is converted with ``value.item()`` and handed to
    ``_value_from_python``.

    Raises:
        TypeError: when ``value.item()`` is still a NumPy scalar for a dtype not handled
            here, which would otherwise recurse without bound.
    """
    dtype = value.dtype
    kind = dtype.kind
    item = value.item()
    if kind == "b":
        return _LIB.lbug_value_create_bool(bool(item))
    if kind == "i":
        if dtype.itemsize == 1:
            return _LIB.lbug_value_create_int8(item)
        if dtype.itemsize == 2:
            return _LIB.lbug_value_create_int16(item)
        if dtype.itemsize == 4:
            return _LIB.lbug_value_create_int32(item)
        return _LIB.lbug_value_create_int64(item)
    if kind == "u":
        if dtype.itemsize == 1:
            return _LIB.lbug_value_create_uint8(item)
        if dtype.itemsize == 2:
            return _LIB.lbug_value_create_uint16(item)
        if dtype.itemsize == 4:
            return _LIB.lbug_value_create_uint32(item)
        return _LIB.lbug_value_create_uint64(item)
    if kind == "f":
        if dtype.itemsize == 4:
            return _LIB.lbug_value_create_float(item)
        return _LIB.lbug_value_create_double(item)

    # For some dtypes (e.g. clongdouble) item() returns a NumPy scalar again; passing that to
    # _value_from_python would route straight back into this function.
    if _is_numpy_scalar(item):
        raise TypeError(f"Unsupported numpy scalar dtype {dtype} for query parameter")
    return _value_from_python(item)
887+
888+
889+
def _numpy_array_value_from_python(value: Any) -> ctypes.POINTER(_LbugValue):
    """Create a native list value from a NumPy ndarray.

    A 0-dimensional array is converted as its single element. Otherwise each item obtained by
    iterating the array is converted with ``_value_from_python`` and the results are combined
    with ``lbug_value_create_list``. The child values created here are always destroyed before
    returning, whether or not list creation succeeded.
    """
    if value.ndim == 0:
        # value[()] is a NumPy scalar for numeric dtypes but the stored Python object for
        # object dtype, which has no .dtype; let the generic dispatcher pick the right path.
        return _value_from_python(value[()])

    child_ptrs: list[ctypes.POINTER(_LbugValue)] = []
    try:
        for item in value:
            child_ptrs.append(_value_from_python(item))
        out = ctypes.POINTER(_LbugValue)()
        arr_type = ctypes.POINTER(_LbugValue) * len(child_ptrs)
        arr = arr_type(*child_ptrs) if child_ptrs else arr_type()
        _check_state(
            _LIB.lbug_value_create_list(len(child_ptrs), arr, ctypes.byref(out)),
            "Failed to create numpy ndarray list value",
        )
        return out
    finally:
        for ptr in child_ptrs:
            _LIB.lbug_value_destroy(ptr)
908+
909+
818910
def _value_from_python(value: Any) -> ctypes.POINTER(_LbugValue):
819911
if value is None:
820912
return _LIB.lbug_value_create_null()
913+
if _is_numpy_array(value):
914+
return _numpy_array_value_from_python(value)
915+
if _is_numpy_scalar(value):
916+
return _numpy_scalar_value_from_python(value)
821917
if isinstance(value, bool):
822918
return _LIB.lbug_value_create_bool(value)
823919
if isinstance(value, int) and not isinstance(value, bool):

0 commit comments

Comments
 (0)