From 2c56784eb91c48928f71aae4fe46132efa4c461f Mon Sep 17 00:00:00 2001 From: Tian Gao Date: Fri, 6 Feb 2026 14:25:42 -0800 Subject: [PATCH 1/7] Refactor the reattach shutdown mechanism for connect --- python/pyspark/sql/connect/client/core.py | 8 ++- python/pyspark/sql/connect/client/reattach.py | 56 ++++++------------- 2 files changed, 23 insertions(+), 41 deletions(-) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 7af0f857e8f7c..8f13896dfe03d 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -28,6 +28,7 @@ check_dependencies(__name__) +import concurrent.futures import logging import threading import os @@ -38,6 +39,7 @@ import sys import time import traceback +import weakref from typing import ( Iterable, Iterator, @@ -751,6 +753,8 @@ def __init__( self._plan_compression_threshold: Optional[int] = None # Will be fetched lazily self._plan_compression_algorithm: Optional[str] = None # Will be fetched lazily + self._release_futures = weakref.WeakSet() + # cleanup ml cache if possible atexit.register(self._cleanup_ml_cache) @@ -1272,7 +1276,7 @@ def close(self) -> None: """ Close the channel. 
""" - ExecutePlanResponseReattachableIterator.shutdown() + concurrent.futures.wait(self._release_futures) self._channel.close() self._closed = True @@ -1487,6 +1491,7 @@ def handle_response(b: pb2.ExecutePlanResponse) -> None: handle_response(b) finally: generator.close() + self._release_futures.update(generator.release_futures) else: for attempt in self._retrying(): with attempt: @@ -1687,6 +1692,7 @@ def handle_response( yield from handle_response(b) finally: generator.close() + self._release_futures.update(generator.release_futures) else: for attempt in self._retrying(): with attempt: diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index 2abe95bd2510e..545980b84fda1 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -22,9 +22,10 @@ from threading import RLock import uuid from collections.abc import Generator -from typing import Optional, Any, Iterator, Iterable, Tuple, Callable, cast, Type, ClassVar -from concurrent.futures import ThreadPoolExecutor +from typing import Optional, Any, Iterator, Iterable, Tuple, Callable, cast, ClassVar +from concurrent.futures import Future,ThreadPoolExecutor import os +import weakref import grpc from grpc_status import rpc_status @@ -56,35 +57,9 @@ class ExecutePlanResponseReattachableIterator(Generator): ReleaseExecute RPCs that instruct the server to release responses that it already processed. """ - # Lock to manage the pool _lock: ClassVar[RLock] = RLock() _release_thread_pool_instance: Optional[ThreadPoolExecutor] = None - @classmethod - def _get_or_create_release_thread_pool(cls) -> ThreadPoolExecutor: - # Perform a first check outside the critical path. 
- if cls._release_thread_pool_instance is not None: - return cls._release_thread_pool_instance - with cls._lock: - if cls._release_thread_pool_instance is None: - max_workers = os.cpu_count() or 8 - cls._release_thread_pool_instance = ThreadPoolExecutor(max_workers=max_workers) - return cls._release_thread_pool_instance - - @classmethod - def shutdown(cls: Type["ExecutePlanResponseReattachableIterator"]) -> None: - """ - When the channel is closed, this method will be called before, to make sure all - outstanding calls are closed. - """ - with cls._lock: - if cls._release_thread_pool_instance is not None: - thread_pool = cls._release_thread_pool_instance - cls._release_thread_pool_instance = None - # This method could be called within the thread pool so don't wait for the - # shutdown to complete. Otherwise it could deadlock. - thread_pool.shutdown(wait=False) - def __init__( self, request: pb2.ExecutePlanRequest, @@ -92,9 +67,9 @@ def __init__( retrying: Callable[[], Retrying], metadata: Iterable[Tuple[str, str]], ): - self._release_thread_pool # Trigger initialization self._request = request self._retrying = retrying + self._release_futures = weakref.WeakSet() if request.operation_id: self._operation_id = request.operation_id else: @@ -132,9 +107,18 @@ def __init__( # Current item from this iterator. 
self._current: Optional[pb2.ExecutePlanResponse] = None + @property + def release_futures(self) -> weakref.WeakSet[Future]: + return self._release_futures + @property def _release_thread_pool(self) -> ThreadPoolExecutor: - return self._get_or_create_release_thread_pool() + if self._release_thread_pool_instance is not None: + return self._release_thread_pool_instance + with self._lock: + if self._release_thread_pool_instance is None: + self._release_thread_pool_instance = ThreadPoolExecutor(max_workers=os.cpu_count() or 8) + return self._release_thread_pool_instance def send(self, value: Any) -> pb2.ExecutePlanResponse: # will trigger reattach in case the stream completed without result_complete @@ -214,11 +198,7 @@ def target() -> None: except Exception as e: logger.warn(f"ReleaseExecute failed with exception: {e}.") - with self._lock: - if self._release_thread_pool_instance is not None: - thread_pool = self._release_thread_pool - if not thread_pool._shutdown: - thread_pool.submit(target) + self._release_futures.add(self._release_thread_pool.submit(target)) def _release_all(self) -> None: """ @@ -241,11 +221,7 @@ def target() -> None: except Exception as e: logger.warn(f"ReleaseExecute failed with exception: {e}.") - with self._lock: - if self._release_thread_pool_instance is not None: - thread_pool = self._release_thread_pool - if not thread_pool._shutdown: - thread_pool.submit(target) + self._release_futures.add(self._release_thread_pool.submit(target)) self._result_complete = True def _call_iter(self, iter_fun: Callable) -> Any: From 9701ccbd6bbc596efc33d0725a7534a308e74998 Mon Sep 17 00:00:00 2001 From: Tian Gao Date: Fri, 6 Feb 2026 14:43:33 -0800 Subject: [PATCH 2/7] Add a mechanism for release threadpool too --- python/pyspark/sql/connect/client/core.py | 3 ++- python/pyspark/sql/connect/client/reattach.py | 15 ++++++++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/connect/client/core.py 
b/python/pyspark/sql/connect/client/core.py index 8f13896dfe03d..8db7126e9a4f5 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -753,7 +753,7 @@ def __init__( self._plan_compression_threshold: Optional[int] = None # Will be fetched lazily self._plan_compression_algorithm: Optional[str] = None # Will be fetched lazily - self._release_futures = weakref.WeakSet() + self._release_futures: weakref.WeakSet[concurrent.futures.Future] = weakref.WeakSet() # cleanup ml cache if possible atexit.register(self._cleanup_ml_cache) @@ -1277,6 +1277,7 @@ def close(self) -> None: Close the channel. """ concurrent.futures.wait(self._release_futures) + ExecutePlanResponseReattachableIterator.shutdown_threadpool_if_idle() self._channel.close() self._closed = True diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index 545980b84fda1..3a22bb8914d8c 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -59,6 +59,12 @@ class ExecutePlanResponseReattachableIterator(Generator): _lock: ClassVar[RLock] = RLock() _release_thread_pool_instance: Optional[ThreadPoolExecutor] = None + _instances: ClassVar[weakref.WeakSet["ExecutePlanResponseReattachableIterator"]] = weakref.WeakSet() + + def __new__(cls, *args, **kwargs) -> "ExecutePlanResponseReattachableIterator": + instance = super().__new__(cls) + cls._instances.add(instance) + return instance def __init__( self, @@ -69,7 +75,7 @@ def __init__( ): self._request = request self._retrying = retrying - self._release_futures = weakref.WeakSet() + self._release_futures: weakref.WeakSet[Future] = weakref.WeakSet() if request.operation_id: self._operation_id = request.operation_id else: @@ -120,6 +126,13 @@ def _release_thread_pool(self) -> ThreadPoolExecutor: self._release_thread_pool_instance = ThreadPoolExecutor(max_workers=os.cpu_count() or 8) return 
self._release_thread_pool_instance + @classmethod + def shutdown_threadpool_if_idle(cls) -> None: + with cls._lock: + if not cls._instances and cls._release_thread_pool_instance is not None: + cls._release_thread_pool_instance.shutdown() + cls._release_thread_pool_instance = None + def send(self, value: Any) -> pb2.ExecutePlanResponse: # will trigger reattach in case the stream completed without result_complete if not self._has_next(): From adf262543e3935995e308c0473e51245a5af709f Mon Sep 17 00:00:00 2001 From: Tian Gao Date: Fri, 6 Feb 2026 14:44:06 -0800 Subject: [PATCH 3/7] Do not wait --- python/pyspark/sql/connect/client/reattach.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index 3a22bb8914d8c..8bcede9287d5a 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -130,7 +130,7 @@ def _release_thread_pool(self) -> ThreadPoolExecutor: def shutdown_threadpool_if_idle(cls) -> None: with cls._lock: if not cls._instances and cls._release_thread_pool_instance is not None: - cls._release_thread_pool_instance.shutdown() + cls._release_thread_pool_instance.shutdown(wait=False) cls._release_thread_pool_instance = None def send(self, value: Any) -> pb2.ExecutePlanResponse: From 6f844efeda3771f85eaa39489947976a273942a7 Mon Sep 17 00:00:00 2001 From: Tian Gao Date: Fri, 6 Feb 2026 14:44:30 -0800 Subject: [PATCH 4/7] Reformat --- python/pyspark/sql/connect/client/reattach.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index 8bcede9287d5a..11578761987ad 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -23,7 +23,7 @@ import uuid from collections.abc import Generator from typing import Optional, Any, Iterator, 
Iterable, Tuple, Callable, cast, ClassVar -from concurrent.futures import Future,ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor import os import weakref @@ -59,10 +59,12 @@ class ExecutePlanResponseReattachableIterator(Generator): _lock: ClassVar[RLock] = RLock() _release_thread_pool_instance: Optional[ThreadPoolExecutor] = None - _instances: ClassVar[weakref.WeakSet["ExecutePlanResponseReattachableIterator"]] = weakref.WeakSet() + _instances: ClassVar[ + weakref.WeakSet["ExecutePlanResponseReattachableIterator"] + ] = weakref.WeakSet() def __new__(cls, *args, **kwargs) -> "ExecutePlanResponseReattachableIterator": - instance = super().__new__(cls) + instance = super().__new__(cls) cls._instances.add(instance) return instance @@ -123,7 +125,9 @@ def _release_thread_pool(self) -> ThreadPoolExecutor: return self._release_thread_pool_instance with self._lock: if self._release_thread_pool_instance is None: - self._release_thread_pool_instance = ThreadPoolExecutor(max_workers=os.cpu_count() or 8) + self._release_thread_pool_instance = ThreadPoolExecutor( + max_workers=os.cpu_count() or 8 + ) return self._release_thread_pool_instance @classmethod From 33b6713f35313ec2822019aff9c98f24f0d2e3f6 Mon Sep 17 00:00:00 2001 From: Tian Gao Date: Fri, 6 Feb 2026 14:55:55 -0800 Subject: [PATCH 5/7] Put everything in lock --- python/pyspark/sql/connect/client/reattach.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index 11578761987ad..0caafefcd14d8 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -121,8 +121,6 @@ def release_futures(self) -> weakref.WeakSet[Future]: @property def _release_thread_pool(self) -> ThreadPoolExecutor: - if self._release_thread_pool_instance is not None: - return self._release_thread_pool_instance with self._lock: if self._release_thread_pool_instance is None: 
self._release_thread_pool_instance = ThreadPoolExecutor( From a5d08f6e6f3143c7a171b5ef432dee8eb4487073 Mon Sep 17 00:00:00 2001 From: Tian Gao Date: Fri, 6 Feb 2026 14:59:14 -0800 Subject: [PATCH 6/7] Actually it would work --- python/pyspark/sql/connect/client/reattach.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index 0caafefcd14d8..0017b651ecf77 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -121,6 +121,11 @@ def release_futures(self) -> weakref.WeakSet[Future]: @property def _release_thread_pool(self) -> ThreadPoolExecutor: + if self._release_thread_pool_instance is not None: + # It's impossible for the thread pool to shut down while the iterator is still alive. + # We can safely return the thread pool instance here as long as it is not used + # by external objects. + return self._release_thread_pool_instance with self._lock: if self._release_thread_pool_instance is None: self._release_thread_pool_instance = ThreadPoolExecutor( From 98d7447af44f848c63f3ad4fa82ef0eb05f52240 Mon Sep 17 00:00:00 2001 From: Tian Gao Date: Fri, 6 Feb 2026 16:23:34 -0800 Subject: [PATCH 7/7] Fix type hint --- python/pyspark/sql/connect/client/reattach.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index 0017b651ecf77..b25107aba2652 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -63,7 +63,7 @@ class ExecutePlanResponseReattachableIterator(Generator): weakref.WeakSet["ExecutePlanResponseReattachableIterator"] ] = weakref.WeakSet() - def __new__(cls, *args, **kwargs) -> "ExecutePlanResponseReattachableIterator": + def __new__(cls, *args: Any, **kwargs: Any) -> "ExecutePlanResponseReattachableIterator": instance = super().__new__(cls)
cls._instances.add(instance) return instance