63 changes: 62 additions & 1 deletion crates/core/src/context.rs
@@ -41,7 +41,7 @@ use datafusion::execution::context::{
};
use datafusion::execution::disk_manager::DiskManagerMode;
use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
use datafusion::execution::options::ReadOptions;
use datafusion::execution::options::{ArrowReadOptions, ReadOptions};
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::prelude::{
@@ -956,6 +956,39 @@ impl PySessionContext {
Ok(())
}

#[pyo3(signature = (name, path, schema=None, file_extension=".arrow", table_partition_cols=vec![]))]
pub fn register_arrow(
&self,
name: &str,
path: &str,
schema: Option<PyArrowType<Schema>>,
file_extension: &str,
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
py: Python,
) -> PyDataFusionResult<()> {
let mut options = ArrowReadOptions::default().table_partition_cols(
table_partition_cols
.into_iter()
.map(|(name, ty)| (name, ty.0))
.collect::<Vec<(String, DataType)>>(),
);
options.file_extension = file_extension;
options.schema = schema.as_ref().map(|x| &x.0);

let result = self.ctx.register_arrow(name, path, options);
wait_for_future(py, result)??;
Ok(())
}

pub fn register_batch(
&self,
name: &str,
batch: PyArrowType<RecordBatch>,
) -> PyDataFusionResult<()> {
self.ctx.register_batch(name, batch.0)?;
Ok(())
}

// Registers a PyArrow.Dataset
pub fn register_dataset(
&self,
Expand Down Expand Up @@ -1184,6 +1217,34 @@ impl PySessionContext {
Ok(PyDataFrame::new(df))
}

pub fn read_empty(&self) -> PyDataFusionResult<PyDataFrame> {
let df = self.ctx.read_empty()?;
Ok(PyDataFrame::new(df))
}

#[pyo3(signature = (path, schema=None, file_extension=".arrow", table_partition_cols=vec![]))]
pub fn read_arrow(
&self,
path: &str,
schema: Option<PyArrowType<Schema>>,
file_extension: &str,
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
py: Python,
) -> PyDataFusionResult<PyDataFrame> {
let mut options = ArrowReadOptions::default().table_partition_cols(
table_partition_cols
.into_iter()
.map(|(name, ty)| (name, ty.0))
.collect::<Vec<(String, DataType)>>(),
);
options.file_extension = file_extension;
options.schema = schema.as_ref().map(|x| &x.0);

let result = self.ctx.read_arrow(path, options);
let df = wait_for_future(py, result)??;
Ok(PyDataFrame::new(df))
}

pub fn read_table(&self, table: Bound<'_, PyAny>) -> PyDataFusionResult<PyDataFrame> {
let session = self.clone().into_bound_py_any(table.py())?;
let table = PyTable::new(table, Some(session))?;
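The Rust additions above follow the existing `register_avro`/`read_avro` plumbing. For reference, a minimal sketch of how they surface through the Python wrapper (the `example.arrow` file name and sample table are illustrative, not part of this change):

```python
import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()

# Write a small Arrow IPC file so there is something to register.
table = pa.table({"id": [1, 2, 3]})
with pa.ipc.new_file("example.arrow", table.schema) as writer:
    writer.write_table(table)

# register_arrow exposes the file to SQL under a table name; the default
# file_extension of ".arrow" matches the file written above.
ctx.register_arrow("example", "example.arrow")
ctx.sql("SELECT id FROM example").show()
```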
69 changes: 69 additions & 0 deletions python/datafusion/context.py
@@ -894,6 +894,15 @@ def register_udtf(self, func: TableFunction) -> None:
"""Register a user defined table function."""
self.ctx.register_udtf(func._udtf)

def register_batch(self, name: str, batch: pa.RecordBatch) -> None:
"""Register a single :py:class:`pa.RecordBatch` as a table.

Args:
name: Name of the resultant table.
batch: Record batch to register as a table.
"""
self.ctx.register_batch(name, batch)

def register_record_batches(
self, name: str, partitions: list[list[pa.RecordBatch]]
) -> None:
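A quick usage sketch for `register_batch`, mirroring the new test further down (table and column names are illustrative):

```python
import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()

# A single in-memory RecordBatch becomes a SQL-visible table; no file I/O.
batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]})
ctx.register_batch("batch_tbl", batch)

result = ctx.sql("SELECT SUM(a) AS s FROM batch_tbl").collect()
assert result[0].column(0).to_pylist() == [6]
```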
@@ -1092,6 +1101,33 @@ def register_avro(
name, str(path), schema, file_extension, table_partition_cols
)

def register_arrow(
self,
name: str,
path: str | pathlib.Path,
schema: pa.Schema | None = None,
file_extension: str = ".arrow",
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
) -> None:
"""Register an Arrow IPC file as a table.

The registered table can be referenced from SQL statements executed
against this context.

Args:
name: Name of the table to register.
path: Path to the Arrow IPC file.
schema: The data source schema.
file_extension: File extension to select.
table_partition_cols: Partition columns.
"""
if table_partition_cols is None:
table_partition_cols = []
table_partition_cols = _convert_table_partition_cols(table_partition_cols)
self.ctx.register_arrow(
name, str(path), schema, file_extension, table_partition_cols
)

def register_dataset(self, name: str, dataset: pa.dataset.Dataset) -> None:
"""Register a :py:class:`pa.dataset.Dataset` as a table.

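The `table_partition_cols` path is not exercised by the new tests. A sketch of how it should look, assuming the hive-style `key=value` directory layout DataFusion uses for its other listing-table registrations (the `data/` directory and `year` column are hypothetical):

```python
import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()

# Hypothetical layout:
#   data/year=2024/part-0.arrow
#   data/year=2025/part-0.arrow
# The partition column is declared with its type and is reconstructed from
# the directory names rather than read from the files themselves.
ctx.register_arrow(
    "events",
    "data/",
    table_partition_cols=[("year", pa.int32())],
)
```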
@@ -1328,6 +1364,39 @@ def read_avro(
self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension)
)

def read_arrow(
self,
path: str | pathlib.Path,
schema: pa.Schema | None = None,
file_extension: str = ".arrow",
file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
) -> DataFrame:
"""Create a :py:class:`DataFrame` for reading an Arrow IPC data source.

Args:
path: Path to the Arrow IPC file.
schema: The data source schema.
file_extension: File extension to select.
file_partition_cols: Partition columns.

Returns:
DataFrame representation of the read Arrow IPC file.
"""
if file_partition_cols is None:
file_partition_cols = []
file_partition_cols = _convert_table_partition_cols(file_partition_cols)
return DataFrame(
self.ctx.read_arrow(str(path), schema, file_extension, file_partition_cols)
)

def read_empty(self) -> DataFrame:
"""Create an empty :py:class:`DataFrame` with no columns or rows.

Returns:
An empty DataFrame.
"""
return DataFrame(self.ctx.read_empty())

def read_table(
self, table: Table | TableProviderExportable | DataFrame | pa.dataset.Dataset
) -> DataFrame:
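And the DataFrame-returning counterparts, as a minimal sketch (the `points.arrow` file name is illustrative):

```python
import pyarrow as pa
from datafusion import SessionContext, col, lit

ctx = SessionContext()

table = pa.table({"x": [10, 20, 30]})
with pa.ipc.new_file("points.arrow", table.schema) as writer:
    writer.write_table(table)

# read_arrow skips registration and hands back a DataFrame directly.
df = ctx.read_arrow("points.arrow").filter(col("x") > lit(15))

# read_empty produces a frame with no columns, handy as a neutral starting
# point for programmatically built queries.
empty = ctx.read_empty()
```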
39 changes: 39 additions & 0 deletions python/tests/test_context.py
@@ -668,6 +668,45 @@ def test_read_avro(ctx):
assert avro_df is not None


def test_read_arrow(ctx, tmp_path):
# Write an Arrow IPC file, then read it back
table = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]})
arrow_path = tmp_path / "test.arrow"
with pa.ipc.new_file(str(arrow_path), table.schema) as writer:
writer.write_table(table)

df = ctx.read_arrow(str(arrow_path))
result = df.collect()
assert result[0].column(0) == pa.array([1, 2, 3])
assert result[0].column(1) == pa.array(["x", "y", "z"])


def test_read_empty(ctx):
df = ctx.read_empty()
result = df.collect()
assert result[0].num_columns == 0


def test_register_arrow(ctx, tmp_path):
# Write an Arrow IPC file, then register and query it
table = pa.table({"x": [10, 20, 30]})
arrow_path = tmp_path / "test.arrow"
with pa.ipc.new_file(str(arrow_path), table.schema) as writer:
writer.write_table(table)

ctx.register_arrow("arrow_tbl", str(arrow_path))
result = ctx.sql("SELECT * FROM arrow_tbl").collect()
assert result[0].column(0) == pa.array([10, 20, 30])


def test_register_batch(ctx):
batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})
ctx.register_batch("batch_tbl", batch)
result = ctx.sql("SELECT * FROM batch_tbl").collect()
assert result[0].column(0) == pa.array([1, 2, 3])
assert result[0].column(1) == pa.array([4, 5, 6])


def test_create_sql_options():
SQLOptions()

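One gap worth noting: none of the new tests exercise the `schema` parameter. A hypothetical companion test (not part of this PR) might look like:

```python
def test_read_arrow_with_schema(ctx, tmp_path):
    # Hypothetical: pass the schema explicitly instead of letting it be
    # inferred from the IPC file footer.
    table = pa.table({"a": [1, 2, 3]})
    arrow_path = tmp_path / "test.arrow"
    with pa.ipc.new_file(str(arrow_path), table.schema) as writer:
        writer.write_table(table)

    df = ctx.read_arrow(str(arrow_path), schema=table.schema)
    assert df.collect()[0].column(0) == pa.array([1, 2, 3])
```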