180 changes: 151 additions & 29 deletions python/pyspark/pandas/frame.py
@@ -12155,7 +12155,6 @@ def op(psser: ps.Series) -> PySparkColumn:

return self._apply_series_op(op, should_resolve=True)

def idxmax(self, axis: Axis = 0) -> "Series":
"""
Return index of first occurrence of maximum over requested axis.
@@ -12166,8 +12165,8 @@ def idxmax(self, axis: Axis = 0) -> "Series":

Parameters
----------
axis : {0 or 'index', 1 or 'columns'}, default 0
The axis to use. 0 or 'index' for row-wise, 1 or 'columns' for column-wise.

Returns
-------
@@ -12195,6 +12194,15 @@ def idxmax(self, axis: Axis = 0) -> "Series":
c 2
dtype: int64

For axis=1, return the column label of the maximum value in each row:

>>> psdf.idxmax(axis=1)
0 c
1 c
2 c
3 c
dtype: object

For Multi-column Index

>>> psdf = ps.DataFrame({'a': [1, 2, 3, 2],
@@ -12215,25 +12223,77 @@ def idxmax(self, axis: Axis = 0) -> "Series":
c z 2
dtype: int64
"""
axis = validate_axis(axis)
if axis == 0:
max_cols = map(lambda scol: F.max(scol), self._internal.data_spark_columns)
sdf_max = self._internal.spark_frame.select(*max_cols).head()
# `sdf_max` looks like below
# +------+------+------+
# |(a, x)|(b, y)|(c, z)|
# +------+------+------+
# | 3| 4.0| 400|
# +------+------+------+

conds = (
scol == max_val for scol, max_val in zip(self._internal.data_spark_columns, sdf_max)
)
cond = reduce(lambda x, y: x | y, conds)
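            # Keep only the rows that attain some column's maximum; the
            # filtered frame is typically tiny, so finishing with pandas'
            # own idxmax on the driver is cheap.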

psdf: DataFrame = DataFrame(self._internal.with_filter(cond))

return cast(ps.Series, ps.from_pandas(psdf._to_internal_pandas().idxmax()))
else:
from pyspark.pandas.series import first_series

column_labels = self._internal.column_labels

if len(column_labels) == 0:
return ps.Series([], dtype=np.int64)

if self._internal.column_labels_level > 1:
raise NotImplementedError(
"idxmax with axis=1 does not support MultiIndex columns yet"
)

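            # Row-wise maximum: coalesce NULLs to -inf so they never win, and
            # append a -inf literal so that F.greatest always receives the two
            # or more arguments it requires, even for a single-column frame.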
max_value = F.greatest(
*[
F.coalesce(self._internal.spark_column_for(label), F.lit(float("-inf")))
for label in column_labels
],
F.lit(float("-inf")),
)

result = None
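            # Chain CASE WHENs over the columns in reverse order: after
            # nesting, the leftmost column is tested first, so ties resolve
            # to the first occurrence, matching pandas. Each label is a
            # tuple; label[0] unwraps it for single-level column labels.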
for label in reversed(column_labels):
scol = self._internal.spark_column_for(label)
label_value = label[0] if len(label) == 1 else label
condition = (scol == max_value) & scol.isNotNull()

result = (
F.when(condition, F.lit(label_value))
if result is None
else F.when(condition, F.lit(label_value)).otherwise(result)
)

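            # A row that is entirely NULL matched none of the conditions
            # above (its max_value is -inf); map it explicitly to NULL
            # instead of a column label.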
result = F.when(max_value == float("-inf"), F.lit(None)).otherwise(result)

sdf = self._internal.spark_frame.select(
                *self._internal.index_spark_columns,
result.alias(SPARK_DEFAULT_SERIES_NAME),
)

return first_series(
DataFrame(
InternalFrame(
spark_frame=sdf,
index_spark_columns=self._internal.index_spark_columns,
index_names=self._internal.index_names,
index_fields=self._internal.index_fields,
column_labels=[None],
)
)
)

def idxmin(self, axis: Axis = 0) -> "Series":
"""
Return index of first occurrence of minimum over requested axis.
@@ -12244,8 +12304,8 @@ def idxmin(self, axis: Axis = 0) -> "Series":

Parameters
----------
axis : {0 or 'index', 1 or 'columns'}, default 0
The axis to use. 0 or 'index' for row-wise, 1 or 'columns' for column-wise.

Returns
-------
@@ -12292,18 +12352,80 @@ def idxmin(self, axis: Axis = 0) -> "Series":
b y 3
c z 1
dtype: int64

For axis=1, return the column label of the minimum value in each row:

>>> psdf.idxmin(axis=1)
0 a
1 a
2 a
3 b
dtype: object
"""
axis = validate_axis(axis)
if axis == 0:
min_cols = map(lambda scol: F.min(scol), self._internal.data_spark_columns)
sdf_min = self._internal.spark_frame.select(*min_cols).head()

conds = (
scol == min_val for scol, min_val in zip(self._internal.data_spark_columns, sdf_min)
)
cond = reduce(lambda x, y: x | y, conds)
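            # Keep only the rows that attain some column's minimum; the
            # filtered frame is typically tiny, so finishing with pandas'
            # own idxmin on the driver is cheap.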

psdf: DataFrame = DataFrame(self._internal.with_filter(cond))

return cast(ps.Series, ps.from_pandas(psdf._to_internal_pandas().idxmin()))
else:
from pyspark.pandas.series import first_series

column_labels = self._internal.column_labels

if len(column_labels) == 0:
return ps.Series([], dtype=np.int64)

if self._internal.column_labels_level > 1:
raise NotImplementedError(
"idxmin with axis=1 does not support MultiIndex columns yet"
)

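            # Row-wise minimum: coalesce NULLs to +inf so they never win, and
            # append a +inf literal so that F.least always receives the two
            # or more arguments it requires, even for a single-column frame.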
min_value = F.least(
*[
F.coalesce(self._internal.spark_column_for(label), F.lit(float("inf")))
for label in column_labels
],
F.lit(float("inf")),
)

result = None
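            # Chain CASE WHENs over the columns in reverse order: after
            # nesting, the leftmost column is tested first, so ties resolve
            # to the first occurrence, matching pandas. Each label is a
            # tuple; label[0] unwraps it for single-level column labels.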
for label in reversed(column_labels):
scol = self._internal.spark_column_for(label)
label_value = label[0] if len(label) == 1 else label
condition = (scol == min_value) & scol.isNotNull()

result = (
F.when(condition, F.lit(label_value))
if result is None
else F.when(condition, F.lit(label_value)).otherwise(result)
)

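            # A row that is entirely NULL matched none of the conditions
            # above (its min_value is +inf); map it explicitly to NULL
            # instead of a column label.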
result = F.when(min_value == float("inf"), F.lit(None)).otherwise(result)

sdf = self._internal.spark_frame.select(
*self._internal.index_spark_columns,
result.alias(SPARK_DEFAULT_SERIES_NAME),
)

return first_series(
DataFrame(
InternalFrame(
spark_frame=sdf,
index_spark_columns=self._internal.index_spark_columns,
index_names=self._internal.index_names,
index_fields=self._internal.index_fields,
column_labels=[None],
)
)
)

def info(
self,