From a9ad3a51b9069535b0dbfdf3c534aacf6088c1c1 Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Wed, 28 Jan 2026 09:47:45 -0600 Subject: [PATCH 1/4] [SPARK-46168][PS] Add axis argument for idxmax Signed-off-by: Devin Petersohn Co-authored-by: Devin Petersohn --- python/pyspark/pandas/frame.py | 90 +++++++++++++++++++++++++++------- 1 file changed, 73 insertions(+), 17 deletions(-) diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index 1c66bbec37b73..af2518d2dcee9 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -12138,7 +12138,6 @@ def op(psser: ps.Series) -> PySparkColumn: return self._apply_series_op(op, should_resolve=True) - # TODO(SPARK-46168): axis = 1 def idxmax(self, axis: Axis = 0) -> "Series": """ Return index of first occurrence of maximum over requested axis. @@ -12149,8 +12148,8 @@ def idxmax(self, axis: Axis = 0) -> "Series": Parameters ---------- - axis : 0 or 'index' - Can only be set to 0 now. + axis : {0 or 'index', 1 or 'columns'}, default 0 + The axis to use. 0 or 'index' for row-wise, 1 or 'columns' for column-wise. Returns ------- @@ -12197,24 +12196,81 @@ def idxmax(self, axis: Axis = 0) -> "Series": b y 0 c z 2 dtype: int64 + + For axis=1, return the column label of the maximum value in each row: + + >>> psdf.idxmax(axis=1) + 0 c + 1 c + 2 c + 3 c + dtype: object """ - max_cols = map(lambda scol: F.max(scol), self._internal.data_spark_columns) - sdf_max = self._internal.spark_frame.select(*max_cols).head() - # `sdf_max` looks like below - # +------+------+------+ - # |(a, x)|(b, y)|(c, z)| - # +------+------+------+ - # | 3| 4.0| 400| - # +------+------+------+ + axis = validate_axis(axis) + if axis == 0: + max_cols = map(lambda scol: F.max(scol), self._internal.data_spark_columns) + sdf_max = self._internal.spark_frame.select(*max_cols).head() + # `sdf_max` looks like below + # +------+------+------+ + # |(a, x)|(b, y)|(c, z)| + # +------+------+------+ + # | 3| 4.0| 400| + # +------+------+------+ + + conds = ( + scol == max_val for scol, max_val in zip(self._internal.data_spark_columns, sdf_max) + ) + cond = reduce(lambda x, y: x | y, conds) - conds = ( - scol == max_val for scol, max_val in zip(self._internal.data_spark_columns, sdf_max) - ) - cond = reduce(lambda x, y: x | y, conds) + psdf: DataFrame = DataFrame(self._internal.with_filter(cond)) - psdf: DataFrame = DataFrame(self._internal.with_filter(cond)) + return cast(ps.Series, ps.from_pandas(psdf._to_internal_pandas().idxmax())) + else: + from pyspark.pandas.series import first_series + + column_labels = self._internal.column_labels + + if len(column_labels) == 0: + return ps.Series([], dtype=np.int64) + + if self._internal.column_labels_level > 1: + raise NotImplementedError( + "idxmax with axis=1 does not support MultiIndex columns yet" + ) + + max_value = F.greatest( + *[F.coalesce(self._internal.spark_column_for(label), F.lit(float('-inf'))) + for label in column_labels], + F.lit(float('-inf')) + ) + + result = None + for label in reversed(column_labels): + scol = self._internal.spark_column_for(label) + label_value = label[0] if len(label) == 1 else label + condition = (scol == max_value) & scol.isNotNull() + + result = (F.when(condition, F.lit(label_value)) if result is None + else F.when(condition, F.lit(label_value)).otherwise(result)) + + result = F.when(max_value == float('-inf'), F.lit(None)).otherwise(result) + + sdf = self._internal.spark_frame.select( + *self._internal_frame.index_spark_columns, + result.alias(SPARK_DEFAULT_SERIES_NAME), + ) - return cast(ps.Series, ps.from_pandas(psdf._to_internal_pandas().idxmax())) + return first_series( + DataFrame( + InternalFrame( + spark_frame=sdf, + index_spark_columns=self._internal.index_spark_columns, + index_names=self._internal.index_names, + index_fields=self._internal.index_fields, + column_labels=[None], + ) + ) + ) # TODO(SPARK-46168): axis = 1 def idxmin(self, axis: Axis = 0) -> "Series": From b0f79bd6bf1cb5ee7a1c08c96a6f0650375e7599 Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Wed, 28 Jan 2026 09:49:07 -0600 Subject: [PATCH 2/4] Add tests Signed-off-by: Devin Petersohn --- .../tests/computation/test_idxmax_idxmin.py | 151 ++++++++++++++++++ 1 file changed, 151 insertions(+) create mode 100644 python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py diff --git a/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py b/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py new file mode 100644 index 0000000000000..2b5e617ffff52 --- /dev/null +++ b/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py @@ -0,0 +1,151 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np +import pandas as pd + +from pyspark import pandas as ps +from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.sqlutils import SQLTestUtils + + +class FrameIdxMaxMinMixin: + def test_idxmax(self): + # Test basic axis=0 (default) + pdf = pd.DataFrame( + { + "a": [1, 2, 3, 2], + "b": [4.0, 2.0, 3.0, 1.0], + "c": [300, 200, 400, 200], + } + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmax(), pdf.idxmax()) + self.assert_eq(psdf.idxmax(axis=0), pdf.idxmax(axis=0)) + self.assert_eq(psdf.idxmax(axis="index"), pdf.idxmax(axis="index")) + + # Test axis=1 + self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1)) + self.assert_eq(psdf.idxmax(axis="columns"), pdf.idxmax(axis="columns")) + + # Test with NAs + pdf = pd.DataFrame( + { + "a": [1.0, None, 3.0], + "b": [None, 2.0, None], + "c": [3.0, 4.0, None], + } + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmax(), pdf.idxmax()) + self.assert_eq(psdf.idxmax(axis=0), pdf.idxmax(axis=0)) + self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1)) + + # Test with all-NA row + pdf = pd.DataFrame( + { + "a": [1.0, None], + "b": [2.0, None], + } + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1)) + + # Test with ties (first occurrence should win) + pdf = pd.DataFrame( + { + "a": [3, 2, 1], + "b": [3, 5, 1], + "c": [1, 5, 1], + } + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1)) + + # Test with single column + pdf = pd.DataFrame({"a": [1, 2, 3]}) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1)) + + # Test with empty DataFrame + pdf = pd.DataFrame({}) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1)) + + # Test with different data types + pdf = pd.DataFrame( + { + "int_col": [1, 2, 3], + "float_col": [1.5, 2.5, 0.5], + "negative": [-5, -10, -1], + } + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1)) + + # Test with custom index + pdf = pd.DataFrame( + { + "a": [1, 2, 3], + "b": [4, 5, 6], + "c": [7, 8, 9], + }, + index=["row1", "row2", "row3"], + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1)) + + def test_idxmax_multiindex_columns(self): + # Test that MultiIndex columns raise NotImplementedError for axis=1 + pdf = pd.DataFrame( + { + "a": [1, 2, 3], + "b": [4, 5, 6], + "c": [7, 8, 9], + } + ) + pdf.columns = pd.MultiIndex.from_tuples([("x", "a"), ("y", "b"), ("z", "c")]) + psdf = ps.from_pandas(pdf) + + # axis=0 should work fine (it uses pandas internally) + self.assert_eq(psdf.idxmax(axis=0), pdf.idxmax(axis=0)) + + # axis=1 should raise NotImplementedError + with self.assertRaises(NotImplementedError): + psdf.idxmax(axis=1) + + +class FrameIdxMaxMinTests( + FrameIdxMaxMinMixin, + PandasOnSparkTestCase, + SQLTestUtils, +): + pass + + +if __name__ == "__main__": + from pyspark.testing import main + + main() From c5b211ad471aead208d7fa88783bb6f540fb5279 Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Thu, 5 Feb 2026 11:06:29 -0600 Subject: [PATCH 3/4] Fix doc and lint Signed-off-by: Devin Petersohn --- python/pyspark/pandas/frame.py | 35 +++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index af2518d2dcee9..a78e5c1e1cf65 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -12177,6 +12177,15 @@ def idxmax(self, axis: Axis = 0) -> "Series": c 2 dtype: int64 + For axis=1, return the column label of the maximum value in each row: + + >>> psdf.idxmax(axis=1) + 0 c + 1 c + 2 c + 3 c + dtype: object + For Multi-column Index >>> psdf = ps.DataFrame({'a': [1, 2, 3, 2], @@ -12196,15 +12205,6 @@ def idxmax(self, axis: Axis = 0) -> "Series": b y 0 c z 2 dtype: int64 - - For axis=1, return the column label of the maximum value in each row: - - >>> psdf.idxmax(axis=1) - 0 c - 1 c - 2 c - 3 c - dtype: object """ axis = validate_axis(axis) if axis == 0: @@ -12239,9 +12239,11 @@ def idxmax(self, axis: Axis = 0) -> "Series": ) max_value = F.greatest( - *[F.coalesce(self._internal.spark_column_for(label), F.lit(float('-inf'))) - for label in column_labels], - F.lit(float('-inf')) + *[ + F.coalesce(self._internal.spark_column_for(label), F.lit(float("-inf"))) + for label in column_labels + ], + F.lit(float("-inf")), ) result = None @@ -12250,10 +12252,13 @@ def idxmax(self, axis: Axis = 0) -> "Series": label_value = label[0] if len(label) == 1 else label condition = (scol == max_value) & scol.isNotNull() - result = (F.when(condition, F.lit(label_value)) if result is None - else F.when(condition, F.lit(label_value)).otherwise(result)) + result = ( + F.when(condition, F.lit(label_value)) + if result is None + else F.when(condition, F.lit(label_value)).otherwise(result) + ) - result = F.when(max_value == float('-inf'), F.lit(None)).otherwise(result) + result = F.when(max_value == float("-inf"), F.lit(None)).otherwise(result) sdf = self._internal.spark_frame.select( *self._internal_frame.index_spark_columns, From 87a1a68f9e4fd102b69ed0b38c54037c696fce7c Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Fri, 6 Feb 2026 13:26:51 -0600 Subject: [PATCH 4/4] [SPARK-46168][PS] Implementation of idxmin Axis argument Signed-off-by: Devin Petersohn Co-authored-by: Devin Petersohn --- python/pyspark/pandas/frame.py | 83 +++++++++++-- .../tests/computation/test_idxmax_idxmin.py | 112 ++++++++++++++++++ 2 files changed, 184 insertions(+), 11 deletions(-) diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index eec9d96a7136a..208ad611e0016 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -12294,7 +12294,6 @@ def idxmax(self, axis: Axis = 0) -> "Series": ) ) - # TODO(SPARK-46168): axis = 1 def idxmin(self, axis: Axis = 0) -> "Series": """ Return index of first occurrence of minimum over requested axis. @@ -12305,8 +12304,8 @@ def idxmin(self, axis: Axis = 0) -> "Series": Parameters ---------- - axis : 0 or 'index' - Can only be set to 0 now. + axis : {0 or 'index', 1 or 'columns'}, default 0 + The axis to use. 0 or 'index' for row-wise, 1 or 'columns' for column-wise. Returns ------- @@ -12353,18 +12352,80 @@ def idxmin(self, axis: Axis = 0) -> "Series": b y 3 c z 1 dtype: int64 + + For axis=1, return the column label of the minimum value in each row: + + >>> psdf.idxmin(axis=1) + 0 a + 1 b + 2 b + 3 b + dtype: object """ - min_cols = map(lambda scol: F.min(scol), self._internal.data_spark_columns) - sdf_min = self._internal.spark_frame.select(*min_cols).head() + axis = validate_axis(axis) + if axis == 0: + min_cols = map(lambda scol: F.min(scol), self._internal.data_spark_columns) + sdf_min = self._internal.spark_frame.select(*min_cols).head() - conds = ( - scol == min_val for scol, min_val in zip(self._internal.data_spark_columns, sdf_min) - ) - cond = reduce(lambda x, y: x | y, conds) + conds = ( + scol == min_val for scol, min_val in zip(self._internal.data_spark_columns, sdf_min) + ) + cond = reduce(lambda x, y: x | y, conds) - psdf: DataFrame = DataFrame(self._internal.with_filter(cond)) + psdf: DataFrame = DataFrame(self._internal.with_filter(cond)) - return cast(ps.Series, ps.from_pandas(psdf._to_internal_pandas().idxmin())) + return cast(ps.Series, ps.from_pandas(psdf._to_internal_pandas().idxmin())) + else: + from pyspark.pandas.series import first_series + + column_labels = self._internal.column_labels + + if len(column_labels) == 0: + return ps.Series([], dtype=np.int64) + + if self._internal.column_labels_level > 1: + raise NotImplementedError( + "idxmin with axis=1 does not support MultiIndex columns yet" + ) + + min_value = F.least( + *[ + F.coalesce(self._internal.spark_column_for(label), F.lit(float("inf"))) + for label in column_labels + ], + F.lit(float("inf")), + ) + + result = None + for label in reversed(column_labels): + scol = self._internal.spark_column_for(label) + label_value = label[0] if len(label) == 1 else label + condition = (scol == min_value) & scol.isNotNull() + + result = ( + F.when(condition, F.lit(label_value)) + if result is None + else F.when(condition, F.lit(label_value)).otherwise(result) + ) + + result = F.when(min_value == float("inf"), F.lit(None)).otherwise(result) + + sdf = self._internal.spark_frame.select( + *self._internal.index_spark_columns, + result.alias(SPARK_DEFAULT_SERIES_NAME), + ) + + return first_series( + DataFrame( + InternalFrame( + spark_frame=sdf, + index_spark_columns=self._internal.index_spark_columns, + index_names=self._internal.index_names, + index_fields=self._internal.index_fields, + column_labels=[None], + ) + ) + ) def info( self, diff --git a/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py b/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py index 2b5e617ffff52..eddaef3dacb37 100644 --- a/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py +++ b/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py @@ -136,6 +136,118 @@ def test_idxmax_multiindex_columns(self): with self.assertRaises(NotImplementedError): psdf.idxmax(axis=1) + def test_idxmin(self): + # Test basic axis=0 (default) + pdf = pd.DataFrame( + { + "a": [1, 2, 3, 2], + "b": [4.0, 2.0, 3.0, 1.0], + "c": [300, 200, 400, 200], + } + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmin(), pdf.idxmin()) + self.assert_eq(psdf.idxmin(axis=0), pdf.idxmin(axis=0)) + self.assert_eq(psdf.idxmin(axis="index"), pdf.idxmin(axis="index")) + + # Test axis=1 + self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1)) + self.assert_eq(psdf.idxmin(axis="columns"), pdf.idxmin(axis="columns")) + + # Test with NAs + pdf = pd.DataFrame( + { + "a": [1.0, None, 3.0], + "b": [None, 2.0, None], + "c": [3.0, 4.0, None], + } + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmin(), pdf.idxmin()) + self.assert_eq(psdf.idxmin(axis=0), pdf.idxmin(axis=0)) + self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1)) + + # Test with all-NA row + pdf = pd.DataFrame( + { + "a": [1.0, None], + "b": [2.0, None], + } + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1)) + + # Test with ties (first occurrence should win) + pdf = pd.DataFrame( + { + "a": [3, 2, 1], + "b": [3, 5, 1], + "c": [1, 5, 1], + } + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1)) + + # Test with single column + pdf = pd.DataFrame({"a": [1, 2, 3]}) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1)) + + # Test with empty DataFrame + pdf = pd.DataFrame({}) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1)) + + # Test with different data types + pdf = pd.DataFrame( + { + "int_col": [1, 2, 3], + "float_col": [1.5, 2.5, 0.5], + "negative": [-5, -10, -1], + } + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1)) + + # Test with custom index + pdf = pd.DataFrame( + { + "a": [1, 2, 3], + "b": [4, 5, 6], + "c": [7, 8, 9], + }, + index=["row1", "row2", "row3"], + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1)) + + def test_idxmin_multiindex_columns(self): + # Test that MultiIndex columns raise NotImplementedError for axis=1 + pdf = pd.DataFrame( + { + "a": [1, 2, 3], + "b": [4, 5, 6], + "c": [7, 8, 9], + } + ) + pdf.columns = pd.MultiIndex.from_tuples([("x", "a"), ("y", "b"), ("z", "c")]) + psdf = ps.from_pandas(pdf) + + # axis=0 should work fine (it uses pandas internally) + self.assert_eq(psdf.idxmin(axis=0), pdf.idxmin(axis=0)) + + # axis=1 should raise NotImplementedError + with self.assertRaises(NotImplementedError): + psdf.idxmin(axis=1) + class FrameIdxMaxMinTests( FrameIdxMaxMinMixin,