Skip to content

Commit dae0934

Browse files
EA: add back _from_scalars / cast_pointwise_result backwards compat
1 parent ecf28e5 commit dae0934

File tree

11 files changed

+107
-17
lines changed

11 files changed

+107
-17
lines changed

pandas/core/arrays/_mixins.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,10 @@ def _hash_pandas_object(
208208
values, encoding=encoding, hash_key=hash_key, categorize=categorize
209209
)
210210

211+
def _cast_pointwise_result(self, values: ArrayLike) -> ArrayLike:
212+
values = np.asarray(values, dtype=object)
213+
return lib.maybe_convert_objects(values, convert_non_numeric=True)
214+
211215
# Signature of "argmin" incompatible with supertype "ExtensionArray"
212216
def argmin(self, axis: AxisInt = 0, skipna: bool = True): # type: ignore[override]
213217
# override base class by adding axis keyword

pandas/core/arrays/arrow/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ def _cast_pointwise_result(self, values) -> ArrayLike:
442442
# e.g. test_by_column_values_with_same_starting_value with nested
443443
# values, one entry of which is an ArrowStringArray
444444
# or test_agg_lambda_complex128_dtype_conversion for complex values
445-
return super()._cast_pointwise_result(values)
445+
return values
446446

447447
if pa.types.is_null(arr.type):
448448
if lib.infer_dtype(values) == "decimal":
@@ -498,7 +498,7 @@ def _cast_pointwise_result(self, values) -> ArrayLike:
498498
if self.dtype.na_value is np.nan:
499499
# ArrowEA has different semantics, so we return numpy-based
500500
# result instead
501-
return super()._cast_pointwise_result(values)
501+
return values
502502
return ArrowExtensionArray(arr)
503503
return self._from_pyarrow_array(arr)
504504

pandas/core/arrays/base.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
cast,
2020
overload,
2121
)
22+
import warnings
2223

2324
import numpy as np
2425

