Skip to content

Commit c6d6831

Browse files
committed
Infer JSON parameters in Python binding
1 parent 7be676a commit c6d6831

1 file changed

Lines changed: 42 additions & 23 deletions

File tree

src_cpp/py_connection.cpp

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -397,8 +397,8 @@ static LogicalType pyLogicalType(const py::handle& val) {
397397
if (!strVal.empty() && (strVal.front() == '{' || strVal.front() == '[')) {
398398
auto jsonModule = py::module_::import("json");
399399
try {
400-
auto parsed = jsonModule.attr("loads")(val);
401-
return pyLogicalType(parsed);
400+
jsonModule.attr("loads")(val);
401+
return LogicalType::JSON();
402402
} catch (...) {}
403403
}
404404
return LogicalType::STRING();
@@ -414,6 +414,9 @@ static LogicalType pyLogicalType(const py::handle& val) {
414414
return LogicalType::UUID();
415415
} else if (py::isinstance<py::dict>(val)) {
416416
py::dict dict = py::reinterpret_borrow<py::dict>(val);
417+
if (dict.empty()) {
418+
return LogicalType::JSON();
419+
}
417420
auto childKeyType = LogicalType::ANY(), childValueType = LogicalType::ANY();
418421
for (auto child : dict) {
419422
auto curChildKeyType = pyLogicalType(child.first),
@@ -433,6 +436,9 @@ static LogicalType pyLogicalType(const py::handle& val) {
433436
childKeyType = std::move(resultKey);
434437
childValueType = std::move(resultValue);
435438
}
439+
if (childKeyType.containsAny() || childValueType.containsAny()) {
440+
return LogicalType::JSON();
441+
}
436442
return LogicalType::MAP(std::move(childKeyType), std::move(childValueType));
437443
} else if (py::isinstance<py::list>(val)) {
438444
py::list lst = py::reinterpret_borrow<py::list>(val);
@@ -447,6 +453,9 @@ static LogicalType pyLogicalType(const py::handle& val) {
447453
}
448454
childType = std::move(result);
449455
}
456+
if (childType.containsAny()) {
457+
return LogicalType::JSON();
458+
}
450459
return LogicalType::LIST(std::move(childType));
451460
} else if (PyConnection::isPyArrowTable(val) || PyConnection::isPandasDataframe(val) ||
452461
PyConnection::isPolarsDataframe(val)) {
@@ -517,10 +526,28 @@ static bool tryCastToMap(py::dict& dict, LogicalType& result) {
517526
}
518527

519528
static LogicalType pyLogicalTypeFromParameter(const py::handle& val) {
529+
if (py::isinstance<py::str>(val)) {
530+
auto strVal = py::cast<std::string>(val);
531+
if (!strVal.empty() && (strVal.front() == '{' || strVal.front() == '[')) {
532+
auto jsonModule = py::module_::import("json");
533+
try {
534+
jsonModule.attr("loads")(val);
535+
return LogicalType::JSON();
536+
} catch (...) {}
537+
}
538+
return LogicalType::STRING();
539+
}
520540
if (py::isinstance<py::dict>(val)) {
521541
auto dict = py::reinterpret_borrow<py::dict>(val);
542+
if (dict.empty()) {
543+
return LogicalType::JSON();
544+
}
522545
LogicalType resultType;
523546
if (tryCastToMap(dict, resultType)) {
547+
if (MapType::getKeyType(resultType).containsAny() ||
548+
MapType::getValueType(resultType).containsAny()) {
549+
return LogicalType::JSON();
550+
}
524551
return resultType;
525552
}
526553
auto structFields = std::vector<StructField>{};
@@ -529,6 +556,11 @@ static LogicalType pyLogicalTypeFromParameter(const py::handle& val) {
529556
auto keyType = pyLogicalTypeFromParameter(child.second);
530557
structFields.emplace_back(std::move(keyName), std::move(keyType));
531558
}
559+
for (const auto& field : structFields) {
560+
if (field.getType().containsAny()) {
561+
return LogicalType::JSON();
562+
}
563+
}
532564
return LogicalType::STRUCT(std::move(structFields));
533565
} else if (py::isinstance<py::list>(val)) {
534566
py::list lst = py::reinterpret_borrow<py::list>(val);
@@ -552,6 +584,9 @@ static LogicalType pyLogicalTypeFromParameter(const py::handle& val) {
552584
}
553585
childType = std::move(result);
554586
}
587+
if (childType.containsAny()) {
588+
return LogicalType::JSON();
589+
}
555590
return LogicalType::LIST(std::move(childType));
556591
} else {
557592
return pyLogicalType(val);
@@ -603,18 +638,9 @@ Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalT
603638
}
604639
case LogicalTypeID::STRING:
605640
if (py::isinstance<py::str>(val)) {
606-
auto strVal = val.cast<std::string>();
607-
if (!strVal.empty() && (strVal.front() == '{' || strVal.front() == '[')) {
608-
auto jsonModule = py::module_::import("json");
609-
try {
610-
auto parsed = jsonModule.attr("loads")(val);
611-
return transformPythonValue(parsed);
612-
} catch (...) {}
613-
}
614-
return Value::createValue<std::string>(strVal);
615-
} else {
616-
return Value::createValue<std::string>(py::str(val));
641+
return Value::createValue<std::string>(val.cast<std::string>());
617642
}
643+
return Value::createValue<std::string>(py::str(val));
618644
case LogicalTypeID::BLOB: {
619645
auto bytes = py::cast<py::bytes>(val);
620646
const char* data = PyBytes_AsString(bytes.ptr());
@@ -727,16 +753,6 @@ Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalT
727753

728754
Value PyConnection::transformPythonValueFromParameterAs(const py::handle& val,
729755
const LogicalType& type) {
730-
if (py::isinstance<py::str>(val)) {
731-
auto strVal = py::cast<std::string>(val);
732-
if (!strVal.empty() && (strVal.front() == '{' || strVal.front() == '[')) {
733-
auto jsonModule = py::module_::import("json");
734-
try {
735-
auto parsed = jsonModule.attr("loads")(val);
736-
return transformPythonValueFromParameter(parsed);
737-
} catch (...) {}
738-
}
739-
}
740756
switch (type.getLogicalTypeID()) {
741757
case LogicalTypeID::LIST: {
742758
if (ListType::getChildType(type).getLogicalTypeID() == LogicalTypeID::JSON) {
@@ -785,6 +801,9 @@ Value PyConnection::transformPythonValueFromParameterAs(const py::handle& val,
785801
return Value(type.copy(), std::move(children));
786802
}
787803
case LogicalTypeID::JSON: {
804+
if (py::isinstance<py::str>(val)) {
805+
return Value::createValue<std::string>(py::cast<std::string>(val));
806+
}
788807
auto jsonStr = pythonObjectToJsonString(val);
789808
return Value::createValue<std::string>(jsonStr);
790809
}

0 commit comments

Comments
 (0)