Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions agave/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .metrics import TaskMetrics, emit_metrics

__all__ = ['TaskMetrics', 'emit_metrics']
26 changes: 26 additions & 0 deletions agave/tasks/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from dataclasses import dataclass
from typing import Awaitable, Callable, Optional


@dataclass
class TaskMetrics:
    """Per-task context handed to ``emit_metrics`` when reporting metrics."""

    # Reported to the metrics callback as ``task_name``.
    task_name: str
    # Reported to the metrics callback as ``queue_url``.
    queue_url: str
    # Zero-argument callable returning an int; ``emit_metrics`` calls it at
    # emission time and reports the result as ``concurrent_tasks``.
    concurrent_tasks_counter: Callable[[], int]


async def emit_metrics(
    metrics: TaskMetrics,
    status: str,
    duration_ms: float,
    metrics_callback: Optional[Callable[..., Awaitable[None]]] = None,
) -> None:
    """Forward one task's metrics to ``metrics_callback``, if one is given.

    Args:
        metrics: Task context. Its ``task_name`` and ``queue_url`` are
            passed through, and ``concurrent_tasks_counter`` is called at
            emission time to obtain the ``concurrent_tasks`` value.
        status: Passed to the callback unchanged.
        duration_ms: Passed to the callback unchanged.
        metrics_callback: Async callable invoked with the keyword arguments
            ``task_name``, ``queue_url``, ``status``, ``duration_ms`` and
            ``concurrent_tasks``. When ``None`` the function returns
            without doing anything.

    Any exception raised by the callback (or by the counter) propagates to
    the caller; this function does not catch it.
    """
    # Compare against None instead of relying on truthiness: a valid callable
    # object may be falsy (e.g. it defines ``__len__`` or ``__bool__``) and
    # must still receive the metrics.
    if metrics_callback is None:
        return
    await metrics_callback(
        task_name=metrics.task_name,
        queue_url=metrics.queue_url,
        status=status,
        duration_ms=duration_ms,
        concurrent_tasks=metrics.concurrent_tasks_counter(),
    )
Comment thread
coderabbitai[bot] marked this conversation as resolved.
187 changes: 139 additions & 48 deletions agave/tasks/sqs_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import json
import logging
import os
import signal
import time
from functools import wraps
from itertools import count
from json import JSONDecodeError
from typing import AsyncGenerator, Callable, Coroutine
from typing import AsyncGenerator, Awaitable, Callable, Coroutine, Optional

from aiobotocore.httpsession import HTTPClientError
from aiobotocore.session import get_session
Expand All @@ -18,14 +20,15 @@
get_sensitive_fields,
obfuscate_sensitive_data,
)
from .metrics import TaskMetrics, emit_metrics

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


AWS_DEFAULT_REGION = os.getenv('AWS_DEFAULT_REGION', '')

BACKGROUND_TASKS = set()
BACKGROUND_TASKS: set[asyncio.Task] = set()


