diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index 056b1e8ce2840..208ad611e0016 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -12155,7 +12155,6 @@ def op(psser: ps.Series) -> PySparkColumn: return self._apply_series_op(op, should_resolve=True) - # TODO(SPARK-46168): axis = 1 def idxmax(self, axis: Axis = 0) -> "Series": """ Return index of first occurrence of maximum over requested axis. @@ -12166,8 +12165,8 @@ def idxmax(self, axis: Axis = 0) -> "Series": Parameters ---------- - axis : 0 or 'index' - Can only be set to 0 now. + axis : {0 or 'index', 1 or 'columns'}, default 0 + The axis to use. 0 or 'index' for row-wise, 1 or 'columns' for column-wise. Returns ------- @@ -12195,6 +12194,15 @@ def idxmax(self, axis: Axis = 0) -> "Series": c 2 dtype: int64 + For axis=1, return the column label of the maximum value in each row: + + >>> psdf.idxmax(axis=1) + 0 c + 1 c + 2 c + 3 c + dtype: object + For Multi-column Index >>> psdf = ps.DataFrame({'a': [1, 2, 3, 2], @@ -12215,25 +12223,77 @@ def idxmax(self, axis: Axis = 0) -> "Series": c z 2 dtype: int64 """ - max_cols = map(lambda scol: F.max(scol), self._internal.data_spark_columns) - sdf_max = self._internal.spark_frame.select(*max_cols).head() - # `sdf_max` looks like below - # +------+------+------+ - # |(a, x)|(b, y)|(c, z)| - # +------+------+------+ - # | 3| 4.0| 400| - # +------+------+------+ - - conds = ( - scol == max_val for scol, max_val in zip(self._internal.data_spark_columns, sdf_max) - ) - cond = reduce(lambda x, y: x | y, conds) + axis = validate_axis(axis) + if axis == 0: + max_cols = map(lambda scol: F.max(scol), self._internal.data_spark_columns) + sdf_max = self._internal.spark_frame.select(*max_cols).head() + # `sdf_max` looks like below + # +------+------+------+ + # |(a, x)|(b, y)|(c, z)| + # +------+------+------+ + # | 3| 4.0| 400| + # +------+------+------+ + + conds = ( + scol == max_val for scol, max_val in 
zip(self._internal.data_spark_columns, sdf_max) + ) + cond = reduce(lambda x, y: x | y, conds) + + psdf: DataFrame = DataFrame(self._internal.with_filter(cond)) + + return cast(ps.Series, ps.from_pandas(psdf._to_internal_pandas().idxmax())) + else: + from pyspark.pandas.series import first_series + + column_labels = self._internal.column_labels + + if len(column_labels) == 0: + return ps.Series([], dtype=np.int64) + + if self._internal.column_labels_level > 1: + raise NotImplementedError( + "idxmax with axis=1 does not support MultiIndex columns yet" + ) + + max_value = F.greatest( + *[ + F.coalesce(self._internal.spark_column_for(label), F.lit(float("-inf"))) + for label in column_labels + ], + F.lit(float("-inf")), + ) - psdf: DataFrame = DataFrame(self._internal.with_filter(cond)) + result = None + for label in reversed(column_labels): + scol = self._internal.spark_column_for(label) + label_value = label[0] if len(label) == 1 else label + condition = (scol == max_value) & scol.isNotNull() + + result = ( + F.when(condition, F.lit(label_value)) + if result is None + else F.when(condition, F.lit(label_value)).otherwise(result) + ) + + result = F.when(max_value == float("-inf"), F.lit(None)).otherwise(result) + + sdf = self._internal.spark_frame.select( + *self._internal.index_spark_columns, + result.alias(SPARK_DEFAULT_SERIES_NAME), + ) - return cast(ps.Series, ps.from_pandas(psdf._to_internal_pandas().idxmax())) + return first_series( + DataFrame( + InternalFrame( + spark_frame=sdf, + index_spark_columns=self._internal.index_spark_columns, + index_names=self._internal.index_names, + index_fields=self._internal.index_fields, + column_labels=[None], + ) + ) + ) - # TODO(SPARK-46168): axis = 1 def idxmin(self, axis: Axis = 0) -> "Series": """ Return index of first occurrence of minimum over requested axis. @@ -12244,8 +12304,8 @@ def idxmin(self, axis: Axis = 0) -> "Series": Parameters ---------- - axis : 0 or 'index' - Can only be set to 0 now.
+ axis : {0 or 'index', 1 or 'columns'}, default 0 + The axis to use. 0 or 'index' for row-wise, 1 or 'columns' for column-wise. Returns ------- @@ -12292,18 +12352,83 @@ def idxmin(self, axis: Axis = 0) -> "Series": b y 3 c z 1 dtype: int64 + + For axis=1, return the column label of the minimum value in each row: + + >>> psdf = ps.DataFrame({'a': [1, 2, 3, 2], + ... 'b': [4.0, 2.0, 3.0, 1.0], + ... 'c': [300, 200, 400, 200]}) + >>> psdf.idxmin(axis=1) + 0 a + 1 a + 2 a + 3 b + dtype: object """ - min_cols = map(lambda scol: F.min(scol), self._internal.data_spark_columns) - sdf_min = self._internal.spark_frame.select(*min_cols).head() + axis = validate_axis(axis) + if axis == 0: + min_cols = map(lambda scol: F.min(scol), self._internal.data_spark_columns) + sdf_min = self._internal.spark_frame.select(*min_cols).head() - conds = ( - scol == min_val for scol, min_val in zip(self._internal.data_spark_columns, sdf_min) - ) - cond = reduce(lambda x, y: x | y, conds) + conds = ( + scol == min_val for scol, min_val in zip(self._internal.data_spark_columns, sdf_min) + ) + cond = reduce(lambda x, y: x | y, conds) + + psdf: DataFrame = DataFrame(self._internal.with_filter(cond)) + + return cast(ps.Series, ps.from_pandas(psdf._to_internal_pandas().idxmin())) + else: + from pyspark.pandas.series import first_series + + column_labels = self._internal.column_labels - psdf: DataFrame = DataFrame(self._internal.with_filter(cond)) + if len(column_labels) == 0: + return ps.Series([], dtype=np.int64) - return cast(ps.Series, ps.from_pandas(psdf._to_internal_pandas().idxmin())) + if self._internal.column_labels_level > 1: + raise NotImplementedError( + "idxmin with axis=1 does not support MultiIndex columns yet" + ) + + min_value = F.least( + *[ + F.coalesce(self._internal.spark_column_for(label), F.lit(float("inf"))) + for label in column_labels + ], + F.lit(float("inf")), + ) + + result = None + for label in reversed(column_labels): + scol = self._internal.spark_column_for(label) + label_value = label[0] if len(label) == 1 else label + condition = (scol == min_value) & scol.isNotNull() + +
result = ( + F.when(condition, F.lit(label_value)) + if result is None + else F.when(condition, F.lit(label_value)).otherwise(result) + ) + + result = F.when(min_value == float("inf"), F.lit(None)).otherwise(result) + + sdf = self._internal.spark_frame.select( + *self._internal.index_spark_columns, + result.alias(SPARK_DEFAULT_SERIES_NAME), + ) + + return first_series( + DataFrame( + InternalFrame( + spark_frame=sdf, + index_spark_columns=self._internal.index_spark_columns, + index_names=self._internal.index_names, + index_fields=self._internal.index_fields, + column_labels=[None], + ) + ) + ) def info( self, diff --git a/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py b/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py new file mode 100644 index 0000000000000..eddaef3dacb37 --- /dev/null +++ b/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py @@ -0,0 +1,263 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import numpy as np +import pandas as pd + +from pyspark import pandas as ps +from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.sqlutils import SQLTestUtils + + +class FrameIdxMaxMinMixin: + def test_idxmax(self): + # Test basic axis=0 (default) + pdf = pd.DataFrame( + { + "a": [1, 2, 3, 2], + "b": [4.0, 2.0, 3.0, 1.0], + "c": [300, 200, 400, 200], + } + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmax(), pdf.idxmax()) + self.assert_eq(psdf.idxmax(axis=0), pdf.idxmax(axis=0)) + self.assert_eq(psdf.idxmax(axis="index"), pdf.idxmax(axis="index")) + + # Test axis=1 + self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1)) + self.assert_eq(psdf.idxmax(axis="columns"), pdf.idxmax(axis="columns")) + + # Test with NAs + pdf = pd.DataFrame( + { + "a": [1.0, None, 3.0], + "b": [None, 2.0, None], + "c": [3.0, 4.0, None], + } + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmax(), pdf.idxmax()) + self.assert_eq(psdf.idxmax(axis=0), pdf.idxmax(axis=0)) + self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1)) + + # Test with all-NA row + pdf = pd.DataFrame( + { + "a": [1.0, None], + "b": [2.0, None], + } + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1)) + + # Test with ties (first occurrence should win) + pdf = pd.DataFrame( + { + "a": [3, 2, 1], + "b": [3, 5, 1], + "c": [1, 5, 1], + } + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1)) + + # Test with single column + pdf = pd.DataFrame({"a": [1, 2, 3]}) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1)) + + # Test with empty DataFrame + pdf = pd.DataFrame({}) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1)) + + # Test with different data types + pdf = pd.DataFrame( + { + "int_col": [1, 2, 3], + "float_col": [1.5, 2.5, 0.5], + "negative": [-5, -10, -1], + } + ) + psdf = ps.from_pandas(pdf) + + 
self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1)) + + # Test with custom index + pdf = pd.DataFrame( + { + "a": [1, 2, 3], + "b": [4, 5, 6], + "c": [7, 8, 9], + }, + index=["row1", "row2", "row3"], + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1)) + + def test_idxmax_multiindex_columns(self): + # Test that MultiIndex columns raise NotImplementedError for axis=1 + pdf = pd.DataFrame( + { + "a": [1, 2, 3], + "b": [4, 5, 6], + "c": [7, 8, 9], + } + ) + pdf.columns = pd.MultiIndex.from_tuples([("x", "a"), ("y", "b"), ("z", "c")]) + psdf = ps.from_pandas(pdf) + + # axis=0 should work fine (it uses pandas internally) + self.assert_eq(psdf.idxmax(axis=0), pdf.idxmax(axis=0)) + + # axis=1 should raise NotImplementedError + with self.assertRaises(NotImplementedError): + psdf.idxmax(axis=1) + + def test_idxmin(self): + # Test basic axis=0 (default) + pdf = pd.DataFrame( + { + "a": [1, 2, 3, 2], + "b": [4.0, 2.0, 3.0, 1.0], + "c": [300, 200, 400, 200], + } + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmin(), pdf.idxmin()) + self.assert_eq(psdf.idxmin(axis=0), pdf.idxmin(axis=0)) + self.assert_eq(psdf.idxmin(axis="index"), pdf.idxmin(axis="index")) + + # Test axis=1 + self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1)) + self.assert_eq(psdf.idxmin(axis="columns"), pdf.idxmin(axis="columns")) + + # Test with NAs + pdf = pd.DataFrame( + { + "a": [1.0, None, 3.0], + "b": [None, 2.0, None], + "c": [3.0, 4.0, None], + } + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmin(), pdf.idxmin()) + self.assert_eq(psdf.idxmin(axis=0), pdf.idxmin(axis=0)) + self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1)) + + # Test with all-NA row + pdf = pd.DataFrame( + { + "a": [1.0, None], + "b": [2.0, None], + } + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1)) + + # Test with ties (first occurrence should win) + pdf = pd.DataFrame( + { + "a": [3, 2, 1], + "b": [3, 5, 1], + 
"c": [1, 5, 1], + } + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1)) + + # Test with single column + pdf = pd.DataFrame({"a": [1, 2, 3]}) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1)) + + # Test with empty DataFrame + pdf = pd.DataFrame({}) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1)) + + # Test with different data types + pdf = pd.DataFrame( + { + "int_col": [1, 2, 3], + "float_col": [1.5, 2.5, 0.5], + "negative": [-5, -10, -1], + } + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1)) + + # Test with custom index + pdf = pd.DataFrame( + { + "a": [1, 2, 3], + "b": [4, 5, 6], + "c": [7, 8, 9], + }, + index=["row1", "row2", "row3"], + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1)) + + def test_idxmin_multiindex_columns(self): + # Test that MultiIndex columns raise NotImplementedError for axis=1 + pdf = pd.DataFrame( + { + "a": [1, 2, 3], + "b": [4, 5, 6], + "c": [7, 8, 9], + } + ) + pdf.columns = pd.MultiIndex.from_tuples([("x", "a"), ("y", "b"), ("z", "c")]) + psdf = ps.from_pandas(pdf) + + # axis=0 should work fine (it uses pandas internally) + self.assert_eq(psdf.idxmin(axis=0), pdf.idxmin(axis=0)) + + # axis=1 should raise NotImplementedError + with self.assertRaises(NotImplementedError): + psdf.idxmin(axis=1) + + +class FrameIdxMaxMinTests( + FrameIdxMaxMinMixin, + PandasOnSparkTestCase, + SQLTestUtils, +): + pass + + +if __name__ == "__main__": + from pyspark.testing import main + + main()