diff --git a/quaddtype/numpy_quaddtype/src/casts.cpp b/quaddtype/numpy_quaddtype/src/casts.cpp index 0d03987..659d6da 100644 --- a/quaddtype/numpy_quaddtype/src/casts.cpp +++ b/quaddtype/numpy_quaddtype/src/casts.cpp @@ -295,10 +295,14 @@ quad_to_unicode_resolve_descriptors(PyObject *NPY_UNUSED(self), PyArray_DTypeMet npy_intp required_size_chars = QUAD_STR_WIDTH; npy_intp required_size_bytes = required_size_chars * 4; // UCS4 = 4 bytes per char + Py_INCREF(given_descrs[0]); + loop_descrs[0] = given_descrs[0]; + if (given_descrs[1] == NULL) { // Create descriptor with required size PyArray_Descr *unicode_descr = PyArray_DescrNewFromType(NPY_UNICODE); if (unicode_descr == nullptr) { + Py_DECREF(loop_descrs[0]); return (NPY_CASTING)-1; } @@ -306,14 +310,21 @@ quad_to_unicode_resolve_descriptors(PyObject *NPY_UNUSED(self), PyArray_DTypeMet loop_descrs[1] = unicode_descr; } else { - Py_INCREF(given_descrs[1]); - loop_descrs[1] = given_descrs[1]; + // Handle non-native byte order by requesting native byte order + // NumPy will handle the byte swapping automatically + if (!PyArray_ISNBO(given_descrs[1]->byteorder)) { + loop_descrs[1] = PyArray_DescrNewByteorder(given_descrs[1], NPY_NATIVE); + if (loop_descrs[1] == nullptr) { + Py_DECREF(loop_descrs[0]); + return (NPY_CASTING)-1; + } + } + else { + Py_INCREF(given_descrs[1]); + loop_descrs[1] = given_descrs[1]; + } } - // Set the input descriptor - Py_INCREF(given_descrs[0]); - loop_descrs[0] = given_descrs[0]; - *view_offset = 0; // If target descriptor is wide enough, it's a safe cast @@ -1321,7 +1332,7 @@ init_casts_internal(void) .nin = 1, .nout = 1, .casting = NPY_UNSAFE_CASTING, - .flags = NPY_METH_SUPPORTS_UNALIGNED, + .flags = static_cast(NPY_METH_SUPPORTS_UNALIGNED | NPY_METH_REQUIRES_PYAPI), .dtypes = unicode_to_quad_dtypes, .slots = unicode_to_quad_slots, }; @@ -1340,7 +1351,7 @@ init_casts_internal(void) .nin = 1, .nout = 1, .casting = NPY_UNSAFE_CASTING, - .flags = NPY_METH_SUPPORTS_UNALIGNED, + .flags = static_cast(NPY_METH_SUPPORTS_UNALIGNED | NPY_METH_REQUIRES_PYAPI), .dtypes = quad_to_unicode_dtypes, .slots = quad_to_unicode_slots, }; @@ -1359,7 +1370,7 @@ init_casts_internal(void) .nin = 1, .nout = 1, .casting = NPY_UNSAFE_CASTING, - .flags = NPY_METH_SUPPORTS_UNALIGNED, + .flags = static_cast(NPY_METH_SUPPORTS_UNALIGNED | NPY_METH_REQUIRES_PYAPI), .dtypes = bytes_to_quad_dtypes, .slots = bytes_to_quad_slots, }; @@ -1378,7 +1389,7 @@ init_casts_internal(void) .nin = 1, .nout = 1, .casting = NPY_UNSAFE_CASTING, - .flags = NPY_METH_SUPPORTS_UNALIGNED, + .flags = static_cast(NPY_METH_SUPPORTS_UNALIGNED | NPY_METH_REQUIRES_PYAPI), .dtypes = quad_to_bytes_dtypes, .slots = quad_to_bytes_slots, }; diff --git a/quaddtype/tests/test_quaddtype.py b/quaddtype/tests/test_quaddtype.py index 14ceeb8..86f6f8f 100644 --- a/quaddtype/tests/test_quaddtype.py +++ b/quaddtype/tests/test_quaddtype.py @@ -746,7 +746,6 @@ def test_empty_bytes_raises_error(self): with pytest.raises(ValueError): bytes_array.astype(QuadPrecDType()) - class TestStringParsingEdgeCases: """Test edge cases in NumPyOS_ascii_strtoq string parsing""" @pytest.mark.parametrize("input_str", ['3.14', '-2.71', '0.0', '1e10', '-1e-10']) @@ -5105,4 +5104,31 @@ def test_pickle_fortran_order(self, backend): # Verify array is preserved np.testing.assert_array_equal(unpickled, original) assert unpickled.dtype == original.dtype - assert unpickled.flags.f_contiguous == original.flags.f_contiguous \ No newline at end of file + assert unpickled.flags.f_contiguous == original.flags.f_contiguous + +@pytest.mark.parametrize("dtype", [ + "bool", + "byte", "int8", "ubyte", "uint8", + "short", "int16", "ushort", "uint16", + "int", "int32", "uint", "uint32", + "long", "ulong", + "longlong", "int64", "ulonglong", "uint64", + "half", "float16", + "float", "float32", + "double", "float64", + "longdouble", "float96", "float128", + "S50", "U50", "U50", +]) +@pytest.mark.parametrize('size', [500, 1000, 10000]) +def test_large_array_casting(dtype, size): + """Test long array casting won't lead segfault, GIL enabled""" + if dtype in ("float96", "float128") and getattr(np, dtype, None) is None: + pytest.skip(f"{dtype} is unsupported on the current platform") + arr = np.arange(size).astype(np.float32).astype(dtype) + quad_arr = arr.astype(QuadPrecDType()) + assert quad_arr.dtype == QuadPrecDType() + assert quad_arr.size == size + + # check roundtrip + roundtrip = quad_arr.astype(dtype) + np.testing.assert_array_equal(arr, roundtrip) \ No newline at end of file