From c40d37baf7f336d750e405c5183cb20d36ccd109 Mon Sep 17 00:00:00 2001
From: Fangchen Li
Date: Fri, 6 Feb 2026 21:59:41 -0800
Subject: [PATCH 1/4] [SPARK-55229][PYTHON] Implement DataFrame.zipWithIndex in PySpark Classic

---
 python/pyspark/sql/classic/dataframe.py    |  6 +++
 python/pyspark/sql/dataframe.py            | 52 ++++++++++++++++++++++
 python/pyspark/sql/tests/test_dataframe.py | 24 ++++++++++
 3 files changed, 82 insertions(+)

diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py
index 41d9d41667274..986a50209ac99 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -57,6 +57,7 @@
 from pyspark.traceback_utils import SCCallSiteSync
 from pyspark.sql.column import Column
 from pyspark.sql.functions import builtin as F
+from pyspark.sql.internal import InternalFunction
 from pyspark.sql.classic.column import _to_seq, _to_list, _to_java_column
 from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
 from pyspark.sql.merge import MergeIntoWriter
@@ -280,6 +281,11 @@ def explain(
     def exceptAll(self, other: ParentDataFrame) -> ParentDataFrame:
         return DataFrame(self._jdf.exceptAll(other._jdf), self.sparkSession)
 
+    def zipWithIndex(self, indexColName: str = "index") -> ParentDataFrame:
+        return self.select(
+            F.col("*"), InternalFunction.distributed_sequence_id().alias(indexColName)
+        )
+
     def isLocal(self) -> bool:
         return self._jdf.isLocal()
 
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index c6f348ce600ad..a1607df4ecef1 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -782,6 +782,58 @@ def exceptAll(self, other: "DataFrame") -> "DataFrame":
         """
         ...
 
+    @dispatch_df_method
+    def zipWithIndex(self, indexColName: str = "index") -> "DataFrame":
+        """Returns a new :class:`DataFrame` by appending a column containing
+        consecutive 0-based Long indices, similar to :meth:`RDD.zipWithIndex`.
+
+        The index column is appended as the last column of the resulting :class:`DataFrame`.
+
+        .. versionadded:: 4.2.0
+
+        Parameters
+        ----------
+        indexColName : str, default "index"
+            The name of the index column to append.
+
+        Returns
+        -------
+        :class:`DataFrame`
+            A new DataFrame with an appended index column.
+
+        Notes
+        -----
+        If a column named `indexColName` already exists in the schema, the resulting
+        :class:`DataFrame` will have duplicate column names. Selecting the duplicate column
+        by name will throw `AMBIGUOUS_REFERENCE`, and writing the :class:`DataFrame` will
+        throw `COLUMN_ALREADY_EXISTS`.
+
+        Examples
+        --------
+        >>> df = spark.createDataFrame(
+        ...     [("a", 1), ("b", 2), ("c", 3)], ["letter", "number"])
+        >>> df.zipWithIndex().show()
+        +------+------+-----+
+        |letter|number|index|
+        +------+------+-----+
+        |     a|     1|    0|
+        |     b|     2|    1|
+        |     c|     3|    2|
+        +------+------+-----+
+
+        Custom index column name:
+
+        >>> df.zipWithIndex("row_id").show()
+        +------+------+------+
+        |letter|number|row_id|
+        +------+------+------+
+        |     a|     1|     0|
+        |     b|     2|     1|
+        |     c|     3|     2|
+        +------+------+------+
+        """
+        ...
+
     @dispatch_df_method
     def isLocal(self) -> bool:
         """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index d850f2598a2de..4444f281e7823 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1214,6 +1214,30 @@ def test_to_json(self):
         self.assertIsInstance(df, DataFrame)
         self.assertEqual(df.select("value").count(), 10)
 
+    def test_zip_with_index(self):
+        df = self.spark.createDataFrame([("a", 1), ("b", 2), ("c", 3)], ["letter", "number"])
+
+        # Default column name "index"
+        result = df.zipWithIndex()
+        self.assertEqual(result.columns, ["letter", "number", "index"])
+        rows = result.collect()
+        self.assertEqual(len(rows), 3)
+        indices = [row["index"] for row in rows]
+        self.assertEqual(sorted(indices), [0, 1, 2])
+
+        # Custom column name
+        result = df.zipWithIndex("row_id")
+        self.assertEqual(result.columns, ["letter", "number", "row_id"])
+        rows = result.collect()
+        indices = [row["row_id"] for row in rows]
+        self.assertEqual(sorted(indices), [0, 1, 2])
+
+        # Duplicate column name causes AMBIGUOUS_REFERENCE on select
+        result = df.zipWithIndex("letter")
+        with self.assertRaises(AnalysisException) as ctx:
+            result.select("letter").collect()
+        self.assertEqual(ctx.exception.getCondition(), "AMBIGUOUS_REFERENCE")
+
 class DataFrameTests(DataFrameTestsMixin, ReusedSQLTestCase):
     pass

From 29b435e01bb4f9e887e4e01ff9924abbfe164973 Mon Sep 17 00:00:00 2001
From: Fangchen Li
Date: Fri, 6 Feb 2026 22:09:24 -0800
Subject: [PATCH 2/4] add error test

---
 python/pyspark/sql/tests/test_dataframe.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 4444f281e7823..7aa25605b2ee4 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1238,6 +1238,12 @@ def test_zip_with_index(self):
             result.select("letter").collect()
         self.assertEqual(ctx.exception.getCondition(), "AMBIGUOUS_REFERENCE")
 
+        # Duplicate column name causes COLUMN_ALREADY_EXISTS on write
+        with tempfile.TemporaryDirectory() as d:
+            with self.assertRaises(AnalysisException) as ctx:
+                result.write.parquet(d)
+            self.assertEqual(ctx.exception.getCondition(), "COLUMN_ALREADY_EXISTS")
+
 class DataFrameTests(DataFrameTestsMixin, ReusedSQLTestCase):
     pass

From ca5e8d0c40f7bb9d494d5baf1828bcc90697669d Mon Sep 17 00:00:00 2001
From: Fangchen Li
Date: Fri, 6 Feb 2026 23:13:13 -0800
Subject: [PATCH 3/4] add zipWithIndex in connect

---
 python/pyspark/sql/connect/dataframe.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 0df13c1020d7f..8de2f6db30e91 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -87,6 +87,7 @@
     UnresolvedStar,
 )
 from pyspark.sql.connect.functions import builtin as F