async def run_task(
Expand All @@ -36,6 +39,10 @@ async def run_task(
queue_url: str,
message_receive_count: int,
max_retries: int,
delete_on_failure: bool,
task_name: str,
task_module: str,
metrics_callback: Optional[Callable[..., Awaitable[None]]] = None,
) -> None:
delete_message = True
request_model = get_request_model(task_func)
Expand All @@ -48,8 +55,8 @@ async def run_task(
response_log_config_fields = get_sensitive_fields(response_model)
log_data = {
'request': {
'task_func': task_func.__name__,
'task_module': task_func.__module__,
'task_func': task_name,
'task_module': task_module,
'queue_url': queue_url,
'max_retries': max_retries,
'body': ofuscated_request_body,
Expand All @@ -61,19 +68,28 @@ async def run_task(
'status': 'success',
},
}
start_time = time.monotonic()
try:
resp = await task_func(body)
except asyncio.CancelledError:
delete_message = False
log_data['response']['status'] = 'cancelled'
raise
except RetryTask as retry:
delete_message = message_receive_count >= max_retries + 1
if not delete_message and retry.countdown and retry.countdown > 0:
retries_exhausted = message_receive_count >= max_retries + 1
delete_message = retries_exhausted and delete_on_failure
if not retries_exhausted and retry.countdown and retry.countdown > 0:
await sqs.change_message_visibility(
QueueUrl=queue_url,
ReceiptHandle=message['ReceiptHandle'],
VisibilityTimeout=retry.countdown,
)
log_data['response']['delete_message'] = delete_message
log_data['response']['status'] = 'retrying'
log_data['response']['status'] = (
'failed' if retries_exhausted else 'retrying'
)
except Exception as exp:
delete_message = delete_on_failure
log_data['response']['status'] = 'failed'
log_data['response']['error'] = str(exp)
else:
Expand All @@ -86,12 +102,29 @@ async def run_task(
ofuscated_response_body = resp
log_data['response']['body'] = ofuscated_response_body
finally:
duration_ms = round((time.monotonic() - start_time) * 1000, 2)
log_data['response']['duration_ms'] = duration_ms
if delete_message:
await sqs.delete_message(
QueueUrl=queue_url,
ReceiptHandle=message['ReceiptHandle'],
)
logger.info(json.dumps(log_data, default=str))
if metrics_callback:
try:
task_metrics = TaskMetrics(
task_name=task_name,
queue_url=queue_url,
concurrent_tasks_counter=lambda: len(BACKGROUND_TASKS),
)
await emit_metrics(
metrics=task_metrics,
status=log_data['response']['status'],
duration_ms=duration_ms,
metrics_callback=metrics_callback,
)
except Exception as exc:
logger.warning(f'metrics_callback failed: {exc}')
Comment on lines +113 to +127
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Bound metrics emission with a timeout to avoid stalling workers.

Line 120 awaits the external metrics callback without a timeout. If the callback hangs, the worker slot is never freed, which throttles or stalls message consumption under load.

Suggested hardening
         if metrics_callback:
             try:
                 task_metrics = TaskMetrics(
                     task_name=task_name,
                     queue_url=queue_url,
                     concurrent_tasks_counter=lambda: len(BACKGROUND_TASKS),
                 )
-                await emit_metrics(
-                    metrics=task_metrics,
-                    status=log_data['response']['status'],
-                    duration_ms=duration_ms,
-                    metrics_callback=metrics_callback,
-                )
+                await asyncio.wait_for(
+                    emit_metrics(
+                        metrics=task_metrics,
+                        status=log_data['response']['status'],
+                        duration_ms=duration_ms,
+                        metrics_callback=metrics_callback,
+                    ),
+                    timeout=5,
+                )
+            except asyncio.TimeoutError:
+                logger.warning('metrics_callback timed out')
             except Exception as exc:
                 logger.warning(f'metrics_callback failed: {exc}')
🧰 Tools
🪛 Ruff (0.15.9)

[warning] 126-126: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@agave/tasks/sqs_tasks.py` around lines 113 - 127, The metrics emission awaits
an external callback (emit_metrics) without a timeout which can stall workers;
wrap the await emit_metrics(...) call in an asyncio timeout (e.g.,
asyncio.wait_for) with a short configurable constant (or env var) and catch
asyncio.TimeoutError (and other exceptions) to log a warning and move on so
BACKGROUND_TASKS slots are freed; ensure you reference TaskMetrics,
emit_metrics, metrics_callback and keep the existing exception logging for
non-timeout errors.



async def message_consumer(
Expand All @@ -100,9 +133,12 @@ async def message_consumer(
visibility_timeout: int,
can_read: asyncio.Event,
sqs,
shutdown_event: asyncio.Event,
) -> AsyncGenerator:
for _ in count():
await can_read.wait()
if shutdown_event.is_set():
return
try:
response = await sqs.receive_message(
QueueUrl=queue_url,
Expand All @@ -112,10 +148,16 @@ async def message_consumer(
)
messages = response['Messages']
except KeyError:
if shutdown_event.is_set():
return
continue
except HTTPClientError:
if shutdown_event.is_set():
return
await asyncio.sleep(1)
continue
if shutdown_event.is_set():
return
for message in messages:
yield message

Expand All @@ -135,11 +177,14 @@ def task(
visibility_timeout: int = 3600,
max_retries: int = 1,
max_concurrent_tasks: int = 5,
delete_on_failure: bool = True,
metrics_callback: Optional[Callable[..., Awaitable[None]]] = None,
):
def task_builder(task_func: Callable):
@wraps(task_func)
async def start_task(*args, **kwargs) -> None:
can_read = asyncio.Event()
shutdown_event = asyncio.Event()
concurrency_semaphore = asyncio.Semaphore(max_concurrent_tasks)
can_read.set()

Expand All @@ -151,51 +196,97 @@ async def concurrency_controller(coro: Coroutine) -> None:
try:
await coro
finally:
can_read.set()
if not shutdown_event.is_set():
can_read.set()

session = get_session()
loop = asyncio.get_running_loop()

task_with_validators = validate_call(task_func)
def _handle_signal(sig):
logger.info(
f'Received {sig.name}, initiating graceful shutdown'
)
shutdown_event.set()
can_read.set()

async with session.create_client('sqs', region_name) as sqs:
async for message in message_consumer(
queue_url,
wait_time_seconds,
visibility_timeout,
can_read,
sqs,
):
try:
body = json.loads(message['Body'])
except JSONDecodeError:
continue

message_receive_count = int(
message['Attributes']['ApproximateReceiveCount']
)
bg_task = asyncio.create_task(
concurrency_controller(
run_task(
task_with_validators,
body,
message,
sqs,
queue_url,
message_receive_count,
max_retries,
previous_handlers = {}
for sig in (signal.SIGTERM, signal.SIGINT):
previous_handlers[sig] = signal.getsignal(sig)
loop.add_signal_handler(sig, _handle_signal, sig)

try:
session = get_session()

task_with_validators = validate_call(task_func)

async with session.create_client('sqs', region_name) as sqs:
async for message in message_consumer(
queue_url,
wait_time_seconds,
visibility_timeout,
can_read,
sqs,
shutdown_event,
):
try:
body = json.loads(message['Body'])
except JSONDecodeError:
msg_id = message['MessageId']
log_data = dict(
message_id=msg_id,
status='invalid_json',
)
logger.warning(json.dumps(log_data))
await sqs.delete_message(
QueueUrl=queue_url,
ReceiptHandle=message['ReceiptHandle'],
)
continue

message_receive_count = int(
message['Attributes']['ApproximateReceiveCount']
)
bg_task = asyncio.create_task(
concurrency_controller(
run_task(
task_with_validators,
body,
message,
sqs,
queue_url,
message_receive_count,
max_retries,
delete_on_failure,
task_name=task_func.__name__,
task_module=task_func.__module__,
metrics_callback=metrics_callback,
),
),
),
name='fast-agave-task',
)
BACKGROUND_TASKS.add(bg_task)
bg_task.add_done_callback(BACKGROUND_TASKS.discard)

# Espera a que terminen todos los tasks pendientes creados por
# `asyncio.create_task`. De esta forma los tasks
# podrán borrar el mensaje del queue usando la misma instancia
# del cliente de SQS
running_tasks = await get_running_fast_agave_tasks()
await asyncio.gather(*running_tasks)
name='fast-agave-task',
)
BACKGROUND_TASKS.add(bg_task)
bg_task.add_done_callback(BACKGROUND_TASKS.discard)

running_tasks = await get_running_fast_agave_tasks()
if shutdown_event.is_set():
try:
await asyncio.wait_for(
asyncio.gather(*running_tasks),
timeout=visibility_timeout,
)
except asyncio.TimeoutError:
logger.warning(
'Graceful shutdown timeout, tasks may retry'
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
else:
await asyncio.gather(*running_tasks)
finally:
for sig, handler in previous_handlers.items():
loop.remove_signal_handler(sig)
if handler and handler not in (
signal.SIG_DFL,
signal.SIG_IGN,
):
signal.signal(sig, handler)

return start_task

Expand Down
2 changes: 1 addition & 1 deletion agave/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.5.2'
__version__ = '1.5.3.dev2'
22 changes: 22 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import datetime as dt
import functools
import json
import logging
import os
import signal
from functools import partial
from typing import Callable, Generator

Expand Down Expand Up @@ -291,3 +293,23 @@ def set_log_level(caplog):
Automatically set logging level to INFO for all tests.
"""
caplog.set_level(logging.INFO)


# Sample payload that the `enqueued_message` fixture JSON-encodes and sends.
TEST_MESSAGE = dict(id='abc123', name='fast-agave')


@pytest.fixture
async def enqueued_message(sqs_client):
    """Send ``TEST_MESSAGE`` as JSON via ``sqs_client`` and return it."""
    payload = json.dumps(TEST_MESSAGE)
    await sqs_client.send_message(MessageBody=payload, MessageGroupId='1234')
    return TEST_MESSAGE


@pytest.fixture
def trigger_shutdown():
    """Provide a callable that sends SIGTERM to the current process."""

    def _trigger():
        current_pid = os.getpid()
        os.kill(current_pid, signal.SIGTERM)

    return _trigger
4 changes: 2 additions & 2 deletions tests/tasks/test_loggin_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,8 @@ async def my_task(data: dict) -> None:
== '2'
)

# For the third execution
assert log_data[2]['response']['status'] == 'retrying'
# For the third execution (retries exhausted)
assert log_data[2]['response']['status'] == 'failed'
assert (
log_data[2]['request']['message_attributes']['ApproximateReceiveCount']
== '3'
Expand Down
Loading
Loading