Skip to content

Commit a2e4bd2

Browse files
committed
cover all dtype + BE handle
1 parent 84470c8 commit a2e4bd2

File tree

2 files changed

+45
-20
lines changed

2 files changed

+45
-20
lines changed

quaddtype/numpy_quaddtype/src/casts.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -295,25 +295,36 @@ quad_to_unicode_resolve_descriptors(PyObject *NPY_UNUSED(self), PyArray_DTypeMet
295295
npy_intp required_size_chars = QUAD_STR_WIDTH;
296296
npy_intp required_size_bytes = required_size_chars * 4; // UCS4 = 4 bytes per char
297297

298+
Py_INCREF(given_descrs[0]);
299+
loop_descrs[0] = given_descrs[0];
300+
298301
if (given_descrs[1] == NULL) {
299302
// Create descriptor with required size
300303
PyArray_Descr *unicode_descr = PyArray_DescrNewFromType(NPY_UNICODE);
301304
if (unicode_descr == nullptr) {
305+
Py_DECREF(loop_descrs[0]);
302306
return (NPY_CASTING)-1;
303307
}
304308

305309
unicode_descr->elsize = required_size_bytes;
306310
loop_descrs[1] = unicode_descr;
307311
}
308312
else {
309-
Py_INCREF(given_descrs[1]);
310-
loop_descrs[1] = given_descrs[1];
313+
// Handle non-native byte order by requesting native byte order
314+
// NumPy will handle the byte swapping automatically
315+
if (!PyArray_ISNBO(given_descrs[1]->byteorder)) {
316+
loop_descrs[1] = PyArray_DescrNewByteorder(given_descrs[1], NPY_NATIVE);
317+
if (loop_descrs[1] == nullptr) {
318+
Py_DECREF(loop_descrs[0]);
319+
return (NPY_CASTING)-1;
320+
}
321+
}
322+
else {
323+
Py_INCREF(given_descrs[1]);
324+
loop_descrs[1] = given_descrs[1];
325+
}
311326
}
312327

313-
// Set the input descriptor
314-
Py_INCREF(given_descrs[0]);
315-
loop_descrs[0] = given_descrs[0];
316-
317328
*view_offset = 0;
318329

319330
// If target descriptor is wide enough, it's a safe cast

quaddtype/tests/test_quaddtype.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -746,19 +746,6 @@ def test_empty_bytes_raises_error(self):
746746
with pytest.raises(ValueError):
747747
bytes_array.astype(QuadPrecDType())
748748

749-
@pytest.mark.parametrize('dtype', ['S50', 'U50'])
750-
@pytest.mark.parametrize('size', [500, 1000, 10000])
751-
def test_large_array_casting(self, dtype, size):
752-
"""Test long array casting won't lead segfault, GIL enabled"""
753-
arr = np.arange(size).astype(np.float32).astype(dtype)
754-
quad_arr = arr.astype(QuadPrecDType())
755-
assert quad_arr.dtype == QuadPrecDType()
756-
assert quad_arr.size == size
757-
758-
# check roundtrip
759-
roundtrip = quad_arr.astype(dtype)
760-
np.testing.assert_array_equal(arr, roundtrip)
761-
762749
class TestStringParsingEdgeCases:
763750
"""Test edge cases in NumPyOS_ascii_strtoq string parsing"""
764751
@pytest.mark.parametrize("input_str", ['3.14', '-2.71', '0.0', '1e10', '-1e-10'])
@@ -5117,4 +5104,31 @@ def test_pickle_fortran_order(self, backend):
51175104
# Verify array is preserved
51185105
np.testing.assert_array_equal(unpickled, original)
51195106
assert unpickled.dtype == original.dtype
5120-
assert unpickled.flags.f_contiguous == original.flags.f_contiguous
5107+
assert unpickled.flags.f_contiguous == original.flags.f_contiguous
5108+
5109+
@pytest.mark.parametrize("dtype", [
5110+
"bool",
5111+
"byte", "int8", "ubyte", "uint8",
5112+
"short", "int16", "ushort", "uint16",
5113+
"int", "int32", "uint", "uint32",
5114+
"long", "ulong",
5115+
"longlong", "int64", "ulonglong", "uint64",
5116+
"half", "float16",
5117+
"float", "float32",
5118+
"double", "float64",
5119+
"longdouble", "float96", "float128",
5120+
"S50", "U50", "<U50", ">U50",
5121+
])
5122+
@pytest.mark.parametrize('size', [500, 1000, 10000])
5123+
def test_large_array_casting(dtype, size):
5124+
"""Test long array casting won't lead segfault, GIL enabled"""
5125+
if dtype in ("float96", "float128") and getattr(np, dtype, None) is None:
5126+
pytest.skip(f"{dtype} is unsupported on the current platform")
5127+
arr = np.arange(size).astype(np.float32).astype(dtype)
5128+
quad_arr = arr.astype(QuadPrecDType())
5129+
assert quad_arr.dtype == QuadPrecDType()
5130+
assert quad_arr.size == size
5131+
5132+
# check roundtrip
5133+
roundtrip = quad_arr.astype(dtype)
5134+
np.testing.assert_array_equal(arr, roundtrip)

0 commit comments

Comments
 (0)