Skip to content

Commit a1a0925

Browse files
committed
Use lazy imports so tests pass without C-API
1 parent ff9c0ab commit a1a0925

4 files changed

Lines changed: 52 additions & 32 deletions

File tree

src_py/_backend.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from __future__ import annotations
2+
3+
from importlib import import_module
4+
from typing import Any
5+
6+
_CAPI_MODULE: Any | None = None
7+
_PYBIND_MODULE: Any | None = None
8+
_PYBIND_IMPORT_ATTEMPTED = False
9+
10+
11+
def get_capi_module() -> Any:
12+
global _CAPI_MODULE
13+
if _CAPI_MODULE is None:
14+
_CAPI_MODULE = import_module("._lbug_capi", __package__)
15+
return _CAPI_MODULE
16+
17+
18+
def get_pybind_module() -> Any | None:
19+
global _PYBIND_MODULE, _PYBIND_IMPORT_ATTEMPTED
20+
if _PYBIND_IMPORT_ATTEMPTED:
21+
return _PYBIND_MODULE
22+
_PYBIND_IMPORT_ATTEMPTED = True
23+
try:
24+
_PYBIND_MODULE = import_module("._lbug", __package__)
25+
except ImportError:
26+
_PYBIND_MODULE = None
27+
return _PYBIND_MODULE

src_py/connection.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,10 @@
66
from typing import TYPE_CHECKING, Any
77
from weakref import WeakSet
88

9-
from . import _lbug_capi as _lbug
9+
from ._backend import get_capi_module, get_pybind_module
1010
from .prepared_statement import PreparedStatement
1111
from .query_result import QueryResult
1212

13-
try:
14-
from . import _lbug as _lbug_pybind
15-
except (
16-
ImportError
17-
): # pragma: no cover - pybind module may be unavailable in some builds
18-
_lbug_pybind = None
19-
2013
if TYPE_CHECKING:
2114
import sys
2215
from collections.abc import Callable
@@ -73,12 +66,18 @@ def init_connection(self) -> None:
7366
self.database.init_database()
7467
if self._connection is None:
7568
backend_module = (
76-
_lbug_pybind if self.database._use_pybind_backend else _lbug
69+
get_pybind_module()
70+
if self.database._use_pybind_backend
71+
else get_capi_module()
72+
)
73+
self._connection = backend_module.Connection(
74+
self.database._database, self.num_threads
7775
)
78-
self._connection = backend_module.Connection(self.database._database, self.num_threads) # type: ignore[union-attr]
7976

8077
def _using_pybind_backend(self) -> bool:
81-
return bool(self.database._use_pybind_backend and _lbug_pybind is not None)
78+
return bool(
79+
self.database._use_pybind_backend and get_pybind_module() is not None
80+
)
8281

8382
def set_max_threads_for_exec(self, num_threads: int) -> None:
8483
"""
@@ -212,7 +211,7 @@ def _rewrite_local_scan_object(
212211
def _should_use_pybind_for_scan(
213212
self, query: str, parameters: dict[str, Any]
214213
) -> bool:
215-
if _lbug_pybind is None:
214+
if get_pybind_module() is None:
216215
return False
217216
if not self._has_scan_pattern(query):
218217
return False
@@ -230,7 +229,8 @@ def _should_use_pybind_for_scan(
230229
return False
231230

232231
def _get_pybind_connection(self) -> Any | None:
233-
if _lbug_pybind is None:
232+
pybind_module = get_pybind_module()
233+
if pybind_module is None:
234234
return None
235235
if self._using_pybind_backend():
236236
return self._connection
@@ -239,7 +239,7 @@ def _get_pybind_connection(self) -> Any | None:
239239
if pybind_db is None:
240240
return None
241241
if self._py_connection is None:
242-
self._py_connection = _lbug_pybind.Connection(pybind_db, self.num_threads)
242+
self._py_connection = pybind_module.Connection(pybind_db, self.num_threads)
243243
return self._py_connection
244244

245245
def _execute_with_pybind(

src_py/database.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,9 @@
55
from typing import TYPE_CHECKING, Any, ClassVar
66
from weakref import WeakSet
77

8-
from . import _lbug_capi as _lbug
8+
from ._backend import get_capi_module, get_pybind_module
99
from .types import Type
1010

11-
try:
12-
from . import _lbug as _lbug_pybind
13-
except (
14-
ImportError
15-
): # pragma: no cover - pybind module may be unavailable in some builds
16-
_lbug_pybind = None
17-
1811
if TYPE_CHECKING:
1912
import sys
2013
from types import TracebackType
@@ -157,12 +150,13 @@ def _resolve_backend_preference(cls, backend: str) -> str:
157150
def _should_use_pybind_backend(self) -> bool:
158151
if self.backend == "capi":
159152
return False
153+
pybind_module = get_pybind_module()
160154
if self.backend == "pybind":
161-
if _lbug_pybind is None:
155+
if pybind_module is None:
162156
msg = "Requested pybind backend, but ladybug._lbug is not available."
163157
raise RuntimeError(msg)
164158
return True
165-
return _lbug_pybind is not None
159+
return pybind_module is not None
166160

167161
def __enter__(self) -> Self:
168162
return self
@@ -185,7 +179,7 @@ def get_version() -> str:
185179
str
186180
The version of the database.
187181
"""
188-
return _lbug.Database.get_version() # type: ignore[union-attr]
182+
return get_capi_module().Database.get_version()
189183

190184
@staticmethod
191185
def get_storage_version() -> int:
@@ -197,7 +191,7 @@ def get_storage_version() -> int:
197191
int
198192
The storage version of the database.
199193
"""
200-
return _lbug.Database.get_storage_version() # type: ignore[union-attr]
194+
return get_capi_module().Database.get_storage_version()
201195

202196
def __getstate__(self) -> dict[str, Any]:
203197
state = {
@@ -217,7 +211,7 @@ def init_database(self) -> None:
217211
if self._use_pybind_backend:
218212
self._database = self.init_pybind_database()
219213
else:
220-
self._database = _lbug.Database( # type: ignore[union-attr]
214+
self._database = get_capi_module().Database(
221215
self.database_path,
222216
self.buffer_pool_size,
223217
self.max_num_threads,
@@ -234,10 +228,11 @@ def init_database(self) -> None:
234228
def init_pybind_database(self) -> Any | None:
235229
"""Initialize and return the optional pybind database backend."""
236230
self.check_for_database_close()
237-
if _lbug_pybind is None:
231+
pybind_module = get_pybind_module()
232+
if pybind_module is None:
238233
return None
239234
if self._pybind_database is None:
240-
self._pybind_database = _lbug_pybind.Database(
235+
self._pybind_database = pybind_module.Database(
241236
self.database_path,
242237
self.buffer_pool_size,
243238
self.max_num_threads,

src_py/query_result.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
import pyarrow as pa
1919
import torch_geometric.data as geo
2020

21-
from . import _lbug_capi as _lbug
22-
2321
if sys.version_info >= (3, 11):
2422
from typing import Self
2523
else:
@@ -29,7 +27,7 @@
2927
class QueryResult:
3028
"""QueryResult stores the result of a query execution."""
3129

32-
def __init__(self, connection: _lbug.Connection, query_result: _lbug.QueryResult): # type: ignore[name-defined]
30+
def __init__(self, connection: Any, query_result: Any):
3331
"""
3432
Parameters
3533
----------

0 commit comments

Comments
 (0)