3 changes: 3 additions & 0 deletions python/pyspark/sql/classic/dataframe.py
@@ -280,6 +280,9 @@ def explain(
def exceptAll(self, other: ParentDataFrame) -> ParentDataFrame:
return DataFrame(self._jdf.exceptAll(other._jdf), self.sparkSession)

def zipWithIndex(self, indexColName: str = "index") -> ParentDataFrame:
return DataFrame(self._jdf.zipWithIndex(indexColName), self.sparkSession)

def isLocal(self) -> bool:
return self._jdf.isLocal()

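On the classic path the new method simply delegates to the JVM-side Dataset.zipWithIndex through the Py4J handle. For Spark versions that predate this method, roughly the same result can be had via RDD.zipWithIndex; the sketch below is a hypothetical fallback under that assumption, not part of this PR (classic API only, since Connect DataFrames expose no .rdd):

# Hypothetical fallback for Spark versions without DataFrame.zipWithIndex,
# built on RDD.zipWithIndex. Classic API only; Connect has no df.rdd.
from pyspark.sql import DataFrame
from pyspark.sql.types import LongType, StructField, StructType

def zip_with_index_fallback(df: DataFrame, index_col: str = "index") -> DataFrame:
    # RDD.zipWithIndex yields (Row, index) pairs; flatten each into one tuple.
    schema = StructType(df.schema.fields + [StructField(index_col, LongType(), False)])
    rdd = df.rdd.zipWithIndex().map(lambda pair: (*pair[0], pair[1]))
    return df.sparkSession.createDataFrame(rdd, schema)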
6 changes: 6 additions & 0 deletions python/pyspark/sql/connect/dataframe.py
@@ -87,6 +87,7 @@
UnresolvedStar,
)
from pyspark.sql.connect.functions import builtin as F
from pyspark.sql.internal import InternalFunction
from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema
from pyspark.sql.pandas.functions import _validate_vectorized_udf # type: ignore[attr-defined]
from pyspark.sql.table_arg import TableArg
@@ -1212,6 +1213,11 @@ def exceptAll(self, other: ParentDataFrame) -> ParentDataFrame:
res._cached_schema = self._merge_cached_schema(other)
return res

def zipWithIndex(self, indexColName: str = "index") -> ParentDataFrame:
return self.select(
F.col("*"), InternalFunction.distributed_sequence_id().alias(indexColName)
)

def intersect(self, other: ParentDataFrame) -> ParentDataFrame:
self._check_same_session(other)
res = DataFrame(
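On the Connect path there is no JVM handle to call into, so zipWithIndex is expressed as a plain projection over the internal distributed_sequence_id expression, the mechanism also used for pandas-on-Spark's distributed-sequence default index. The more familiar monotonically_increasing_id would not work here: its values are unique but not consecutive, because the partition ID is encoded in the upper bits. A quick illustration, assuming a running `spark` session:

# Why distributed_sequence_id rather than monotonically_increasing_id:
# the latter jumps between partitions instead of counting 0, 1, 2, ...
from pyspark.sql import functions as F

df = spark.range(6).repartition(3)
df.withColumn("mid", F.monotonically_increasing_id()).show()
# "mid" may show values like 0, 1, 8589934592, 8589934593, ... whereas
# zipWithIndex must always produce the consecutive run 0..5.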
52 changes: 52 additions & 0 deletions python/pyspark/sql/dataframe.py
@@ -782,6 +782,58 @@ def exceptAll(self, other: "DataFrame") -> "DataFrame":
"""
...

@dispatch_df_method
def zipWithIndex(self, indexColName: str = "index") -> "DataFrame":
"""Returns a new :class:`DataFrame` by appending a column containing consecutive
0-based Long indices, similar to :meth:`RDD.zipWithIndex`.

The index column is appended as the last column of the resulting :class:`DataFrame`.

.. versionadded:: 4.2.0

Parameters
----------
indexColName : str, default "index"
The name of the index column to append.

Returns
-------
:class:`DataFrame`
A new DataFrame with an appended index column.

Notes
-----
If a column named `indexColName` already exists in the schema, the resulting
:class:`DataFrame` will have duplicate column names. Selecting the duplicate column
by name will throw `AMBIGUOUS_REFERENCE`, and writing the :class:`DataFrame` will
throw `COLUMN_ALREADY_EXISTS`.

Examples
--------
>>> df = spark.createDataFrame(
... [("a", 1), ("b", 2), ("c", 3)], ["letter", "number"])
>>> df.zipWithIndex().show()
+------+------+-----+
|letter|number|index|
+------+------+-----+
| a| 1| 0|
| b| 2| 1|
| c| 3| 2|
+------+------+-----+

Custom index column name:

>>> df.zipWithIndex("row_id").show()
+------+------+------+
|letter|number|row_id|
+------+------+------+
| a| 1| 0|
| b| 2| 1|
| c| 3| 2|
+------+------+------+
"""
...

@dispatch_df_method
def isLocal(self) -> bool:
"""Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally
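As the Notes section above warns, passing an indexColName that already exists yields duplicate column names rather than an immediate error. A caller can sidestep this with a small wrapper; the helper below is a hypothetical illustration, not part of the proposed API:

# Hypothetical convenience wrapper (not in this PR): rename a clashing
# column before appending the index, so the result never has duplicates.
from pyspark.sql import DataFrame

def safe_zip_with_index(df: DataFrame, index_col: str = "index") -> DataFrame:
    if index_col in df.columns:
        df = df.withColumnRenamed(index_col, index_col + "_orig")
    return df.zipWithIndex(index_col)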
30 changes: 30 additions & 0 deletions python/pyspark/sql/tests/test_dataframe.py
@@ -1214,6 +1214,36 @@ def test_to_json(self):
self.assertIsInstance(df, DataFrame)
self.assertEqual(df.select("value").count(), 10)

def test_zip_with_index(self):
df = self.spark.createDataFrame([("a", 1), ("b", 2), ("c", 3)], ["letter", "number"])

# Default column name "index"
result = df.zipWithIndex()
self.assertEqual(result.columns, ["letter", "number", "index"])
rows = result.collect()
self.assertEqual(len(rows), 3)
indices = [row["index"] for row in rows]
self.assertEqual(sorted(indices), [0, 1, 2])

# Custom column name
result = df.zipWithIndex("row_id")
self.assertEqual(result.columns, ["letter", "number", "row_id"])
rows = result.collect()
indices = [row["row_id"] for row in rows]
self.assertEqual(sorted(indices), [0, 1, 2])

# Duplicate column name causes AMBIGUOUS_REFERENCE on select
result = df.zipWithIndex("letter")
with self.assertRaises(AnalysisException) as ctx:
result.select("letter").collect()
self.assertEqual(ctx.exception.getCondition(), "AMBIGUOUS_REFERENCE")

# Duplicate column name causes COLUMN_ALREADY_EXISTS on write
with tempfile.TemporaryDirectory() as d:
with self.assertRaises(AnalysisException) as ctx:
result.write.parquet(d)
self.assertEqual(ctx.exception.getCondition(), "COLUMN_ALREADY_EXISTS")


class DataFrameTests(DataFrameTestsMixin, ReusedSQLTestCase):
pass
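The assertions above check the set of index values but not their alignment with row order. Since zipWithIndex assigns consecutive indices in partition-and-row order, a stricter check could also pin the ordering; a possible extension of the test, not part of this PR:

# Possible stricter assertion (not in this PR): the index should follow
# the original row order, not merely cover the values 0..n-1.
rows = df.zipWithIndex().orderBy("index").collect()
assert [r["letter"] for r in rows] == ["a", "b", "c"]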