Skip to content

Commit dbe02ec

Browse files
authored
Merge pull request #238 from SwayamInSync/gil-flag
FIX: Enable GIL for bytes/unicode array casting to/from Quaddtype
2 parents 8db38b0 + a2e4bd2 commit dbe02ec

File tree

2 files changed

+49
-12
lines changed

2 files changed

+49
-12
lines changed

quaddtype/numpy_quaddtype/src/casts.cpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -295,25 +295,36 @@ quad_to_unicode_resolve_descriptors(PyObject *NPY_UNUSED(self), PyArray_DTypeMet
295295
npy_intp required_size_chars = QUAD_STR_WIDTH;
296296
npy_intp required_size_bytes = required_size_chars * 4; // UCS4 = 4 bytes per char
297297

298+
Py_INCREF(given_descrs[0]);
299+
loop_descrs[0] = given_descrs[0];
300+
298301
if (given_descrs[1] == NULL) {
299302
// Create descriptor with required size
300303
PyArray_Descr *unicode_descr = PyArray_DescrNewFromType(NPY_UNICODE);
301304
if (unicode_descr == nullptr) {
305+
Py_DECREF(loop_descrs[0]);
302306
return (NPY_CASTING)-1;
303307
}
304308

305309
unicode_descr->elsize = required_size_bytes;
306310
loop_descrs[1] = unicode_descr;
307311
}
308312
else {
309-
Py_INCREF(given_descrs[1]);
310-
loop_descrs[1] = given_descrs[1];
313+
// Handle non-native byte order by requesting native byte order
314+
// NumPy will handle the byte swapping automatically
315+
if (!PyArray_ISNBO(given_descrs[1]->byteorder)) {
316+
loop_descrs[1] = PyArray_DescrNewByteorder(given_descrs[1], NPY_NATIVE);
317+
if (loop_descrs[1] == nullptr) {
318+
Py_DECREF(loop_descrs[0]);
319+
return (NPY_CASTING)-1;
320+
}
321+
}
322+
else {
323+
Py_INCREF(given_descrs[1]);
324+
loop_descrs[1] = given_descrs[1];
325+
}
311326
}
312327

313-
// Set the input descriptor
314-
Py_INCREF(given_descrs[0]);
315-
loop_descrs[0] = given_descrs[0];
316-
317328
*view_offset = 0;
318329

319330
// If target descriptor is wide enough, it's a safe cast
@@ -1321,7 +1332,7 @@ init_casts_internal(void)
13211332
.nin = 1,
13221333
.nout = 1,
13231334
.casting = NPY_UNSAFE_CASTING,
1324-
.flags = NPY_METH_SUPPORTS_UNALIGNED,
1335+
.flags = static_cast<NPY_ARRAYMETHOD_FLAGS>(NPY_METH_SUPPORTS_UNALIGNED | NPY_METH_REQUIRES_PYAPI),
13251336
.dtypes = unicode_to_quad_dtypes,
13261337
.slots = unicode_to_quad_slots,
13271338
};
@@ -1340,7 +1351,7 @@ init_casts_internal(void)
13401351
.nin = 1,
13411352
.nout = 1,
13421353
.casting = NPY_UNSAFE_CASTING,
1343-
.flags = NPY_METH_SUPPORTS_UNALIGNED,
1354+
.flags = static_cast<NPY_ARRAYMETHOD_FLAGS>(NPY_METH_SUPPORTS_UNALIGNED | NPY_METH_REQUIRES_PYAPI),
13441355
.dtypes = quad_to_unicode_dtypes,
13451356
.slots = quad_to_unicode_slots,
13461357
};
@@ -1359,7 +1370,7 @@ init_casts_internal(void)
13591370
.nin = 1,
13601371
.nout = 1,
13611372
.casting = NPY_UNSAFE_CASTING,
1362-
.flags = NPY_METH_SUPPORTS_UNALIGNED,
1373+
.flags = static_cast<NPY_ARRAYMETHOD_FLAGS>(NPY_METH_SUPPORTS_UNALIGNED | NPY_METH_REQUIRES_PYAPI),
13631374
.dtypes = bytes_to_quad_dtypes,
13641375
.slots = bytes_to_quad_slots,
13651376
};
@@ -1378,7 +1389,7 @@ init_casts_internal(void)
13781389
.nin = 1,
13791390
.nout = 1,
13801391
.casting = NPY_UNSAFE_CASTING,
1381-
.flags = NPY_METH_SUPPORTS_UNALIGNED,
1392+
.flags = static_cast<NPY_ARRAYMETHOD_FLAGS>(NPY_METH_SUPPORTS_UNALIGNED | NPY_METH_REQUIRES_PYAPI),
13821393
.dtypes = quad_to_bytes_dtypes,
13831394
.slots = quad_to_bytes_slots,
13841395
};

quaddtype/tests/test_quaddtype.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,6 @@ def test_empty_bytes_raises_error(self):
746746
with pytest.raises(ValueError):
747747
bytes_array.astype(QuadPrecDType())
748748

749-
750749
class TestStringParsingEdgeCases:
751750
"""Test edge cases in NumPyOS_ascii_strtoq string parsing"""
752751
@pytest.mark.parametrize("input_str", ['3.14', '-2.71', '0.0', '1e10', '-1e-10'])
@@ -5105,4 +5104,31 @@ def test_pickle_fortran_order(self, backend):
51055104
# Verify array is preserved
51065105
np.testing.assert_array_equal(unpickled, original)
51075106
assert unpickled.dtype == original.dtype
5108-
assert unpickled.flags.f_contiguous == original.flags.f_contiguous
5107+
assert unpickled.flags.f_contiguous == original.flags.f_contiguous
5108+
5109+
@pytest.mark.parametrize("dtype", [
5110+
"bool",
5111+
"byte", "int8", "ubyte", "uint8",
5112+
"short", "int16", "ushort", "uint16",
5113+
"int", "int32", "uint", "uint32",
5114+
"long", "ulong",
5115+
"longlong", "int64", "ulonglong", "uint64",
5116+
"half", "float16",
5117+
"float", "float32",
5118+
"double", "float64",
5119+
"longdouble", "float96", "float128",
5120+
"S50", "U50", "<U50", ">U50",
5121+
])
5122+
@pytest.mark.parametrize('size', [500, 1000, 10000])
5123+
def test_large_array_casting(dtype, size):
5124+
"""Test long array casting won't lead segfault, GIL enabled"""
5125+
if dtype in ("float96", "float128") and getattr(np, dtype, None) is None:
5126+
pytest.skip(f"{dtype} is unsupported on the current platform")
5127+
arr = np.arange(size).astype(np.float32).astype(dtype)
5128+
quad_arr = arr.astype(QuadPrecDType())
5129+
assert quad_arr.dtype == QuadPrecDType()
5130+
assert quad_arr.size == size
5131+
5132+
# check roundtrip
5133+
roundtrip = quad_arr.astype(dtype)
5134+
np.testing.assert_array_equal(arr, roundtrip)

0 commit comments

Comments
 (0)