diff --git a/snuba/subscriptions/executor_consumer.py b/snuba/subscriptions/executor_consumer.py index afb94964266..18c67de7ff4 100644 --- a/snuba/subscriptions/executor_consumer.py +++ b/snuba/subscriptions/executor_consumer.py @@ -292,15 +292,27 @@ def poll(self) -> None: ) except QueryException as exc: cause = exc.__cause__ - if isinstance(cause, ClickhouseError): - if cause.code in NON_RETRYABLE_CLICKHOUSE_ERROR_CODES: - logger.exception("Error running subscription query %r", exc) - self.__metrics.increment( - "subscription_executor_nonretryable_error", - tags={"error_type": str(cause.code)}, - ) - else: - raise SubscriptionQueryException(exc.message) + # Retryable ClickHouse errors are re-raised so the consumer + # crashes and the message is retried. Every other failure + # (non-retryable ClickHouse error codes as well as non-ClickHouse + # causes such as QueryTooLongException) is non-retryable: log it, + # emit a metric and skip the message instead of submitting an + # unassigned transformed_message downstream. + if ( + isinstance(cause, ClickhouseError) + and cause.code not in NON_RETRYABLE_CLICKHOUSE_ERROR_CODES + ): + raise SubscriptionQueryException(exc.message) + + logger.exception("Error running subscription query %r", exc) + error_type = ( + str(cause.code) if isinstance(cause, ClickhouseError) else type(cause).__name__ + ) + self.__metrics.increment( + "subscription_executor_nonretryable_error", + tags={"error_type": error_type}, + ) + continue self.__next_step.submit(transformed_message) diff --git a/tests/subscriptions/test_executor_consumer.py b/tests/subscriptions/test_executor_consumer.py index 15c66353650..9f1b90a5f78 100644 --- a/tests/subscriptions/test_executor_consumer.py +++ b/tests/subscriptions/test_executor_consumer.py @@ -1,6 +1,7 @@ import json import time import uuid +from concurrent.futures import Future from datetime import datetime, timedelta from typing import Iterator, Mapping, Optional from unittest import mock @@ -42,6 +43,7 @@ get_default_kafka_configuration, ) from snuba.utils.streams.topics import Topic as SnubaTopic +from snuba.web import QueryException, QueryTooLongException from tests.backends.metrics import Increment, TestingMetricsBackend @@ -97,7 +99,7 @@ def on_partitions_assigned(partitions: Mapping[Partition, int]) -> None: # We need to wait for the consumer to receive partitions otherwise, # when we try to consume messages, we will not find anything. # Subscription is an async process. - assert assigned == True, "Did not receive assignment within 10 attempts" + assert assigned, "Did not receive assignment within 10 attempts" consumer_group = str(uuid.uuid1().hex) auto_offset_reset = "latest" @@ -288,6 +290,55 @@ def test_execute_query_exception() -> None: strategy.join() +@pytest.mark.redis_db +@pytest.mark.clickhouse_db +def test_poll_skips_non_retryable_query_exception() -> None: + """ + A QueryException whose cause is not a retryable ClickhouseError (e.g. + QueryTooLongException) must be logged and skipped, not re-raised, and must + not fall through to submitting an unassigned transformed_message (which + previously raised UnboundLocalError). See SNUBA-9E1. + """ + metrics = TestingMetricsBackend() + next_step = mock.Mock() + + strategy = ExecuteQuery( + dataset=get_dataset("events"), + entity_names=["events"], + max_concurrent_queries=2, + stale_threshold_seconds=None, + metrics=metrics, + next_step=next_step, + ) + + exc = QueryException("boom") + exc.__cause__ = QueryTooLongException("query is too long") + + future: "Future[object]" = Future() + future.set_exception(exc) + + message = Message(BrokerValue("payload", Partition(Topic("test"), 0), 0, datetime(1970, 1, 1))) + result_future = mock.Mock() + result_future.future = future + + strategy._ExecuteQuery__queue.append((message, result_future)) # type: ignore[attr-defined] + + strategy.poll() + + assert next_step.submit.call_count == 0 + assert ( + Increment( + "subscription_executor_nonretryable_error", + 1, + {"error_type": "QueryTooLongException"}, + ) + in metrics.calls + ) + + strategy.close() + strategy.join() + + @pytest.mark.redis_db @pytest.mark.clickhouse_db def test_too_many_concurrent_queries() -> None: