Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions crates/core/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,20 @@ impl PySessionContext {
Ok(())
}

/// Deregister an object store with the given url
#[pyo3(signature = (scheme, host=None))]
pub fn deregister_object_store(
&self,
scheme: &str,
host: Option<&str>,
) -> PyDataFusionResult<()> {
let host = host.unwrap_or("");
let url_string = format!("{scheme}{host}");
let url = Url::parse(&url_string).unwrap();
self.ctx.runtime_env().deregister_object_store(&url)?;
Ok(())
}

#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (name, path, table_partition_cols=vec![],
file_extension=".parquet",
Expand Down Expand Up @@ -492,6 +506,10 @@ impl PySessionContext {
self.ctx.register_udtf(&name, func);
}

/// Deregister the user-defined table function registered under `name`.
///
/// Forwards to the wrapped session context's `deregister_udtf` and discards
/// whatever that call returns.
pub fn deregister_udtf(&self, name: &str) {
    self.ctx.deregister_udtf(name);
}

#[pyo3(signature = (query, options=None, param_values=HashMap::default(), param_strings=HashMap::default()))]
pub fn sql_with_options(
&self,
Expand Down Expand Up @@ -975,16 +993,28 @@ impl PySessionContext {
Ok(())
}

/// Deregister the user-defined scalar function registered under `name`.
///
/// Forwards to the wrapped session context's `deregister_udf` and discards
/// whatever that call returns.
pub fn deregister_udf(&self, name: &str) {
    self.ctx.deregister_udf(name);
}

/// Register a user-defined aggregate function with the session.
///
/// Hands the wrapped `udaf.function` to the session context. This method
/// itself always returns `Ok(())`.
pub fn register_udaf(&self, udaf: PyAggregateUDF) -> PyResult<()> {
    self.ctx.register_udaf(udaf.function);
    Ok(())
}

/// Deregister the user-defined aggregate function registered under `name`.
///
/// Forwards to the wrapped session context's `deregister_udaf` and discards
/// whatever that call returns.
pub fn deregister_udaf(&self, name: &str) {
    self.ctx.deregister_udaf(name);
}

/// Register a user-defined window function with the session.
///
/// Hands the wrapped `udwf.function` to the session context. This method
/// itself always returns `Ok(())`.
pub fn register_udwf(&self, udwf: PyWindowUDF) -> PyResult<()> {
    self.ctx.register_udwf(udwf.function);
    Ok(())
}

/// Deregister the user-defined window function registered under `name`.
///
/// Forwards to the wrapped session context's `deregister_udwf` and discards
/// whatever that call returns.
pub fn deregister_udwf(&self, name: &str) {
    self.ctx.deregister_udwf(name);
}

#[pyo3(signature = (name="datafusion"))]
pub fn catalog(&self, py: Python, name: &str) -> PyResult<Py<PyAny>> {
let catalog = self.ctx.catalog(name).ok_or(PyKeyError::new_err(format!(
Expand Down
41 changes: 41 additions & 0 deletions python/datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,15 @@ def register_object_store(
"""
self.ctx.register_object_store(schema, store, host)

def deregister_object_store(self, schema: str, host: str | None = None) -> None:
    """Remove an object store from the session.

    The store to remove is identified by the URL formed from ``schema``
    followed directly by ``host``.

    Args:
        schema: URL scheme prefix of the data source, including the
            separator (e.g. ``"s3://"``).
        host: Host part of the store URL (e.g. bucket name). Treated as
            empty when omitted.
    """
    self.ctx.deregister_object_store(schema, host)

def register_listing_table(
self,
name: str,
Expand Down Expand Up @@ -894,6 +903,14 @@ def register_udtf(self, func: TableFunction) -> None:
"""Register a user defined table function."""
self.ctx.register_udtf(func._udtf)

def deregister_udtf(self, name: str) -> None:
    """Remove a user-defined table function from the session.

    Counterpart to ``register_udtf``.

    Args:
        name: Name the UDTF was registered under.
    """
    self.ctx.deregister_udtf(name)

def register_record_batches(
self, name: str, partitions: list[list[pa.RecordBatch]]
) -> None:
Expand Down Expand Up @@ -1105,14 +1122,38 @@ def register_udf(self, udf: ScalarUDF) -> None:
"""Register a user-defined function (UDF) with the context."""
self.ctx.register_udf(udf._udf)

def deregister_udf(self, name: str) -> None:
    """Remove a user-defined scalar function from the session.

    Counterpart to ``register_udf``.

    Args:
        name: Name the UDF was registered under.
    """
    self.ctx.deregister_udf(name)

def register_udaf(self, udaf: AggregateUDF) -> None:
    """Register a user-defined aggregation function (UDAF) with the context.

    Args:
        udaf: The aggregate function to register; its wrapped ``_udaf``
            object is what gets passed to the underlying context.
    """
    self.ctx.register_udaf(udaf._udaf)

def deregister_udaf(self, name: str) -> None:
    """Remove a user-defined aggregate function from the session.

    Counterpart to ``register_udaf``.

    Args:
        name: Name the UDAF was registered under.
    """
    self.ctx.deregister_udaf(name)

def register_udwf(self, udwf: WindowUDF) -> None:
    """Register a user-defined window function (UDWF) with the context.

    Args:
        udwf: The window function to register; its wrapped ``_udwf``
            object is what gets passed to the underlying context.
    """
    self.ctx.register_udwf(udwf._udwf)

def deregister_udwf(self, name: str) -> None:
    """Remove a user-defined window function from the session.

    Counterpart to ``register_udwf``.

    Args:
        name: Name the UDWF was registered under.
    """
    self.ctx.deregister_udwf(name)

def catalog(self, name: str = "datafusion") -> Catalog:
    """Retrieve a catalog by name.

    Args:
        name: Name of the catalog to look up. Defaults to ``"datafusion"``.

    Returns:
        A ``Catalog`` wrapping the object returned by the underlying context.
    """
    return Catalog(self.ctx.catalog(name))
Expand Down
120 changes: 120 additions & 0 deletions python/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,126 @@ def test_deregister_table(ctx, database):
assert public.names() == {"csv1", "csv2"}


def test_deregister_udf():
    """A scalar UDF is usable from SQL until it is deregistered."""
    from datafusion import udf

    query = "SELECT my_is_null(a) FROM t"

    session = SessionContext()
    session.register_udf(
        udf(
            lambda x: x.is_null(),
            [pa.float64()],
            pa.bool_(),
            volatility="immutable",
            name="my_is_null",
        )
    )
    view = session.from_pydict({"a": [1.0, None]}).into_view()
    session.register_table("t", view)

    # While registered, the function resolves and evaluates.
    batches = session.sql(query).collect()
    assert batches[0].column(0) == pa.array([False, True])

    # After deregistration the very same query must be rejected.
    session.deregister_udf("my_is_null")
    with pytest.raises(RuntimeError):
        session.sql(query).collect()


def test_deregister_udaf():
    """An aggregate UDF is usable from SQL until it is deregistered."""
    import pyarrow.compute as pc

    ctx = SessionContext()
    from datafusion import Accumulator, udaf

    class MySum(Accumulator):
        """Running float sum kept as a plain Python float."""

        def __init__(self):
            self._sum = 0.0

        def update(self, values: pa.Array) -> None:
            self._sum += pc.sum(values).as_py()

        def merge(self, states: list[pa.Array]) -> None:
            self._sum += pc.sum(states[0]).as_py()

        def state(self) -> list[pa.Scalar]:
            # Wrap the float so the returned value matches the annotation.
            return [pa.scalar(self._sum)]

        def evaluate(self) -> pa.Scalar:
            # Same wrapping as in ``state``: hand back a real Arrow scalar.
            return pa.scalar(self._sum)

    my_sum = udaf(
        MySum,
        [pa.float64()],
        pa.float64(),
        [pa.float64()],
        volatility="immutable",
        name="my_sum",
    )
    ctx.register_udaf(my_sum)
    df = ctx.from_pydict({"a": [1.0, 2.0, 3.0]})
    ctx.register_table("t", df.into_view())

    # While registered, the aggregate resolves and evaluates.
    result = ctx.sql("SELECT my_sum(a) FROM t").collect()
    assert result[0].column(0) == pa.array([6.0])

    # After deregistration the same query must be rejected.
    ctx.deregister_udaf("my_sum")
    with pytest.raises(RuntimeError):
        ctx.sql("SELECT my_sum(a) FROM t").collect()


def test_deregister_udwf():
    """A window UDF is usable from SQL until it is deregistered."""
    from datafusion import udwf
    from datafusion.user_defined import WindowEvaluator

    query = "SELECT my_row_number(a) OVER () FROM t"

    class MyRowNumber(WindowEvaluator):
        def __init__(self):
            self._row = 0

        def evaluate_all(self, values, num_rows):
            # 1-based row numbers for the whole partition.
            numbers = [index + 1 for index in range(num_rows)]
            return pa.array(numbers, type=pa.uint64())

    session = SessionContext()
    session.register_udwf(
        udwf(
            MyRowNumber,
            [pa.float64()],
            pa.uint64(),
            volatility="immutable",
            name="my_row_number",
        )
    )
    view = session.from_pydict({"a": [1.0, 2.0, 3.0]}).into_view()
    session.register_table("t", view)

    # While registered, the window function resolves and evaluates.
    batches = session.sql(query).collect()
    assert batches[0].column(0) == pa.array([1, 2, 3], type=pa.uint64())

    # After deregistration the same query must be rejected.
    session.deregister_udwf("my_row_number")
    with pytest.raises(RuntimeError):
        session.sql(query).collect()


def test_deregister_udtf():
    """A table function is usable from SQL until it is deregistered."""
    import pyarrow.dataset as ds
    from datafusion import Table, udtf

    query = "SELECT * FROM my_table()"

    class MyTable:
        def __call__(self):
            rows = pa.RecordBatch.from_pydict({"x": [1, 2, 3]})
            return Table(ds.dataset([rows]))

    session = SessionContext()
    session.register_udtf(udtf(MyTable(), "my_table"))

    # While registered, the table function resolves and produces rows.
    batches = session.sql(query).collect()
    assert batches[0].column(0) == pa.array([1, 2, 3])

    # After deregistration the same query must be rejected.
    session.deregister_udtf("my_table")
    with pytest.raises(RuntimeError):
        session.sql(query).collect()


def test_register_table_from_dataframe(ctx):
df = ctx.from_pydict({"a": [1, 2]})
ctx.register_table("df_tbl", df)
Expand Down
Loading