diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index e0a395ebbf..5d9bd0a2ce 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -159,6 +159,8 @@ struct ExecutionContext { pub stream: Option, /// Receives batches from a spawned tokio task (async I/O path) pub batch_receiver: Option>>, + /// Handle to the spawned tokio task so we can wait for it during cleanup + pub task_handle: Option>, /// Native metrics pub metrics: Arc, // The interval in milliseconds to update metrics @@ -317,6 +319,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( input_sources, stream: None, batch_receiver: None, + task_handle: None, metrics, metrics_update_interval, metrics_last_update_time: Instant::now(), @@ -579,7 +582,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( // decreasing to 1 would serialize production and consumption. let (tx, rx) = mpsc::channel(2); let mut stream = stream; - get_runtime().spawn(async move { + let handle = get_runtime().spawn(async move { let result = std::panic::AssertUnwindSafe(async { while let Some(batch) = stream.next().await { if tx.send(batch).await.is_err() { @@ -606,6 +609,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( } }); exec_context.batch_receiver = Some(rx); + exec_context.task_handle = Some(handle); } else { exec_context.stream = Some(stream); } @@ -701,6 +705,17 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan( execution_context.task_attempt_id, ); + // If a tokio task was spawned for async execution, wait for it to + // complete before dropping the context. This ensures all + // MemoryReservations held by the stream are released back to Spark + // before Spark calls cleanUpAllAllocatedMemory(). + if let Some(handle) = execution_context.task_handle.take() { + // Drop the receiver first so the task sees a closed channel + // and exits its loop. + execution_context.batch_receiver.take(); + let _ = get_runtime().block_on(handle); + } + let _: Box = Box::from_raw(execution_context); Ok(()) })