Skip to content

Commit 181eb59

Browse files
committed
fix the cyclic importing issue
1 parent 3ab930c commit 181eb59

1 file changed

Lines changed: 27 additions & 24 deletions

File tree

dpnp/scipy/sparse/linalg/_iterative.py

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -73,18 +73,6 @@
7373

7474
from ._interface import IdentityOperator, LinearOperator, aslinearoperator
7575

76-
# ---------------------------------------------------------------------------
77-
# oneMKL sparse SpMV hook -- cached-handle API
78-
# ---------------------------------------------------------------------------
79-
80-
try:
81-
from dpnp.backend.extensions.sparse import _sparse_impl as _si
82-
83-
_HAS_SPARSE_IMPL = True
84-
except ImportError:
85-
_si = None
86-
_HAS_SPARSE_IMPL = False
87-
8876
_SUPPORTED_DTYPES = frozenset("fdFD")
8977

9078

@@ -112,11 +100,15 @@ class _CachedSpMV:
112100
Parameters
113101
----------
114102
A : dpnp CSR sparse matrix
103+
si : dpnp.backend.extensions.sparse._sparse_impl module
104+
Passed in from _make_fast_matvec to keep the import lazy and
105+
avoid a circular import during dpnp package initialization.
115106
trans : int 0=N, 1=T, 2=C (fixed at construction)
116107
"""
117108

118109
__slots__ = (
119110
"_A",
111+
"_si",
120112
"_exec_q",
121113
"_handle",
122114
"_trans",
@@ -129,8 +121,9 @@ class _CachedSpMV:
129121
"_val_type_id",
130122
)
131123

132-
def __init__(self, A, trans: int = 0):
124+
def __init__(self, A, si, trans: int = 0):
133125
self._A = A # keep alive so USM pointers stay valid
126+
self._si = si
134127
self._trans = int(trans)
135128
self._nrows = int(A.shape[0])
136129
self._ncols = int(A.shape[1])
@@ -154,7 +147,7 @@ def __init__(self, A, trans: int = 0):
154147
# init_matrix_handle + set_csr_data + optimize_gemv (once).
155148
# We must wait on optimize_gemv before any compute call can run;
156149
# this is the only place __init__/__call__ blocks.
157-
handle, val_type_id, ev = _si._sparse_gemv_init(
150+
handle, val_type_id, ev = self._si._sparse_gemv_init(
158151
self._exec_q,
159152
self._trans,
160153
A.indptr,
@@ -177,7 +170,7 @@ def __call__(self, x: dpnp.ndarray) -> dpnp.ndarray:
177170
# Do NOT wait on the event -- subsequent dpnp ops on the same
178171
# queue will serialize behind it automatically. Blocking here
179172
# throws away async overlap and dominates small-problem runtime.
180-
_si._sparse_gemv_compute(
173+
self._si._sparse_gemv_compute(
181174
self._exec_q,
182175
self._handle,
183176
self._val_type_id,
@@ -196,9 +189,10 @@ def __del__(self):
196189
# Guard against partial construction: _handle may not be set if
197190
# __init__ raised before the assignment.
198191
handle = getattr(self, "_handle", None)
199-
if handle is not None and _si is not None:
192+
si = getattr(self, "_si", None)
193+
if handle is not None and si is not None:
200194
try:
201-
_si._sparse_gemv_release(self._exec_q, handle, [])
195+
si._sparse_gemv_release(self._exec_q, handle, [])
202196
except Exception:
203197
pass
204198
self._handle = None
@@ -207,24 +201,27 @@ def __del__(self):
207201
class _CachedSpMVPair:
208202
"""Holds forward and (lazily built) adjoint cached SpMV handles."""
209203

210-
__slots__ = ("forward", "_A", "_adjoint")
204+
__slots__ = ("forward", "_A", "_si", "_adjoint")
211205

212-
def __init__(self, A):
213-
self.forward = _CachedSpMV(A, trans=0)
206+
def __init__(self, A, si):
214207
self._A = A
208+
self._si = si
209+
self.forward = _CachedSpMV(A, si, trans=0)
215210
self._adjoint = None
216211

217212
def matvec(self, x):
218213
"""Apply the operator to vector x."""
219214
return self.forward(x)
220215

221216
def rmatvec(self, x):
222-
"""Return the data type of the operator."""
217+
"""Apply the conjugate-transpose operator to vector x."""
223218
if self._adjoint is None:
224219
# Build conjtrans handle on first use. For real dtypes
225220
# this is equivalent to trans=1.
226221
is_cpx = dpnp.issubdtype(self._A.data.dtype, dpnp.complexfloating)
227-
self._adjoint = _CachedSpMV(self._A, trans=2 if is_cpx else 1)
222+
self._adjoint = _CachedSpMV(
223+
self._A, self._si, trans=2 if is_cpx else 1
224+
)
228225
return self._adjoint(x)
229226

230227

@@ -245,15 +242,21 @@ def _make_fast_matvec(A):
245242
except (ImportError, AttributeError):
246243
return None
247244

248-
if not _HAS_SPARSE_IMPL:
245+
# Lazy backend import -- mirrors cupyx/scipy/sparse/linalg/_iterative.py,
246+
# which imports cusparse inside this function. Keeping the import out of
247+
# module scope avoids re-entering the partially-initialized dpnp package
248+
# while dpnp/__init__.py is still executing `from . import scipy as scipy`.
249+
try:
250+
from dpnp.backend.extensions.sparse import _sparse_impl as _si
251+
except ImportError:
249252
return None
250253

251254
# Only build the cached handle for supported dtypes.
252255
if _np_dtype(A.data.dtype).char not in _SUPPORTED_DTYPES:
253256
return None
254257

255258
try:
256-
return _CachedSpMVPair(A)
259+
return _CachedSpMVPair(A, _si)
257260
except Exception:
258261
return None
259262

0 commit comments

Comments
 (0)