Skip to content

Commit 5eb7c91

Browse files
committed
Refactor pybind ownership around shared handles
1 parent 0bca48d commit 5eb7c91

10 files changed

Lines changed: 285 additions & 167 deletions

src_cpp/include/py_connection.h

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
#pragma once
22

3+
#include <memory>
34
#include <unordered_map>
45

56
#include "main/storage_driver.h"
67
#include "py_database.h"
8+
#include "py_handle_state.h"
79
#include "py_prepared_statement.h"
810
#include "py_query_result.h"
911

@@ -29,8 +31,7 @@ class PyConnection {
2931
const py::dict& params);
3032

3133
std::unique_ptr<PyQueryResult> query(const std::string& statement);
32-
std::unique_ptr<PyQueryResult> queryAsArrow(const std::string& statement,
33-
int64_t chunkSize);
34+
std::unique_ptr<PyQueryResult> queryAsArrow(const std::string& statement, int64_t chunkSize);
3435

3536
void setMaxNumThreadForExec(uint64_t numThreads);
3637

@@ -65,12 +66,10 @@ class PyConnection {
6566
const LogicalType& type);
6667

6768
private:
68-
void checkOpen() const;
69+
PyConnectionState& refState() const;
6970

70-
std::unique_ptr<StorageDriver> storageDriver;
71-
std::unique_ptr<Connection> conn;
72-
std::unordered_map<std::string, py::object> arrowTableRefs;
71+
std::shared_ptr<PyConnectionState> state;
7372

7473
static std::unique_ptr<PyQueryResult> checkAndWrapQueryResult(
75-
std::unique_ptr<QueryResult>& queryResult);
74+
std::unique_ptr<QueryResult>& queryResult, std::shared_ptr<PyConnectionState> state);
7675
};

src_cpp/include/py_database.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
#pragma once
22

3+
#include <memory>
4+
35
#include "main/lbug.h"
46
#include "main/storage_driver.h"
7+
#include "py_handle_state.h"
58
#include "pybind_include.h" // IWYU pragma: keep (used for py:: namespace)
69
#define PYBIND11_DETAILED_ERROR_MESSAGES
710
using namespace lbug::main;
@@ -30,6 +33,5 @@ class PyDatabase {
3033
const py::array_t<uint64_t>& indices, py::array_t<T>& result, int numThreads);
3134

3235
private:
33-
std::unique_ptr<Database> database;
34-
std::unique_ptr<StorageDriver> storageDriver;
36+
std::shared_ptr<PyDatabaseState> state;
3537
};

src_cpp/include/py_handle_state.h

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
#pragma once
2+
3+
#include <memory>
4+
#include <string>
5+
#include <unordered_map>
6+
7+
#include "common/exception/runtime.h"
8+
#include "main/lbug.h"
9+
#include "main/prepared_statement.h"
10+
#include "main/storage_driver.h"
11+
#include "pybind_include.h"
12+
13+
struct PyDatabaseState {
14+
std::unique_ptr<lbug::main::Database> database;
15+
std::unique_ptr<lbug::main::StorageDriver> storageDriver;
16+
17+
~PyDatabaseState() { closeNative(); }
18+
19+
void closeNative() {
20+
storageDriver.reset();
21+
database.reset();
22+
}
23+
24+
lbug::main::Database& ref() const {
25+
if (database == nullptr) {
26+
throw lbug::common::RuntimeException("Database is closed.");
27+
}
28+
return *database;
29+
}
30+
31+
lbug::main::StorageDriver& storage() const {
32+
if (storageDriver == nullptr) {
33+
throw lbug::common::RuntimeException("Database is closed.");
34+
}
35+
return *storageDriver;
36+
}
37+
};
38+
39+
struct PyConnectionState {
40+
std::shared_ptr<PyDatabaseState> database;
41+
std::unique_ptr<lbug::main::StorageDriver> storageDriver;
42+
std::unique_ptr<lbug::main::Connection> conn;
43+
std::unordered_map<std::string, py::object> arrowTableRefs;
44+
45+
~PyConnectionState() { closeNative(); }
46+
47+
void closeNative() {
48+
arrowTableRefs.clear();
49+
conn.reset();
50+
storageDriver.reset();
51+
database.reset();
52+
}
53+
54+
lbug::main::Connection& ref() const {
55+
if (conn == nullptr) {
56+
throw lbug::common::RuntimeException("Connection is closed.");
57+
}
58+
return *conn;
59+
}
60+
61+
lbug::main::StorageDriver& storage() const {
62+
if (storageDriver == nullptr) {
63+
throw lbug::common::RuntimeException("Connection is closed.");
64+
}
65+
return *storageDriver;
66+
}
67+
};
68+
69+
struct PyPreparedStatementState {
70+
std::shared_ptr<PyConnectionState> connection;
71+
std::unique_ptr<lbug::main::PreparedStatement> preparedStatement;
72+
73+
lbug::main::PreparedStatement& ref() const {
74+
if (preparedStatement == nullptr) {
75+
throw lbug::common::RuntimeException("Prepared statement is closed.");
76+
}
77+
return *preparedStatement;
78+
}
79+
};
80+
81+
struct PyQueryResultState {
82+
std::shared_ptr<PyConnectionState> connection;
83+
std::shared_ptr<PyQueryResultState> parent;
84+
std::unique_ptr<lbug::main::QueryResult> owned;
85+
lbug::main::QueryResult* borrowed = nullptr;
86+
87+
lbug::main::QueryResult& ref() const {
88+
auto* result = owned != nullptr ? owned.get() : borrowed;
89+
if (result == nullptr) {
90+
throw lbug::common::RuntimeException("Query result is closed.");
91+
}
92+
return *result;
93+
}
94+
};

src_cpp/include/py_prepared_statement.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "main/lbug.h"
44
#include "main/prepared_statement.h"
5+
#include "py_handle_state.h"
56
#include "pybind_include.h"
67

78
using namespace lbug::main;
@@ -17,5 +18,5 @@ class PyPreparedStatement {
1718
bool isSuccess() const;
1819

1920
private:
20-
std::unique_ptr<PreparedStatement> preparedStatement;
21+
std::shared_ptr<PyPreparedStatementState> state;
2122
};

src_cpp/include/py_query_result.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "arrow_array.h"
77
#include "common/arrow/arrow.h"
88
#include "main/lbug.h"
9+
#include "py_handle_state.h"
910
#include "pybind_include.h"
1011

1112
using namespace lbug::main;
@@ -54,7 +55,7 @@ class PyQueryResult {
5455
size_t getNumTuples();
5556

5657
private:
57-
void checkOpen() const;
58+
PyQueryResultState& refState() const;
5859

5960
static py::dict convertNodeIdToPyDict(const lbug::common::nodeID_t& nodeId);
6061

@@ -65,6 +66,5 @@ class PyQueryResult {
6566
const std::vector<std::string>& names, std::int64_t chunkSize, bool fallbackExtensionTypes);
6667

6768
private:
68-
QueryResult* queryResult = nullptr;
69-
bool isOwned = false;
69+
std::shared_ptr<PyQueryResultState> state;
7070
};

0 commit comments

Comments
 (0)