Skip to content

Commit 0ea6dc3

Browse files
committed
fix: handle mixed nested/non-nested types in Python list parameters
When converting Python list parameters with mixed types (e.g., [str, str, dict]), type inference incorrectly combined STRING and STRUCT types because STRING is treated as a universal cast type. This caused data corruption: list elements were converted to the wrong types.
The fix:
1. Detect mixed nested/non-nested type combinations in lists and return the JSON type.
2. Parse JSON strings passed as parameters into their proper Python objects.
3. Handle JSON type conversion for lists with mixed element types.
Fixes test_to_json_string_param_roundtrip, which failed when passing JSON strings containing nested structures with mixed-type arrays such as @context.
1 parent 69c526d commit 0ea6dc3

2 files changed

Lines changed: 52 additions & 1191 deletions

File tree

src_cpp/py_connection.cpp

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
#include "common/constants.h"
77
#include "common/exception/not_implemented.h"
88
#include "common/exception/runtime.h"
9+
#include "common/json_utils.h"
10+
#include "common/types/json_type.h"
911
#include "common/types/uuid.h"
1012
#include "common/utils.h"
1113
#include "datetime.h" // from Python
@@ -391,6 +393,14 @@ static LogicalType pyLogicalType(const py::handle& val) {
391393
}
392394
return LogicalType::DECIMAL(precision, -exponent);
393395
} else if (py::isinstance<py::str>(val)) {
396+
auto strVal = py::cast<std::string>(val);
397+
if (!strVal.empty() && (strVal.front() == '{' || strVal.front() == '[')) {
398+
auto jsonModule = py::module_::import("json");
399+
try {
400+
auto parsed = jsonModule.attr("loads")(val);
401+
return pyLogicalType(parsed);
402+
} catch (...) {}
403+
}
394404
return LogicalType::STRING();
395405
} else if (py::isinstance<py::bytes>(val)) {
396406
return LogicalType::BLOB();
@@ -525,6 +535,15 @@ static LogicalType pyLogicalTypeFromParameter(const py::handle& val) {
525535
auto childType = LogicalType::ANY();
526536
for (auto child : lst) {
527537
auto curChildType = pyLogicalTypeFromParameter(child);
538+
if (childType.getLogicalTypeID() != LogicalTypeID::ANY &&
539+
childType.getLogicalTypeID() != curChildType.getLogicalTypeID()) {
540+
if ((LogicalTypeUtils::isNested(childType.getLogicalTypeID()) &&
541+
!LogicalTypeUtils::isNested(curChildType.getLogicalTypeID())) ||
542+
(!LogicalTypeUtils::isNested(childType.getLogicalTypeID()) &&
543+
LogicalTypeUtils::isNested(curChildType.getLogicalTypeID()))) {
544+
return LogicalType::JSON();
545+
}
546+
}
528547
LogicalType result;
529548
if (!LogicalTypeUtils::tryGetMaxLogicalType(childType, curChildType, result)) {
530549
throw RuntimeException(std::format(
@@ -539,6 +558,12 @@ static LogicalType pyLogicalTypeFromParameter(const py::handle& val) {
539558
}
540559
}
541560

561+
// Serializes a Python object to JSON text: imports Python's `json` module,
// calls `json.dumps(val)`, and casts the resulting Python object to std::string.
// Nothing is caught here, so any exception raised by the import, by `dumps`,
// or by the cast propagates to the caller.
static std::string pythonObjectToJsonString(const py::handle& val) {
562+
auto jsonModule = py::module_::import("json");
563+
auto jsonStr = jsonModule.attr("dumps")(val);
564+
return py::cast<std::string>(jsonStr);
565+
}
566+
542567
Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalType& type) {
543568
// ignore the type of the actual python object, just directly cast
544569
auto datetime_datetime = importCache->datetime.datetime();
@@ -578,7 +603,15 @@ Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalT
578603
}
579604
case LogicalTypeID::STRING:
580605
if (py::isinstance<py::str>(val)) {
581-
return Value::createValue<std::string>(val.cast<std::string>());
606+
auto strVal = val.cast<std::string>();
607+
if (!strVal.empty() && (strVal.front() == '{' || strVal.front() == '[')) {
608+
auto jsonModule = py::module_::import("json");
609+
try {
610+
auto parsed = jsonModule.attr("loads")(val);
611+
return transformPythonValue(parsed);
612+
} catch (...) {}
613+
}
614+
return Value::createValue<std::string>(strVal);
582615
} else {
583616
return Value::createValue<std::string>(py::str(val));
584617
}
@@ -694,8 +727,22 @@ Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalT
694727

695728
Value PyConnection::transformPythonValueFromParameterAs(const py::handle& val,
696729
const LogicalType& type) {
730+
if (py::isinstance<py::str>(val)) {
731+
auto strVal = py::cast<std::string>(val);
732+
if (!strVal.empty() && (strVal.front() == '{' || strVal.front() == '[')) {
733+
auto jsonModule = py::module_::import("json");
734+
try {
735+
auto parsed = jsonModule.attr("loads")(val);
736+
return transformPythonValueFromParameter(parsed);
737+
} catch (...) {}
738+
}
739+
}
697740
switch (type.getLogicalTypeID()) {
698741
case LogicalTypeID::LIST: {
742+
if (ListType::getChildType(type).getLogicalTypeID() == LogicalTypeID::JSON) {
743+
auto jsonStr = pythonObjectToJsonString(val);
744+
return Value::createValue<std::string>(jsonStr);
745+
}
699746
py::list lst = py::reinterpret_borrow<py::list>(val);
700747
std::vector<std::unique_ptr<Value>> children;
701748
for (auto child : lst) {
@@ -737,6 +784,10 @@ Value PyConnection::transformPythonValueFromParameterAs(const py::handle& val,
737784
}
738785
return Value(type.copy(), std::move(children));
739786
}
787+
case LogicalTypeID::JSON: {
788+
auto jsonStr = pythonObjectToJsonString(val);
789+
return Value::createValue<std::string>(jsonStr);
790+
}
740791
case LogicalTypeID::POINTER: {
741792
return Value::createValue(reinterpret_cast<uint8_t*>(val.ptr()));
742793
}

0 commit comments

Comments
 (0)