5 changes: 3 additions & 2 deletions python/docs/source/tutorial/sql/python_data_source.rst
@@ -512,6 +512,7 @@ The following example demonstrates how to implement a basic Data Source using Arrow

from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
from pyspark.sql import SparkSession
from pyspark.sql.pandas.types import to_arrow_schema
import pyarrow as pa

# Define the ArrowBatchDataSource
@@ -534,14 +535,14 @@ The following example demonstrates how to implement a basic Data Source using Arrow
class ArrowBatchDataSourceReader(DataSourceReader):
def __init__(self, schema, options):
self.schema: str = schema
self.arrow_schema = to_arrow_schema(self.schema)

Contributor:

I just noticed that this is an example. Does this example really work?

self.schema: str is a str, but to_arrow_schema accepts a Spark StructType as input, not a str.


Contributor:

Moreover, the to_arrow_schema method is not a public API; I think we should not use it in examples.
cc @HyukjinKwon


Author (@casgie, Feb 10, 2026):


Yes, curiously it worked (tested on Databricks DBR 17.3 with Spark 4.0.0).
We could of course change it to:

def schema(self):
    return StructType([
        StructField("key", IntegerType(), True),
        StructField("value", StringType(), True),
    ])

Regarding your second point: yes, fair. But maybe we can find another way to avoid having users specify the schema twice? Maybe allow using a PyArrow schema as a schema definition in DataSource?


Contributor:


> Maybe allow using a PyArrow schema as a schema definition in DataSource?

LGTM


Author:


I'll try to implement that, then change the PR accordingly.

self.options = options

def read(self, partition):
# Create Arrow Record Batch
keys = pa.array([1, 2, 3, 4, 5], type=pa.int32())
values = pa.array(["one", "two", "three", "four", "five"], type=pa.string())
-        schema = pa.schema([("key", pa.int32()), ("value", pa.string())])
-        record_batch = pa.RecordBatch.from_arrays([keys, values], schema=schema)
+        record_batch = pa.RecordBatch.from_arrays([keys, values], schema=self.arrow_schema)

Contributor:


It seems the line schema = pa.schema([("key", pa.int32()), ("value", pa.string())]) is dropped. Is this expected?


Author:


Yes, that is the core idea of this PR: to avoid specifying both the Arrow schema and the PySpark schema.

yield record_batch

def partitions(self):
13 changes: 12 additions & 1 deletion python/pyspark/sql/datasource.py
@@ -37,6 +37,7 @@
from pyspark.errors import PySparkNotImplementedError

if TYPE_CHECKING:
import pyarrow as pa
from pyarrow import RecordBatch
from pyspark.sql.session import SparkSession

@@ -115,7 +116,7 @@ def name(cls) -> str:
"""
return cls.__name__

def schema(self) -> Union[StructType, str]:
def schema(self) -> Union[StructType, str, "pa.Schema"]:
"""
Returns the schema of the data source.

@@ -142,6 +143,16 @@ def schema(self) -> Union[StructType, str]:

>>> def schema(self):
... return StructType().add("a", "int").add("b", "string")

Returns a PyArrow schema:

>>> def schema(self):
...     return pa.schema([
...         pa.field("a", pa.int64()),
...         pa.field("b", pa.string()),
...     ])
"""
raise PySparkNotImplementedError(
errorClass="NOT_IMPLEMENTED",
9 changes: 8 additions & 1 deletion python/pyspark/sql/worker/create_data_source.py
@@ -27,7 +27,7 @@
write_with_length,
)
from pyspark.sql.datasource import DataSource, CaseInsensitiveDict
from pyspark.sql.types import _parse_datatype_json_string, StructType
from pyspark.sql.pandas.types import from_arrow_schema
from pyspark.sql.types import _parse_datatype_json_string, StructType
from pyspark.sql.worker.utils import worker_run
from pyspark.util import local_connect_and_auth
from pyspark.worker_util import (
@@ -36,6 +36,8 @@
utf8_deserializer,
)

import pyarrow as pa


def _main(infile: IO, outfile: IO) -> None:
"""
@@ -125,6 +127,11 @@ def _main(infile: IO, outfile: IO) -> None:
# Here we cannot use _parse_datatype_string to parse the DDL string schema.
# as it requires an active Spark session.
is_ddl_string = True
if isinstance(schema, pa.Schema):

Contributor:


Sorry, let me be clearer: what I had in mind is to add pa.schema in the example ArrowBatchDataSourceReader.


Contributor:


@allisonwang-db for this change


Author:


Sorry, I don't get what you mean.
Could you please elaborate?


Contributor:


I am wondering whether we can directly pass the pyarrow schema in this example

    class ArrowBatchDataSourceReader(DataSourceReader):
        def __init__(self, pa_schema, options):
            self.pa_schema: pa.Schema = pa_schema

Not sure whether this works, @allisonwang-db is this allowed?

# Convert the Arrow schema to a Spark schema for compatibility,
# as the Python data source API allows a data source to
# return an Arrow schema directly.
schema = from_arrow_schema(schema)
else:
schema = user_specified_schema # type: ignore