@@ -33,6 +34,7 @@
3334
cache_readonly,
3435
set_module,
3536
)
37+
from pandas.util._exceptions import find_stack_level
3638
from pandas.util._validators import (
3739
validate_bool_kwarg,
3840
validate_insert_loc,
@@ -86,6 +88,7 @@
8688
AstypeArg,
8789
AxisInt,
8890
Dtype,
91+
DtypeObj,
8992
FillnaOptions,
9093
InterpolateOptions,
9194
NumpySorter,
@@ -353,6 +356,38 @@ def _from_sequence_of_strings(
353356
"""
354357
raise AbstractMethodError(cls)
355358

359+
@classmethod
360+
def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
361+
"""
362+
Strict analogue to _from_sequence, allowing only sequences of scalars
363+
that should be specifically inferred to the given dtype.
364+
365+
Parameters
366+
----------
367+
scalars : sequence
368+
dtype : ExtensionDtype
369+
370+
Raises
371+
------
372+
TypeError or ValueError
373+
374+
Notes
375+
-----
376+
This is called in a try/except block when casting the result of a
377+
pointwise operation.
378+
"""
379+
try:
380+
return cls._from_sequence(scalars, dtype=dtype, copy=False)
381+
except (ValueError, TypeError):
382+
raise
383+
except Exception:
384+
warnings.warn(
385+
"_from_scalars should only raise ValueError or TypeError. "
386+
"Consider overriding _from_scalars where appropriate.",
387+
stacklevel=find_stack_level(),
388+
)
389+
raise
390+
356391
@classmethod
357392
def _from_factorized(cls, values, original):
358393
"""
@@ -383,13 +418,26 @@ def _from_factorized(cls, values, original):
383418
"""
384419
raise AbstractMethodError(cls)
385420

386-
def _cast_pointwise_result(self, values) -> ArrayLike:
421+
def _cast_pointwise_result(self, values: ArrayLike) -> ArrayLike:
387422
"""
423+
Construct an ExtensionArray after a pointwise operation.
424+
388425
Cast the result of a pointwise operation (e.g. Series.map) to an
389-
array, preserve dtype_backend if possible.
426+
array. This is not required to return an ExtensionArray of the same
427+
type as self or of the same dtype. It can also return another
428+
ExtensionArray of the same "family" if you implement multiple
429+
ExtensionArrays/Dtypes that are interoperable (e.g. if you have a float
430+
array with units, this method can return an int array with units).
431+
432+
If converting to your own ExtensionArray is not possible, this method
433+
can raise an error (TypeError or ValueError) or return the input
434+
`values` as-is, in which case pandas will perform further type inference.
435+
390436
"""
391-
values = np.asarray(values, dtype=object)
392-
return lib.maybe_convert_objects(values, convert_non_numeric=True)
437+
try:
438+
return type(self)._from_scalars(values, dtype=self.dtype)
439+
except (ValueError, TypeError):
440+
return values
393441

394442
# ------------------------------------------------------------------------
395443
# Must be a Sequence

pandas/core/arrays/sparse/array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,7 @@ def _from_factorized(cls, values, original) -> Self:
622622
return cls(values, dtype=original.dtype)
623623

624624
def _cast_pointwise_result(self, values):
625-
result = super()._cast_pointwise_result(values)
625+
result = lib.maybe_convert_objects(values, convert_non_numeric=True)
626626
if result.dtype.kind == self.dtype.kind:
627627
try:
628628
# e.g. test_groupby_agg_extension

pandas/core/dtypes/cast.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,42 @@ def maybe_upcast_numeric_to_64bit(arr: NumpyIndexT) -> NumpyIndexT:
414414
return arr
415415

416416

417+
def cast_pointwise_result(
418+
result: ArrayLike,
419+
original_array: ArrayLike,
420+
) -> ArrayLike:
421+
"""
422+
Try casting result of a pointwise operation back to the original dtype if
423+
appropriate.
424+
425+
Parameters
426+
----------
427+
result : array-like
428+
Result to cast.
429+
original_array : array-like
430+
Input array from which result was calculated.
431+
432+
Returns
433+
-------
434+
array-like
435+
"""
436+
if isinstance(original_array.dtype, ExtensionDtype):
437+
try:
438+
result = original_array._cast_pointwise_result(result)
439+
except (TypeError, ValueError):
440+
pass
441+
442+
if isinstance(result.dtype, ExtensionDtype):
443+
return result
444+
445+
if not isinstance(result, np.ndarray):
446+
result = np.asarray(result, dtype=object)
447+
448+
if result.dtype != object:
449+
return result
450+
return lib.maybe_convert_objects(result, convert_non_numeric=True)
451+
452+
417453
@overload
418454
def ensure_dtype_can_hold_na(dtype: np.dtype) -> np.dtype: ...
419455

pandas/core/frame.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
from pandas.core.dtypes.cast import (
8585
LossySetitemError,
8686
can_hold_element,
87+
cast_pointwise_result,
8788
construct_1d_arraylike_from_scalar,
8889
construct_2d_arraylike_from_scalar,
8990
find_common_type,
@@ -11200,7 +11201,7 @@ def _append_internal(
1120011201
if isinstance(self.index.dtype, ExtensionDtype):
1120111202
# GH#41626 retain e.g. CategoricalDtype if reached via
1120211203
# df.loc[key] = item
11203-
row_df.index = self.index.array._cast_pointwise_result(row_df.index._values)
11204+
row_df.index = cast_pointwise_result(row_df.index._values, self.index.array)
1120411205

1120511206
# infer_objects is needed for
1120611207
# test_append_empty_frame_to_series_with_dateutil_tz

pandas/core/groupby/ops.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from pandas.util._decorators import cache_readonly
3636

3737
from pandas.core.dtypes.cast import (
38+
cast_pointwise_result,
3839
maybe_downcast_to_dtype,
3940
)
4041
from pandas.core.dtypes.common import (
@@ -963,7 +964,7 @@ def agg_series(
963964
np.ndarray or ExtensionArray
964965
"""
965966
result = self._aggregate_series_pure_python(obj, func)
966-
return obj.array._cast_pointwise_result(result)
967+
return cast_pointwise_result(result, obj.array)
967968

968969
@final
969970
def _aggregate_series_pure_python(

pandas/core/indexes/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@
8989
from pandas.core.dtypes.cast import (
9090
LossySetitemError,
9191
can_hold_element,
92+
cast_pointwise_result,
9293
common_dtype_categorical_compat,
9394
find_result_type,
9495
infer_dtype_from,
@@ -6531,7 +6532,7 @@ def map(self, mapper, na_action: Literal["ignore"] | None = None):
65316532
# e.g. if we are floating and new_values is all ints, then we
65326533
# don't want to cast back to floating. But if we are UInt64
65336534
# and new_values is all ints, we want to try.
6534-
new_values = arr._cast_pointwise_result(new_values)
6535+
new_values = cast_pointwise_result(new_values, arr)
65356536
dtype = new_values.dtype
65366537
return Index(new_values, dtype=dtype, copy=False, name=self.name)
65376538

pandas/core/series.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
from pandas.core.dtypes.astype import astype_is_view
7070
from pandas.core.dtypes.cast import (
7171
LossySetitemError,
72+
cast_pointwise_result,
7273
construct_1d_arraylike_from_scalar,
7374
find_common_type,
7475
infer_dtype_from,
@@ -3252,7 +3253,7 @@ def combine(
32523253
new_values[:] = [func(lv, other) for lv in self._values]
32533254
new_name = self.name
32543255

3255-
res_values = self.array._cast_pointwise_result(new_values)
3256+
res_values = cast_pointwise_result(new_values, self.array)
32563257
return self._constructor(
32573258
res_values,
32583259
dtype=res_values.dtype,

pandas/tests/extension/decimal/array.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,13 +112,12 @@ def _from_factorized(cls, values, original):
112112
return cls(values)
113113

114114
def _cast_pointwise_result(self, values):
115-
result = super()._cast_pointwise_result(values)
116115
try:
117116
# If this were ever made a non-test EA, special-casing could
118117
# be avoided by handling Decimal in maybe_convert_objects
119-
res = type(self)._from_sequence(result, dtype=self.dtype)
118+
res = type(self)._from_sequence(values, dtype=self.dtype)
120119
except (ValueError, TypeError):
121-
return result
120+
return values
122121
return res
123122

124123
_HANDLED_TYPES = (decimal.Decimal, numbers.Number, np.ndarray)

0 commit comments

Comments
 (0)