Skip to content

Commit 76fcc08

Browse files
committed
refactor: replace extract_table_provider with coerce_table_provider for improved clarity and error handling
1 parent 4c64b1e commit 76fcc08

File tree

4 files changed: 34 additions (+34) and 39 deletions (-39).

python/tests/test_catalog.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import pyarrow as pa
2121
import pyarrow.dataset as ds
2222
import pytest
23-
from datafusion import SessionContext, Table
23+
from datafusion import EXPECTED_PROVIDER_MSG, SessionContext, Table
2424

2525

2626
# Note we take in `database` as a variable even though we don't use
@@ -164,6 +164,16 @@ def test_python_table_provider(ctx: SessionContext):
164164
assert schema.table_names() == {"table4"}
165165

166166

167+
def test_schema_register_table_with_dataframe_errors(ctx: SessionContext):
168+
schema = ctx.catalog().schema()
169+
df = ctx.from_pydict({"a": [1]})
170+
171+
with pytest.raises(Exception) as exc_info:
172+
schema.register_table("bad", df)
173+
174+
assert str(exc_info.value) == EXPECTED_PROVIDER_MSG
175+
176+
167177
def test_in_end_to_end_python_providers(ctx: SessionContext):
168178
"""Test registering all python providers and running a query against them."""
169179

src/catalog.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use crate::dataset::Dataset;
1919
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
2020
use crate::table::PyTableProvider;
2121
use crate::utils::{
22-
extract_table_provider, table_provider_from_pycapsule, validate_pycapsule, wait_for_future,
22+
coerce_table_provider, table_provider_from_pycapsule, validate_pycapsule, wait_for_future,
2323
};
2424
use async_trait::async_trait;
2525
use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider};
@@ -198,10 +198,7 @@ impl PySchema {
198198
}
199199

200200
fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> {
201-
let provider = match extract_table_provider(&table_provider) {
202-
Ok(provider) => provider,
203-
Err(err) => return Err(err.into()),
204-
};
201+
let provider = coerce_table_provider(&table_provider).map_err(PyErr::from)?;
205202

206203
let _ = self
207204
.schema

src/context.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ use crate::udf::PyScalarUDF;
4646
use crate::udtf::PyTableFunction;
4747
use crate::udwf::PyWindowUDF;
4848
use crate::utils::{
49-
extract_table_provider, get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future,
49+
coerce_table_provider, get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future,
5050
};
5151
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
5252
use datafusion::arrow::pyarrow::PyArrowType;
@@ -608,7 +608,7 @@ impl PySessionContext {
608608
name: &str,
609609
table_provider: Bound<'_, PyAny>,
610610
) -> PyDataFusionResult<()> {
611-
let provider = extract_table_provider(&table_provider)?;
611+
let provider = coerce_table_provider(&table_provider)?;
612612

613613
self.ctx.register_table(name, provider)?;
614614
Ok(())

src/utils.rs

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use crate::{
2020
common::data_type::PyScalarValue,
2121
dataframe::PyDataFrame,
2222
dataset::Dataset,
23-
errors::{PyDataFusionError, PyDataFusionResult},
23+
errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult},
2424
table::PyTableProvider,
2525
TokioRuntime,
2626
};
@@ -140,37 +140,25 @@ pub(crate) fn table_provider_from_pycapsule(
140140
}
141141
}
142142

143-
pub(crate) fn extract_table_provider(
144-
table_like: &Bound<PyAny>,
143+
pub(crate) fn coerce_table_provider(
144+
obj: &Bound<PyAny>,
145145
) -> PyDataFusionResult<Arc<dyn TableProvider>> {
146-
if let Ok(py_table) = table_like.extract::<PyTable>() {
147-
return Ok(py_table.table());
148-
}
149-
150-
if let Ok(py_provider) = table_like.extract::<PyTableProvider>() {
151-
return Ok(py_provider.into_inner());
152-
}
153-
154-
if table_like.extract::<PyDataFrame>().is_ok() {
155-
return Err(PyDataFusionError::Common(EXPECTED_PROVIDER_MSG.to_string()));
156-
}
157-
158-
match table_provider_from_pycapsule(table_like) {
159-
Ok(Some(provider)) => Ok(provider),
160-
Ok(None) => {
161-
let py = table_like.py();
162-
match Dataset::new(table_like, py) {
163-
Ok(dataset) => Ok(Arc::new(dataset) as Arc<dyn TableProvider>),
164-
Err(err) => {
165-
if err.is_instance_of::<PyValueError>(py) {
166-
Err(PyDataFusionError::Common(EXPECTED_PROVIDER_MSG.to_string()))
167-
} else {
168-
Err(err.into())
169-
}
170-
}
171-
}
172-
}
173-
Err(err) => Err(err.into()),
146+
if let Ok(py_table) = obj.extract::<PyTable>() {
147+
Ok(py_table.table())
148+
} else if let Ok(py_provider) = obj.extract::<PyTableProvider>() {
149+
Ok(py_provider.into_inner())
150+
} else if obj.is_instance_of::<PyDataFrame>()
151+
|| obj
152+
.getattr("df")
153+
.is_ok_and(|inner| inner.is_instance_of::<PyDataFrame>())
154+
{
155+
Err(PyDataFusionError::Common(EXPECTED_PROVIDER_MSG.to_string()))
156+
} else if let Some(provider) = table_provider_from_pycapsule(obj)? {
157+
Ok(provider)
158+
} else {
159+
let py = obj.py();
160+
let provider = Dataset::new(obj, py)?;
161+
Ok(Arc::new(provider) as Arc<dyn TableProvider>)
174162
}
175163
}
176164

0 commit comments

Comments
 (0)