-
Notifications
You must be signed in to change notification settings - Fork 1
Add TaskMetrics and emit_metrics for task performance tracking #181
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| from .metrics import TaskMetrics, emit_metrics | ||
|
|
||
| __all__ = ['TaskMetrics', 'emit_metrics'] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| from dataclasses import dataclass | ||
| from typing import Awaitable, Callable, Optional | ||
|
|
||
|
|
||
@dataclass
class TaskMetrics:
    """Identifying data for one task run, consumed by ``emit_metrics``."""

    # Name reported for the task in the emitted metrics.
    task_name: str
    # Queue URL reported alongside the task name.
    queue_url: str
    # Zero-argument callable returning the current number of concurrent
    # tasks. It is invoked at emission time, not at construction time, so the
    # reported value reflects the moment the metrics are sent.
    concurrent_tasks_counter: Callable[[], int]
|
|
||
|
|
||
async def emit_metrics(
    metrics: 'TaskMetrics',
    status: str,
    duration_ms: float,
    metrics_callback: Optional[Callable[..., Awaitable[None]]] = None,
) -> None:
    """Forward one task's metrics to ``metrics_callback``, if one is given.

    Args:
        metrics: Task identity plus the concurrency counter to sample.
        status: Outcome label reported for the task.
        duration_ms: Task duration, in milliseconds.
        metrics_callback: Async callable awaited with the keyword arguments
            ``task_name``, ``queue_url``, ``status``, ``duration_ms`` and
            ``concurrent_tasks``. When falsy (the default ``None``) nothing
            is emitted and the counter is not sampled.

    Any exception raised by the counter or the callback propagates to the
    caller; this function does not catch anything.
    """
    if not metrics_callback:
        return
    await metrics_callback(
        task_name=metrics.task_name,
        queue_url=metrics.queue_url,
        status=status,
        duration_ms=duration_ms,
        # Sampled here so the value reflects the moment of emission.
        concurrent_tasks=metrics.concurrent_tasks_counter(),
    )
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,10 +2,12 @@ | |
| import json | ||
| import logging | ||
| import os | ||
| import signal | ||
| import time | ||
| from functools import wraps | ||
| from itertools import count | ||
| from json import JSONDecodeError | ||
| from typing import AsyncGenerator, Callable, Coroutine | ||
| from typing import AsyncGenerator, Awaitable, Callable, Coroutine, Optional | ||
|
|
||
| from aiobotocore.httpsession import HTTPClientError | ||
| from aiobotocore.session import get_session | ||
|
|
@@ -18,14 +20,15 @@ | |
| get_sensitive_fields, | ||
| obfuscate_sensitive_data, | ||
| ) | ||
| from .metrics import TaskMetrics, emit_metrics | ||
|
|
||
| logging.basicConfig(level=logging.INFO) | ||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| AWS_DEFAULT_REGION = os.getenv('AWS_DEFAULT_REGION', '') | ||
|
|
||
| BACKGROUND_TASKS = set() | ||
| BACKGROUND_TASKS: set[asyncio.Task] = set() | ||
|
|
||
|
|
||
| async def run_task( | ||
|
|
@@ -36,6 +39,10 @@ async def run_task( | |
| queue_url: str, | ||
| message_receive_count: int, | ||
| max_retries: int, | ||
| delete_on_failure: bool, | ||
| task_name: str, | ||
| task_module: str, | ||
| metrics_callback: Optional[Callable[..., Awaitable[None]]] = None, | ||
| ) -> None: | ||
| delete_message = True | ||
| request_model = get_request_model(task_func) | ||
|
|
@@ -48,8 +55,8 @@ async def run_task( | |
| response_log_config_fields = get_sensitive_fields(response_model) | ||
| log_data = { | ||
| 'request': { | ||
| 'task_func': task_func.__name__, | ||
| 'task_module': task_func.__module__, | ||
| 'task_func': task_name, | ||
| 'task_module': task_module, | ||
| 'queue_url': queue_url, | ||
| 'max_retries': max_retries, | ||
| 'body': ofuscated_request_body, | ||
|
|
@@ -61,19 +68,28 @@ async def run_task( | |
| 'status': 'success', | ||
| }, | ||
| } | ||
| start_time = time.monotonic() | ||
| try: | ||
| resp = await task_func(body) | ||
| except asyncio.CancelledError: | ||
| delete_message = False | ||
| log_data['response']['status'] = 'cancelled' | ||
| raise | ||
| except RetryTask as retry: | ||
| delete_message = message_receive_count >= max_retries + 1 | ||
| if not delete_message and retry.countdown and retry.countdown > 0: | ||
| retries_exhausted = message_receive_count >= max_retries + 1 | ||
| delete_message = retries_exhausted and delete_on_failure | ||
| if not retries_exhausted and retry.countdown and retry.countdown > 0: | ||
| await sqs.change_message_visibility( | ||
| QueueUrl=queue_url, | ||
| ReceiptHandle=message['ReceiptHandle'], | ||
| VisibilityTimeout=retry.countdown, | ||
| ) | ||
| log_data['response']['delete_message'] = delete_message | ||
| log_data['response']['status'] = 'retrying' | ||
| log_data['response']['status'] = ( | ||
| 'failed' if retries_exhausted else 'retrying' | ||
| ) | ||
| except Exception as exp: | ||
| delete_message = delete_on_failure | ||
| log_data['response']['status'] = 'failed' | ||
| log_data['response']['error'] = str(exp) | ||
| else: | ||
|
|
@@ -86,12 +102,29 @@ async def run_task( | |
| ofuscated_response_body = resp | ||
| log_data['response']['body'] = ofuscated_response_body | ||
| finally: | ||
| duration_ms = round((time.monotonic() - start_time) * 1000, 2) | ||
| log_data['response']['duration_ms'] = duration_ms | ||
| if delete_message: | ||
| await sqs.delete_message( | ||
| QueueUrl=queue_url, | ||
| ReceiptHandle=message['ReceiptHandle'], | ||
| ) | ||
| logger.info(json.dumps(log_data, default=str)) | ||
| if metrics_callback: | ||
| try: | ||
| task_metrics = TaskMetrics( | ||
| task_name=task_name, | ||
| queue_url=queue_url, | ||
| concurrent_tasks_counter=lambda: len(BACKGROUND_TASKS), | ||
| ) | ||
| await emit_metrics( | ||
| metrics=task_metrics, | ||
| status=log_data['response']['status'], | ||
| duration_ms=duration_ms, | ||
| metrics_callback=metrics_callback, | ||
| ) | ||
| except Exception as exc: | ||
| logger.warning(f'metrics_callback failed: {exc}') | ||
|
Comment on lines
+113
to
+127
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Bound metrics emission with a timeout to avoid stalling workers.
Line 120 awaits the external callback flow without a timeout. If it hangs, the worker slot never frees, throttling or stalling consumption under load.
Suggested hardening:
if metrics_callback:
try:
task_metrics = TaskMetrics(
task_name=task_name,
queue_url=queue_url,
concurrent_tasks_counter=lambda: len(BACKGROUND_TASKS),
)
- await emit_metrics(
- metrics=task_metrics,
- status=log_data['response']['status'],
- duration_ms=duration_ms,
- metrics_callback=metrics_callback,
- )
+ await asyncio.wait_for(
+ emit_metrics(
+ metrics=task_metrics,
+ status=log_data['response']['status'],
+ duration_ms=duration_ms,
+ metrics_callback=metrics_callback,
+ ),
+ timeout=5,
+ )
+ except asyncio.TimeoutError:
+ logger.warning('metrics_callback timed out')
except Exception as exc:
logger.warning(f'metrics_callback failed: {exc}')
🧰 Tools
🪛 Ruff (0.15.9)
[warning] 126-126: Do not catch blind exception: `Exception` (BLE001)
🤖 Prompt for AI Agents |
||
|
|
||
|
|
||
| async def message_consumer( | ||
|
|
@@ -100,9 +133,12 @@ async def message_consumer( | |
| visibility_timeout: int, | ||
| can_read: asyncio.Event, | ||
| sqs, | ||
| shutdown_event: asyncio.Event, | ||
| ) -> AsyncGenerator: | ||
| for _ in count(): | ||
| await can_read.wait() | ||
| if shutdown_event.is_set(): | ||
| return | ||
| try: | ||
| response = await sqs.receive_message( | ||
| QueueUrl=queue_url, | ||
|
|
@@ -112,10 +148,16 @@ async def message_consumer( | |
| ) | ||
| messages = response['Messages'] | ||
| except KeyError: | ||
| if shutdown_event.is_set(): | ||
| return | ||
| continue | ||
| except HTTPClientError: | ||
| if shutdown_event.is_set(): | ||
| return | ||
| await asyncio.sleep(1) | ||
| continue | ||
| if shutdown_event.is_set(): | ||
| return | ||
| for message in messages: | ||
| yield message | ||
|
|
||
|
|
@@ -135,11 +177,14 @@ def task( | |
| visibility_timeout: int = 3600, | ||
| max_retries: int = 1, | ||
| max_concurrent_tasks: int = 5, | ||
| delete_on_failure: bool = True, | ||
| metrics_callback: Optional[Callable[..., Awaitable[None]]] = None, | ||
| ): | ||
| def task_builder(task_func: Callable): | ||
| @wraps(task_func) | ||
| async def start_task(*args, **kwargs) -> None: | ||
| can_read = asyncio.Event() | ||
| shutdown_event = asyncio.Event() | ||
| concurrency_semaphore = asyncio.Semaphore(max_concurrent_tasks) | ||
| can_read.set() | ||
|
|
||
|
|
@@ -151,51 +196,97 @@ async def concurrency_controller(coro: Coroutine) -> None: | |
| try: | ||
| await coro | ||
| finally: | ||
| can_read.set() | ||
| if not shutdown_event.is_set(): | ||
| can_read.set() | ||
|
|
||
| session = get_session() | ||
| loop = asyncio.get_running_loop() | ||
|
|
||
| task_with_validators = validate_call(task_func) | ||
| def _handle_signal(sig): | ||
| logger.info( | ||
| f'Received {sig.name}, initiating graceful shutdown' | ||
| ) | ||
| shutdown_event.set() | ||
| can_read.set() | ||
|
|
||
| async with session.create_client('sqs', region_name) as sqs: | ||
| async for message in message_consumer( | ||
| queue_url, | ||
| wait_time_seconds, | ||
| visibility_timeout, | ||
| can_read, | ||
| sqs, | ||
| ): | ||
| try: | ||
| body = json.loads(message['Body']) | ||
| except JSONDecodeError: | ||
| continue | ||
|
|
||
| message_receive_count = int( | ||
| message['Attributes']['ApproximateReceiveCount'] | ||
| ) | ||
| bg_task = asyncio.create_task( | ||
| concurrency_controller( | ||
| run_task( | ||
| task_with_validators, | ||
| body, | ||
| message, | ||
| sqs, | ||
| queue_url, | ||
| message_receive_count, | ||
| max_retries, | ||
| previous_handlers = {} | ||
| for sig in (signal.SIGTERM, signal.SIGINT): | ||
| previous_handlers[sig] = signal.getsignal(sig) | ||
| loop.add_signal_handler(sig, _handle_signal, sig) | ||
|
|
||
| try: | ||
| session = get_session() | ||
|
|
||
| task_with_validators = validate_call(task_func) | ||
|
|
||
| async with session.create_client('sqs', region_name) as sqs: | ||
| async for message in message_consumer( | ||
| queue_url, | ||
| wait_time_seconds, | ||
| visibility_timeout, | ||
| can_read, | ||
| sqs, | ||
| shutdown_event, | ||
| ): | ||
| try: | ||
| body = json.loads(message['Body']) | ||
| except JSONDecodeError: | ||
| msg_id = message['MessageId'] | ||
| log_data = dict( | ||
| message_id=msg_id, | ||
| status='invalid_json', | ||
| ) | ||
| logger.warning(json.dumps(log_data)) | ||
| await sqs.delete_message( | ||
| QueueUrl=queue_url, | ||
| ReceiptHandle=message['ReceiptHandle'], | ||
| ) | ||
| continue | ||
|
|
||
| message_receive_count = int( | ||
| message['Attributes']['ApproximateReceiveCount'] | ||
| ) | ||
| bg_task = asyncio.create_task( | ||
| concurrency_controller( | ||
| run_task( | ||
| task_with_validators, | ||
| body, | ||
| message, | ||
| sqs, | ||
| queue_url, | ||
| message_receive_count, | ||
| max_retries, | ||
| delete_on_failure, | ||
| task_name=task_func.__name__, | ||
| task_module=task_func.__module__, | ||
| metrics_callback=metrics_callback, | ||
| ), | ||
| ), | ||
| ), | ||
| name='fast-agave-task', | ||
| ) | ||
| BACKGROUND_TASKS.add(bg_task) | ||
| bg_task.add_done_callback(BACKGROUND_TASKS.discard) | ||
|
|
||
| # Espera a que terminen todos los tasks pendientes creados por | ||
| # `asyncio.create_task`. De esta forma los tasks | ||
| # podrán borrar el mensaje del queue usando la misma instancia | ||
| # del cliente de SQS | ||
| running_tasks = await get_running_fast_agave_tasks() | ||
| await asyncio.gather(*running_tasks) | ||
| name='fast-agave-task', | ||
| ) | ||
| BACKGROUND_TASKS.add(bg_task) | ||
| bg_task.add_done_callback(BACKGROUND_TASKS.discard) | ||
|
|
||
| running_tasks = await get_running_fast_agave_tasks() | ||
| if shutdown_event.is_set(): | ||
| try: | ||
| await asyncio.wait_for( | ||
| asyncio.gather(*running_tasks), | ||
| timeout=visibility_timeout, | ||
| ) | ||
| except asyncio.TimeoutError: | ||
| logger.warning( | ||
| 'Graceful shutdown timeout, tasks may retry' | ||
| ) | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| else: | ||
| await asyncio.gather(*running_tasks) | ||
| finally: | ||
| for sig, handler in previous_handlers.items(): | ||
| loop.remove_signal_handler(sig) | ||
| if handler and handler not in ( | ||
| signal.SIG_DFL, | ||
| signal.SIG_IGN, | ||
| ): | ||
| signal.signal(sig, handler) | ||
|
|
||
| return start_task | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1 @@ | ||
| __version__ = '1.5.2' | ||
| __version__ = '1.5.3.dev2' |
Uh oh!
There was an error while loading. Please reload this page.