Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@
# but still work for all expected requests
_DEFAULT_TIMEOUT = 600

# Errors that indicate a transient connectivity failure. The stream can be
# transparently reconnected and in-flight requests replayed on these errors.
_STREAM_RESUMPTION_EXCEPTIONS = (
exceptions.ServiceUnavailable,
exceptions.Unknown,
)


def _is_retryable_error(reason: Optional[BaseException]) -> bool:
return isinstance(reason, _STREAM_RESUMPTION_EXCEPTIONS)


def _wrap_as_exception(maybe_exception) -> BaseException:
"""Wrap an object as a Python exception, if needed.
Expand Down Expand Up @@ -191,25 +202,36 @@ def close(self, reason: Optional[Exception] = None) -> None:
def _renew_connection(self, reason: Optional[Exception] = None) -> None:
"""Helper function that is called when the RPC connection is closed
without recovery. It first creates a new Connection instance in an
atomic manner, and then cleans up the failed connection. Note that a
new RPC connection is not established by instantiating _Connection,
but only when `send()` is called for the first time.
atomic manner, and then cleans up the failed connection.

On transient errors (:data:`_STREAM_RESUMPTION_EXCEPTIONS`) any
in-flight requests are replayed on the new connection so that callers
do not need to handle reconnection themselves. On non-transient errors
the pending futures are failed immediately as before.
"""
# Creates a new Connection instance, but doesn't establish a new RPC
# connection. New connection is only started when `send()` is called
# again, in order to save resource if the stream is idle. This action
# is atomic.
with self._thread_lock:
_closed_connection = self._connection
self._connection = _Connection(
client=self._client,
writer=self,
metadata=self._metadata,
)

# Cleanup, and marks futures as failed. To minimize the length of the
# critical section, this step is not guaranteed to be atomic.
_closed_connection._shutdown(reason=reason)
# Copy the stream name so the new connection can build routing
# metadata even before the first send().
self._connection._stream_name = self._stream_name

# Shutdown the old connection. On transient errors this returns the
# in-flight (request, future) pairs so we can replay them; on
# non-transient errors it returns an empty list after failing futures.
pending = _closed_connection._shutdown(reason=reason)

if pending:
_LOGGER.debug(
"Replaying %d in-flight request(s) after transient error: %s",
len(pending),
reason,
)
self._connection._reopen_with_pending(pending)

def _on_rpc_done(self, reason: Optional[BaseException] = None) -> None:
"""Callback passecd to _Connection. It's called when the RPC connection
Expand Down Expand Up @@ -257,7 +279,9 @@ def __init__(
self._rpc: Union[bidi.BidiRpc | None] = None
self._consumer: Union[bidi.BackgroundConsumer | None] = None
self._stream_name: str = ""
self._queue: queue.Queue[AppendRowsFuture] = queue.Queue()
self._queue: queue.Queue[
Tuple[gapic_types.AppendRowsRequest, AppendRowsFuture]
] = queue.Queue()

# statuses
self._closed = False
Expand Down Expand Up @@ -314,7 +338,7 @@ def _open(
request = self._make_initial_request(initial_request)

future = AppendRowsFuture(self._writer)
self._queue.put(future)
self._queue.put((initial_request, future))

self._rpc = bidi.BidiRpc(
self._client.append_rows,
Expand Down Expand Up @@ -428,22 +452,32 @@ def send(self, request: gapic_types.AppendRowsRequest) -> AppendRowsFuture:
# future to the queue so that when the response comes, the callback can
# pull it off and notify completion.
future = AppendRowsFuture(self._writer)
self._queue.put(future)
self._queue.put((request, future))
if self._rpc is not None:
self._rpc.send(request)
return future

def _shutdown(self, reason: Optional[Exception] = None) -> None:
def _shutdown(
self, reason: Optional[Exception] = None
) -> List[Tuple[gapic_types.AppendRowsRequest, "AppendRowsFuture"]]:
"""Run the actual shutdown sequence (stop the stream and all helper threads).

Args:
reason:
The reason to close the stream. If ``None``, this is
considered an "intentional" shutdown.

Returns:
A list of ``(request, future)`` pairs for requests that were
in-flight when the connection closed. On transient errors these
are returned instead of being failed so the caller can replay
them on a new connection. On non-transient errors the list is
always empty (futures are failed immediately).
"""
pending: List[Tuple[gapic_types.AppendRowsRequest, "AppendRowsFuture"]] = []
with self._thread_lock:
if self._closed:
return
return pending

# Stop consuming messages.
if self.is_active:
Expand All @@ -459,19 +493,25 @@ def _shutdown(self, reason: Optional[Exception] = None) -> None:

# We know that no new items will be added to the queue because
# we've marked the stream as closed.
retryable = _is_retryable_error(reason)
while not self._queue.empty():
# Mark each future as failed. Since the consumer thread has
# stopped (or at least is attempting to stop), we won't get
# response callbacks to populate the remaining futures.
future = self._queue.get_nowait()
exc: Union[Exception, bqstorage_exceptions.StreamClosedError]
if reason is None:
exc = bqstorage_exceptions.StreamClosedError(
"Stream closed before receiving a response."
)
# On transient errors, collect in-flight requests so they can
# be replayed on a fresh connection instead of surfacing an
# error to the caller.
request, future = self._queue.get_nowait()
if retryable:
pending.append((request, future))
else:
exc = reason
future.set_exception(exc)
exc: Union[Exception, bqstorage_exceptions.StreamClosedError]
if reason is None:
exc = bqstorage_exceptions.StreamClosedError(
"Stream closed before receiving a response."
)
else:
exc = reason
future.set_exception(exc)

return pending

def close(self, reason: Optional[Exception] = None) -> None:
"""Stop consuming messages and shutdown all helper threads.
Expand All @@ -496,7 +536,7 @@ def _on_response(self, response: gapic_types.AppendRowsResponse) -> None:

# Since we have 1 response per request, if we get here from a response
# callback, the queue should never be empty.
future: AppendRowsFuture = self._queue.get_nowait()
_, future = self._queue.get_nowait()
if response.error.code:
exc = exceptions.from_grpc_status(
response.error.code, response.error.message, response=response
Expand All @@ -505,6 +545,86 @@ def _on_response(self, response: gapic_types.AppendRowsResponse) -> None:
else:
future.set_result(response)

def _reopen_with_pending(
self,
pending: List[Tuple[gapic_types.AppendRowsRequest, "AppendRowsFuture"]],
timeout: float = _DEFAULT_TIMEOUT,
) -> None:
"""Open a fresh RPC connection and replay ``pending`` in-flight requests.

The existing :class:`AppendRowsFuture` objects are reused so callers
that already hold references transparently receive their results once
the server acknowledges the replayed requests.

Args:
pending:
``(request, future)`` pairs collected from the failed
connection's queue. The first entry is used as the stream's
initial request (merged with the writer template); subsequent
entries are sent in order once the connection is active.
timeout:
How long (in seconds) to wait for the stream to be ready.
"""
if not pending:
return

initial_user_request, initial_future = pending[0]

with self._thread_lock:
# Inject the existing future so _on_response resolves it.
self._queue.put((initial_user_request, initial_future))

merged = self._make_initial_request(initial_user_request)
metadata = tuple(self._metadata) + (
(
"x-goog-request-params",
f"write_stream={self._stream_name}",
),
)
rpc = bidi.BidiRpc(
self._client.append_rows,
initial_request=merged,
metadata=metadata,
)
rpc.add_done_callback(self._on_rpc_done)

consumer = bidi.BackgroundConsumer(rpc, self._on_response)
consumer.start()

self._rpc = rpc
self._consumer = consumer

start_time = time.monotonic()
while not rpc.is_active and consumer.is_active:
time.sleep(_WRITE_OPEN_INTERVAL)
if timeout is not None and time.monotonic() - start_time > timeout:
break

if not consumer.is_active:
# Connection failed — drain the queue and fail futures directly
# rather than going through close() to avoid triggering another
# reconnect attempt (which would cause an infinite retry loop).
exc = exceptions.Unknown(
"There was a problem reopening the stream after a transient error. "
"Try turning on DEBUG level logs to see the error."
)
with self._thread_lock:
self._closed = True
while not self._queue.empty():
_, future = self._queue.get_nowait()
if not future.done():
future.set_exception(exc)
for _, future in pending:
if not future.done():
future.set_exception(exc)
return
Comment on lines +573 to +620
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There are three important issues in _reopen_with_pending that should be addressed:

  1. Thread Safety & Potential AttributeError: Accessing self._rpc and self._consumer outside the lock block can lead to an AttributeError if another thread calls _shutdown and sets them to None. Using local variables rpc and consumer resolved within the lock block makes the wait loop and status check completely thread-safe and eliminates the need for the try-except AttributeError block.
  2. Hanging Futures on Reopen Failure: If the connection fails to reopen, only the first pending request (which was added to self._queue) is failed. The remaining pending[1:] requests are never added to self._queue and are never failed, leaving them hanging indefinitely. We should iterate over all pending futures and fail them, guarding with not future.done() to avoid raising InvalidStateError on already-resolved futures.
  3. Avoid itertools Dependency: We can construct the metadata tuple using standard tuple concatenation instead of itertools.chain, which avoids a potential NameError if itertools is not imported at the module level.
        with self._thread_lock:
            # Inject the existing future so _on_response resolves it.
            self._queue.put((initial_user_request, initial_future))

            merged = self._make_initial_request(initial_user_request)
            metadata = tuple(self._metadata) + (
                (
                    "x-goog-request-params",
                    f"write_stream={self._stream_name}",
                ),
            )
            rpc = bidi.BidiRpc(
                self._client.append_rows,
                initial_request=merged,
                metadata=metadata,
            )
            rpc.add_done_callback(self._on_rpc_done)

            consumer = bidi.BackgroundConsumer(rpc, self._on_response)
            consumer.start()

            self._rpc = rpc
            self._consumer = consumer

        start_time = time.monotonic()
        while not rpc.is_active and consumer.is_active:
            time.sleep(_WRITE_OPEN_INTERVAL)
            if timeout is not None and time.monotonic() - start_time > timeout:
                break

        if not consumer.is_active:
            # Connection failed — drain the queue and fail futures directly
            # rather than going through close() to avoid triggering another
            # reconnect attempt (which would cause an infinite retry loop).
            exc = exceptions.Unknown(
                "There was a problem reopening the stream after a transient error. "
                "Try turning on DEBUG level logs to see the error."
            )
            with self._thread_lock:
                self._closed = True
                while not self._queue.empty():
                    _, future = self._queue.get_nowait()
                    if not future.done():
                        future.set_exception(exc)
                for _, future in pending:
                    if not future.done():
                        future.set_exception(exc)
            return


# Send remaining pending requests over the live connection.
for request, future in pending[1:]:
self._queue.put((request, future))
if self._rpc is not None:
self._rpc.send(request)

def _on_rpc_done(self, future: AppendRowsFuture) -> None:
"""Triggered when the underlying RPC terminates without recovery.

Expand Down
Loading
Loading