Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion src/processing/outbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ async def run_forever(self) -> None:
while True:
await asyncio.sleep(2)
try:
self.process_pending()
await self.process_pending_async()
except duckdb.Error as exc:
logger.warning(
"outbox_processing_failed",
Expand Down Expand Up @@ -107,6 +107,63 @@ def process_pending(self, limit: int = 100) -> int:
processed += 1
return processed

async def process_pending_async(self, limit: int = 100) -> int:
"""Async variant used by run_forever.

DuckDB reads/updates stay on the event loop (the connection may be shared
with the query engine, so it must not be touched from a worker thread),
but the blocking Kafka produce+flush(10) is offloaded so a slow or
unreachable broker can't freeze the whole event loop. (audit_28_06_26.md #1)
"""
rows = self._connection.execute(
"""
SELECT id, event_id, payload, topic, retry_count
FROM outbox
WHERE status = 'pending'
AND (next_attempt_at IS NULL OR next_attempt_at <= ?)
ORDER BY created_at
LIMIT ?
""",
[datetime.now(UTC), limit],
).fetchall()
processed = 0
for row in rows:
if await self._process_row_async(row):
processed += 1
return processed

async def _process_row_async(self, row: tuple[Any, ...]) -> bool:
    """Publish one outbox row and record the outcome.

    The producer call runs in a worker thread via ``asyncio.to_thread``;
    the bookkeeping (``_schedule_retry`` / ``_mark_sent``) runs on the
    calling coroutine.

    Args:
        row: ``(id, event_id, payload, topic, retry_count)`` as selected by
            ``process_pending_async``.

    Returns:
        ``True`` when the message was produced and the row marked sent,
        ``False`` when delivery failed and a retry was scheduled.

    Raises:
        RuntimeError: Re-raised when the error text does not look like a
            Kafka delivery failure.
    """
    outbox_id, event_id, payload, topic, retry_count = row
    message = self._decode_payload(payload)
    try:
        await asyncio.to_thread(self._producer, topic, message)
    except (BufferError, ConnectionError, TimeoutError, KafkaException, RuntimeError) as exc:
        reason = str(exc)
        if isinstance(exc, RuntimeError):
            # Only RuntimeErrors whose text identifies a Kafka failure are
            # retried; any other RuntimeError propagates to the caller.
            kafka_originated = (
                reason.startswith("KafkaError{")
                or "Kafka message(s) were not delivered" in reason
            )
            if not kafka_originated:
                raise
        attempt = int(retry_count or 0) + 1
        logger.warning(
            "outbox_delivery_retry_scheduled",
            outbox_id=outbox_id,
            event_id=event_id,
            topic=topic,
            retry_count=attempt,
            error=reason,
            exc_info=True,
        )
        self._schedule_retry(
            outbox_id=outbox_id,
            event_id=event_id,
            retry_count=attempt,
            error_message=reason,
        )
        return False
    else:
        self._mark_sent(outbox_id=outbox_id, event_id=event_id)
        return True

def process_entry(self, outbox_id: str) -> bool:
row = self._connection.execute(
"""
Expand Down
75 changes: 60 additions & 15 deletions src/serving/api/alerts/escalation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import httpx
import structlog

from src.serving.api.egress_guard import UnsafeEgressURLError, validate_public_url
from src.serving.api.webhook_dispatcher import _event_body, _signature

from .evaluator import evaluate_rule
Expand Down Expand Up @@ -78,10 +79,6 @@ async def dispatch_alert(
return alert, True, 0

if current_triggered and alert.fired_at is None:
alert.fired_at = now
alert.resolved_at = None
alert.state = "firing"
alert.last_escalation_level = 1
payload = {
"alert_id": alert.id,
"alert_name": alert.name,
Expand All @@ -100,7 +97,7 @@ async def dispatch_alert(
payload["previous_value"] = evaluation["previous_value"]
if evaluation["change_pct"] is not None:
payload["change_pct"] = evaluation["change_pct"]
await deliver(
result = await deliver(
dispatcher,
alert,
payload,
Expand All @@ -110,6 +107,17 @@ async def dispatch_alert(
change_pct=evaluation["change_pct"],
webhook_url=alert.escalation[0].webhook_url,
)
if not result.get("success"):
# Delivery failed: do NOT advance fired state, so the next evaluation
# tick re-attempts the page instead of recording the alert as fired
# and going silent until cooldown. (audit_28_06_26.md #4)
alert.last_condition_triggered = True
alert.updated_at = now
return alert, True, 0
alert.fired_at = now
alert.resolved_at = None
alert.state = "firing"
alert.last_escalation_level = 1
alert.last_triggered_at = now
alert.last_condition_triggered = True
alert.updated_at = now
Expand Down Expand Up @@ -138,7 +146,7 @@ async def dispatch_alert(
payload["previous_value"] = evaluation["previous_value"]
if evaluation["change_pct"] is not None:
payload["change_pct"] = evaluation["change_pct"]
await deliver(
result = await deliver(
dispatcher,
alert,
payload,
Expand All @@ -152,14 +160,17 @@ async def dispatch_alert(
change_pct=evaluation["change_pct"],
webhook_url=next_step.webhook_url,
)
alert.last_triggered_at = now
alert.last_escalation_level = max(
alert.last_escalation_level,
next_step.level,
)
alert.updated_at = now
triggered += 1
alert_changed = True
if result.get("success"):
alert.last_triggered_at = now
alert.last_escalation_level = max(
alert.last_escalation_level,
next_step.level,
)
alert.updated_at = now
triggered += 1
alert_changed = True
# else: leave last_escalation_level unchanged so the next evaluation
# tick re-attempts this escalation step. (audit_28_06_26.md #4)
alert.state = "sustained"
alert.last_condition_triggered = True
return alert, alert_changed, triggered
Expand Down Expand Up @@ -238,13 +249,47 @@ async def deliver(
status_code: int | None = None
error: str | None = None

target_url = webhook_url or alert.webhook_url
# Re-validate at delivery time (DNS rebinding): a name public at registration
# could now resolve to an internal address. (audit_28_06_26.md #2)
try:
await asyncio.to_thread(validate_public_url, target_url)
except UnsafeEgressURLError as exc:
error = f"unsafe egress URL: {exc}"
log_alert_history(
conn,
delivery_id=delivery_id,
alert=alert,
metric=alert.metric,
current_value=current_value,
previous_value=previous_value,
change_pct=change_pct,
threshold=alert.threshold,
condition=alert.condition,
window=alert.window,
event_type=event_type,
status_code=None,
success=False,
error=error,
payload=payload,
)
return {
"delivery_id": delivery_id,
"alert_id": alert.id,
"event_type": event_type,
"success": False,
"status_code": None,
"error": error,
"attempts": 0,
}

async with httpx.AsyncClient(timeout=5.0) as client:
for attempt in range(1, 4):
attempts = attempt
error = None
try:
response = await client.post(
webhook_url or alert.webhook_url,
target_url,
content=body,
headers=headers,
)
Expand Down
6 changes: 5 additions & 1 deletion src/serving/api/auth/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import structlog
from fastapi import Header, HTTPException, Request, Response
from fastapi.responses import JSONResponse
from starlette.concurrency import run_in_threadpool

from src.constants import DEFAULT_RATE_LIMIT_WINDOW_SECONDS, FAILED_AUTH_WINDOW_SECONDS
from src.serving.api.metrics import AUTH_FAILURES
Expand Down Expand Up @@ -106,7 +107,10 @@ async def __call__(
)

manager.clear_failed_auth(client_ip)
manager.record_usage(tenant_key, path)
# record_usage opens a DuckDB connection, writes, and retries with a
# blocking sleep; running it inline froze the event loop on every
# authenticated request. Offload to a worker thread. (audit_28_06_26.md #13)
await run_in_threadpool(manager.record_usage, tenant_key, path)
is_allowed, remaining, reset_at = await manager.check_rate_limit(tenant_key)
rate_limit_headers = {
"X-RateLimit-Limit": str(tenant_key.rate_limit_rpm),
Expand Down
66 changes: 66 additions & 0 deletions src/serving/api/egress_guard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Egress URL guard against SSRF.

Webhook and alert targets are tenant-controlled URLs that the server fetches.
Without a guard a tenant can point them at loopback / private / link-local /
cloud-metadata addresses and use the delivery result (status / error) as an
SSRF oracle to map and reach the internal network (audit_28_06_26.md #2).

This module resolves the host and rejects any URL that is not an http(s) target
resolving exclusively to public unicast addresses. It is applied both at
registration time (reject early, 4xx) and immediately before each delivery
(narrowing the DNS-rebinding window — a name that resolved public at creation
could later point at an internal IP).
"""

from __future__ import annotations

import ipaddress
import socket
from urllib.parse import urlsplit

# URL schemes accepted for outbound targets; validate_public_url rejects any other scheme.
_ALLOWED_SCHEMES = {"http", "https"}


class UnsafeEgressURLError(ValueError):
    """Signal that an outbound URL was rejected by the egress guard.

    Subclasses :class:`ValueError`; raised for targets that are not public
    http(s) URLs.
    """


def _ip_is_public(ip: str) -> bool:
    """Return ``True`` only when ``ip`` is a globally routable unicast address.

    Args:
        ip: Textual IPv4 or IPv6 address, e.g. as returned by
            ``socket.getaddrinfo``.

    Returns:
        ``False`` for private, loopback, link-local, reserved, multicast and
        unspecified addresses, for anything the ``ipaddress`` module does not
        classify as global (notably the 100.64.0.0/10 shared address space,
        for which ``is_private`` is ``False``), and for strings that do not
        parse as an IP address at all.
    """
    try:
        addr = ipaddress.ip_address(ip)
    except ValueError:
        # Fail closed: an address we cannot parse is never treated as public.
        return False
    if not addr.is_global:
        # Catches ranges that are neither "private" nor publicly routable,
        # such as carrier-grade NAT space (100.64.0.0/10).
        return False
    # is_global alone still admits multicast, so keep the explicit checks.
    return not (
        addr.is_private
        or addr.is_loopback
        or addr.is_link_local
        or addr.is_reserved
        or addr.is_multicast
        or addr.is_unspecified
    )


def validate_public_url(url: str) -> None:
    """Raise :class:`UnsafeEgressURLError` unless ``url`` is an http(s) URL whose
    host resolves *only* to public unicast addresses.

    Resolution is synchronous (``socket.getaddrinfo``); call it via
    ``asyncio.to_thread`` on the event loop. IP-literal hosts resolve to
    themselves, so loopback/private/link-local literals are rejected without any
    network DNS.

    Args:
        url: The outbound target to check.

    Raises:
        UnsafeEgressURLError: If the URL is malformed, uses a scheme other than
            http/https, has no host, carries an invalid port, does not
            resolve, or resolves to at least one non-public address.
    """
    try:
        parts = urlsplit(url)
    except ValueError as exc:
        # urlsplit rejects e.g. unbalanced IPv6 brackets with a bare ValueError;
        # surface it as the guard's own error so callers' handlers see it.
        raise UnsafeEgressURLError(f"malformed URL: {exc}") from exc
    scheme = parts.scheme.lower()
    if scheme not in _ALLOWED_SCHEMES:
        raise UnsafeEgressURLError(f"scheme not allowed: {parts.scheme!r}")
    host = parts.hostname
    if not host:
        raise UnsafeEgressURLError("missing host")
    try:
        # The .port property raises ValueError for non-numeric or
        # out-of-range ports.
        explicit_port = parts.port
    except ValueError as exc:
        raise UnsafeEgressURLError(f"invalid port: {exc}") from exc
    port = explicit_port or (443 if scheme == "https" else 80)
    try:
        infos = socket.getaddrinfo(host, port, proto=socket.IPPROTO_TCP)
    except (socket.gaierror, UnicodeError) as exc:
        raise UnsafeEgressURLError(f"host does not resolve: {host}") from exc
    # Each getaddrinfo entry is (family, type, proto, canonname, sockaddr);
    # sockaddr[0] is the address string for both IPv4 and IPv6.
    resolved = {str(info[4][0]) for info in infos}
    if not resolved:
        raise UnsafeEgressURLError(f"host does not resolve: {host}")
    # Every resolved address must be public: one internal record is enough
    # to reject the URL.
    for ip in resolved:
        if not _ip_is_public(ip):
            raise UnsafeEgressURLError(f"host {host} resolves to non-public address {ip}")
32 changes: 21 additions & 11 deletions src/serving/api/rate_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,21 @@ def __init__(
self._time_source = time_source
self._windows: dict[str, list[float]] = defaultdict(list)

def _check_local(
    self, key: str, limit: int, window_seconds: int, now: float
) -> tuple[bool, int, int]:
    """Sliding-window check against this process's in-memory timestamps.

    Serves both the no-Redis configuration and the fallback taken when a
    Redis call fails.

    Args:
        key: Bucket identifier whose timestamps are tracked in ``self._windows``.
        limit: Maximum number of hits allowed inside one window.
        window_seconds: Length of the sliding window.
        now: Current time, in the same units as the stored timestamps.

    Returns:
        ``(allowed, remaining, reset_at)``. ``reset_at`` is the moment the
        oldest retained hit leaves the window, or ``now + window_seconds``
        when the request is refused with no hits retained.
    """
    oldest_allowed = now - window_seconds
    # Drop expired stamps and store the pruned list back on the instance.
    live = [ts for ts in self._windows[key] if ts > oldest_allowed]
    self._windows[key] = live

    if len(live) < limit:
        # Record this hit; `live` is the same list object kept in _windows.
        live.append(now)
        return True, max(0, limit - len(live)), int(live[0] + window_seconds)
    if live:
        return False, 0, int(live[0] + window_seconds)
    return False, 0, int(now + window_seconds)

async def check(
self,
key: str,
Expand All @@ -43,16 +58,7 @@ async def check(
now = self._time_source()
reset_at = int(now + window_seconds)
if self._redis is None:
cutoff = now - window_seconds
window = [stamp for stamp in self._windows[key] if stamp > cutoff]
self._windows[key] = window
if len(window) >= limit:
if window:
reset_at = int(window[0] + window_seconds)
return False, 0, reset_at
window.append(now)
reset_at = int(window[0] + window_seconds)
return True, max(0, limit - len(window)), reset_at
return self._check_local(key, limit, window_seconds, now)

try:
pipeline = self._redis.pipeline()
Expand All @@ -68,7 +74,11 @@ async def check(
operation="check",
error=str(exc),
)
return True, limit, reset_at
# Fail closed to a per-process cap instead of fail-open: a Redis
# outage must not silently disable rate limiting fleet-wide, which
# would open a brute-force / DoS-amplification window on the
# expensive NL->SQL and entity paths. (audit_28_06_26.md #7)
return self._check_local(key, limit, window_seconds, now)

count = int(results[2])
oldest_entry = results[4]
Expand Down
11 changes: 11 additions & 0 deletions src/serving/api/routers/alerts.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Literal, cast

from fastapi import APIRouter, HTTPException, Request, Response, status
Expand All @@ -13,6 +14,7 @@
list_alerts,
update_alert,
)
from src.serving.api.egress_guard import UnsafeEgressURLError, validate_public_url

router = APIRouter(prefix="/v1/alerts", tags=["alerts"])

Expand Down Expand Up @@ -66,6 +68,10 @@ def _validate_metric_request(request: Request, metric: str, window: str) -> None
@router.post("", status_code=status.HTTP_201_CREATED)
async def register_alert(payload: AlertCreateRequest, request: Request) -> dict[str, object]:
_validate_metric_request(request, payload.metric, payload.window)
try:
await asyncio.to_thread(validate_public_url, str(payload.webhook_url))
except UnsafeEgressURLError as exc:
raise HTTPException(status_code=400, detail=f"Unsafe webhook URL: {exc}") from exc
rule = create_alert(
get_alert_config_path(request.app),
name=payload.name,
Expand Down Expand Up @@ -101,6 +107,11 @@ async def modify_alert(
next_metric = updates.get("metric", existing.metric)
next_window = updates.get("window", existing.window)
_validate_metric_request(request, next_metric, next_window)
if "webhook_url" in updates:
try:
await asyncio.to_thread(validate_public_url, str(updates["webhook_url"]))
except UnsafeEgressURLError as exc:
raise HTTPException(status_code=400, detail=f"Unsafe webhook URL: {exc}") from exc

updated = update_alert(path, alert_id, _tenant(request), updates)
if updated is None:
Expand Down
Loading