diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs index 53994d2f5..4c001b55b 100644 --- a/crates/core/src/context.rs +++ b/crates/core/src/context.rs @@ -41,7 +41,7 @@ use datafusion::execution::context::{ }; use datafusion::execution::disk_manager::DiskManagerMode; use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool}; -use datafusion::execution::options::ReadOptions; +use datafusion::execution::options::{ArrowReadOptions, ReadOptions}; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::prelude::{ @@ -956,6 +956,39 @@ impl PySessionContext { Ok(()) } + #[pyo3(signature = (name, path, schema=None, file_extension=".arrow", table_partition_cols=vec![]))] + pub fn register_arrow( + &self, + name: &str, + path: &str, + schema: Option>, + file_extension: &str, + table_partition_cols: Vec<(String, PyArrowType)>, + py: Python, + ) -> PyDataFusionResult<()> { + let mut options = ArrowReadOptions::default().table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::>(), + ); + options.file_extension = file_extension; + options.schema = schema.as_ref().map(|x| &x.0); + + let result = self.ctx.register_arrow(name, path, options); + wait_for_future(py, result)??; + Ok(()) + } + + pub fn register_batch( + &self, + name: &str, + batch: PyArrowType, + ) -> PyDataFusionResult<()> { + self.ctx.register_batch(name, batch.0)?; + Ok(()) + } + // Registers a PyArrow.Dataset pub fn register_dataset( &self, @@ -1184,6 +1217,34 @@ impl PySessionContext { Ok(PyDataFrame::new(df)) } + pub fn read_empty(&self) -> PyDataFusionResult { + let df = self.ctx.read_empty()?; + Ok(PyDataFrame::new(df)) + } + + #[pyo3(signature = (path, schema=None, file_extension=".arrow", table_partition_cols=vec![]))] + pub fn read_arrow( + &self, + path: &str, + schema: Option>, + file_extension: &str, + table_partition_cols: Vec<(String, PyArrowType)>, + py: Python, + ) -> PyDataFusionResult { + let mut options = ArrowReadOptions::default().table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::>(), + ); + options.file_extension = file_extension; + options.schema = schema.as_ref().map(|x| &x.0); + + let result = self.ctx.read_arrow(path, options); + let df = wait_for_future(py, result)??; + Ok(PyDataFrame::new(df)) + } + pub fn read_table(&self, table: Bound<'_, PyAny>) -> PyDataFusionResult { let session = self.clone().into_bound_py_any(table.py())?; let table = PyTable::new(table, Some(session))?; diff --git a/python/datafusion/context.py b/python/datafusion/context.py index c8edc816f..c2a06dc82 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -894,6 +894,15 @@ def register_udtf(self, func: TableFunction) -> None: """Register a user defined table function.""" self.ctx.register_udtf(func._udtf) + def register_batch(self, name: str, batch: pa.RecordBatch) -> None: + """Register a single :py:class:`pa.RecordBatch` as a table. + + Args: + name: Name of the resultant table. + batch: Record batch to register as a table. + """ + self.ctx.register_batch(name, batch) + def register_record_batches( self, name: str, partitions: list[list[pa.RecordBatch]] ) -> None: @@ -1092,6 +1101,33 @@ def register_avro( name, str(path), schema, file_extension, table_partition_cols ) + def register_arrow( + self, + name: str, + path: str | pathlib.Path, + schema: pa.Schema | None = None, + file_extension: str = ".arrow", + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, + ) -> None: + """Register an Arrow IPC file as a table. + + The registered table can be referenced from SQL statements executed + against this context. + + Args: + name: Name of the table to register. + path: Path to the Arrow IPC file. + schema: The data source schema. + file_extension: File extension to select. + table_partition_cols: Partition columns. + """ + if table_partition_cols is None: + table_partition_cols = [] + table_partition_cols = _convert_table_partition_cols(table_partition_cols) + self.ctx.register_arrow( + name, str(path), schema, file_extension, table_partition_cols + ) + def register_dataset(self, name: str, dataset: pa.dataset.Dataset) -> None: """Register a :py:class:`pa.dataset.Dataset` as a table. @@ -1328,6 +1364,39 @@ def read_avro( self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension) ) + def read_arrow( + self, + path: str | pathlib.Path, + schema: pa.Schema | None = None, + file_extension: str = ".arrow", + file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, + ) -> DataFrame: + """Create a :py:class:`DataFrame` for reading an Arrow IPC data source. + + Args: + path: Path to the Arrow IPC file. + schema: The data source schema. + file_extension: File extension to select. + file_partition_cols: Partition columns. + + Returns: + DataFrame representation of the read Arrow IPC file. + """ + if file_partition_cols is None: + file_partition_cols = [] + file_partition_cols = _convert_table_partition_cols(file_partition_cols) + return DataFrame( + self.ctx.read_arrow(str(path), schema, file_extension, file_partition_cols) + ) + + def read_empty(self) -> DataFrame: + """Create an empty :py:class:`DataFrame` with no columns or rows. + + Returns: + An empty DataFrame. + """ + return DataFrame(self.ctx.read_empty()) + def read_table( self, table: Table | TableProviderExportable | DataFrame | pa.dataset.Dataset ) -> DataFrame: diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 5df6ed20f..a4a82cdf6 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -668,6 +668,45 @@ def test_read_avro(ctx): assert avro_df is not None +def test_read_arrow(ctx, tmp_path): + # Write an Arrow IPC file, then read it back + table = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]}) + arrow_path = tmp_path / "test.arrow" + with pa.ipc.new_file(str(arrow_path), table.schema) as writer: + writer.write_table(table) + + df = ctx.read_arrow(str(arrow_path)) + result = df.collect() + assert result[0].column(0) == pa.array([1, 2, 3]) + assert result[0].column(1) == pa.array(["x", "y", "z"]) + + +def test_read_empty(ctx): + df = ctx.read_empty() + result = df.collect() + assert result[0].num_columns == 0 + + +def test_register_arrow(ctx, tmp_path): + # Write an Arrow IPC file, then register and query it + table = pa.table({"x": [10, 20, 30]}) + arrow_path = tmp_path / "test.arrow" + with pa.ipc.new_file(str(arrow_path), table.schema) as writer: + writer.write_table(table) + + ctx.register_arrow("arrow_tbl", str(arrow_path)) + result = ctx.sql("SELECT * FROM arrow_tbl").collect() + assert result[0].column(0) == pa.array([10, 20, 30]) + + +def test_register_batch(ctx): + batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}) + ctx.register_batch("batch_tbl", batch) + result = ctx.sql("SELECT * FROM batch_tbl").collect() + assert result[0].column(0) == pa.array([1, 2, 3]) + assert result[0].column(1) == pa.array([4, 5, 6]) + + def test_create_sql_options(): SQLOptions()