File tree Expand file tree Collapse file tree 2 files changed +25
-3
lines changed
Expand file tree Collapse file tree 2 files changed +25
-3
lines changed Original file line number Diff line number Diff line change @@ -165,13 +165,21 @@ def __repr__(self) -> str:
165165 # Instead of `__array__` we now implement the buffer protocol.
166166 # Note that it makes array-apis-strict requiring python>=3.12
167167 def __buffer__ (self , flags ):
168- print ('__buffer__' )
169168 if self ._device != CPU_DEVICE :
170169 raise RuntimeError (f"Can not convert array on the '{ self ._device } ' device to a Numpy array." )
171170 return memoryview (self ._array )
172171 def __release_buffer (self , buffer ):
173- print ('__release__' )
174172 # XXX anything to do here?
173+ pass
174+
175+ def __array__ (self , * args , ** kwds ):
176+ # a stub for python < 3.12; otherwise numpy silently produces object arrays
177+ import sys
178+ minor , major = sys .version_info .minor , sys .version_info .major
179+ if major < 3 or minor < 12 :
180+ raise TypeError (
181+ "Interoperation with NumPy requires python >= 3.12. Please upgrade."
182+ )
175183
176184 # These are various helper functions to make the array behavior match the
177185 # spec in places where it either deviates from or is more strict than
Original file line number Diff line number Diff line change @@ -541,9 +541,23 @@ def test_array_conversion():
541541
542542 for device in ("device1" , "device2" ):
543543 a = ones ((2 , 3 ), device = array_api_strict .Device (device ))
544- with pytest .raises (RuntimeError , match = "Can not convert array" ):
544+ with pytest .raises (( RuntimeError , ValueError ) ):
545545 np .asarray (a )
546546
547+ # __buffer__ should work for now for conversion to numpy
548+ a = ones ((2 , 3 ))
549+ na = np .array (a )
550+ assert na .shape == (2 , 3 )
551+ assert na .dtype == np .float64
552+
553+ @pytest .mark .skipif (not sys .version_info .major * 100 + sys .version_info .minor < 312 ,
554+ reason = "conversion to numpy errors out unless python >= 3.12"
555+ )
556+ def test_array_conversion_2 ():
557+ a = ones ((2 , 3 ))
558+ with pytest .raises (TypeError ):
559+ np .array (a )
560+
547561
548562def test_allow_newaxis ():
549563 a = ones (5 )
You can’t perform that action at this time.
0 commit comments