diff --git a/agave/tasks/__init__.py b/agave/tasks/__init__.py index e69de29b..86426099 100644 --- a/agave/tasks/__init__.py +++ b/agave/tasks/__init__.py @@ -0,0 +1,3 @@ +from .metrics import TaskMetrics, emit_metrics + +__all__ = ['TaskMetrics', 'emit_metrics'] diff --git a/agave/tasks/metrics.py b/agave/tasks/metrics.py new file mode 100644 index 00000000..ed712eb8 --- /dev/null +++ b/agave/tasks/metrics.py @@ -0,0 +1,26 @@ +from dataclasses import dataclass +from typing import Awaitable, Callable, Optional + + +@dataclass +class TaskMetrics: + task_name: str + queue_url: str + concurrent_tasks_counter: Callable[[], int] + + +async def emit_metrics( + metrics: TaskMetrics, + status: str, + duration_ms: float, + metrics_callback: Optional[Callable[..., Awaitable[None]]] = None, +) -> None: + if not metrics_callback: + return + await metrics_callback( + task_name=metrics.task_name, + queue_url=metrics.queue_url, + status=status, + duration_ms=duration_ms, + concurrent_tasks=metrics.concurrent_tasks_counter(), + ) diff --git a/agave/tasks/sqs_tasks.py b/agave/tasks/sqs_tasks.py index cda571c5..3dd758d5 100644 --- a/agave/tasks/sqs_tasks.py +++ b/agave/tasks/sqs_tasks.py @@ -2,10 +2,12 @@ import json import logging import os +import signal +import time from functools import wraps from itertools import count from json import JSONDecodeError -from typing import AsyncGenerator, Callable, Coroutine +from typing import AsyncGenerator, Awaitable, Callable, Coroutine, Optional from aiobotocore.httpsession import HTTPClientError from aiobotocore.session import get_session @@ -18,6 +20,7 @@ get_sensitive_fields, obfuscate_sensitive_data, ) +from .metrics import TaskMetrics, emit_metrics logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -25,7 +28,7 @@ AWS_DEFAULT_REGION = os.getenv('AWS_DEFAULT_REGION', '') -BACKGROUND_TASKS = set() +BACKGROUND_TASKS: set[asyncio.Task] = set() async def run_task( @@ -36,6 +39,10 @@ async def run_task( queue_url: str, message_receive_count: int, max_retries: int, + delete_on_failure: bool, + task_name: str, + task_module: str, + metrics_callback: Optional[Callable[..., Awaitable[None]]] = None, ) -> None: delete_message = True request_model = get_request_model(task_func) @@ -48,8 +55,8 @@ async def run_task( response_log_config_fields = get_sensitive_fields(response_model) log_data = { 'request': { - 'task_func': task_func.__name__, - 'task_module': task_func.__module__, + 'task_func': task_name, + 'task_module': task_module, 'queue_url': queue_url, 'max_retries': max_retries, 'body': ofuscated_request_body, @@ -61,19 +68,28 @@ async def run_task( 'status': 'success', }, } + start_time = time.monotonic() try: resp = await task_func(body) + except asyncio.CancelledError: + delete_message = False + log_data['response']['status'] = 'cancelled' + raise except RetryTask as retry: - delete_message = message_receive_count >= max_retries + 1 - if not delete_message and retry.countdown and retry.countdown > 0: + retries_exhausted = message_receive_count >= max_retries + 1 + delete_message = retries_exhausted and delete_on_failure + if not retries_exhausted and retry.countdown and retry.countdown > 0: await sqs.change_message_visibility( QueueUrl=queue_url, ReceiptHandle=message['ReceiptHandle'], VisibilityTimeout=retry.countdown, ) log_data['response']['delete_message'] = delete_message - log_data['response']['status'] = 'retrying' + log_data['response']['status'] = ( + 'failed' if retries_exhausted else 'retrying' + ) except Exception as exp: + delete_message = delete_on_failure log_data['response']['status'] = 'failed' log_data['response']['error'] = str(exp) else: @@ -86,12 +102,29 @@ async def run_task( ofuscated_response_body = resp log_data['response']['body'] = ofuscated_response_body finally: + duration_ms = round((time.monotonic() - start_time) * 1000, 2) + log_data['response']['duration_ms'] = duration_ms if delete_message: await sqs.delete_message( QueueUrl=queue_url, ReceiptHandle=message['ReceiptHandle'], ) logger.info(json.dumps(log_data, default=str)) + if metrics_callback: + try: + task_metrics = TaskMetrics( + task_name=task_name, + queue_url=queue_url, + concurrent_tasks_counter=lambda: len(BACKGROUND_TASKS), + ) + await emit_metrics( + metrics=task_metrics, + status=log_data['response']['status'], + duration_ms=duration_ms, + metrics_callback=metrics_callback, + ) + except Exception as exc: + logger.warning(f'metrics_callback failed: {exc}') async def message_consumer( @@ -100,9 +133,12 @@ async def message_consumer( visibility_timeout: int, can_read: asyncio.Event, sqs, + shutdown_event: asyncio.Event, ) -> AsyncGenerator: for _ in count(): await can_read.wait() + if shutdown_event.is_set(): + return try: response = await sqs.receive_message( QueueUrl=queue_url, @@ -112,10 +148,16 @@ async def message_consumer( ) messages = response['Messages'] except KeyError: + if shutdown_event.is_set(): + return continue except HTTPClientError: + if shutdown_event.is_set(): + return await asyncio.sleep(1) continue + if shutdown_event.is_set(): + return for message in messages: yield message @@ -135,11 +177,14 @@ def task( visibility_timeout: int = 3600, max_retries: int = 1, max_concurrent_tasks: int = 5, + delete_on_failure: bool = True, + metrics_callback: Optional[Callable[..., Awaitable[None]]] = None, ): def task_builder(task_func: Callable): @wraps(task_func) async def start_task(*args, **kwargs) -> None: can_read = asyncio.Event() + shutdown_event = asyncio.Event() concurrency_semaphore = asyncio.Semaphore(max_concurrent_tasks) can_read.set() @@ -151,51 +196,97 @@ async def concurrency_controller(coro: Coroutine) -> None: try: await coro finally: - can_read.set() + if not shutdown_event.is_set(): + can_read.set() - session = get_session() + loop = asyncio.get_running_loop() - task_with_validators = validate_call(task_func) + def _handle_signal(sig): + logger.info( + f'Received {sig.name}, initiating graceful shutdown' + ) + shutdown_event.set() + can_read.set() - async with session.create_client('sqs', region_name) as sqs: - async for message in message_consumer( - queue_url, - wait_time_seconds, - visibility_timeout, - can_read, - sqs, - ): - try: - body = json.loads(message['Body']) - except JSONDecodeError: - continue - - message_receive_count = int( - message['Attributes']['ApproximateReceiveCount'] - ) - bg_task = asyncio.create_task( - concurrency_controller( - run_task( - task_with_validators, - body, - message, - sqs, - queue_url, - message_receive_count, - max_retries, + previous_handlers = {} + for sig in (signal.SIGTERM, signal.SIGINT): + previous_handlers[sig] = signal.getsignal(sig) + loop.add_signal_handler(sig, _handle_signal, sig) + + try: + session = get_session() + + task_with_validators = validate_call(task_func) + + async with session.create_client('sqs', region_name) as sqs: + async for message in message_consumer( + queue_url, + wait_time_seconds, + visibility_timeout, + can_read, + sqs, + shutdown_event, + ): + try: + body = json.loads(message['Body']) + except JSONDecodeError: + msg_id = message['MessageId'] + log_data = dict( + message_id=msg_id, + status='invalid_json', + ) + logger.warning(json.dumps(log_data)) + await sqs.delete_message( + QueueUrl=queue_url, + ReceiptHandle=message['ReceiptHandle'], + ) + continue + + message_receive_count = int( + message['Attributes']['ApproximateReceiveCount'] + ) + bg_task = asyncio.create_task( + concurrency_controller( + run_task( + task_with_validators, + body, + message, + sqs, + queue_url, + message_receive_count, + max_retries, + delete_on_failure, + task_name=task_func.__name__, + task_module=task_func.__module__, + metrics_callback=metrics_callback, + ), ), - ), - name='fast-agave-task', - ) - BACKGROUND_TASKS.add(bg_task) - bg_task.add_done_callback(BACKGROUND_TASKS.discard) - - # Espera a que terminen todos los tasks pendientes creados por - # `asyncio.create_task`. De esta forma los tasks - # podrán borrar el mensaje del queue usando la misma instancia - # del cliente de SQS - running_tasks = await get_running_fast_agave_tasks() - await asyncio.gather(*running_tasks) + name='fast-agave-task', + ) + BACKGROUND_TASKS.add(bg_task) + bg_task.add_done_callback(BACKGROUND_TASKS.discard) + + running_tasks = await get_running_fast_agave_tasks() + if shutdown_event.is_set(): + try: + await asyncio.wait_for( + asyncio.gather(*running_tasks), + timeout=visibility_timeout, + ) + except asyncio.TimeoutError: + logger.warning( + 'Graceful shutdown timeout, tasks may retry' + ) + else: + await asyncio.gather(*running_tasks) + finally: + for sig, handler in previous_handlers.items(): + loop.remove_signal_handler(sig) + if handler and handler not in ( + signal.SIG_DFL, + signal.SIG_IGN, + ): + signal.signal(sig, handler) return start_task diff --git a/agave/version.py b/agave/version.py index c3b38415..bb4c5a45 100644 --- a/agave/version.py +++ b/agave/version.py @@ -1 +1 @@ -__version__ = '1.5.2' +__version__ = '1.5.3.dev2' diff --git a/tests/conftest.py b/tests/conftest.py index e09323e9..118c4325 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,9 @@ import datetime as dt import functools +import json import logging import os +import signal from functools import partial from typing import Callable, Generator @@ -291,3 +293,23 @@ def set_log_level(caplog): Automatically set logging level to INFO for all tests. """ caplog.set_level(logging.INFO) + + +TEST_MESSAGE = dict(id='abc123', name='fast-agave') + + +@pytest.fixture +async def enqueued_message(sqs_client): + await sqs_client.send_message( + MessageBody=json.dumps(TEST_MESSAGE), + MessageGroupId='1234', + ) + return TEST_MESSAGE + + +@pytest.fixture +def trigger_shutdown(): + def _trigger(): + os.kill(os.getpid(), signal.SIGTERM) + + return _trigger diff --git a/tests/tasks/test_loggin_tasks.py b/tests/tasks/test_loggin_tasks.py index bf527b1d..cfbc19c1 100644 --- a/tests/tasks/test_loggin_tasks.py +++ b/tests/tasks/test_loggin_tasks.py @@ -278,8 +278,8 @@ async def my_task(data: dict) -> None: == '2' ) - # For the third execution - assert log_data[2]['response']['status'] == 'retrying' + # For the third execution (retries exhausted) + assert log_data[2]['response']['status'] == 'failed' assert ( log_data[2]['request']['message_attributes']['ApproximateReceiveCount'] == '3' diff --git a/tests/tasks/test_metrics.py b/tests/tasks/test_metrics.py new file mode 100644 index 00000000..31df4824 --- /dev/null +++ b/tests/tasks/test_metrics.py @@ -0,0 +1,43 @@ +from unittest.mock import AsyncMock + +from agave.tasks.metrics import TaskMetrics, emit_metrics + +TASK_NAME = 'my_task' +QUEUE_URL = 'https://sqs.us-east-1.amazonaws.com/queue' + + +async def test_emit_metrics_calls_callback() -> None: + callback = AsyncMock() + metrics = TaskMetrics( + task_name=TASK_NAME, + queue_url=QUEUE_URL, + concurrent_tasks_counter=lambda: 3, + ) + + await emit_metrics( + metrics=metrics, + status='success', + duration_ms=150.5, + metrics_callback=callback, + ) + + callback.assert_called_once_with( + task_name=TASK_NAME, + queue_url=QUEUE_URL, + status='success', + duration_ms=150.5, + concurrent_tasks=3, + ) + + +async def test_emit_metrics_skips_when_no_callback() -> None: + metrics = TaskMetrics( + task_name=TASK_NAME, + queue_url=QUEUE_URL, + concurrent_tasks_counter=lambda: 0, + ) + await emit_metrics( + metrics=metrics, + status='success', + duration_ms=100.0, + ) diff --git a/tests/tasks/test_sqs_tasks.py b/tests/tasks/test_sqs_tasks.py index 3a2523c8..5af08a0e 100644 --- a/tests/tasks/test_sqs_tasks.py +++ b/tests/tasks/test_sqs_tasks.py @@ -13,6 +13,7 @@ from agave.tasks.sqs_tasks import ( BACKGROUND_TASKS, get_running_fast_agave_tasks, + message_consumer, task, ) @@ -427,3 +428,389 @@ async def my_task(data: dict) -> None: resp = await sqs_client.receive_message() assert 'Messages' not in resp assert len(BACKGROUND_TASKS) == 0 + + +async def test_delete_on_failure_false_unhandled_exception( + sqs_client, + enqueued_message, + trigger_shutdown, +) -> None: + """ + Cuando delete_on_failure=False, un mensaje que falla por una excepción + no controlada NO debe ser eliminado del queue (para que el redrive + policy lo envíe al DLQ) + """ + async_mock_function = AsyncMock() + + async def my_task(data: dict) -> None: + await async_mock_function(data) + trigger_shutdown() + raise Exception('something went wrong :(') + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + max_retries=1, + delete_on_failure=False, + )(my_task)() + + assert async_mock_function.call_count == 1 + async_mock_function.assert_called_with(enqueued_message) + + resp = await sqs_client.receive_message(WaitTimeSeconds=2) + assert 'Messages' in resp + + +async def test_delete_on_failure_false_retries_exhausted( + sqs_client, + enqueued_message, + trigger_shutdown, +) -> None: + """ + Cuando delete_on_failure=False y se agotan los reintentos por RetryTask, + el mensaje NO debe ser eliminado del queue + """ + retry_count = 0 + async_mock_function = AsyncMock(side_effect=RetryTask) + + async def my_task(data: dict) -> None: + nonlocal retry_count + retry_count += 1 + if retry_count >= 2: + trigger_shutdown() + await async_mock_function(data) + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + max_retries=1, + delete_on_failure=False, + )(my_task)() + + expected_calls = [call(enqueued_message)] * 2 + assert async_mock_function.call_count == len(expected_calls) + async_mock_function.assert_has_calls(expected_calls) + + resp = await sqs_client.receive_message(WaitTimeSeconds=2) + assert 'Messages' in resp + + +async def test_delete_on_failure_true_is_default_behavior( + sqs_client, + enqueued_message, +) -> None: + """ + Verifica que delete_on_failure=True (el default) mantiene el + comportamiento actual: elimina el mensaje cuando hay una excepción + no controlada + """ + async_mock_function = AsyncMock( + side_effect=Exception('something went wrong :(') + ) + + async def my_task(data: dict) -> None: + await async_mock_function(data) + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + max_retries=1, + delete_on_failure=True, + )(my_task)() + + async_mock_function.assert_called_with(enqueued_message) + assert async_mock_function.call_count == 1 + + resp = await sqs_client.receive_message() + assert 'Messages' not in resp + + +async def test_graceful_shutdown_completes_inflight_task( + sqs_client, + enqueued_message, + trigger_shutdown, +) -> None: + """ + Verifica que al recibir SIGTERM, los tasks en vuelo completan + antes de que el listener se detenga + """ + task_completed = False + + async def my_task(data: dict) -> None: + nonlocal task_completed + trigger_shutdown() + await asyncio.sleep(0.5) + task_completed = True + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=10, + )(my_task)() + + assert task_completed is True + + +async def test_no_new_messages_after_shutdown( + sqs_client, + trigger_shutdown, +) -> None: + """ + Verifica que después de recibir SIGTERM con max_concurrent=1, + el consumer se detiene y los mensajes restantes quedan en + el queue + """ + for i in range(3): + await sqs_client.send_message( + MessageBody=json.dumps(dict(id=f'msg{i}')), + MessageGroupId=str(i), + ) + + processed = [] + + async def my_task(data: dict) -> None: + processed.append(data['id']) + if data['id'] == 'msg0': + trigger_shutdown() + await asyncio.sleep(0.5) + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=10, + max_concurrent_tasks=1, + )(my_task)() + + assert processed[0] == 'msg0' + assert 'msg2' not in processed + + +async def test_shutdown_before_receive() -> None: + """ + Cuando shutdown_event ya está activo antes de que el consumer + intente leer, debe detenerse inmediatamente sin llamar receive_message + """ + shutdown_event = asyncio.Event() + can_read = asyncio.Event() + can_read.set() + shutdown_event.set() + + mock_sqs = AsyncMock() + + messages = [] + async for msg in message_consumer( + 'queue_url', + 1, + 1, + can_read, + mock_sqs, + shutdown_event, + ): + messages.append(msg) + + assert messages == [] + mock_sqs.receive_message.assert_not_called() + + +async def test_http_client_error_during_shutdown() -> None: + """ + Cuando ocurre un HTTPClientError y shutdown_event ya está activo, + el consumer debe detenerse inmediatamente + """ + shutdown_event = asyncio.Event() + can_read = asyncio.Event() + can_read.set() + + mock_sqs = AsyncMock() + + async def receive_and_shutdown(**kw): + shutdown_event.set() + raise HTTPClientError(error='Connection reset') + + mock_sqs.receive_message = receive_and_shutdown + + messages = [] + async for msg in message_consumer( + 'queue_url', + 1, + 1, + can_read, + mock_sqs, + shutdown_event, + ): + messages.append(msg) + + assert messages == [] + + +async def test_graceful_shutdown_timeout( + sqs_client, + enqueued_message, + trigger_shutdown, + caplog, +) -> None: + """ + Cuando un task en vuelo no termina dentro del visibility_timeout, + el shutdown debe completar con un warning de timeout + """ + + async def my_task(data: dict) -> None: + trigger_shutdown() + await asyncio.sleep(60) + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + )(my_task)() + + assert any( + 'Graceful shutdown timeout' in r.message for r in caplog.records + ) + + +async def test_duration_ms_in_log( + sqs_client, + enqueued_message, + caplog, +) -> None: + """ + Verifica que el log de request/response incluye duration_ms + """ + + async def my_task(data: dict) -> None: + await asyncio.sleep(0.1) + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + )(my_task)() + + request_logs = [ + json.loads(r.message) + for r in caplog.records + if r.name == 'agave.tasks.sqs_tasks' + and '{' in r.message + and 'duration_ms' in r.message + ] + assert len(request_logs) >= 1 + assert request_logs[0]['response']['duration_ms'] >= 100 + + +async def test_metrics_callback_is_called( + sqs_client, + enqueued_message, +) -> None: + """ + Verifica que metrics_callback se invoca con los datos + correctos después de ejecutar un task + """ + metrics_mock = AsyncMock() + + async def my_task(data: dict) -> None: + pass + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + metrics_callback=metrics_mock, + )(my_task)() + + metrics_mock.assert_called_once() + call_kwargs = metrics_mock.call_args.kwargs + assert call_kwargs['task_name'] == 'my_task' + assert call_kwargs['status'] == 'success' + assert call_kwargs['duration_ms'] >= 0 + assert 'concurrent_tasks' in call_kwargs + assert 'queue_url' in call_kwargs + + +async def test_metrics_callback_on_failure( + sqs_client, + enqueued_message, +) -> None: + """ + Verifica que metrics_callback reporta status='failed' + cuando el task falla + """ + metrics_mock = AsyncMock() + + async def my_task(data: dict) -> None: + raise Exception('boom') + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + metrics_callback=metrics_mock, + )(my_task)() + + metrics_mock.assert_called_once() + assert metrics_mock.call_args.kwargs['status'] == 'failed' + + +async def test_metrics_callback_on_retry( + sqs_client, + enqueued_message, +) -> None: + """ + Verifica que metrics_callback reporta status='retrying' + cuando el task hace retry + """ + metrics_mock = AsyncMock() + + async def my_task(data: dict) -> None: + raise RetryTask() + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + max_retries=1, + metrics_callback=metrics_mock, + )(my_task)() + + assert metrics_mock.call_count == 2 + statuses = [c.kwargs['status'] for c in metrics_mock.call_args_list] + assert 'retrying' in statuses + + +async def test_metrics_callback_error_is_handled( + sqs_client, + enqueued_message, +) -> None: + """ + Si metrics_callback lanza una excepción, el task debe completar + normalmente sin propagar el error + """ + async_mock_function = AsyncMock() + metrics_mock = AsyncMock(side_effect=Exception('metrics broke')) + + async def my_task(data: dict) -> None: + await async_mock_function(data) + + await task( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + metrics_callback=metrics_mock, + )(my_task)() + + async_mock_function.assert_called_with(enqueued_message) + metrics_mock.assert_called_once()