From 1a853758c99885e9c480a6292895f8a241941680 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ra=C3=BAl=20Cabrera?= Date: Wed, 8 Apr 2026 11:33:31 -0600 Subject: [PATCH 1/2] Add TaskMetrics and emit_metrics for task performance tracking - Introduced TaskMetrics class to encapsulate task-related metrics. - Implemented emit_metrics function to asynchronously send metrics data. - Updated run_task to include metrics tracking and logging. - Enhanced message consumer and task decorator to support graceful shutdown and metrics emission. - Added tests for metrics functionality and graceful shutdown behavior. --- agave/tasks/__init__.py | 3 + agave/tasks/metrics.py | 26 +++ agave/tasks/sqs_tasks.py | 177 ++++++++++++---- tests/conftest.py | 22 ++ tests/tasks/test_metrics.py | 43 ++++ tests/tasks/test_sqs_tasks.py | 389 ++++++++++++++++++++++++++++++++++ 6 files changed, 613 insertions(+), 47 deletions(-) create mode 100644 agave/tasks/metrics.py create mode 100644 tests/tasks/test_metrics.py diff --git a/agave/tasks/__init__.py b/agave/tasks/__init__.py index e69de29b..86426099 100644 --- a/agave/tasks/__init__.py +++ b/agave/tasks/__init__.py @@ -0,0 +1,3 @@ +from .metrics import TaskMetrics, emit_metrics + +__all__ = ['TaskMetrics', 'emit_metrics'] diff --git a/agave/tasks/metrics.py b/agave/tasks/metrics.py new file mode 100644 index 00000000..00ac96a9 --- /dev/null +++ b/agave/tasks/metrics.py @@ -0,0 +1,26 @@ +from dataclasses import dataclass +from typing import Callable, Optional + + +@dataclass +class TaskMetrics: + task_name: str + queue_url: str + concurrent_tasks_counter: Callable[[], int] + + +async def emit_metrics( + metrics: TaskMetrics, + status: str, + duration_ms: float, + metrics_callback: Optional[Callable] = None, +) -> None: + if not metrics_callback: + return + await metrics_callback( + task_name=metrics.task_name, + queue_url=metrics.queue_url, + status=status, + duration_ms=duration_ms, + concurrent_tasks=metrics.concurrent_tasks_counter(), + ) diff --git 
a/agave/tasks/sqs_tasks.py b/agave/tasks/sqs_tasks.py index cda571c5..e2b53d78 100644 --- a/agave/tasks/sqs_tasks.py +++ b/agave/tasks/sqs_tasks.py @@ -2,10 +2,12 @@ import json import logging import os +import signal +import time from functools import wraps from itertools import count from json import JSONDecodeError -from typing import AsyncGenerator, Callable, Coroutine +from typing import AsyncGenerator, Callable, Coroutine, Optional from aiobotocore.httpsession import HTTPClientError from aiobotocore.session import get_session @@ -18,6 +20,7 @@ get_sensitive_fields, obfuscate_sensitive_data, ) +from .metrics import TaskMetrics, emit_metrics logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -25,7 +28,7 @@ AWS_DEFAULT_REGION = os.getenv('AWS_DEFAULT_REGION', '') -BACKGROUND_TASKS = set() +BACKGROUND_TASKS: set[asyncio.Task] = set() async def run_task( @@ -36,6 +39,10 @@ async def run_task( queue_url: str, message_receive_count: int, max_retries: int, + delete_on_failure: bool, + task_name: str, + task_module: str, + metrics_callback: Optional[Callable] = None, ) -> None: delete_message = True request_model = get_request_model(task_func) @@ -48,8 +55,8 @@ async def run_task( response_log_config_fields = get_sensitive_fields(response_model) log_data = { 'request': { - 'task_func': task_func.__name__, - 'task_module': task_func.__module__, + 'task_func': task_name, + 'task_module': task_module, 'queue_url': queue_url, 'max_retries': max_retries, 'body': ofuscated_request_body, @@ -61,11 +68,13 @@ async def run_task( 'status': 'success', }, } + start_time = time.monotonic() try: resp = await task_func(body) except RetryTask as retry: - delete_message = message_receive_count >= max_retries + 1 - if not delete_message and retry.countdown and retry.countdown > 0: + retries_exhausted = message_receive_count >= max_retries + 1 + delete_message = retries_exhausted and delete_on_failure + if not retries_exhausted and retry.countdown and 
retry.countdown > 0: await sqs.change_message_visibility( QueueUrl=queue_url, ReceiptHandle=message['ReceiptHandle'], @@ -74,6 +83,7 @@ async def run_task( log_data['response']['delete_message'] = delete_message log_data['response']['status'] = 'retrying' except Exception as exp: + delete_message = delete_on_failure log_data['response']['status'] = 'failed' log_data['response']['error'] = str(exp) else: @@ -86,12 +96,29 @@ async def run_task( ofuscated_response_body = resp log_data['response']['body'] = ofuscated_response_body finally: + duration_ms = round((time.monotonic() - start_time) * 1000, 2) + log_data['response']['duration_ms'] = duration_ms if delete_message: await sqs.delete_message( QueueUrl=queue_url, ReceiptHandle=message['ReceiptHandle'], ) logger.info(json.dumps(log_data, default=str)) + if metrics_callback: + try: + task_metrics = TaskMetrics( + task_name=task_name, + queue_url=queue_url, + concurrent_tasks_counter=lambda: len(BACKGROUND_TASKS), + ) + await emit_metrics( + metrics=task_metrics, + status=log_data['response']['status'], + duration_ms=duration_ms, + metrics_callback=metrics_callback, + ) + except Exception as exc: + logger.warning(f'metrics_callback failed: {exc}') async def message_consumer( @@ -100,9 +127,12 @@ async def message_consumer( visibility_timeout: int, can_read: asyncio.Event, sqs, + shutdown_event: asyncio.Event, ) -> AsyncGenerator: for _ in count(): await can_read.wait() + if shutdown_event.is_set(): + return try: response = await sqs.receive_message( QueueUrl=queue_url, @@ -112,8 +142,12 @@ async def message_consumer( ) messages = response['Messages'] except KeyError: + if shutdown_event.is_set(): + return continue except HTTPClientError: + if shutdown_event.is_set(): + return await asyncio.sleep(1) continue for message in messages: @@ -135,11 +169,14 @@ def task( visibility_timeout: int = 3600, max_retries: int = 1, max_concurrent_tasks: int = 5, + delete_on_failure: bool = True, + metrics_callback: 
Optional[Callable] = None, ): def task_builder(task_func: Callable): @wraps(task_func) async def start_task(*args, **kwargs) -> None: can_read = asyncio.Event() + shutdown_event = asyncio.Event() concurrency_semaphore = asyncio.Semaphore(max_concurrent_tasks) can_read.set() @@ -151,51 +188,97 @@ async def concurrency_controller(coro: Coroutine) -> None: try: await coro finally: - can_read.set() + if not shutdown_event.is_set(): + can_read.set() - session = get_session() + loop = asyncio.get_running_loop() - task_with_validators = validate_call(task_func) + def _handle_signal(sig): + logger.info( + f'Received {sig.name}, initiating graceful shutdown' + ) + shutdown_event.set() + can_read.set() - async with session.create_client('sqs', region_name) as sqs: - async for message in message_consumer( - queue_url, - wait_time_seconds, - visibility_timeout, - can_read, - sqs, - ): - try: - body = json.loads(message['Body']) - except JSONDecodeError: - continue - - message_receive_count = int( - message['Attributes']['ApproximateReceiveCount'] - ) - bg_task = asyncio.create_task( - concurrency_controller( - run_task( - task_with_validators, - body, - message, - sqs, - queue_url, - message_receive_count, - max_retries, + previous_handlers = {} + for sig in (signal.SIGTERM, signal.SIGINT): + previous_handlers[sig] = signal.getsignal(sig) + loop.add_signal_handler(sig, _handle_signal, sig) + + try: + session = get_session() + + task_with_validators = validate_call(task_func) + + async with session.create_client('sqs', region_name) as sqs: + async for message in message_consumer( + queue_url, + wait_time_seconds, + visibility_timeout, + can_read, + sqs, + shutdown_event, + ): + try: + body = json.loads(message['Body']) + except JSONDecodeError: + msg_id = message['MessageId'] + log_data = dict( + message_id=msg_id, + status='invalid_json', + ) + logger.warning(json.dumps(log_data)) + await sqs.delete_message( + QueueUrl=queue_url, + ReceiptHandle=message['ReceiptHandle'], + ) + 
continue + + message_receive_count = int( + message['Attributes']['ApproximateReceiveCount'] + ) + bg_task = asyncio.create_task( + concurrency_controller( + run_task( + task_with_validators, + body, + message, + sqs, + queue_url, + message_receive_count, + max_retries, + delete_on_failure, + task_name=task_func.__name__, + task_module=task_func.__module__, + metrics_callback=metrics_callback, + ), ), - ), - name='fast-agave-task', - ) - BACKGROUND_TASKS.add(bg_task) - bg_task.add_done_callback(BACKGROUND_TASKS.discard) - - # Espera a que terminen todos los tasks pendientes creados por - # `asyncio.create_task`. De esta forma los tasks - # podrán borrar el mensaje del queue usando la misma instancia - # del cliente de SQS - running_tasks = await get_running_fast_agave_tasks() - await asyncio.gather(*running_tasks) + name='fast-agave-task', + ) + BACKGROUND_TASKS.add(bg_task) + bg_task.add_done_callback(BACKGROUND_TASKS.discard) + + running_tasks = await get_running_fast_agave_tasks() + if shutdown_event.is_set(): + try: + await asyncio.wait_for( + asyncio.gather(*running_tasks), + timeout=visibility_timeout, + ) + except asyncio.TimeoutError: + logger.warning( + 'Graceful shutdown timeout, tasks may retry' + ) + else: + await asyncio.gather(*running_tasks) + finally: + for sig, handler in previous_handlers.items(): + loop.remove_signal_handler(sig) + if handler and handler not in ( + signal.SIG_DFL, + signal.SIG_IGN, + ): + signal.signal(sig, handler) return start_task diff --git a/tests/conftest.py b/tests/conftest.py index e09323e9..118c4325 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,9 @@ import datetime as dt import functools +import json import logging import os +import signal from functools import partial from typing import Callable, Generator @@ -291,3 +293,23 @@ def set_log_level(caplog): Automatically set logging level to INFO for all tests. 
""" caplog.set_level(logging.INFO) + + +TEST_MESSAGE = dict(id='abc123', name='fast-agave') + + +@pytest.fixture +async def enqueued_message(sqs_client): + await sqs_client.send_message( + MessageBody=json.dumps(TEST_MESSAGE), + MessageGroupId='1234', + ) + return TEST_MESSAGE + + +@pytest.fixture +def trigger_shutdown(): + def _trigger(): + os.kill(os.getpid(), signal.SIGTERM) + + return _trigger diff --git a/tests/tasks/test_metrics.py b/tests/tasks/test_metrics.py new file mode 100644 index 00000000..31df4824 --- /dev/null +++ b/tests/tasks/test_metrics.py @@ -0,0 +1,43 @@ +from unittest.mock import AsyncMock + +from agave.tasks.metrics import TaskMetrics, emit_metrics + +TASK_NAME = 'my_task' +QUEUE_URL = 'https://sqs.us-east-1.amazonaws.com/queue' + + +async def test_emit_metrics_calls_callback() -> None: + callback = AsyncMock() + metrics = TaskMetrics( + task_name=TASK_NAME, + queue_url=QUEUE_URL, + concurrent_tasks_counter=lambda: 3, + ) + + await emit_metrics( + metrics=metrics, + status='success', + duration_ms=150.5, + metrics_callback=callback, + ) + + callback.assert_called_once_with( + task_name=TASK_NAME, + queue_url=QUEUE_URL, + status='success', + duration_ms=150.5, + concurrent_tasks=3, + ) + + +async def test_emit_metrics_skips_when_no_callback() -> None: + metrics = TaskMetrics( + task_name=TASK_NAME, + queue_url=QUEUE_URL, + concurrent_tasks_counter=lambda: 0, + ) + await emit_metrics( + metrics=metrics, + status='success', + duration_ms=100.0, + ) diff --git a/tests/tasks/test_sqs_tasks.py b/tests/tasks/test_sqs_tasks.py index 3a2523c8..e14382d9 100644 --- a/tests/tasks/test_sqs_tasks.py +++ b/tests/tasks/test_sqs_tasks.py @@ -13,6 +13,7 @@ from agave.tasks.sqs_tasks import ( BACKGROUND_TASKS, get_running_fast_agave_tasks, + message_consumer, task, ) @@ -427,3 +428,391 @@ async def my_task(data: dict) -> None: resp = await sqs_client.receive_message() assert 'Messages' not in resp assert len(BACKGROUND_TASKS) == 0 + + +async def 
test_invalid_json_message_is_deleted(sqs_client) -> None: + """ + Verifica que los mensajes con JSON inválido son eliminados del queue + y el task nunca es ejecutado, con visibility_timeout alto para + confirmar que la eliminación es explícita + """ + await sqs_client.send_message( + MessageBody='not valid json!!!', + MessageGroupId='1234', + ) + + async_mock_function = AsyncMock() + + async def my_task(data: dict) -> None: + await async_mock_function(data) + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=10, + )(my_task)() + + async_mock_function.assert_not_called() + + resp = await sqs_client.receive_message() + assert 'Messages' not in resp + + +async def test_delete_on_failure_false_unhandled_exception( + sqs_client, + enqueued_message, + trigger_shutdown, +) -> None: + """ + Cuando delete_on_failure=False, un mensaje que falla por una excepción + no controlada NO debe ser eliminado del queue (para que el redrive + policy lo envíe al DLQ) + """ + async_mock_function = AsyncMock() + + async def my_task(data: dict) -> None: + await async_mock_function(data) + trigger_shutdown() + raise Exception('something went wrong :(') + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + max_retries=1, + delete_on_failure=False, + )(my_task)() + + assert async_mock_function.call_count == 1 + async_mock_function.assert_called_with(enqueued_message) + + resp = await sqs_client.receive_message(WaitTimeSeconds=2) + assert 'Messages' in resp + + +async def test_delete_on_failure_false_retries_exhausted( + sqs_client, + enqueued_message, + trigger_shutdown, +) -> None: + """ + Cuando delete_on_failure=False y se agotan los reintentos por RetryTask, + el mensaje NO debe ser eliminado del queue + """ + retry_count = 0 + async_mock_function = AsyncMock(side_effect=RetryTask) + + async def my_task(data: dict) -> None: + nonlocal 
retry_count + retry_count += 1 + if retry_count >= 2: + trigger_shutdown() + await async_mock_function(data) + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + max_retries=1, + delete_on_failure=False, + )(my_task)() + + expected_calls = [call(enqueued_message)] * 2 + assert async_mock_function.call_count == len(expected_calls) + async_mock_function.assert_has_calls(expected_calls) + + resp = await sqs_client.receive_message(WaitTimeSeconds=2) + assert 'Messages' in resp + + +async def test_delete_on_failure_true_is_default_behavior( + sqs_client, + enqueued_message, +) -> None: + """ + Verifica que delete_on_failure=True (el default) mantiene el + comportamiento actual: elimina el mensaje cuando hay una excepción + no controlada + """ + async_mock_function = AsyncMock( + side_effect=Exception('something went wrong :(') + ) + + async def my_task(data: dict) -> None: + await async_mock_function(data) + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + max_retries=1, + delete_on_failure=True, + )(my_task)() + + async_mock_function.assert_called_with(enqueued_message) + assert async_mock_function.call_count == 1 + + resp = await sqs_client.receive_message() + assert 'Messages' not in resp + + +async def test_graceful_shutdown_completes_inflight_task( + sqs_client, + enqueued_message, + trigger_shutdown, +) -> None: + """ + Verifica que al recibir SIGTERM, los tasks en vuelo completan + antes de que el listener se detenga + """ + task_completed = False + + async def my_task(data: dict) -> None: + nonlocal task_completed + trigger_shutdown() + await asyncio.sleep(0.5) + task_completed = True + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=10, + )(my_task)() + + assert task_completed is True + + +async def 
test_no_new_messages_after_shutdown( + sqs_client, + trigger_shutdown, +) -> None: + """ + Verifica que después de recibir SIGTERM con max_concurrent=1, + el consumer se detiene y los mensajes restantes quedan en + el queue + """ + for i in range(3): + await sqs_client.send_message( + MessageBody=json.dumps(dict(id=f'msg{i}')), + MessageGroupId=str(i), + ) + + processed = [] + + async def my_task(data: dict) -> None: + processed.append(data['id']) + if data['id'] == 'msg0': + trigger_shutdown() + await asyncio.sleep(0.5) + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=10, + max_concurrent_tasks=1, + )(my_task)() + + assert processed[0] == 'msg0' + assert 'msg2' not in processed + + +async def test_http_client_error_during_shutdown() -> None: + """ + Cuando ocurre un HTTPClientError y shutdown_event ya está activo, + el consumer debe detenerse inmediatamente + """ + shutdown_event = asyncio.Event() + can_read = asyncio.Event() + can_read.set() + + mock_sqs = AsyncMock() + + async def receive_and_shutdown(**kw): + shutdown_event.set() + raise HTTPClientError(error='Connection reset') + + mock_sqs.receive_message = receive_and_shutdown + + messages = [] + async for msg in message_consumer( + 'queue_url', + 1, + 1, + can_read, + mock_sqs, + shutdown_event, + ): + messages.append(msg) + + assert messages == [] + + +async def test_graceful_shutdown_timeout( + sqs_client, + enqueued_message, + trigger_shutdown, + caplog, +) -> None: + """ + Cuando un task en vuelo no termina dentro del visibility_timeout, + el shutdown debe completar con un warning de timeout + """ + + async def my_task(data: dict) -> None: + trigger_shutdown() + await asyncio.sleep(60) + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + )(my_task)() + + assert any( + 'Graceful shutdown timeout' in r.message for r in caplog.records + ) + + +async 
def test_duration_ms_in_log( + sqs_client, + enqueued_message, + caplog, +) -> None: + """ + Verifica que el log de request/response incluye duration_ms + """ + + async def my_task(data: dict) -> None: + await asyncio.sleep(0.1) + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + )(my_task)() + + request_logs = [ + json.loads(r.message) + for r in caplog.records + if r.name == 'agave.tasks.sqs_tasks' + and '{' in r.message + and 'duration_ms' in r.message + ] + assert len(request_logs) >= 1 + assert request_logs[0]['response']['duration_ms'] >= 100 + + +async def test_metrics_callback_is_called( + sqs_client, + enqueued_message, +) -> None: + """ + Verifica que metrics_callback se invoca con los datos + correctos después de ejecutar un task + """ + metrics_mock = AsyncMock() + + async def my_task(data: dict) -> None: + pass + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + metrics_callback=metrics_mock, + )(my_task)() + + metrics_mock.assert_called_once() + call_kwargs = metrics_mock.call_args.kwargs + assert call_kwargs['task_name'] == 'my_task' + assert call_kwargs['status'] == 'success' + assert call_kwargs['duration_ms'] >= 0 + assert 'concurrent_tasks' in call_kwargs + assert 'queue_url' in call_kwargs + + +async def test_metrics_callback_on_failure( + sqs_client, + enqueued_message, +) -> None: + """ + Verifica que metrics_callback reporta status='failed' + cuando el task falla + """ + metrics_mock = AsyncMock() + + async def my_task(data: dict) -> None: + raise Exception('boom') + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + metrics_callback=metrics_mock, + )(my_task)() + + metrics_mock.assert_called_once() + assert metrics_mock.call_args.kwargs['status'] == 'failed' + + +async def test_metrics_callback_on_retry( + 
sqs_client, + enqueued_message, +) -> None: + """ + Verifica que metrics_callback reporta status='retrying' + cuando el task hace retry + """ + metrics_mock = AsyncMock() + + async def my_task(data: dict) -> None: + raise RetryTask() + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + max_retries=1, + metrics_callback=metrics_mock, + )(my_task)() + + assert metrics_mock.call_count == 2 + statuses = [c.kwargs['status'] for c in metrics_mock.call_args_list] + assert 'retrying' in statuses + + +async def test_metrics_callback_error_is_handled( + sqs_client, + enqueued_message, +) -> None: + """ + Si metrics_callback lanza una excepción, el task debe completar + normalmente sin propagar el error + """ + async_mock_function = AsyncMock() + metrics_mock = AsyncMock(side_effect=Exception('metrics broke')) + + async def my_task(data: dict) -> None: + await async_mock_function(data) + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + metrics_callback=metrics_mock, + )(my_task)() + + async_mock_function.assert_called_with(enqueued_message) + metrics_mock.assert_called_once() From 17a3655c934af83b1520da2084b1b8456b3501d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ra=C3=BAl=20Cabrera?= Date: Wed, 8 Apr 2026 12:46:45 -0600 Subject: [PATCH 2/2] Update version to 1.5.3.dev2 and enhance type annotations for metrics callbacks - Bump version from 1.5.2 to 1.5.3.dev2. - Refine type annotations for metrics_callback in emit_metrics and run_task functions to specify Awaitable return type. - Update test assertions to reflect changes in task retry status handling. - Add a new test for immediate shutdown behavior in message consumer. 
--- agave/tasks/metrics.py | 4 +-- agave/tasks/sqs_tasks.py | 16 ++++++--- agave/version.py | 2 +- tests/tasks/test_loggin_tasks.py | 4 +-- tests/tasks/test_sqs_tasks.py | 56 +++++++++++++++----------------- 5 files changed, 44 insertions(+), 38 deletions(-) diff --git a/agave/tasks/metrics.py b/agave/tasks/metrics.py index 00ac96a9..ed712eb8 100644 --- a/agave/tasks/metrics.py +++ b/agave/tasks/metrics.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Callable, Optional +from typing import Awaitable, Callable, Optional @dataclass @@ -13,7 +13,7 @@ async def emit_metrics( metrics: TaskMetrics, status: str, duration_ms: float, - metrics_callback: Optional[Callable] = None, + metrics_callback: Optional[Callable[..., Awaitable[None]]] = None, ) -> None: if not metrics_callback: return diff --git a/agave/tasks/sqs_tasks.py b/agave/tasks/sqs_tasks.py index e2b53d78..3dd758d5 100644 --- a/agave/tasks/sqs_tasks.py +++ b/agave/tasks/sqs_tasks.py @@ -7,7 +7,7 @@ from functools import wraps from itertools import count from json import JSONDecodeError -from typing import AsyncGenerator, Callable, Coroutine, Optional +from typing import AsyncGenerator, Awaitable, Callable, Coroutine, Optional from aiobotocore.httpsession import HTTPClientError from aiobotocore.session import get_session @@ -42,7 +42,7 @@ async def run_task( delete_on_failure: bool, task_name: str, task_module: str, - metrics_callback: Optional[Callable] = None, + metrics_callback: Optional[Callable[..., Awaitable[None]]] = None, ) -> None: delete_message = True request_model = get_request_model(task_func) @@ -71,6 +71,10 @@ async def run_task( start_time = time.monotonic() try: resp = await task_func(body) + except asyncio.CancelledError: + delete_message = False + log_data['response']['status'] = 'cancelled' + raise except RetryTask as retry: retries_exhausted = message_receive_count >= max_retries + 1 delete_message = retries_exhausted and delete_on_failure @@ -81,7 +85,9 @@ async def 
run_task( VisibilityTimeout=retry.countdown, ) log_data['response']['delete_message'] = delete_message - log_data['response']['status'] = 'retrying' + log_data['response']['status'] = ( + 'failed' if retries_exhausted else 'retrying' + ) except Exception as exp: delete_message = delete_on_failure log_data['response']['status'] = 'failed' @@ -150,6 +156,8 @@ async def message_consumer( return await asyncio.sleep(1) continue + if shutdown_event.is_set(): + return for message in messages: yield message @@ -170,7 +178,7 @@ def task( max_retries: int = 1, max_concurrent_tasks: int = 5, delete_on_failure: bool = True, - metrics_callback: Optional[Callable] = None, + metrics_callback: Optional[Callable[..., Awaitable[None]]] = None, ): def task_builder(task_func: Callable): @wraps(task_func) diff --git a/agave/version.py b/agave/version.py index c3b38415..bb4c5a45 100644 --- a/agave/version.py +++ b/agave/version.py @@ -1 +1 @@ -__version__ = '1.5.2' +__version__ = '1.5.3.dev2' diff --git a/tests/tasks/test_loggin_tasks.py b/tests/tasks/test_loggin_tasks.py index bf527b1d..cfbc19c1 100644 --- a/tests/tasks/test_loggin_tasks.py +++ b/tests/tasks/test_loggin_tasks.py @@ -278,8 +278,8 @@ async def my_task(data: dict) -> None: == '2' ) - # For the third execution - assert log_data[2]['response']['status'] == 'retrying' + # For the third execution (retries exhausted) + assert log_data[2]['response']['status'] == 'failed' assert ( log_data[2]['request']['message_attributes']['ApproximateReceiveCount'] == '3' diff --git a/tests/tasks/test_sqs_tasks.py b/tests/tasks/test_sqs_tasks.py index e14382d9..5af08a0e 100644 --- a/tests/tasks/test_sqs_tasks.py +++ b/tests/tasks/test_sqs_tasks.py @@ -430,35 +430,6 @@ async def my_task(data: dict) -> None: assert len(BACKGROUND_TASKS) == 0 -async def test_invalid_json_message_is_deleted(sqs_client) -> None: - """ - Verifica que los mensajes con JSON inválido son eliminados del queue - y el task nunca es ejecutado, con visibility_timeout alto 
para - confirmar que la eliminación es explícita - """ - await sqs_client.send_message( - MessageBody='not valid json!!!', - MessageGroupId='1234', - ) - - async_mock_function = AsyncMock() - - async def my_task(data: dict) -> None: - await async_mock_function(data) - - await task( - queue_url=sqs_client.queue_url, - region_name=CORE_QUEUE_REGION, - wait_time_seconds=1, - visibility_timeout=10, - )(my_task)() - - async_mock_function.assert_not_called() - - resp = await sqs_client.receive_message() - assert 'Messages' not in resp - - async def test_delete_on_failure_false_unhandled_exception( sqs_client, enqueued_message, @@ -622,6 +593,33 @@ async def my_task(data: dict) -> None: assert 'msg2' not in processed +async def test_shutdown_before_receive() -> None: + """ + Cuando shutdown_event ya está activo antes de que el consumer + intente leer, debe detenerse inmediatamente sin llamar receive_message + """ + shutdown_event = asyncio.Event() + can_read = asyncio.Event() + can_read.set() + shutdown_event.set() + + mock_sqs = AsyncMock() + + messages = [] + async for msg in message_consumer( + 'queue_url', + 1, + 1, + can_read, + mock_sqs, + shutdown_event, + ): + messages.append(msg) + + assert messages == [] + mock_sqs.receive_message.assert_not_called() + + async def test_http_client_error_during_shutdown() -> None: """ Cuando ocurre un HTTPClientError y shutdown_event ya está activo,