Skip to content

Commit d5fdcc4

Browse files
committed
Handle NumPy ndarray query parameters
Fix ndarray parameter conversion by importing importlib.util directly, inferring NumPy parameter types from dtype and shape, reading ndarray buffers through Python's buffer metadata instead of materializing tolist(), and using stable homogeneous-list inference for Python bool, int, and float lists.
1 parent 694bf0c commit d5fdcc4

5 files changed

Lines changed: 258 additions & 23 deletions

File tree

src_cpp/include/cached_import/py_cached_modules.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,13 @@ class DecimalCachedItem : public PythonCachedItem {
2525
class ImportLibCachedItem : public PythonCachedItem {
2626
class UtilCachedItem : public PythonCachedItem {
2727
public:
28-
explicit UtilCachedItem(PythonCachedItem* parent)
29-
: PythonCachedItem{"util", parent}, find_spec{"find_spec", this} {}
28+
UtilCachedItem() : PythonCachedItem{"importlib.util"}, find_spec{"find_spec", this} {}
3029

3130
PythonCachedItem find_spec;
3231
};
3332

3433
public:
35-
ImportLibCachedItem() : PythonCachedItem("importlib"), util(this) {}
34+
ImportLibCachedItem() : PythonCachedItem("importlib"), util() {}
3635

3736
UtilCachedItem util;
3837
};

src_cpp/py_connection.cpp

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "include/py_udf.h"
1616
#include "main/connection.h"
1717
#include "main/query_result/materialized_query_result.h"
18+
#include "numpy/numpy_type.h"
1819
#include "pandas/pandas_scan.h"
1920
#include "processor/result/factorized_table.h"
2021
#include "pyarrow/pyarrow_scan.h"
@@ -388,6 +389,47 @@ bool integerFitsIn<uint8_t>(int64_t val) {
388389
return val >= 0 && val <= UINT8_MAX;
389390
}
390391

392+
// Fast-path type inference for a Python list of plain scalars.
//
// Returns:
//   - LIST(ANY) when the list is empty or contains only None;
//   - LIST(BOOL), LIST(INT64) or LIST(DOUBLE) when every non-None element has exactly
//     the same Python type and that type is a bool, int or float respectively;
//   - ANY (not a LIST) when the fast path does not apply, i.e. the first non-None
//     element is not a bool/int/float or the non-None elements differ in exact type.
//     Callers test the returned type id against ANY to fall back to generic inference.
static LogicalType pyHomogeneousListType(const py::list& lst) {
    py::handle exemplar; // first non-None element; stays null if there is none
    for (auto element : lst) {
        if (element.is_none()) {
            continue;
        }
        if (!exemplar) {
            const bool isPlainScalar = py::isinstance<py::bool_>(element) ||
                                       py::isinstance<py::int_>(element) ||
                                       py::isinstance<py::float_>(element);
            if (!isPlainScalar) {
                return LogicalType::ANY();
            }
            exemplar = element;
        } else if (element.get_type().ptr() != exemplar.get_type().ptr()) {
            // Exact type identity is required: a list mixing bool with int (or int with
            // float) is not treated as homogeneous.
            return LogicalType::ANY();
        }
    }
    if (!exemplar) {
        return LogicalType::LIST(LogicalType::ANY());
    }
    // bool must be tested before int because a Python bool is also an int.
    if (py::isinstance<py::bool_>(exemplar)) {
        return LogicalType::LIST(LogicalType::BOOL());
    }
    return py::isinstance<py::int_>(exemplar) ? LogicalType::LIST(LogicalType::INT64()) :
                                                LogicalType::LIST(LogicalType::DOUBLE());
}
423+
424+
// Derives the logical type of a NumPy array parameter: the element type comes from the
// array's dtype and is wrapped in one LIST level per array dimension (so a 0-d array
// yields the bare element type).
static LogicalType pyNumpyArrayLogicalType(const py::array& arr) {
    auto elementNpType = NumpyTypeUtils::convertNumpyType(arr.attr("dtype"));
    auto result = NumpyTypeUtils::numpyToLogicalType(elementNpType);
    auto levelsLeft = arr.ndim();
    while (levelsLeft > 0) {
        result = LogicalType::LIST(std::move(result));
        --levelsLeft;
    }
    return result;
}
432+
391433
static LogicalType pyLogicalType(const py::handle& val) {
392434
auto datetime_datetime = importCache->datetime.datetime();
393435
auto time_delta = importCache->datetime.timedelta();
@@ -468,8 +510,14 @@ static LogicalType pyLogicalType(const py::handle& val) {
468510
childValueType = std::move(resultValue);
469511
}
470512
return LogicalType::MAP(std::move(childKeyType), std::move(childValueType));
513+
} else if (py::isinstance<py::array>(val)) {
514+
return pyNumpyArrayLogicalType(py::reinterpret_borrow<py::array>(val));
471515
} else if (py::isinstance<py::list>(val)) {
472516
py::list lst = py::reinterpret_borrow<py::list>(val);
517+
auto homogeneousType = pyHomogeneousListType(lst);
518+
if (homogeneousType.getLogicalTypeID() != LogicalTypeID::ANY) {
519+
return homogeneousType;
520+
}
473521
auto childType = LogicalType::ANY();
474522
for (auto child : lst) {
475523
auto curChildType = pyLogicalType(child);
@@ -568,8 +616,14 @@ static LogicalType pyLogicalTypeFromParameter(const py::handle& val) {
568616
structFields.emplace_back(std::move(keyName), std::move(keyType));
569617
}
570618
return LogicalType::STRUCT(std::move(structFields));
619+
} else if (py::isinstance<py::array>(val)) {
620+
return pyNumpyArrayLogicalType(py::reinterpret_borrow<py::array>(val));
571621
} else if (py::isinstance<py::list>(val)) {
572622
py::list lst = py::reinterpret_borrow<py::list>(val);
623+
auto homogeneousType = pyHomogeneousListType(lst);
624+
if (homogeneousType.getLogicalTypeID() != LogicalTypeID::ANY) {
625+
return homogeneousType;
626+
}
573627
auto childType = LogicalType::ANY();
574628
for (auto child : lst) {
575629
auto curChildType = pyLogicalTypeFromParameter(child);
@@ -603,6 +657,90 @@ static std::string pythonObjectToJsonString(const py::handle& val) {
603657
return py::cast<std::string>(jsonStr);
604658
}
605659

660+
template<typename T>
661+
static Value transformNumpyScalarAs(const void* ptr, const LogicalType& type) {
662+
auto value = *reinterpret_cast<const T*>(ptr);
663+
switch (type.getLogicalTypeID()) {
664+
case LogicalTypeID::BOOL:
665+
return Value::createValue<bool>(static_cast<bool>(value));
666+
case LogicalTypeID::INT64:
667+
return Value::createValue<int64_t>(static_cast<int64_t>(value));
668+
case LogicalTypeID::UINT32:
669+
return Value::createValue<uint32_t>(static_cast<uint32_t>(value));
670+
case LogicalTypeID::INT32:
671+
return Value::createValue<int32_t>(static_cast<int32_t>(value));
672+
case LogicalTypeID::UINT16:
673+
return Value::createValue<uint16_t>(static_cast<uint16_t>(value));
674+
case LogicalTypeID::INT16:
675+
return Value::createValue<int16_t>(static_cast<int16_t>(value));
676+
case LogicalTypeID::UINT8:
677+
return Value::createValue<uint8_t>(static_cast<uint8_t>(value));
678+
case LogicalTypeID::INT8:
679+
return Value::createValue<int8_t>(static_cast<int8_t>(value));
680+
case LogicalTypeID::FLOAT:
681+
return Value(static_cast<float>(value));
682+
case LogicalTypeID::DOUBLE:
683+
return Value::createValue<double>(static_cast<double>(value));
684+
default:
685+
throw RuntimeException("Unsupported numpy ndarray parameter child type " + type.toString());
686+
}
687+
}
688+
689+
// Selects the C++ storage type that matches the array's NumPy element type and forwards
// to the templated converter, which reads the element at `ptr` and casts it to `type`.
//
// @throws RuntimeException when `npType` is not a bool, 8/16/32/64-bit (un)signed
//         integer, or 32/64-bit float.
static Value transformNumpyScalarAs(const void* ptr, NumpyNullableType npType,
    const LogicalType& type) {
    if (npType == NumpyNullableType::BOOL) {
        return transformNumpyScalarAs<bool>(ptr, type);
    }
    // Signed integers.
    if (npType == NumpyNullableType::INT_8) {
        return transformNumpyScalarAs<int8_t>(ptr, type);
    }
    if (npType == NumpyNullableType::INT_16) {
        return transformNumpyScalarAs<int16_t>(ptr, type);
    }
    if (npType == NumpyNullableType::INT_32) {
        return transformNumpyScalarAs<int32_t>(ptr, type);
    }
    if (npType == NumpyNullableType::INT_64) {
        return transformNumpyScalarAs<int64_t>(ptr, type);
    }
    // Unsigned integers.
    if (npType == NumpyNullableType::UINT_8) {
        return transformNumpyScalarAs<uint8_t>(ptr, type);
    }
    if (npType == NumpyNullableType::UINT_16) {
        return transformNumpyScalarAs<uint16_t>(ptr, type);
    }
    if (npType == NumpyNullableType::UINT_32) {
        return transformNumpyScalarAs<uint32_t>(ptr, type);
    }
    if (npType == NumpyNullableType::UINT_64) {
        return transformNumpyScalarAs<uint64_t>(ptr, type);
    }
    // Floating point.
    if (npType == NumpyNullableType::FLOAT_32) {
        return transformNumpyScalarAs<float>(ptr, type);
    }
    if (npType == NumpyNullableType::FLOAT_64) {
        return transformNumpyScalarAs<double>(ptr, type);
    }
    throw RuntimeException("Unsupported numpy ndarray parameter dtype");
}
718+
719+
// Recursively converts the sub-array that starts at `ptr` into a nested LIST Value.
//
// @param type       logical type expected at this nesting level: a LIST for every array
//                   dimension still to be walked, then the element type.
// @param dimension  index of the array dimension handled by this call; when it equals
//                   info.ndim, `ptr` addresses a single element.
// @param ptr        address of the first byte of the current sub-array / element.
// @param info       buffer description of the whole array (shape and byte strides).
// @param npType     NumPy element type of the array.
// @throws RuntimeException if `type` is not a LIST while dimensions remain, or if the
//         element conversion is unsupported.
static Value transformNumpyArrayAs(const LogicalType& type, uint64_t dimension, const uint8_t* ptr,
    const py::buffer_info& info, NumpyNullableType npType) {
    if (dimension == static_cast<uint64_t>(info.ndim)) {
        return transformNumpyScalarAs(ptr, npType, type);
    }
    if (type.getLogicalTypeID() != LogicalTypeID::LIST) {
        throw RuntimeException("Cannot convert numpy ndarray parameter to " + type.toString());
    }
    // Shape and strides are py::ssize_t; keep the index in the same type so large
    // extents and byte offsets are not computed through a narrower int.
    const py::ssize_t extent = info.shape[dimension];
    const py::ssize_t stride = info.strides[dimension];
    const auto& childType = ListType::getChildType(type);
    std::vector<std::unique_ptr<Value>> children;
    children.reserve(static_cast<size_t>(extent));
    for (py::ssize_t i = 0; i < extent; ++i) {
        // Strides are expressed in bytes, hence the arithmetic on a uint8_t pointer.
        const uint8_t* childPtr = ptr + i * stride;
        children.push_back(std::make_unique<Value>(
            transformNumpyArrayAs(childType, dimension + 1, childPtr, info, npType)));
    }
    return Value(type.copy(), std::move(children));
}
737+
738+
// Entry point for ndarray parameters: obtains the array's buffer description and its
// NumPy element type, then converts the array to a Value of `type` starting at
// dimension 0.
static Value transformNumpyArrayAs(const py::array& arr, const LogicalType& type) {
    auto bufferInfo = arr.request();
    auto elementNpType = NumpyTypeUtils::convertNumpyType(arr.attr("dtype")).type;
    const auto* firstByte = reinterpret_cast<const uint8_t*>(bufferInfo.ptr);
    return transformNumpyArrayAs(type, 0, firstByte, bufferInfo, elementNpType);
}
743+
606744
Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalType& type) {
607745
// ignore the type of the actual python object, just directly cast
608746
auto datetime_datetime = importCache->datetime.datetime();
@@ -632,6 +770,8 @@ Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalT
632770
return Value::createValue<int8_t>(py::cast<py::int_>(val).cast<int8_t>());
633771
case LogicalTypeID::DOUBLE:
634772
return Value::createValue<double>(py::cast<py::float_>(val).cast<double>());
773+
case LogicalTypeID::FLOAT:
774+
return Value(py::cast<py::float_>(val).cast<float>());
635775
case LogicalTypeID::DECIMAL: {
636776
auto str = py::cast<std::string>(py::str(val));
637777
int128_t result = 0;
@@ -708,6 +848,9 @@ Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalT
708848
return Value{uuidToAppend};
709849
}
710850
case LogicalTypeID::LIST: {
851+
if (py::isinstance<py::array>(val)) {
852+
return transformNumpyArrayAs(py::reinterpret_borrow<py::array>(val), type);
853+
}
711854
py::list lst = py::reinterpret_borrow<py::list>(val);
712855
std::vector<std::unique_ptr<Value>> children;
713856
for (auto child : lst) {
@@ -763,6 +906,9 @@ Value PyConnection::transformPythonValueFromParameterAs(const py::handle& val,
763906
auto jsonStr = pythonObjectToJsonString(val);
764907
return Value::createValue<std::string>(jsonStr);
765908
}
909+
if (py::isinstance<py::array>(val)) {
910+
return transformNumpyArrayAs(py::reinterpret_borrow<py::array>(val), type);
911+
}
766912
py::list lst = py::reinterpret_borrow<py::list>(val);
767913
std::vector<std::unique_ptr<Value>> children;
768914
for (auto child : lst) {

src_py/_lbug_capi.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,8 @@ def _ensure_arrow_atexit_cleanup() -> None:
244244
_LBUG_MAP = 55
245245
_LBUG_UNION = 56
246246
_LBUG_UUID = 59
247+
_NUMPY_MODULE: Any | None = None
248+
_NUMPY_IMPORT_ATTEMPTED = False
247249

248250

249251
def _setup_signatures() -> None:
@@ -392,6 +394,16 @@ def _setup_signatures() -> None:
392394
_LIB.lbug_value_create_int32.restype = ctypes.POINTER(_LbugValue)
393395
_LIB.lbug_value_create_int64.argtypes = [ctypes.c_int64]
394396
_LIB.lbug_value_create_int64.restype = ctypes.POINTER(_LbugValue)
397+
_LIB.lbug_value_create_uint8.argtypes = [ctypes.c_uint8]
398+
_LIB.lbug_value_create_uint8.restype = ctypes.POINTER(_LbugValue)
399+
_LIB.lbug_value_create_uint16.argtypes = [ctypes.c_uint16]
400+
_LIB.lbug_value_create_uint16.restype = ctypes.POINTER(_LbugValue)
401+
_LIB.lbug_value_create_uint32.argtypes = [ctypes.c_uint32]
402+
_LIB.lbug_value_create_uint32.restype = ctypes.POINTER(_LbugValue)
403+
_LIB.lbug_value_create_uint64.argtypes = [ctypes.c_uint64]
404+
_LIB.lbug_value_create_uint64.restype = ctypes.POINTER(_LbugValue)
405+
_LIB.lbug_value_create_float.argtypes = [ctypes.c_float]
406+
_LIB.lbug_value_create_float.restype = ctypes.POINTER(_LbugValue)
395407
_LIB.lbug_value_create_double.argtypes = [ctypes.c_double]
396408
_LIB.lbug_value_create_double.restype = ctypes.POINTER(_LbugValue)
397409
_LIB.lbug_value_create_string.argtypes = [ctypes.c_char_p]
@@ -930,11 +942,89 @@ def _parse_rendered_value(value: str) -> Any:
930942
return value
931943

932944

945+
def _numpy_module() -> Any | None:
    """Return the imported ``numpy`` module, or ``None`` if it cannot be imported.

    The import is attempted at most once per process; the outcome is cached in the
    module-level ``_NUMPY_MODULE`` / ``_NUMPY_IMPORT_ATTEMPTED`` globals.
    """
    global _NUMPY_IMPORT_ATTEMPTED, _NUMPY_MODULE
    if _NUMPY_IMPORT_ATTEMPTED:
        return _NUMPY_MODULE
    _NUMPY_IMPORT_ATTEMPTED = True
    try:
        import numpy as np
    except ImportError:
        # ImportError also covers ModuleNotFoundError. Catching only the subclass would
        # let a broken NumPy install raise on the first call while every later call
        # silently returned None, because the attempted flag is already set.
        return None
    _NUMPY_MODULE = np
    return np
956+
957+
958+
def _is_numpy_scalar(value: Any) -> bool:
    """Return True when NumPy is importable and ``value`` is a ``numpy.generic`` instance."""
    numpy = _numpy_module()
    if numpy is None:
        return False
    return isinstance(value, numpy.generic)
961+
962+
963+
def _is_numpy_array(value: Any) -> bool:
    """Return True when NumPy is importable and ``value`` is a ``numpy.ndarray`` instance."""
    numpy = _numpy_module()
    if numpy is None:
        return False
    return isinstance(value, numpy.ndarray)
966+
967+
968+
def _numpy_scalar_value_from_python(value: Any) -> ctypes.POINTER(_LbugValue):
    """Create a native value from a NumPy scalar, choosing the constructor from its dtype.

    Dispatch is on ``dtype.kind`` and ``dtype.itemsize``:

    * ``"b"``: bool.
    * ``"i"`` / ``"u"``: int8/16/32 or uint8/16/32 for item sizes 1/2/4; any other
      size uses the 64-bit constructor.
    * ``"f"``: float for item size 4; any other size uses the double constructor.
    * anything else: the scalar's ``.item()`` is handed to ``_value_from_python``.
    """
    np_dtype = value.dtype
    kind_code = np_dtype.kind
    py_item = value.item()

    if kind_code == "b":
        return _LIB.lbug_value_create_bool(bool(py_item))

    if kind_code in ("i", "u"):
        prefix = "int" if kind_code == "i" else "uint"
        bits = {1: 8, 2: 16, 4: 32}.get(np_dtype.itemsize, 64)
        # Resolve the constructor by name so only the one actually needed is looked up.
        return getattr(_LIB, f"lbug_value_create_{prefix}{bits}")(py_item)

    if kind_code == "f":
        if np_dtype.itemsize == 4:
            return _LIB.lbug_value_create_float(py_item)
        return _LIB.lbug_value_create_double(py_item)

    return _value_from_python(py_item)
996+
997+
998+
def _numpy_array_value_from_python(value: Any) -> ctypes.POINTER(_LbugValue):
    """Create a native list value from a NumPy array.

    A 0-d array is unwrapped with ``value[()]`` and converted as a scalar. Otherwise
    each item along the first axis is converted with ``_value_from_python`` and the
    results are passed to ``lbug_value_create_list``. The child values created here are
    destroyed before returning, whether or not list creation succeeded.
    """
    if value.ndim == 0:
        return _numpy_scalar_value_from_python(value[()])

    value_ptr_type = ctypes.POINTER(_LbugValue)
    created = []
    try:
        # Append one at a time so children built before a failure are still released.
        for element in value:
            created.append(_value_from_python(element))
        count = len(created)
        result = value_ptr_type()
        c_children = (value_ptr_type * count)(*created)
        _check_state(
            _LIB.lbug_value_create_list(count, c_children, ctypes.byref(result)),
            "Failed to create numpy ndarray list value",
        )
        return result
    finally:
        for child in created:
            _LIB.lbug_value_destroy(child)
1017+
1018+
9331019
def _value_from_python(value: Any) -> ctypes.POINTER(_LbugValue):
9341020
if value is None:
9351021
return _LIB.lbug_value_create_null()
9361022
if isinstance(value, CAPIJsonParameter):
9371023
return _LIB.lbug_value_create_json(value.value.encode())
1024+
if _is_numpy_array(value):
1025+
return _numpy_array_value_from_python(value)
1026+
if _is_numpy_scalar(value):
1027+
return _numpy_scalar_value_from_python(value)
9381028
if isinstance(value, bool):
9391029
return _LIB.lbug_value_create_bool(value)
9401030
if isinstance(value, int) and not isinstance(value, bool):

0 commit comments

Comments
 (0)