diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py
index 1adbd4e9807..b3604beb685 100644
--- a/python/pyarrow/tests/test_extension_type.py
+++ b/python/pyarrow/tests/test_extension_type.py
@@ -1786,6 +1786,55 @@ def test_tensor_type_cast():
     assert storage_result.equals(storage)
 
 
+@pytest.mark.pandas
+@pytest.mark.parametrize("value_type", [pa.int8(), pa.float32(), pa.float64()])
+@pytest.mark.parametrize("shape,permutation", [
+    ([2, 2], None),
+    ([2, 3], None),
+    ([2, 2, 3], [0, 2, 1]),
+])
+def test_tensor_type_to_pandas(value_type, shape, permutation):
+    # GH-49907: to_pandas_dtype should return a pandas dtype instead of
+    # raising NotImplementedError, and enable Table.to_pandas(split_blocks=True)
+    import pandas as pd
+
+    if Version(pd.__version__) < Version("2.1.0"):
+        # pd.ArrowDtype extension blocks are only reliable from 2.1.0,
+        # see GH-35821
+        pytest.skip("requires pandas >= 2.1.0")
+
+    tensor_type = pa.fixed_shape_tensor(
+        value_type, shape, permutation=permutation)
+
+    # The type maps to a pandas ArrowDtype wrapping the extension type
+    dtype = tensor_type.to_pandas_dtype()
+    assert isinstance(dtype, pd.ArrowDtype)
+    assert dtype.pyarrow_dtype == tensor_type
+
+    # Build an extension array of a few tensors via the storage type so the
+    # explicit permutation is preserved exactly
+    size = 3
+    n = int(np.prod(shape))
+    storage = pa.array(
+        [list(range(i * n, (i + 1) * n)) for i in range(size)],
+        pa.list_(value_type, n))
+    arr = pa.ExtensionArray.from_storage(tensor_type, storage)
+
+    # Array.to_pandas uses the ArrowDtype
+    series = arr.to_pandas()
+    assert isinstance(series.dtype, pd.ArrowDtype)
+    assert series.dtype.pyarrow_dtype == tensor_type
+    assert len(series) == size
+
+    # Table.to_pandas, including the split_blocks=True path from GH-49907
+    table = pa.table({"tensor": arr})
+    for split_blocks in [False, True]:
+        result = table.to_pandas(split_blocks=split_blocks)
+        assert isinstance(result["tensor"].dtype, pd.ArrowDtype)
+        assert result["tensor"].dtype.pyarrow_dtype == tensor_type
+        assert len(result) == size
+
+
 @pytest.mark.pandas
 def test_extension_to_pandas_storage_type(registered_period_type):
     period_type, _ = registered_period_type
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index ec1a5a2ba9a..8f444041121 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -2049,6 +2049,46 @@ cdef class FixedShapeTensorType(BaseExtensionType):
         else:
             return None
 
+    def to_pandas_dtype(self):
+        """
+        Return the equivalent pandas dtype, an instance of
+        :class:`pandas.ArrowDtype` wrapping this extension type.
+
+        Each value of the resulting pandas column is a tensor with this
+        type's ``shape``. Returning a pandas extension dtype (rather than a
+        NumPy dtype) is what lets ``Table.to_pandas(split_blocks=True)``
+        build an extension block for this type.
+
+        This requires pandas >= 2.1.0, the first version with reliable
+        ``ArrowDtype`` extension blocks (see GH-35821). On older pandas it
+        raises ``NotImplementedError`` and conversion falls back to the
+        object dtype.
+
+        Returns
+        -------
+        pandas.ArrowDtype
+            A dtype whose ``pyarrow_dtype`` is this tensor type.
+
+        Raises
+        ------
+        NotImplementedError
+            If the installed pandas is older than 2.1.0.
+
+        Examples
+        --------
+        >>> import pyarrow as pa
+        >>> pa.fixed_shape_tensor(pa.int32(), [2, 2]).to_pandas_dtype()
+        extension<arrow.fixed_shape_tensor[value_type=int32, shape=[2,2]]>[pyarrow]
+        """
+        if not _pandas_api.is_ge_v21():
+            # pandas.ArrowDtype extension blocks are only reliable from 2.1.0
+            # (GH-35821); on older pandas keep the documented fallback so the
+            # conversion code produces an object-dtype column instead.
+            raise NotImplementedError(
+                f"{self} requires pandas >= 2.1.0 to map to pandas.ArrowDtype")
+        import pandas as pd
+        return pd.ArrowDtype(self)
+
     def __arrow_ext_class__(self):
         return FixedShapeTensorArray
 