|
1 | 1 | import numpy as np |
2 | 2 | import pytest |
3 | 3 |
|
4 | | -from pandas._libs.missing import NA |
| 4 | +import pandas.util._test_decorators as td |
5 | 5 |
|
6 | 6 | from pandas.core.dtypes.common import is_integer |
7 | 7 |
|
@@ -449,11 +449,23 @@ def test_where_datetimelike_categorical(tz_naive_fixture): |
449 | 449 | tm.assert_frame_equal(res, pd.DataFrame(dr)) |
450 | 450 |
|
451 | 451 |
|
452 | | -def test_where_list_with_nan(): |
453 | | - ser = Series([None, 1, 2, np.nan, 3, 4, NA]) |
454 | | - cond = [np.nan, False, False, np.nan, True, True, np.nan] |
455 | | - expected = Series([None, -99, -99, np.nan, 3, 4, NA]) |
456 | | - |
457 | | - res = ser.where(cond, -99) |
458 | | - |
459 | | - tm.assert_series_equal(res, expected) |
| 452 | +@pytest.mark.parametrize( |
| 453 | + "dtype", |
| 454 | + [ |
| 455 | + "Int64", |
| 456 | + pytest.param("int64[pyarrow]", marks=td.skip_if_no("pyarrow")), |
| 457 | + ], |
| 458 | +) |
| 459 | +@pytest.mark.parametrize("cond_type", [["series", "list", "numpy"]]) |
| 460 | +def test_where_na(dtype, cond_type): |
| 461 | + series = Series([None, 1, 2, None, 3, 4, None], dtype=dtype) |
| 462 | + expected = Series([None, 1, 2, None, -99, -99, None], dtype=dtype) |
| 463 | + cond = series <= 2 |
| 464 | + |
| 465 | + if cond_type == "list": |
| 466 | + cond = cond.to_list() |
| 467 | + elif cond_type == "numpy": |
| 468 | + cond = cond.to_numpy() |
| 469 | + |
| 470 | + result = series.where(cond, -99) |
| 471 | + tm.assert_series_equal(result, expected) |
0 commit comments