diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs index 53994d2f5..cf38b6f25 100644 --- a/crates/core/src/context.rs +++ b/crates/core/src/context.rs @@ -439,6 +439,20 @@ impl PySessionContext { Ok(()) } + /// Deregister an object store with the given url + #[pyo3(signature = (scheme, host=None))] + pub fn deregister_object_store( + &self, + scheme: &str, + host: Option<&str>, + ) -> PyDataFusionResult<()> { + let host = host.unwrap_or(""); + let url_string = format!("{scheme}{host}"); + let url = Url::parse(&url_string).unwrap(); + self.ctx.runtime_env().deregister_object_store(&url)?; + Ok(()) + } + #[allow(clippy::too_many_arguments)] #[pyo3(signature = (name, path, table_partition_cols=vec![], file_extension=".parquet", @@ -492,6 +506,10 @@ impl PySessionContext { self.ctx.register_udtf(&name, func); } + pub fn deregister_udtf(&self, name: &str) { + self.ctx.deregister_udtf(name); + } + #[pyo3(signature = (query, options=None, param_values=HashMap::default(), param_strings=HashMap::default()))] pub fn sql_with_options( &self, @@ -975,16 +993,28 @@ impl PySessionContext { Ok(()) } + pub fn deregister_udf(&self, name: &str) { + self.ctx.deregister_udf(name); + } + pub fn register_udaf(&self, udaf: PyAggregateUDF) -> PyResult<()> { self.ctx.register_udaf(udaf.function); Ok(()) } + pub fn deregister_udaf(&self, name: &str) { + self.ctx.deregister_udaf(name); + } + pub fn register_udwf(&self, udwf: PyWindowUDF) -> PyResult<()> { self.ctx.register_udwf(udwf.function); Ok(()) } + pub fn deregister_udwf(&self, name: &str) { + self.ctx.deregister_udwf(name); + } + #[pyo3(signature = (name="datafusion"))] pub fn catalog(&self, py: Python, name: &str) -> PyResult> { let catalog = self.ctx.catalog(name).ok_or(PyKeyError::new_err(format!( diff --git a/python/datafusion/context.py b/python/datafusion/context.py index c8edc816f..f190e3ca1 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -568,6 +568,15 @@ def register_object_store( """ self.ctx.register_object_store(schema, store, host) + def deregister_object_store(self, schema: str, host: str | None = None) -> None: + """Remove an object store from the session. + + Args: + schema: The data source schema (e.g. ``"s3://"``). + host: URL for the host (e.g. bucket name). + """ + self.ctx.deregister_object_store(schema, host) + def register_listing_table( self, name: str, @@ -894,6 +903,14 @@ def register_udtf(self, func: TableFunction) -> None: """Register a user defined table function.""" self.ctx.register_udtf(func._udtf) + def deregister_udtf(self, name: str) -> None: + """Remove a user-defined table function from the session. + + Args: + name: Name of the UDTF to deregister. + """ + self.ctx.deregister_udtf(name) + def register_record_batches( self, name: str, partitions: list[list[pa.RecordBatch]] ) -> None: @@ -1105,14 +1122,38 @@ def register_udf(self, udf: ScalarUDF) -> None: """Register a user-defined function (UDF) with the context.""" self.ctx.register_udf(udf._udf) + def deregister_udf(self, name: str) -> None: + """Remove a user-defined scalar function from the session. + + Args: + name: Name of the UDF to deregister. + """ + self.ctx.deregister_udf(name) + def register_udaf(self, udaf: AggregateUDF) -> None: """Register a user-defined aggregation function (UDAF) with the context.""" self.ctx.register_udaf(udaf._udaf) + def deregister_udaf(self, name: str) -> None: + """Remove a user-defined aggregate function from the session. + + Args: + name: Name of the UDAF to deregister. + """ + self.ctx.deregister_udaf(name) + def register_udwf(self, udwf: WindowUDF) -> None: """Register a user-defined window function (UDWF) with the context.""" self.ctx.register_udwf(udwf._udwf) + def deregister_udwf(self, name: str) -> None: + """Remove a user-defined window function from the session. + + Args: + name: Name of the UDWF to deregister. + """ + self.ctx.deregister_udwf(name) + def catalog(self, name: str = "datafusion") -> Catalog: """Retrieve a catalog by name.""" return Catalog(self.ctx.catalog(name)) diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 5df6ed20f..43b85406a 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -351,6 +351,126 @@ def test_deregister_table(ctx, database): assert public.names() == {"csv1", "csv2"} +def test_deregister_udf(): + ctx = SessionContext() + from datafusion import udf + + is_null = udf( + lambda x: x.is_null(), + [pa.float64()], + pa.bool_(), + volatility="immutable", + name="my_is_null", + ) + ctx.register_udf(is_null) + + # Verify it works + df = ctx.from_pydict({"a": [1.0, None]}) + ctx.register_table("t", df.into_view()) + result = ctx.sql("SELECT my_is_null(a) FROM t").collect() + assert result[0].column(0) == pa.array([False, True]) + + # Deregister and verify it's gone + ctx.deregister_udf("my_is_null") + with pytest.raises(RuntimeError): + ctx.sql("SELECT my_is_null(a) FROM t").collect() + + +def test_deregister_udaf(): + import pyarrow.compute as pc + + ctx = SessionContext() + from datafusion import Accumulator, udaf + + class MySum(Accumulator): + def __init__(self): + self._sum = 0.0 + + def update(self, values: pa.Array) -> None: + self._sum += pc.sum(values).as_py() + + def merge(self, states: list[pa.Array]) -> None: + self._sum += pc.sum(states[0]).as_py() + + def state(self) -> list: + return [self._sum] + + def evaluate(self) -> pa.Scalar: + return self._sum + + my_sum = udaf( + MySum, + [pa.float64()], + pa.float64(), + [pa.float64()], + volatility="immutable", + name="my_sum", + ) + ctx.register_udaf(my_sum) + df = ctx.from_pydict({"a": [1.0, 2.0, 3.0]}) + ctx.register_table("t", df.into_view()) + + result = ctx.sql("SELECT my_sum(a) FROM t").collect() + assert result[0].column(0) == pa.array([6.0]) + + ctx.deregister_udaf("my_sum") + with pytest.raises(RuntimeError): + ctx.sql("SELECT my_sum(a) FROM t").collect() + + +def test_deregister_udwf(): + ctx = SessionContext() + from datafusion import udwf + from datafusion.user_defined import WindowEvaluator + + class MyRowNumber(WindowEvaluator): + def __init__(self): + self._row = 0 + + def evaluate_all(self, values, num_rows): + return pa.array(list(range(1, num_rows + 1)), type=pa.uint64()) + + my_row_number = udwf( + MyRowNumber, + [pa.float64()], + pa.uint64(), + volatility="immutable", + name="my_row_number", + ) + ctx.register_udwf(my_row_number) + df = ctx.from_pydict({"a": [1.0, 2.0, 3.0]}) + ctx.register_table("t", df.into_view()) + + result = ctx.sql("SELECT my_row_number(a) OVER () FROM t").collect() + assert result[0].column(0) == pa.array([1, 2, 3], type=pa.uint64()) + + ctx.deregister_udwf("my_row_number") + with pytest.raises(RuntimeError): + ctx.sql("SELECT my_row_number(a) OVER () FROM t").collect() + + +def test_deregister_udtf(): + import pyarrow.dataset as ds + + ctx = SessionContext() + from datafusion import Table, udtf + + class MyTable: + def __call__(self): + batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}) + return Table(ds.dataset([batch])) + + my_table = udtf(MyTable(), "my_table") + ctx.register_udtf(my_table) + + result = ctx.sql("SELECT * FROM my_table()").collect() + assert result[0].column(0) == pa.array([1, 2, 3]) + + ctx.deregister_udtf("my_table") + with pytest.raises(RuntimeError): + ctx.sql("SELECT * FROM my_table()").collect() + + def test_register_table_from_dataframe(ctx): df = ctx.from_pydict({"a": [1, 2]}) ctx.register_table("df_tbl", df)