|
15 | 15 | #include "include/py_udf.h" |
16 | 16 | #include "main/connection.h" |
17 | 17 | #include "main/query_result/materialized_query_result.h" |
| 18 | +#include "numpy/numpy_type.h" |
18 | 19 | #include "pandas/pandas_scan.h" |
19 | 20 | #include "processor/result/factorized_table.h" |
20 | 21 | #include "pyarrow/pyarrow_scan.h" |
@@ -358,6 +359,47 @@ bool integerFitsIn<uint8_t>(int64_t val) { |
358 | 359 | return val >= 0 && val <= UINT8_MAX; |
359 | 360 | } |
360 | 361 |
|
| 362 | +static LogicalType pyHomogeneousListType(const py::list& lst) { |
| 363 | + py::handle firstNonNull; |
| 364 | + for (auto child : lst) { |
| 365 | + if (!child.is_none()) { |
| 366 | + firstNonNull = child; |
| 367 | + break; |
| 368 | + } |
| 369 | + } |
| 370 | + if (!firstNonNull) { |
| 371 | + return LogicalType::LIST(LogicalType::ANY()); |
| 372 | + } |
| 373 | + if (!py::isinstance<py::bool_>(firstNonNull) && !py::isinstance<py::int_>(firstNonNull) && |
| 374 | + !py::isinstance<py::float_>(firstNonNull)) { |
| 375 | + return LogicalType::ANY(); |
| 376 | + } |
| 377 | + for (auto child : lst) { |
| 378 | + if (child.is_none()) { |
| 379 | + continue; |
| 380 | + } |
| 381 | + if (child.get_type().ptr() != firstNonNull.get_type().ptr()) { |
| 382 | + return LogicalType::ANY(); |
| 383 | + } |
| 384 | + } |
| 385 | + if (py::isinstance<py::bool_>(firstNonNull)) { |
| 386 | + return LogicalType::LIST(LogicalType::BOOL()); |
| 387 | + } |
| 388 | + if (py::isinstance<py::int_>(firstNonNull)) { |
| 389 | + return LogicalType::LIST(LogicalType::INT64()); |
| 390 | + } |
| 391 | + return LogicalType::LIST(LogicalType::DOUBLE()); |
| 392 | +} |
| 393 | + |
| 394 | +static LogicalType pyNumpyArrayLogicalType(const py::array& arr) { |
| 395 | + auto npType = NumpyTypeUtils::convertNumpyType(arr.attr("dtype")); |
| 396 | + auto type = NumpyTypeUtils::numpyToLogicalType(npType); |
| 397 | + for (auto i = 0; i < arr.ndim(); ++i) { |
| 398 | + type = LogicalType::LIST(std::move(type)); |
| 399 | + } |
| 400 | + return type; |
| 401 | +} |
| 402 | + |
361 | 403 | static LogicalType pyLogicalType(const py::handle& val) { |
362 | 404 | auto datetime_datetime = importCache->datetime.datetime(); |
363 | 405 | auto time_delta = importCache->datetime.timedelta(); |
@@ -436,8 +478,14 @@ static LogicalType pyLogicalType(const py::handle& val) { |
436 | 478 | childValueType = std::move(resultValue); |
437 | 479 | } |
438 | 480 | return LogicalType::MAP(std::move(childKeyType), std::move(childValueType)); |
| 481 | + } else if (py::isinstance<py::array>(val)) { |
| 482 | + return pyNumpyArrayLogicalType(py::reinterpret_borrow<py::array>(val)); |
439 | 483 | } else if (py::isinstance<py::list>(val)) { |
440 | 484 | py::list lst = py::reinterpret_borrow<py::list>(val); |
| 485 | + auto homogeneousType = pyHomogeneousListType(lst); |
| 486 | + if (homogeneousType.getLogicalTypeID() != LogicalTypeID::ANY) { |
| 487 | + return homogeneousType; |
| 488 | + } |
441 | 489 | auto childType = LogicalType::ANY(); |
442 | 490 | for (auto child : lst) { |
443 | 491 | auto curChildType = pyLogicalType(child); |
@@ -535,8 +583,14 @@ static LogicalType pyLogicalTypeFromParameter(const py::handle& val) { |
535 | 583 | structFields.emplace_back(std::move(keyName), std::move(keyType)); |
536 | 584 | } |
537 | 585 | return LogicalType::STRUCT(std::move(structFields)); |
| 586 | + } else if (py::isinstance<py::array>(val)) { |
| 587 | + return pyNumpyArrayLogicalType(py::reinterpret_borrow<py::array>(val)); |
538 | 588 | } else if (py::isinstance<py::list>(val)) { |
539 | 589 | py::list lst = py::reinterpret_borrow<py::list>(val); |
| 590 | + auto homogeneousType = pyHomogeneousListType(lst); |
| 591 | + if (homogeneousType.getLogicalTypeID() != LogicalTypeID::ANY) { |
| 592 | + return homogeneousType; |
| 593 | + } |
540 | 594 | auto childType = LogicalType::ANY(); |
541 | 595 | for (auto child : lst) { |
542 | 596 | auto curChildType = pyLogicalTypeFromParameter(child); |
@@ -569,6 +623,90 @@ static std::string pythonObjectToJsonString(const py::handle& val) { |
569 | 623 | return py::cast<std::string>(jsonStr); |
570 | 624 | } |
571 | 625 |
|
| 626 | +template<typename T> |
| 627 | +static Value transformNumpyScalarAs(const void* ptr, const LogicalType& type) { |
| 628 | + auto value = *reinterpret_cast<const T*>(ptr); |
| 629 | + switch (type.getLogicalTypeID()) { |
| 630 | + case LogicalTypeID::BOOL: |
| 631 | + return Value::createValue<bool>(static_cast<bool>(value)); |
| 632 | + case LogicalTypeID::INT64: |
| 633 | + return Value::createValue<int64_t>(static_cast<int64_t>(value)); |
| 634 | + case LogicalTypeID::UINT32: |
| 635 | + return Value::createValue<uint32_t>(static_cast<uint32_t>(value)); |
| 636 | + case LogicalTypeID::INT32: |
| 637 | + return Value::createValue<int32_t>(static_cast<int32_t>(value)); |
| 638 | + case LogicalTypeID::UINT16: |
| 639 | + return Value::createValue<uint16_t>(static_cast<uint16_t>(value)); |
| 640 | + case LogicalTypeID::INT16: |
| 641 | + return Value::createValue<int16_t>(static_cast<int16_t>(value)); |
| 642 | + case LogicalTypeID::UINT8: |
| 643 | + return Value::createValue<uint8_t>(static_cast<uint8_t>(value)); |
| 644 | + case LogicalTypeID::INT8: |
| 645 | + return Value::createValue<int8_t>(static_cast<int8_t>(value)); |
| 646 | + case LogicalTypeID::FLOAT: |
| 647 | + return Value(static_cast<float>(value)); |
| 648 | + case LogicalTypeID::DOUBLE: |
| 649 | + return Value::createValue<double>(static_cast<double>(value)); |
| 650 | + default: |
| 651 | + throw RuntimeException("Unsupported numpy ndarray parameter child type " + type.toString()); |
| 652 | + } |
| 653 | +} |
| 654 | + |
| 655 | +static Value transformNumpyScalarAs(const void* ptr, NumpyNullableType npType, |
| 656 | + const LogicalType& type) { |
| 657 | + switch (npType) { |
| 658 | + case NumpyNullableType::BOOL: |
| 659 | + return transformNumpyScalarAs<bool>(ptr, type); |
| 660 | + case NumpyNullableType::INT_8: |
| 661 | + return transformNumpyScalarAs<int8_t>(ptr, type); |
| 662 | + case NumpyNullableType::UINT_8: |
| 663 | + return transformNumpyScalarAs<uint8_t>(ptr, type); |
| 664 | + case NumpyNullableType::INT_16: |
| 665 | + return transformNumpyScalarAs<int16_t>(ptr, type); |
| 666 | + case NumpyNullableType::UINT_16: |
| 667 | + return transformNumpyScalarAs<uint16_t>(ptr, type); |
| 668 | + case NumpyNullableType::INT_32: |
| 669 | + return transformNumpyScalarAs<int32_t>(ptr, type); |
| 670 | + case NumpyNullableType::UINT_32: |
| 671 | + return transformNumpyScalarAs<uint32_t>(ptr, type); |
| 672 | + case NumpyNullableType::INT_64: |
| 673 | + return transformNumpyScalarAs<int64_t>(ptr, type); |
| 674 | + case NumpyNullableType::UINT_64: |
| 675 | + return transformNumpyScalarAs<uint64_t>(ptr, type); |
| 676 | + case NumpyNullableType::FLOAT_32: |
| 677 | + return transformNumpyScalarAs<float>(ptr, type); |
| 678 | + case NumpyNullableType::FLOAT_64: |
| 679 | + return transformNumpyScalarAs<double>(ptr, type); |
| 680 | + default: |
| 681 | + throw RuntimeException("Unsupported numpy ndarray parameter dtype"); |
| 682 | + } |
| 683 | +} |
| 684 | + |
| 685 | +static Value transformNumpyArrayAs(const LogicalType& type, uint64_t dimension, const uint8_t* ptr, |
| 686 | + const py::buffer_info& info, NumpyNullableType npType) { |
| 687 | + if (dimension == static_cast<uint64_t>(info.ndim)) { |
| 688 | + return transformNumpyScalarAs(ptr, npType, type); |
| 689 | + } |
| 690 | + if (type.getLogicalTypeID() != LogicalTypeID::LIST) { |
| 691 | + throw RuntimeException("Cannot convert numpy ndarray parameter to " + type.toString()); |
| 692 | + } |
| 693 | + std::vector<std::unique_ptr<Value>> children; |
| 694 | + children.reserve(info.shape[dimension]); |
| 695 | + const auto& childType = ListType::getChildType(type); |
| 696 | + for (auto i = 0; i < info.shape[dimension]; ++i) { |
| 697 | + auto childPtr = ptr + i * info.strides[dimension]; |
| 698 | + children.push_back(std::make_unique<Value>( |
| 699 | + transformNumpyArrayAs(childType, dimension + 1, childPtr, info, npType))); |
| 700 | + } |
| 701 | + return Value(type.copy(), std::move(children)); |
| 702 | +} |
| 703 | + |
| 704 | +static Value transformNumpyArrayAs(const py::array& arr, const LogicalType& type) { |
| 705 | + auto info = arr.request(); |
| 706 | + auto npType = NumpyTypeUtils::convertNumpyType(arr.attr("dtype")).type; |
| 707 | + return transformNumpyArrayAs(type, 0, reinterpret_cast<const uint8_t*>(info.ptr), info, npType); |
| 708 | +} |
| 709 | + |
572 | 710 | Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalType& type) { |
573 | 711 | // ignore the type of the actual python object, just directly cast |
574 | 712 | auto datetime_datetime = importCache->datetime.datetime(); |
@@ -598,6 +736,8 @@ Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalT |
598 | 736 | return Value::createValue<int8_t>(py::cast<py::int_>(val).cast<int8_t>()); |
599 | 737 | case LogicalTypeID::DOUBLE: |
600 | 738 | return Value::createValue<double>(py::cast<py::float_>(val).cast<double>()); |
| 739 | + case LogicalTypeID::FLOAT: |
| 740 | + return Value(py::cast<py::float_>(val).cast<float>()); |
601 | 741 | case LogicalTypeID::DECIMAL: { |
602 | 742 | auto str = py::cast<std::string>(py::str(val)); |
603 | 743 | int128_t result = 0; |
@@ -674,6 +814,9 @@ Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalT |
674 | 814 | return Value{uuidToAppend}; |
675 | 815 | } |
676 | 816 | case LogicalTypeID::LIST: { |
| 817 | + if (py::isinstance<py::array>(val)) { |
| 818 | + return transformNumpyArrayAs(py::reinterpret_borrow<py::array>(val), type); |
| 819 | + } |
677 | 820 | py::list lst = py::reinterpret_borrow<py::list>(val); |
678 | 821 | std::vector<std::unique_ptr<Value>> children; |
679 | 822 | for (auto child : lst) { |
@@ -729,6 +872,9 @@ Value PyConnection::transformPythonValueFromParameterAs(const py::handle& val, |
729 | 872 | auto jsonStr = pythonObjectToJsonString(val); |
730 | 873 | return Value::createValue<std::string>(jsonStr); |
731 | 874 | } |
| 875 | + if (py::isinstance<py::array>(val)) { |
| 876 | + return transformNumpyArrayAs(py::reinterpret_borrow<py::array>(val), type); |
| 877 | + } |
732 | 878 | py::list lst = py::reinterpret_borrow<py::list>(val); |
733 | 879 | std::vector<std::unique_ptr<Value>> children; |
734 | 880 | for (auto child : lst) { |
|
0 commit comments