Skip to content

Commit 75165c7

Browse files
committed
Use lazy imports so tests pass without C-API
1 parent ff9c0ab commit 75165c7

3 files changed

Lines changed: 20 additions & 31 deletions

File tree

src_py/connection.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,10 @@
66
from typing import TYPE_CHECKING, Any
77
from weakref import WeakSet
88

9-
from . import _lbug_capi as _lbug
9+
from ._backend import get_capi_module, get_pybind_module
1010
from .prepared_statement import PreparedStatement
1111
from .query_result import QueryResult
1212

13-
try:
14-
from . import _lbug as _lbug_pybind
15-
except (
16-
ImportError
17-
): # pragma: no cover - pybind module may be unavailable in some builds
18-
_lbug_pybind = None
19-
2013
if TYPE_CHECKING:
2114
import sys
2215
from collections.abc import Callable
@@ -73,12 +66,14 @@ def init_connection(self) -> None:
7366
self.database.init_database()
7467
if self._connection is None:
7568
backend_module = (
76-
_lbug_pybind if self.database._use_pybind_backend else _lbug
69+
get_pybind_module()
70+
if self.database._use_pybind_backend
71+
else get_capi_module()
7772
)
78-
self._connection = backend_module.Connection(self.database._database, self.num_threads) # type: ignore[union-attr]
73+
self._connection = backend_module.Connection(self.database._database, self.num_threads)
7974

8075
def _using_pybind_backend(self) -> bool:
81-
return bool(self.database._use_pybind_backend and _lbug_pybind is not None)
76+
return bool(self.database._use_pybind_backend and get_pybind_module() is not None)
8277

8378
def set_max_threads_for_exec(self, num_threads: int) -> None:
8479
"""
@@ -230,7 +225,8 @@ def _should_use_pybind_for_scan(
230225
return False
231226

232227
def _get_pybind_connection(self) -> Any | None:
233-
if _lbug_pybind is None:
228+
pybind_module = get_pybind_module()
229+
if pybind_module is None:
234230
return None
235231
if self._using_pybind_backend():
236232
return self._connection
@@ -239,7 +235,7 @@ def _get_pybind_connection(self) -> Any | None:
239235
if pybind_db is None:
240236
return None
241237
if self._py_connection is None:
242-
self._py_connection = _lbug_pybind.Connection(pybind_db, self.num_threads)
238+
self._py_connection = pybind_module.Connection(pybind_db, self.num_threads)
243239
return self._py_connection
244240

245241
def _execute_with_pybind(

src_py/database.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,9 @@
55
from typing import TYPE_CHECKING, Any, ClassVar
66
from weakref import WeakSet
77

8-
from . import _lbug_capi as _lbug
8+
from ._backend import get_capi_module, get_pybind_module
99
from .types import Type
1010

11-
try:
12-
from . import _lbug as _lbug_pybind
13-
except (
14-
ImportError
15-
): # pragma: no cover - pybind module may be unavailable in some builds
16-
_lbug_pybind = None
17-
1811
if TYPE_CHECKING:
1912
import sys
2013
from types import TracebackType
@@ -157,12 +150,13 @@ def _resolve_backend_preference(cls, backend: str) -> str:
157150
def _should_use_pybind_backend(self) -> bool:
158151
if self.backend == "capi":
159152
return False
153+
pybind_module = get_pybind_module()
160154
if self.backend == "pybind":
161-
if _lbug_pybind is None:
155+
if pybind_module is None:
162156
msg = "Requested pybind backend, but ladybug._lbug is not available."
163157
raise RuntimeError(msg)
164158
return True
165-
return _lbug_pybind is not None
159+
return pybind_module is not None
166160

167161
def __enter__(self) -> Self:
168162
return self
@@ -185,7 +179,7 @@ def get_version() -> str:
185179
str
186180
The version of the database.
187181
"""
188-
return _lbug.Database.get_version() # type: ignore[union-attr]
182+
return get_capi_module().Database.get_version()
189183

190184
@staticmethod
191185
def get_storage_version() -> int:
@@ -197,7 +191,7 @@ def get_storage_version() -> int:
197191
int
198192
The storage version of the database.
199193
"""
200-
return _lbug.Database.get_storage_version() # type: ignore[union-attr]
194+
return get_capi_module().Database.get_storage_version()
201195

202196
def __getstate__(self) -> dict[str, Any]:
203197
state = {
@@ -217,7 +211,7 @@ def init_database(self) -> None:
217211
if self._use_pybind_backend:
218212
self._database = self.init_pybind_database()
219213
else:
220-
self._database = _lbug.Database( # type: ignore[union-attr]
214+
self._database = get_capi_module().Database(
221215
self.database_path,
222216
self.buffer_pool_size,
223217
self.max_num_threads,
@@ -234,10 +228,11 @@ def init_database(self) -> None:
234228
def init_pybind_database(self) -> Any | None:
235229
"""Initialize and return the optional pybind database backend."""
236230
self.check_for_database_close()
237-
if _lbug_pybind is None:
231+
pybind_module = get_pybind_module()
232+
if pybind_module is None:
238233
return None
239234
if self._pybind_database is None:
240-
self._pybind_database = _lbug_pybind.Database(
235+
self._pybind_database = pybind_module.Database(
241236
self.database_path,
242237
self.buffer_pool_size,
243238
self.max_num_threads,

src_py/query_result.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
import pyarrow as pa
1919
import torch_geometric.data as geo
2020

21-
from . import _lbug_capi as _lbug
22-
2321
if sys.version_info >= (3, 11):
2422
from typing import Self
2523
else:
@@ -29,7 +27,7 @@
2927
class QueryResult:
3028
"""QueryResult stores the result of a query execution."""
3129

32-
def __init__(self, connection: _lbug.Connection, query_result: _lbug.QueryResult): # type: ignore[name-defined]
30+
def __init__(self, connection: Any, query_result: Any):
3331
"""
3432
Parameters
3533
----------

0 commit comments

Comments
 (0)