|
15 | 15 | #include "include/py_udf.h" |
16 | 16 | #include "main/connection.h" |
17 | 17 | #include "main/query_result/materialized_query_result.h" |
| 18 | +#include "numpy/numpy_type.h" |
18 | 19 | #include "pandas/pandas_scan.h" |
19 | 20 | #include "processor/result/factorized_table.h" |
20 | 21 | #include "pyarrow/pyarrow_scan.h" |
@@ -388,6 +389,47 @@ bool integerFitsIn<uint8_t>(int64_t val) { |
388 | 389 | return val >= 0 && val <= UINT8_MAX; |
389 | 390 | } |
390 | 391 |
|
| 392 | +static LogicalType pyHomogeneousListType(const py::list& lst) { |
| 393 | + py::handle firstNonNull; |
| 394 | + for (auto child : lst) { |
| 395 | + if (!child.is_none()) { |
| 396 | + firstNonNull = child; |
| 397 | + break; |
| 398 | + } |
| 399 | + } |
| 400 | + if (!firstNonNull) { |
| 401 | + return LogicalType::LIST(LogicalType::ANY()); |
| 402 | + } |
| 403 | + if (!py::isinstance<py::bool_>(firstNonNull) && !py::isinstance<py::int_>(firstNonNull) && |
| 404 | + !py::isinstance<py::float_>(firstNonNull)) { |
| 405 | + return LogicalType::ANY(); |
| 406 | + } |
| 407 | + for (auto child : lst) { |
| 408 | + if (child.is_none()) { |
| 409 | + continue; |
| 410 | + } |
| 411 | + if (child.get_type().ptr() != firstNonNull.get_type().ptr()) { |
| 412 | + return LogicalType::ANY(); |
| 413 | + } |
| 414 | + } |
| 415 | + if (py::isinstance<py::bool_>(firstNonNull)) { |
| 416 | + return LogicalType::LIST(LogicalType::BOOL()); |
| 417 | + } |
| 418 | + if (py::isinstance<py::int_>(firstNonNull)) { |
| 419 | + return LogicalType::LIST(LogicalType::INT64()); |
| 420 | + } |
| 421 | + return LogicalType::LIST(LogicalType::DOUBLE()); |
| 422 | +} |
| 423 | + |
| 424 | +static LogicalType pyNumpyArrayLogicalType(const py::array& arr) { |
| 425 | + auto npType = NumpyTypeUtils::convertNumpyType(arr.attr("dtype")); |
| 426 | + auto type = NumpyTypeUtils::numpyToLogicalType(npType); |
| 427 | + for (auto i = 0; i < arr.ndim(); ++i) { |
| 428 | + type = LogicalType::LIST(std::move(type)); |
| 429 | + } |
| 430 | + return type; |
| 431 | +} |
| 432 | + |
391 | 433 | static LogicalType pyLogicalType(const py::handle& val) { |
392 | 434 | auto datetime_datetime = importCache->datetime.datetime(); |
393 | 435 | auto time_delta = importCache->datetime.timedelta(); |
@@ -468,8 +510,14 @@ static LogicalType pyLogicalType(const py::handle& val) { |
468 | 510 | childValueType = std::move(resultValue); |
469 | 511 | } |
470 | 512 | return LogicalType::MAP(std::move(childKeyType), std::move(childValueType)); |
| 513 | + } else if (py::isinstance<py::array>(val)) { |
| 514 | + return pyNumpyArrayLogicalType(py::reinterpret_borrow<py::array>(val)); |
471 | 515 | } else if (py::isinstance<py::list>(val)) { |
472 | 516 | py::list lst = py::reinterpret_borrow<py::list>(val); |
| 517 | + auto homogeneousType = pyHomogeneousListType(lst); |
| 518 | + if (homogeneousType.getLogicalTypeID() != LogicalTypeID::ANY) { |
| 519 | + return homogeneousType; |
| 520 | + } |
473 | 521 | auto childType = LogicalType::ANY(); |
474 | 522 | for (auto child : lst) { |
475 | 523 | auto curChildType = pyLogicalType(child); |
@@ -568,8 +616,14 @@ static LogicalType pyLogicalTypeFromParameter(const py::handle& val) { |
568 | 616 | structFields.emplace_back(std::move(keyName), std::move(keyType)); |
569 | 617 | } |
570 | 618 | return LogicalType::STRUCT(std::move(structFields)); |
| 619 | + } else if (py::isinstance<py::array>(val)) { |
| 620 | + return pyNumpyArrayLogicalType(py::reinterpret_borrow<py::array>(val)); |
571 | 621 | } else if (py::isinstance<py::list>(val)) { |
572 | 622 | py::list lst = py::reinterpret_borrow<py::list>(val); |
| 623 | + auto homogeneousType = pyHomogeneousListType(lst); |
| 624 | + if (homogeneousType.getLogicalTypeID() != LogicalTypeID::ANY) { |
| 625 | + return homogeneousType; |
| 626 | + } |
573 | 627 | auto childType = LogicalType::ANY(); |
574 | 628 | for (auto child : lst) { |
575 | 629 | auto curChildType = pyLogicalTypeFromParameter(child); |
@@ -603,6 +657,90 @@ static std::string pythonObjectToJsonString(const py::handle& val) { |
603 | 657 | return py::cast<std::string>(jsonStr); |
604 | 658 | } |
605 | 659 |
|
| 660 | +template<typename T> |
| 661 | +static Value transformNumpyScalarAs(const void* ptr, const LogicalType& type) { |
| 662 | + auto value = *reinterpret_cast<const T*>(ptr); |
| 663 | + switch (type.getLogicalTypeID()) { |
| 664 | + case LogicalTypeID::BOOL: |
| 665 | + return Value::createValue<bool>(static_cast<bool>(value)); |
| 666 | + case LogicalTypeID::INT64: |
| 667 | + return Value::createValue<int64_t>(static_cast<int64_t>(value)); |
| 668 | + case LogicalTypeID::UINT32: |
| 669 | + return Value::createValue<uint32_t>(static_cast<uint32_t>(value)); |
| 670 | + case LogicalTypeID::INT32: |
| 671 | + return Value::createValue<int32_t>(static_cast<int32_t>(value)); |
| 672 | + case LogicalTypeID::UINT16: |
| 673 | + return Value::createValue<uint16_t>(static_cast<uint16_t>(value)); |
| 674 | + case LogicalTypeID::INT16: |
| 675 | + return Value::createValue<int16_t>(static_cast<int16_t>(value)); |
| 676 | + case LogicalTypeID::UINT8: |
| 677 | + return Value::createValue<uint8_t>(static_cast<uint8_t>(value)); |
| 678 | + case LogicalTypeID::INT8: |
| 679 | + return Value::createValue<int8_t>(static_cast<int8_t>(value)); |
| 680 | + case LogicalTypeID::FLOAT: |
| 681 | + return Value(static_cast<float>(value)); |
| 682 | + case LogicalTypeID::DOUBLE: |
| 683 | + return Value::createValue<double>(static_cast<double>(value)); |
| 684 | + default: |
| 685 | + throw RuntimeException("Unsupported numpy ndarray parameter child type " + type.toString()); |
| 686 | + } |
| 687 | +} |
| 688 | + |
| 689 | +static Value transformNumpyScalarAs(const void* ptr, NumpyNullableType npType, |
| 690 | + const LogicalType& type) { |
| 691 | + switch (npType) { |
| 692 | + case NumpyNullableType::BOOL: |
| 693 | + return transformNumpyScalarAs<bool>(ptr, type); |
| 694 | + case NumpyNullableType::INT_8: |
| 695 | + return transformNumpyScalarAs<int8_t>(ptr, type); |
| 696 | + case NumpyNullableType::UINT_8: |
| 697 | + return transformNumpyScalarAs<uint8_t>(ptr, type); |
| 698 | + case NumpyNullableType::INT_16: |
| 699 | + return transformNumpyScalarAs<int16_t>(ptr, type); |
| 700 | + case NumpyNullableType::UINT_16: |
| 701 | + return transformNumpyScalarAs<uint16_t>(ptr, type); |
| 702 | + case NumpyNullableType::INT_32: |
| 703 | + return transformNumpyScalarAs<int32_t>(ptr, type); |
| 704 | + case NumpyNullableType::UINT_32: |
| 705 | + return transformNumpyScalarAs<uint32_t>(ptr, type); |
| 706 | + case NumpyNullableType::INT_64: |
| 707 | + return transformNumpyScalarAs<int64_t>(ptr, type); |
| 708 | + case NumpyNullableType::UINT_64: |
| 709 | + return transformNumpyScalarAs<uint64_t>(ptr, type); |
| 710 | + case NumpyNullableType::FLOAT_32: |
| 711 | + return transformNumpyScalarAs<float>(ptr, type); |
| 712 | + case NumpyNullableType::FLOAT_64: |
| 713 | + return transformNumpyScalarAs<double>(ptr, type); |
| 714 | + default: |
| 715 | + throw RuntimeException("Unsupported numpy ndarray parameter dtype"); |
| 716 | + } |
| 717 | +} |
| 718 | + |
| 719 | +static Value transformNumpyArrayAs(const LogicalType& type, uint64_t dimension, const uint8_t* ptr, |
| 720 | + const py::buffer_info& info, NumpyNullableType npType) { |
| 721 | + if (dimension == static_cast<uint64_t>(info.ndim)) { |
| 722 | + return transformNumpyScalarAs(ptr, npType, type); |
| 723 | + } |
| 724 | + if (type.getLogicalTypeID() != LogicalTypeID::LIST) { |
| 725 | + throw RuntimeException("Cannot convert numpy ndarray parameter to " + type.toString()); |
| 726 | + } |
| 727 | + std::vector<std::unique_ptr<Value>> children; |
| 728 | + children.reserve(info.shape[dimension]); |
| 729 | + const auto& childType = ListType::getChildType(type); |
| 730 | + for (auto i = 0; i < info.shape[dimension]; ++i) { |
| 731 | + auto childPtr = ptr + i * info.strides[dimension]; |
| 732 | + children.push_back(std::make_unique<Value>( |
| 733 | + transformNumpyArrayAs(childType, dimension + 1, childPtr, info, npType))); |
| 734 | + } |
| 735 | + return Value(type.copy(), std::move(children)); |
| 736 | +} |
| 737 | + |
| 738 | +static Value transformNumpyArrayAs(const py::array& arr, const LogicalType& type) { |
| 739 | + auto info = arr.request(); |
| 740 | + auto npType = NumpyTypeUtils::convertNumpyType(arr.attr("dtype")).type; |
| 741 | + return transformNumpyArrayAs(type, 0, reinterpret_cast<const uint8_t*>(info.ptr), info, npType); |
| 742 | +} |
| 743 | + |
606 | 744 | Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalType& type) { |
607 | 745 | // ignore the type of the actual python object, just directly cast |
608 | 746 | auto datetime_datetime = importCache->datetime.datetime(); |
@@ -632,6 +770,8 @@ Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalT |
632 | 770 | return Value::createValue<int8_t>(py::cast<py::int_>(val).cast<int8_t>()); |
633 | 771 | case LogicalTypeID::DOUBLE: |
634 | 772 | return Value::createValue<double>(py::cast<py::float_>(val).cast<double>()); |
| 773 | + case LogicalTypeID::FLOAT: |
| 774 | + return Value(py::cast<py::float_>(val).cast<float>()); |
635 | 775 | case LogicalTypeID::DECIMAL: { |
636 | 776 | auto str = py::cast<std::string>(py::str(val)); |
637 | 777 | int128_t result = 0; |
@@ -708,6 +848,9 @@ Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalT |
708 | 848 | return Value{uuidToAppend}; |
709 | 849 | } |
710 | 850 | case LogicalTypeID::LIST: { |
| 851 | + if (py::isinstance<py::array>(val)) { |
| 852 | + return transformNumpyArrayAs(py::reinterpret_borrow<py::array>(val), type); |
| 853 | + } |
711 | 854 | py::list lst = py::reinterpret_borrow<py::list>(val); |
712 | 855 | std::vector<std::unique_ptr<Value>> children; |
713 | 856 | for (auto child : lst) { |
@@ -763,6 +906,9 @@ Value PyConnection::transformPythonValueFromParameterAs(const py::handle& val, |
763 | 906 | auto jsonStr = pythonObjectToJsonString(val); |
764 | 907 | return Value::createValue<std::string>(jsonStr); |
765 | 908 | } |
| 909 | + if (py::isinstance<py::array>(val)) { |
| 910 | + return transformNumpyArrayAs(py::reinterpret_borrow<py::array>(val), type); |
| 911 | + } |
766 | 912 | py::list lst = py::reinterpret_borrow<py::list>(val); |
767 | 913 | std::vector<std::unique_ptr<Value>> children; |
768 | 914 | for (auto child : lst) { |
|
0 commit comments