Skip to content

Commit fedadcb

Browse files
committed
Refactor to use PyCapsule for ArrowArrayStream handling in tests and improve memory management
1 parent 3b7f834 commit fedadcb

3 files changed

Lines changed: 18 additions & 13 deletions

File tree

python/tests/test_dataframe.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1606,7 +1606,7 @@ def fail_collect(self): # pragma: no cover - failure path
16061606

16071607

16081608
def test_arrow_c_stream_reader(df):
1609-
reader = pa.RecordBatchReader._import_from_c(df.__arrow_c_stream__())
1609+
reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())
16101610
assert isinstance(reader, pa.RecordBatchReader)
16111611
table = pa.Table.from_batches(reader)
16121612
expected = pa.Table.from_batches(df.collect())
@@ -2751,7 +2751,7 @@ def test_arrow_c_stream_interrupted():
27512751
"""
27522752
)
27532753

2754-
reader = pa.RecordBatchReader._import_from_c(df.__arrow_c_stream__())
2754+
reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())
27552755

27562756
interrupted = False
27572757
interrupt_error = None

python/tests/test_io.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -106,7 +106,7 @@ def test_arrow_c_stream_large_dataset(ctx):
106106
# Create a very large DataFrame using range; this would be terabytes if collected
107107
df = ctx.range(0, 1 << 40)
108108

109-
reader = pa.RecordBatchReader._import_from_c(df.__arrow_c_stream__())
109+
reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())
110110

111111
# Track RSS before consuming batches
112112
psutil = pytest.importorskip("psutil")
@@ -126,7 +126,8 @@ def test_table_from_batches_stream(ctx, monkeypatch):
126126
df = ctx.range(0, 10)
127127

128128
def fail_collect(self): # pragma: no cover - failure path
129-
raise AssertionError("collect should not be called")
129+
msg = "collect should not be called"
130+
raise AssertionError(msg)
130131

131132
monkeypatch.setattr(DataFrame, "collect", fail_collect)
132133

src/dataframe.rs

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,8 @@
1616
// under the License.
1717

1818
use std::collections::HashMap;
19-
use std::ffi::{c_void, CString};
20-
use std::sync::{Arc, OnceLock};
19+
use std::ffi::{c_void, CStr, CString};
20+
use std::sync::Arc;
2121

2222
use arrow::array::{new_null_array, RecordBatch, RecordBatchReader};
2323
use arrow::compute::can_cast_types;
@@ -59,14 +59,15 @@ use crate::{
5959
expr::{sort_expr::PySortExpr, PyExpr},
6060
};
6161

62-
static ARROW_STREAM_NAME: OnceLock<CString> = OnceLock::new();
62+
static ARROW_STREAM_NAME: &CStr =
63+
unsafe { CStr::from_bytes_with_nul_unchecked(b"arrow_array_stream\0") };
6364

6465
unsafe extern "C" fn drop_stream(capsule: *mut ffi::PyObject) {
6566
if capsule.is_null() {
6667
return;
6768
}
68-
let name = ARROW_STREAM_NAME.get_or_init(|| CString::new("arrow_array_stream").unwrap());
69-
let stream_ptr = ffi::PyCapsule_GetPointer(capsule, name.as_ptr()) as *mut FFI_ArrowArrayStream;
69+
let stream_ptr =
70+
ffi::PyCapsule_GetPointer(capsule, ARROW_STREAM_NAME.as_ptr()) as *mut FFI_ArrowArrayStream;
7071
if !stream_ptr.is_null() {
7172
drop(Box::from_raw(stream_ptr));
7273
}
@@ -945,7 +946,7 @@ impl PyDataFrame {
945946
let stream = spawn_stream(py, async move { df.execute_stream().await })?;
946947

947948
let mut schema: Schema = self.df.schema().to_owned().into();
948-
let mut projection: Option<SchemaRef> = None;
949+
let mut projection: Option<SchemaRef> = Some(Arc::new(schema.clone()));
949950

950951
if let Some(schema_capsule) = requested_schema {
951952
validate_pycapsule(&schema_capsule, "arrow_schema")?;
@@ -957,7 +958,7 @@ impl PyDataFrame {
957958
projection = Some(Arc::new(schema.clone()));
958959
}
959960

960-
let schema_ref = projection.clone().unwrap_or_else(|| Arc::new(schema));
961+
let schema_ref = Arc::new(schema.clone());
961962

962963
let reader = DataFrameStreamReader {
963964
stream,
@@ -972,9 +973,12 @@ impl PyDataFrame {
972973
!stream_ptr.is_null(),
973974
"ArrowArrayStream pointer should never be null"
974975
);
975-
let name = ARROW_STREAM_NAME.get_or_init(|| CString::new("arrow_array_stream").unwrap());
976976
let capsule = unsafe {
977-
ffi::PyCapsule_New(stream_ptr as *mut c_void, name.as_ptr(), Some(drop_stream))
977+
ffi::PyCapsule_New(
978+
stream_ptr as *mut c_void,
979+
ARROW_STREAM_NAME.as_ptr(),
980+
Some(drop_stream),
981+
)
978982
};
979983
if capsule.is_null() {
980984
unsafe { drop(Box::from_raw(stream_ptr)) };

0 commit comments

Comments (0)