Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion native/core/src/execution/jni_api.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -159,6 +159,8 @@ struct ExecutionContext {
pub stream: Option<SendableRecordBatchStream>,
/// Receives batches from a spawned tokio task (async I/O path)
pub batch_receiver: Option<mpsc::Receiver<DataFusionResult<RecordBatch>>>,
/// Handle to the spawned tokio task so we can wait for it during cleanup
pub task_handle: Option<tokio::task::JoinHandle<()>>,
/// Native metrics
pub metrics: Arc<GlobalRef>,
// The interval in milliseconds to update metrics
Expand Down Expand Up @@ -317,6 +319,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
input_sources,
stream: None,
batch_receiver: None,
task_handle: None,
metrics,
metrics_update_interval,
metrics_last_update_time: Instant::now(),
Expand Down Expand Up @@ -579,7 +582,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
// decreasing to 1 would serialize production and consumption.
let (tx, rx) = mpsc::channel(2);
let mut stream = stream;
get_runtime().spawn(async move {
let handle = get_runtime().spawn(async move {
let result = std::panic::AssertUnwindSafe(async {
while let Some(batch) = stream.next().await {
if tx.send(batch).await.is_err() {
Expand All @@ -606,6 +609,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
}
});
exec_context.batch_receiver = Some(rx);
exec_context.task_handle = Some(handle);
} else {
exec_context.stream = Some(stream);
}
Expand Down Expand Up @@ -701,6 +705,17 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan(
execution_context.task_attempt_id,
);

// If a tokio task was spawned for async execution, wait for it to
// finish before dropping the context. The task owns the stream, so the
// MemoryReservations held by the stream are only released back to Spark
// once the task has exited; this must happen before Spark calls
// cleanUpAllAllocatedMemory().
if let Some(handle) = execution_context.task_handle.take() {
// Drop the receiver first so the task sees a closed channel
// and exits its loop.
execution_context.batch_receiver.take();
let _ = get_runtime().block_on(handle);
}

let _: Box<ExecutionContext> = Box::from_raw(execution_context);
Ok(())
})
Expand Down
Loading