Skip to content

Commit 30fc3d5

Browse files
timsaucer and claude committed
Add missing SessionContext read/register methods for Arrow IPC and batches
Add read_arrow, read_empty, register_arrow, and register_batch methods to SessionContext, exposing upstream DataFusion v53 functionality. The write_* methods and read_batch/read_batches are already covered by DataFrame.write_* and SessionContext.from_arrow respectively. Closes #1458. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2499409 commit 30fc3d5

File tree

3 files changed

+170
-1
lines changed

3 files changed

+170
-1
lines changed

crates/core/src/context.rs

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ use datafusion::execution::context::{
4141
};
4242
use datafusion::execution::disk_manager::DiskManagerMode;
4343
use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
44-
use datafusion::execution::options::ReadOptions;
44+
use datafusion::execution::options::{ArrowReadOptions, ReadOptions};
4545
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
4646
use datafusion::execution::session_state::SessionStateBuilder;
4747
use datafusion::prelude::{
@@ -956,6 +956,39 @@ impl PySessionContext {
956956
Ok(())
957957
}
958958

959+
#[pyo3(signature = (name, path, schema=None, file_extension=".arrow", table_partition_cols=vec![]))]
960+
pub fn register_arrow(
961+
&self,
962+
name: &str,
963+
path: &str,
964+
schema: Option<PyArrowType<Schema>>,
965+
file_extension: &str,
966+
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
967+
py: Python,
968+
) -> PyDataFusionResult<()> {
969+
let mut options = ArrowReadOptions::default().table_partition_cols(
970+
table_partition_cols
971+
.into_iter()
972+
.map(|(name, ty)| (name, ty.0))
973+
.collect::<Vec<(String, DataType)>>(),
974+
);
975+
options.file_extension = file_extension;
976+
options.schema = schema.as_ref().map(|x| &x.0);
977+
978+
let result = self.ctx.register_arrow(name, path, options);
979+
wait_for_future(py, result)??;
980+
Ok(())
981+
}
982+
983+
pub fn register_batch(
984+
&self,
985+
name: &str,
986+
batch: PyArrowType<RecordBatch>,
987+
) -> PyDataFusionResult<()> {
988+
self.ctx.register_batch(name, batch.0)?;
989+
Ok(())
990+
}
991+
959992
// Registers a PyArrow.Dataset
960993
pub fn register_dataset(
961994
&self,
@@ -1184,6 +1217,34 @@ impl PySessionContext {
11841217
Ok(PyDataFrame::new(df))
11851218
}
11861219

1220+
pub fn read_empty(&self) -> PyDataFusionResult<PyDataFrame> {
1221+
let df = self.ctx.read_empty()?;
1222+
Ok(PyDataFrame::new(df))
1223+
}
1224+
1225+
#[pyo3(signature = (path, schema=None, file_extension=".arrow", table_partition_cols=vec![]))]
1226+
pub fn read_arrow(
1227+
&self,
1228+
path: &str,
1229+
schema: Option<PyArrowType<Schema>>,
1230+
file_extension: &str,
1231+
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
1232+
py: Python,
1233+
) -> PyDataFusionResult<PyDataFrame> {
1234+
let mut options = ArrowReadOptions::default().table_partition_cols(
1235+
table_partition_cols
1236+
.into_iter()
1237+
.map(|(name, ty)| (name, ty.0))
1238+
.collect::<Vec<(String, DataType)>>(),
1239+
);
1240+
options.file_extension = file_extension;
1241+
options.schema = schema.as_ref().map(|x| &x.0);
1242+
1243+
let result = self.ctx.read_arrow(path, options);
1244+
let df = wait_for_future(py, result)??;
1245+
Ok(PyDataFrame::new(df))
1246+
}
1247+
11871248
pub fn read_table(&self, table: Bound<'_, PyAny>) -> PyDataFusionResult<PyDataFrame> {
11881249
let session = self.clone().into_bound_py_any(table.py())?;
11891250
let table = PyTable::new(table, Some(session))?;

python/datafusion/context.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,15 @@ def register_udtf(self, func: TableFunction) -> None:
894894
"""Register a user defined table function."""
895895
self.ctx.register_udtf(func._udtf)
896896

897+
def register_batch(self, name: str, batch: pa.RecordBatch) -> None:
    """Register one :py:class:`pa.RecordBatch` under ``name`` for SQL queries.

    Args:
        name: Name of the resultant table.
        batch: Record batch to register as a table.
    """
    self.ctx.register_batch(name, batch)
905+
897906
def register_record_batches(
898907
self, name: str, partitions: list[list[pa.RecordBatch]]
899908
) -> None:
@@ -1092,6 +1101,33 @@ def register_avro(
10921101
name, str(path), schema, file_extension, table_partition_cols
10931102
)
10941103

1104+
def register_arrow(
    self,
    name: str,
    path: str | pathlib.Path,
    schema: pa.Schema | None = None,
    file_extension: str = ".arrow",
    table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
) -> None:
    """Register an Arrow IPC file as a table.

    The registered table can be referenced from SQL statements executed
    against this context.

    Args:
        name: Name of the table to register.
        path: Path to the Arrow IPC file.
        schema: The data source schema.
        file_extension: File extension to select.
        table_partition_cols: Partition columns.
    """
    # A missing partition-column list is treated as no partitioning at all.
    partition_cols = _convert_table_partition_cols(table_partition_cols or [])
    self.ctx.register_arrow(name, str(path), schema, file_extension, partition_cols)
1130+
10951131
def register_dataset(self, name: str, dataset: pa.dataset.Dataset) -> None:
10961132
"""Register a :py:class:`pa.dataset.Dataset` as a table.
10971133
@@ -1328,6 +1364,39 @@ def read_avro(
13281364
self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension)
13291365
)
13301366

1367+
def read_arrow(
    self,
    path: str | pathlib.Path,
    schema: pa.Schema | None = None,
    file_extension: str = ".arrow",
    file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
) -> DataFrame:
    """Create a :py:class:`DataFrame` for reading an Arrow IPC data source.

    Args:
        path: Path to the Arrow IPC file.
        schema: The data source schema.
        file_extension: File extension to select.
        file_partition_cols: Partition columns.

    Returns:
        DataFrame representation of the read Arrow IPC file.
    """
    # A missing partition-column list is treated as no partitioning at all.
    partition_cols = _convert_table_partition_cols(file_partition_cols or [])
    raw_df = self.ctx.read_arrow(str(path), schema, file_extension, partition_cols)
    return DataFrame(raw_df)
1391+
1392+
def read_empty(self) -> DataFrame:
    """Create an empty :py:class:`DataFrame` with no columns or rows.

    Returns:
        An empty DataFrame.
    """
    return DataFrame(self.ctx.read_empty())
1399+
13311400
def read_table(
13321401
self, table: Table | TableProviderExportable | DataFrame | pa.dataset.Dataset
13331402
) -> DataFrame:

python/tests/test_context.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,45 @@ def test_read_avro(ctx):
668668
assert avro_df is not None
669669

670670

671+
def test_read_arrow(ctx, tmp_path):
    # Round-trip: write an Arrow IPC file, then read it back via the context.
    expected = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]})
    ipc_path = tmp_path / "test.arrow"
    with pa.ipc.new_file(str(ipc_path), expected.schema) as writer:
        writer.write_table(expected)

    batches = ctx.read_arrow(str(ipc_path)).collect()
    assert batches[0].column(0) == pa.array([1, 2, 3])
    assert batches[0].column(1) == pa.array(["x", "y", "z"])
682+
683+
684+
def test_read_empty(ctx):
    # An empty DataFrame should collect to a batch with zero columns.
    batches = ctx.read_empty().collect()
    assert batches[0].num_columns == 0
688+
689+
690+
def test_register_arrow(ctx, tmp_path):
    # Write an Arrow IPC file, register it as a table, and query it via SQL.
    expected = pa.table({"x": [10, 20, 30]})
    ipc_path = tmp_path / "test.arrow"
    with pa.ipc.new_file(str(ipc_path), expected.schema) as writer:
        writer.write_table(expected)

    ctx.register_arrow("arrow_tbl", str(ipc_path))
    batches = ctx.sql("SELECT * FROM arrow_tbl").collect()
    assert batches[0].column(0) == pa.array([10, 20, 30])
700+
701+
702+
def test_register_batch(ctx):
    # A registered RecordBatch should be queryable by its table name.
    source = pa.RecordBatch.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})
    ctx.register_batch("batch_tbl", source)
    batches = ctx.sql("SELECT * FROM batch_tbl").collect()
    assert batches[0].column(0) == pa.array([1, 2, 3])
    assert batches[0].column(1) == pa.array([4, 5, 6])
708+
709+
671710
def test_create_sql_options():
672711
SQLOptions()
673712

0 commit comments

Comments
 (0)