python/pyspark/sql/connect/client/core.py (8 additions, 1 deletion)
@@ -28,6 +28,7 @@

check_dependencies(__name__)

import concurrent.futures
import logging
import threading
import os
@@ -38,6 +39,7 @@
import sys
import time
import traceback
import weakref
from typing import (
Iterable,
Iterator,
@@ -751,6 +753,8 @@ def __init__(
self._plan_compression_threshold: Optional[int] = None # Will be fetched lazily
self._plan_compression_algorithm: Optional[str] = None # Will be fetched lazily

self._release_futures: weakref.WeakSet[concurrent.futures.Future] = weakref.WeakSet()

# cleanup ml cache if possible
atexit.register(self._cleanup_ml_cache)

@@ -1272,7 +1276,8 @@ def close(self) -> None:
"""
Close the channel.
"""
ExecutePlanResponseReattachableIterator.shutdown()
concurrent.futures.wait(self._release_futures)
ExecutePlanResponseReattachableIterator.shutdown_threadpool_if_idle()
self._channel.close()
self._closed = True

@@ -1487,6 +1492,7 @@ def handle_response(b: pb2.ExecutePlanResponse) -> None:
handle_response(b)
finally:
generator.close()
self._release_futures.update(generator.release_futures)
else:
for attempt in self._retrying():
with attempt:
@@ -1687,6 +1693,7 @@ def handle_response(
yield from handle_response(b)
finally:
generator.close()
self._release_futures.update(generator.release_futures)
else:
for attempt in self._retrying():
with attempt:
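Seen in isolation, the client-side half of the change is two pieces of bookkeeping: the client adopts each generator's outstanding ReleaseExecute futures into a weakref.WeakSet, and close() first waits on those futures, then asks the iterator class to shut the shared thread pool down only if it is idle. A minimal sketch of that wait-before-shutdown ordering, with the gRPC calls replaced by a hypothetical fire_and_forget_release helper:

import concurrent.futures
import weakref

# Hypothetical stand-ins for the client's bookkeeping; the real code tracks
# futures produced by ExecutePlanResponseReattachableIterator's release calls.
_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)
_release_futures: "weakref.WeakSet[concurrent.futures.Future]" = weakref.WeakSet()

def fire_and_forget_release(operation_id: str) -> None:
    # Submit a background ReleaseExecute-style call; the future is held only
    # weakly, so completed futures can be garbage collected.
    _release_futures.add(_pool.submit(print, f"released {operation_id}"))

fire_and_forget_release("op-1")
fire_and_forget_release("op-2")

# What close() now does first: wait for every still-referenced future so no
# background release call outlives the gRPC channel, then tear the pool down.
concurrent.futures.wait(_release_futures)
_pool.shutdown(wait=False)

Tracking the futures weakly means close() can still wait on anything in flight without pinning every completed future in memory for the lifetime of the session.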
python/pyspark/sql/connect/client/reattach.py (35 additions, 39 deletions)
@@ -22,9 +22,10 @@
from threading import RLock
import uuid
from collections.abc import Generator
from typing import Optional, Any, Iterator, Iterable, Tuple, Callable, cast, Type, ClassVar
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Any, Iterator, Iterable, Tuple, Callable, cast, ClassVar
from concurrent.futures import Future, ThreadPoolExecutor
import os
import weakref

import grpc
from grpc_status import rpc_status
@@ -56,34 +57,16 @@ class ExecutePlanResponseReattachableIterator(Generator):
ReleaseExecute RPCs that instruct the server to release responses that it already processed.
"""

# Lock to manage the pool
_lock: ClassVar[RLock] = RLock()
_release_thread_pool_instance: Optional[ThreadPoolExecutor] = None
_instances: ClassVar[
weakref.WeakSet["ExecutePlanResponseReattachableIterator"]
] = weakref.WeakSet()

@classmethod
def _get_or_create_release_thread_pool(cls) -> ThreadPoolExecutor:
# Perform a first check outside the critical path.
if cls._release_thread_pool_instance is not None:
return cls._release_thread_pool_instance
with cls._lock:
if cls._release_thread_pool_instance is None:
max_workers = os.cpu_count() or 8
cls._release_thread_pool_instance = ThreadPoolExecutor(max_workers=max_workers)
return cls._release_thread_pool_instance

@classmethod
def shutdown(cls: Type["ExecutePlanResponseReattachableIterator"]) -> None:
"""
When the channel is closed, this method will be called before, to make sure all
outstanding calls are closed.
"""
with cls._lock:
if cls._release_thread_pool_instance is not None:
thread_pool = cls._release_thread_pool_instance
cls._release_thread_pool_instance = None
# This method could be called within the thread pool so don't wait for the
# shutdown to complete. Otherwise it could deadlock.
thread_pool.shutdown(wait=False)
def __new__(cls, *args: Any, **kwargs: Any) -> "ExecutePlanResponseReattachableIterator":
instance = super().__new__(cls)
cls._instances.add(instance)
return instance

def __init__(
self,
@@ -92,9 +75,9 @@ def __init__(
retrying: Callable[[], Retrying],
metadata: Iterable[Tuple[str, str]],
):
self._release_thread_pool # Trigger initialization
self._request = request
self._retrying = retrying
self._release_futures: weakref.WeakSet[Future] = weakref.WeakSet()
if request.operation_id:
self._operation_id = request.operation_id
else:
@@ -132,9 +115,30 @@
# Current item from this iterator.
self._current: Optional[pb2.ExecutePlanResponse] = None

@property
def release_futures(self) -> weakref.WeakSet[Future]:
return self._release_futures

@property
def _release_thread_pool(self) -> ThreadPoolExecutor:
return self._get_or_create_release_thread_pool()
if self._release_thread_pool_instance is not None:
# It's impossible for the thread pool to shut down while the iterator is still alive.
# We can safely return the thread pool instance here as long as it is not used
# by external objects.
return self._release_thread_pool_instance
with self._lock:
if self._release_thread_pool_instance is None:
self._release_thread_pool_instance = ThreadPoolExecutor(
max_workers=os.cpu_count() or 8
)
return self._release_thread_pool_instance

@classmethod
def shutdown_threadpool_if_idle(cls) -> None:
with cls._lock:
if not cls._instances and cls._release_thread_pool_instance is not None:
cls._release_thread_pool_instance.shutdown(wait=False)
cls._release_thread_pool_instance = None

def send(self, value: Any) -> pb2.ExecutePlanResponse:
# will trigger reattach in case the stream completed without result_complete
@@ -214,11 +218,7 @@ def target() -> None:
except Exception as e:
logger.warn(f"ReleaseExecute failed with exception: {e}.")

with self._lock:
if self._release_thread_pool_instance is not None:
thread_pool = self._release_thread_pool
if not thread_pool._shutdown:
thread_pool.submit(target)
self._release_futures.add(self._release_thread_pool.submit(target))

def _release_all(self) -> None:
"""
@@ -241,11 +241,7 @@ def target() -> None:
except Exception as e:
logger.warn(f"ReleaseExecute failed with exception: {e}.")

with self._lock:
if self._release_thread_pool_instance is not None:
thread_pool = self._release_thread_pool
if not thread_pool._shutdown:
thread_pool.submit(target)
self._release_futures.add(self._release_thread_pool.submit(target))
self._result_complete = True

def _call_iter(self, iter_fun: Callable) -> Any:
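Taken together, the reattach.py changes replace the explicit class-level shutdown() with lifetime tracking: every iterator registers itself in a class-level weakref.WeakSet from __new__, the shared thread pool is created lazily under the class lock, each release call hands its future back through release_futures, and shutdown_threadpool_if_idle() tears the pool down only once no live iterators remain. The sketch below condenses that pattern with the gRPC machinery stripped out; the shortened class name and the print-based release are illustrative, not the actual implementation.

import os
import threading
import weakref
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, ClassVar, Optional

class ReattachableIterator:
    _lock: ClassVar[threading.RLock] = threading.RLock()
    _release_thread_pool_instance: ClassVar[Optional[ThreadPoolExecutor]] = None
    _instances: ClassVar["weakref.WeakSet"] = weakref.WeakSet()

    def __new__(cls, *args: Any, **kwargs: Any) -> "ReattachableIterator":
        instance = super().__new__(cls)
        cls._instances.add(instance)  # weak reference: GC drops it automatically
        return instance

    def __init__(self) -> None:
        # Futures for in-flight release calls, exposed so the client can wait on them.
        self.release_futures: "weakref.WeakSet[Future]" = weakref.WeakSet()

    @property
    def _release_thread_pool(self) -> ThreadPoolExecutor:
        # Lazy, double-checked creation; while any iterator is alive the pool
        # cannot be shut down concurrently, so the unlocked fast path is safe.
        if ReattachableIterator._release_thread_pool_instance is not None:
            return ReattachableIterator._release_thread_pool_instance
        with self._lock:
            if ReattachableIterator._release_thread_pool_instance is None:
                ReattachableIterator._release_thread_pool_instance = ThreadPoolExecutor(
                    max_workers=os.cpu_count() or 8
                )
            return ReattachableIterator._release_thread_pool_instance

    def release(self) -> None:
        # Fire-and-forget background release; the returned future is tracked weakly.
        self.release_futures.add(self._release_thread_pool.submit(print, "released"))

    @classmethod
    def shutdown_threadpool_if_idle(cls) -> None:
        # Tear the pool down only when every iterator has been garbage collected.
        with cls._lock:
            if not cls._instances and cls._release_thread_pool_instance is not None:
                cls._release_thread_pool_instance.shutdown(wait=False)
                cls._release_thread_pool_instance = None

it = ReattachableIterator()
it.release()
del it  # on CPython, refcounting removes it from the class-level WeakSet here
ReattachableIterator.shutdown_threadpool_if_idle()

Because the pool is dropped only when the instance set is empty, closing one client's channel no longer risks shutting down an executor that another live iterator is still using.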