diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs index 53994d2f5..74864d8b4 100644 --- a/crates/core/src/context.rs +++ b/crates/core/src/context.rs @@ -28,7 +28,7 @@ use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::arrow::record_batch::RecordBatch; use datafusion::catalog::{CatalogProvider, CatalogProviderList, TableProviderFactory}; -use datafusion::common::{ScalarValue, TableReference, exec_err}; +use datafusion::common::{DFSchema, ScalarValue, TableReference, exec_err}; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::{ @@ -70,11 +70,13 @@ use crate::catalog::{ PyCatalog, PyCatalogList, RustWrappedPyCatalogProvider, RustWrappedPyCatalogProviderList, }; use crate::common::data_type::PyScalarValue; +use crate::common::df_schema::PyDFSchema; use crate::dataframe::PyDataFrame; use crate::dataset::Dataset; use crate::errors::{ PyDataFusionError, PyDataFusionResult, from_datafusion_error, py_datafusion_err, }; +use crate::expr::PyExpr; use crate::expr::sort_expr::PySortExpr; use crate::options::PyCsvReadOptions; use crate::physical_plan::PyExecutionPlan; @@ -1050,6 +1052,45 @@ impl PySessionContext { self.ctx.session_id() } + + pub fn session_start_time(&self) -> String { + self.ctx.session_start_time().to_rfc3339() + } + + pub fn enable_ident_normalization(&self) -> bool { + self.ctx.enable_ident_normalization() + } + + pub fn parse_sql_expr(&self, sql: &str, schema: PyDFSchema) -> PyDataFusionResult<PyExpr> { + let df_schema: DFSchema = schema.into(); + Ok(self.ctx.parse_sql_expr(sql, &df_schema)?.into()) + } + + pub fn execute_logical_plan( + &self, + plan: PyLogicalPlan, + py: Python, + ) -> PyDataFusionResult<PyDataFrame> { + let df = wait_for_future( + py, + self.ctx.execute_logical_plan(plan.plan.as_ref().clone()), + )??; + Ok(PyDataFrame::new(df)) + } 
+ + pub fn refresh_catalogs(&self, py: Python) -> PyDataFusionResult<()> { + wait_for_future(py, self.ctx.refresh_catalogs())??; + Ok(()) + } + + pub fn remove_optimizer_rule(&self, name: &str) -> bool { + self.ctx.remove_optimizer_rule(name) + } + + pub fn table_provider(&self, name: &str, py: Python) -> PyDataFusionResult<PyTable> { + let provider = wait_for_future(py, self.ctx.table_provider(name))??; + Ok(PyTable { table: provider }) + } + #[allow(clippy::too_many_arguments)] #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![], file_compression_type=None))] pub fn read_json( diff --git a/python/datafusion/context.py b/python/datafusion/context.py index c8edc816f..fb40f73f8 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -63,7 +63,8 @@ import polars as pl # type: ignore[import] from datafusion.catalog import CatalogProvider, Table - from datafusion.expr import SortKey + from datafusion.common import DFSchema + from datafusion.expr import Expr, SortKey from datafusion.plan import ExecutionPlan, LogicalPlan from datafusion.user_defined import ( AggregateUDF, @@ -1141,6 +1142,67 @@ def session_id(self) -> str: """Return an id that uniquely identifies this :py:class:`SessionContext`.""" return self.ctx.session_id() + def session_start_time(self) -> str: + """Return the session start time as an RFC 3339 formatted string.""" + return self.ctx.session_start_time() + + def enable_ident_normalization(self) -> bool: + """Return whether identifier normalization (lowercasing) is enabled.""" + return self.ctx.enable_ident_normalization() + + def parse_sql_expr(self, sql: str, schema: DFSchema) -> Expr: + """Parse a SQL expression string into a logical expression. + + Args: + sql: SQL expression string. + schema: Schema to use for resolving column references. + + Returns: + Parsed expression. 
+ """ + from datafusion.expr import Expr # noqa: PLC0415 + + return Expr(self.ctx.parse_sql_expr(sql, schema)) + + def execute_logical_plan(self, plan: LogicalPlan) -> DataFrame: + """Execute a :py:class:`~datafusion.plan.LogicalPlan` and return a DataFrame. + + Args: + plan: Logical plan to execute. + + Returns: + DataFrame resulting from the execution. + """ + return DataFrame(self.ctx.execute_logical_plan(plan._raw_plan)) + + def refresh_catalogs(self) -> None: + """Refresh catalog metadata.""" + self.ctx.refresh_catalogs() + + def remove_optimizer_rule(self, name: str) -> bool: + """Remove an optimizer rule by name. + + Args: + name: Name of the optimizer rule to remove. + + Returns: + True if a rule with the given name was found and removed. + """ + return self.ctx.remove_optimizer_rule(name) + + def table_provider(self, name: str) -> Table: + """Return the :py:class:`~datafusion.catalog.Table` for the given table name. + + Args: + name: Name of the table. + + Returns: + The table provider. 
+ """ + from datafusion.catalog import Table # noqa: PLC0415 + + return Table(self.ctx.table_provider(name)) + def read_json( self, path: str | pathlib.Path, diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 5df6ed20f..3660524d2 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -551,6 +551,48 @@ def test_table_not_found(ctx): ctx.table(f"not-found-{uuid4()}") +def test_session_start_time(ctx): + st = ctx.session_start_time() + assert isinstance(st, str) + assert "T" in st # RFC 3339 format + + +def test_enable_ident_normalization(ctx): + result = ctx.enable_ident_normalization() + assert isinstance(result, bool) + + +def test_parse_sql_expr(ctx): + from datafusion.common import DFSchema + + schema = DFSchema.empty() + expr = ctx.parse_sql_expr("1 + 2", schema) + assert "Int64(1) + Int64(2)" in str(expr) + + +def test_execute_logical_plan(ctx): + df = ctx.from_pydict({"a": [1, 2, 3]}) + plan = df.logical_plan() + df2 = ctx.execute_logical_plan(plan) + result = df2.collect() + assert result[0].column(0) == pa.array([1, 2, 3]) + + +def test_refresh_catalogs(ctx): + ctx.refresh_catalogs() + + +def test_remove_optimizer_rule(ctx): + assert ctx.remove_optimizer_rule("nonexistent_rule") is False + + +def test_table_provider(ctx): + batch = pa.RecordBatch.from_pydict({"x": [10, 20, 30]}) + ctx.register_record_batches("provider_test", [[batch]]) + tbl = ctx.table_provider("provider_test") + assert tbl.schema == pa.schema([("x", pa.int64())]) + + def test_read_json(ctx): path = pathlib.Path(__file__).parent.resolve()