|
1 | 1 | #include "include/py_connection.h" |
2 | 2 |
|
| 3 | +#include <algorithm> |
| 4 | +#include <cctype> |
3 | 5 | #include <utility> |
4 | 6 |
|
5 | 7 | #include "cached_import/py_cached_import.h" |
@@ -52,7 +54,8 @@ void PyConnection::initialize(py::handle& m) { |
52 | 54 | .def("create_arrow_table", &PyConnection::createArrowTable, py::arg("table_name"), |
53 | 55 | py::arg("arrow_table")) |
54 | 56 | .def("create_arrow_rel_table", &PyConnection::createArrowRelTable, py::arg("table_name"), |
55 | | - py::arg("arrow_table"), py::arg("src_table_name"), py::arg("dst_table_name")) |
| 57 | + py::arg("arrow_table"), py::arg("src_table_name"), py::arg("dst_table_name"), |
| 58 | + py::arg("layout") = "FLAT", py::arg("indptr_table") = py::none()) |
56 | 59 | .def("drop_arrow_table", &PyConnection::dropArrowTable, py::arg("table_name")); |
57 | 60 | PyDateTime_IMPORT; |
58 | 61 | } |
@@ -1013,79 +1016,91 @@ void PyConnection::removeScalarFunction(const std::string& name) { |
1013 | 1016 | refState().ref().removeUDFFunction(name); |
1014 | 1017 | } |
1015 | 1018 |
|
1016 | | -std::unique_ptr<PyQueryResult> PyConnection::createArrowTable(const std::string& tableName, |
1017 | | - py::object arrowTable) { |
1018 | | - auto& stateRef = refState(); |
1019 | | - py::gil_scoped_acquire acquire; |
| 1019 | +struct ExportedArrowTable { |
| 1020 | + ArrowSchemaWrapper schema; |
| 1021 | + std::vector<ArrowArrayWrapper> arrays; |
| 1022 | + py::list keepAlive; |
| 1023 | +}; |
1020 | 1024 |
|
1021 | | - // Convert pandas/polars to pyarrow if needed |
| 1025 | +static py::object normalizeArrowTable(py::object arrowTable) { |
1022 | 1026 | if (PyConnection::isPandasDataframe(arrowTable)) { |
1023 | | - arrowTable = importCache->pyarrow.lib.Table.from_pandas()(arrowTable); |
1024 | | - } else if (PyConnection::isPolarsDataframe(arrowTable)) { |
1025 | | - arrowTable = arrowTable.attr("to_arrow")(); |
| 1027 | + return importCache->pyarrow.lib.Table.from_pandas()(arrowTable); |
| 1028 | + } |
| 1029 | + if (PyConnection::isPolarsDataframe(arrowTable)) { |
| 1030 | + return arrowTable.attr("to_arrow")(); |
1026 | 1031 | } |
1027 | | - |
1028 | | - // Ensure we have a pyarrow table |
1029 | 1032 | if (!PyConnection::isPyArrowTable(arrowTable)) { |
1030 | 1033 | throw RuntimeException("Expected a pyarrow Table, polars DataFrame, or pandas DataFrame"); |
1031 | 1034 | } |
| 1035 | + return arrowTable; |
| 1036 | +} |
1032 | 1037 |
|
1033 | | - // Export Arrow table to C Data Interface |
1034 | | - // First, get the schema |
1035 | | - ArrowSchemaWrapper schema; |
1036 | | - arrowTable.attr("schema").attr("_export_to_c")(reinterpret_cast<uint64_t>(&schema)); |
| 1038 | +static ExportedArrowTable exportArrowTable(py::object arrowTable) { |
| 1039 | + arrowTable = normalizeArrowTable(std::move(arrowTable)); |
| 1040 | + |
| 1041 | + ExportedArrowTable exported; |
| 1042 | + arrowTable.attr("schema").attr("_export_to_c")(reinterpret_cast<uint64_t>(&exported.schema)); |
1037 | 1043 |
|
1038 | | - // Get the batches (arrays) |
1039 | | - std::vector<ArrowArrayWrapper> arrays; |
1040 | 1044 | py::list batches = arrowTable.attr("to_batches")(); |
1041 | 1045 | for (auto& batch : batches) { |
1042 | | - arrays.emplace_back(); |
1043 | | - batch.attr("_export_to_c")(reinterpret_cast<uint64_t>(&arrays.back())); |
| 1046 | + exported.arrays.emplace_back(); |
| 1047 | + batch.attr("_export_to_c")(reinterpret_cast<uint64_t>(&exported.arrays.back())); |
1044 | 1048 | } |
1045 | 1049 |
|
1046 | 1050 | // Keep pyarrow producers alive while C++ accesses exported Arrow memory. |
1047 | | - py::list keepAlive; |
1048 | | - keepAlive.append(arrowTable); |
1049 | | - keepAlive.append(batches); |
| 1051 | + exported.keepAlive.append(arrowTable); |
| 1052 | + exported.keepAlive.append(batches); |
| 1053 | + return exported; |
| 1054 | +} |
| 1055 | + |
| 1056 | +std::unique_ptr<PyQueryResult> PyConnection::createArrowTable(const std::string& tableName, |
| 1057 | + py::object arrowTable) { |
| 1058 | + auto& stateRef = refState(); |
| 1059 | + py::gil_scoped_acquire acquire; |
1050 | 1060 |
|
| 1061 | + auto exported = exportArrowTable(std::move(arrowTable)); |
1051 | 1062 | auto result = ArrowTableSupport::createViewFromArrowTable(stateRef.ref(), tableName, |
1052 | | - std::move(schema), std::move(arrays)); |
| 1063 | + std::move(exported.schema), std::move(exported.arrays)); |
1053 | 1064 | if (result.queryResult && result.queryResult->isSuccess()) { |
1054 | | - stateRef.arrowTableRefs[tableName] = std::move(keepAlive); |
| 1065 | + stateRef.arrowTableRefs[tableName] = std::move(exported.keepAlive); |
1055 | 1066 | } |
1056 | 1067 |
|
1057 | 1068 | return checkAndWrapQueryResult(result.queryResult, state); |
1058 | 1069 | } |
1059 | 1070 |
|
1060 | 1071 | std::unique_ptr<PyQueryResult> PyConnection::createArrowRelTable(const std::string& tableName, |
1061 | | - py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName) { |
| 1072 | + py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName, |
| 1073 | + const std::string& layout, py::object indptrTable) { |
1062 | 1074 | auto& stateRef = refState(); |
1063 | 1075 | py::gil_scoped_acquire acquire; |
1064 | 1076 |
|
1065 | | - if (PyConnection::isPandasDataframe(arrowTable)) { |
1066 | | - arrowTable = importCache->pyarrow.lib.Table.from_pandas()(arrowTable); |
1067 | | - } else if (PyConnection::isPolarsDataframe(arrowTable)) { |
1068 | | - arrowTable = arrowTable.attr("to_arrow")(); |
1069 | | - } |
1070 | | - if (!PyConnection::isPyArrowTable(arrowTable)) { |
1071 | | - throw RuntimeException("Expected a pyarrow Table, polars DataFrame, or pandas DataFrame"); |
1072 | | - } |
1073 | | - |
1074 | | - ArrowSchemaWrapper schema; |
1075 | | - arrowTable.attr("schema").attr("_export_to_c")(reinterpret_cast<uint64_t>(&schema)); |
1076 | | - std::vector<ArrowArrayWrapper> arrays; |
1077 | | - py::list batches = arrowTable.attr("to_batches")(); |
1078 | | - for (auto& batch : batches) { |
1079 | | - arrays.emplace_back(); |
1080 | | - batch.attr("_export_to_c")(reinterpret_cast<uint64_t>(&arrays.back())); |
1081 | | - } |
| 1077 | + auto layoutUpper = layout; |
| 1078 | + std::transform(layoutUpper.begin(), layoutUpper.end(), layoutUpper.begin(), |
| 1079 | + [](unsigned char c) { return static_cast<char>(std::toupper(c)); }); |
1082 | 1080 |
|
| 1081 | + auto exported = exportArrowTable(std::move(arrowTable)); |
| 1082 | + ArrowTableCreationResult result; |
1083 | 1083 | py::list keepAlive; |
1084 | | - keepAlive.append(arrowTable); |
1085 | | - keepAlive.append(batches); |
| 1084 | + keepAlive.append(exported.keepAlive); |
1086 | 1085 |
|
1087 | | - auto result = ArrowTableSupport::createRelTableFromArrowTable(stateRef.ref(), tableName, |
1088 | | - srcTableName, dstTableName, std::move(schema), std::move(arrays)); |
| 1086 | + if (layoutUpper == "FLAT") { |
| 1087 | + if (!py::none().is(indptrTable)) { |
| 1088 | + throw RuntimeException("indptr_table is only valid for CSR Arrow relationship tables"); |
| 1089 | + } |
| 1090 | + result = ArrowTableSupport::createRelTableFromArrowTable(stateRef.ref(), tableName, |
| 1091 | + srcTableName, dstTableName, std::move(exported.schema), std::move(exported.arrays)); |
| 1092 | + } else if (layoutUpper == "CSR") { |
| 1093 | + if (py::none().is(indptrTable)) { |
| 1094 | + throw RuntimeException("indptr_table is required for CSR Arrow relationship tables"); |
| 1095 | + } |
| 1096 | + auto exportedIndptr = exportArrowTable(std::move(indptrTable)); |
| 1097 | + keepAlive.append(exportedIndptr.keepAlive); |
| 1098 | + result = ArrowTableSupport::createRelTableFromArrowCSR(stateRef.ref(), tableName, |
| 1099 | + srcTableName, dstTableName, std::move(exported.schema), std::move(exported.arrays), |
| 1100 | + std::move(exportedIndptr.schema), std::move(exportedIndptr.arrays)); |
| 1101 | + } else { |
| 1102 | + throw RuntimeException("Arrow relationship table layout must be FLAT or CSR"); |
| 1103 | + } |
1089 | 1104 | if (result.queryResult && result.queryResult->isSuccess()) { |
1090 | 1105 | stateRef.arrowTableRefs[tableName] = std::move(keepAlive); |
1091 | 1106 | } |
|
0 commit comments