7373
7474from ._interface import IdentityOperator , LinearOperator , aslinearoperator
7575
76- # ---------------------------------------------------------------------------
77- # oneMKL sparse SpMV hook -- cached-handle API
78- # ---------------------------------------------------------------------------
79-
80- try :
81- from dpnp .backend .extensions .sparse import _sparse_impl as _si
82-
83- _HAS_SPARSE_IMPL = True
84- except ImportError :
85- _si = None
86- _HAS_SPARSE_IMPL = False
87-
8876_SUPPORTED_DTYPES = frozenset ("fdFD" )
8977
9078
@@ -112,11 +100,15 @@ class _CachedSpMV:
112100 Parameters
113101 ----------
114102 A : dpnp CSR sparse matrix
103+ si : dpnp.backend.extensions.sparse._sparse_impl module
104+ Passed in from _make_fast_matvec to keep the import lazy and
105+ avoid a circular import during dpnp package initialization.
115106 trans : int 0=N, 1=T, 2=C (fixed at construction)
116107 """
117108
118109 __slots__ = (
119110 "_A" ,
111+ "_si" ,
120112 "_exec_q" ,
121113 "_handle" ,
122114 "_trans" ,
@@ -129,8 +121,9 @@ class _CachedSpMV:
129121 "_val_type_id" ,
130122 )
131123
132- def __init__ (self , A , trans : int = 0 ):
124+ def __init__ (self , A , si , trans : int = 0 ):
133125 self ._A = A # keep alive so USM pointers stay valid
126+ self ._si = si
134127 self ._trans = int (trans )
135128 self ._nrows = int (A .shape [0 ])
136129 self ._ncols = int (A .shape [1 ])
@@ -154,7 +147,7 @@ def __init__(self, A, trans: int = 0):
154147 # init_matrix_handle + set_csr_data + optimize_gemv (once).
155148 # We must wait on optimize_gemv before any compute call can run;
156149 # this is the only place __init__/__call__ blocks.
157- handle , val_type_id , ev = _si ._sparse_gemv_init (
150+ handle , val_type_id , ev = self . _si ._sparse_gemv_init (
158151 self ._exec_q ,
159152 self ._trans ,
160153 A .indptr ,
@@ -177,7 +170,7 @@ def __call__(self, x: dpnp.ndarray) -> dpnp.ndarray:
177170 # Do NOT wait on the event -- subsequent dpnp ops on the same
178171 # queue will serialize behind it automatically. Blocking here
179172 # throws away async overlap and dominates small-problem runtime.
180- _si ._sparse_gemv_compute (
173+ self . _si ._sparse_gemv_compute (
181174 self ._exec_q ,
182175 self ._handle ,
183176 self ._val_type_id ,
@@ -196,9 +189,10 @@ def __del__(self):
196189 # Guard against partial construction: _handle may not be set if
197190 # __init__ raised before the assignment.
198191 handle = getattr (self , "_handle" , None )
199- if handle is not None and _si is not None :
192+ si = getattr (self , "_si" , None )
193+ if handle is not None and si is not None :
200194 try :
201- _si ._sparse_gemv_release (self ._exec_q , handle , [])
195+ si ._sparse_gemv_release (self ._exec_q , handle , [])
202196 except Exception :
203197 pass
204198 self ._handle = None
@@ -207,24 +201,27 @@ def __del__(self):
207201class _CachedSpMVPair :
208202 """Holds forward and (lazily built) adjoint cached SpMV handles."""
209203
210- __slots__ = ("forward" , "_A" , "_adjoint" )
204+ __slots__ = ("forward" , "_A" , "_si" , " _adjoint" )
211205
212- def __init__ (self , A ):
213- self .forward = _CachedSpMV (A , trans = 0 )
206+ def __init__ (self , A , si ):
214207 self ._A = A
208+ self ._si = si
209+ self .forward = _CachedSpMV (A , si , trans = 0 )
215210 self ._adjoint = None
216211
217212 def matvec (self , x ):
218213 """Apply the operator to vector x."""
219214 return self .forward (x )
220215
221216 def rmatvec (self , x ):
222- """Return the data type of the operator ."""
217+ """Apply the conjugate-transpose operator to vector x ."""
223218 if self ._adjoint is None :
224219 # Build conjtrans handle on first use. For real dtypes
225220 # this is equivalent to trans=1.
226221 is_cpx = dpnp .issubdtype (self ._A .data .dtype , dpnp .complexfloating )
227- self ._adjoint = _CachedSpMV (self ._A , trans = 2 if is_cpx else 1 )
222+ self ._adjoint = _CachedSpMV (
223+ self ._A , self ._si , trans = 2 if is_cpx else 1
224+ )
228225 return self ._adjoint (x )
229226
230227
@@ -245,15 +242,21 @@ def _make_fast_matvec(A):
245242 except (ImportError , AttributeError ):
246243 return None
247244
248- if not _HAS_SPARSE_IMPL :
245+ # Lazy backend import -- mirrors cupyx/scipy/sparse/linalg/_iterative.py,
246+ # which imports cusparse inside this function. Keeping the import out of
247+ # module scope avoids re-entering the partially-initialized dpnp package
248+ # while dpnp/__init__.py is still executing `from . import scipy as scipy`.
249+ try :
250+ from dpnp .backend .extensions .sparse import _sparse_impl as _si
251+ except ImportError :
249252 return None
250253
251254 # Only build the cached handle for supported dtypes.
252255 if _np_dtype (A .data .dtype ).char not in _SUPPORTED_DTYPES :
253256 return None
254257
255258 try :
256- return _CachedSpMVPair (A )
259+ return _CachedSpMVPair (A , _si )
257260 except Exception :
258261 return None
259262
0 commit comments