+from pyspark.sql.internal import InternalFunction
 from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema
 from pyspark.sql.pandas.functions import _validate_vectorized_udf  # type: ignore[attr-defined]
 from pyspark.sql.table_arg import TableArg
@@ -1212,6 +1213,11 @@ def exceptAll(self, other: ParentDataFrame) -> ParentDataFrame:
         res._cached_schema = self._merge_cached_schema(other)
         return res
 
+    def zipWithIndex(self, indexColName: str = "index") -> ParentDataFrame:
+        return self.select(
+            F.col("*"), InternalFunction.distributed_sequence_id().alias(indexColName)
+        )
+
     def intersect(self, other: ParentDataFrame) -> ParentDataFrame:
         self._check_same_session(other)
         res = DataFrame(

From f3f51793229e41ff559cb8647ee826ebb5164f22 Mon Sep 17 00:00:00 2001
From: Fangchen Li
Date: Sat, 7 Feb 2026 09:36:13 -0800
Subject: [PATCH 4/4] directly invoke jvm method

---
 python/pyspark/sql/classic/dataframe.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py
index 986a50209ac99..854ab0ff89a07 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -57,7 +57,6 @@
 from pyspark.traceback_utils import SCCallSiteSync
 from pyspark.sql.column import Column
 from pyspark.sql.functions import builtin as F
-from pyspark.sql.internal import InternalFunction
 from pyspark.sql.classic.column import _to_seq, _to_list, _to_java_column
 from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
 from pyspark.sql.merge import MergeIntoWriter
@@ -282,9 +281,7 @@ def exceptAll(self, other: ParentDataFrame) -> ParentDataFrame:
         return DataFrame(self._jdf.exceptAll(other._jdf), self.sparkSession)
 
     def zipWithIndex(self, indexColName: str = "index") -> ParentDataFrame:
-        return self.select(
-            F.col("*"), InternalFunction.distributed_sequence_id().alias(indexColName)
-        )
+        return DataFrame(self._jdf.zipWithIndex(indexColName), self.sparkSession)
 
     def isLocal(self) -> bool:
         return self._jdf.isLocal()