diff --git a/src/serving/api/alerts/escalation.py b/src/serving/api/alerts/escalation.py index 80fed6d..2346464 100644 --- a/src/serving/api/alerts/escalation.py +++ b/src/serving/api/alerts/escalation.py @@ -194,8 +194,9 @@ async def dispatch_alert( continue if step.webhook_url not in notified_urls: notified_urls.append(step.webhook_url) + all_delivered = True for webhook_url in notified_urls or [alert.webhook_url]: - await deliver( + result = await deliver( dispatcher, alert, payload, @@ -205,7 +206,20 @@ async def dispatch_alert( change_pct=evaluation["change_pct"], webhook_url=webhook_url, ) - triggered += 1 + if result.get("success"): + triggered += 1 + else: + all_delivered = False + if not all_delivered: + # A resolved notification failed to deliver: do NOT clear fired_at or + # advance to the "resolved" state this tick, so the next evaluation + # tick re-attempts. Otherwise the resolved page is lost for good and + # the receiver keeps treating the incident as open — the same silent + # loss the firing/escalation paths were hardened against. + # (audit_30_06_26.md C1; mirrors audit_28_06_26.md #4) + alert.last_condition_triggered = False + alert.updated_at = now + return alert, True, triggered alert.state = "resolved" alert.resolved_at = now alert.fired_at = None diff --git a/src/serving/api/analytics.py b/src/serving/api/analytics.py index 580ed7b..81f6668 100644 --- a/src/serving/api/analytics.py +++ b/src/serving/api/analytics.py @@ -23,6 +23,15 @@ Awaitable[Response], ] +# Analytics runs before the route's Pydantic validation, so cap the persisted +# query text defensively at the same bound /v1/query enforces (1000 chars). +_MAX_QUERY_TEXT_CHARS = 1000 + +# Auth/throttle outcomes whose requests must never be recorded: the analytics +# middleware sits OUTSIDE AuthMiddleware, so recording these would let +# unauthenticated/rejected traffic drive un-throttled DB writes. (audit_30 S1) +_UNRECORDED_STATUS_CODES = frozenset({401, 403, 429, 503}) + def ensure_analytics_table(db_path: Path | str) -> None: for attempt in range(10): @@ -98,22 +107,35 @@ async def receive() -> Message: response = await call_next(request) # failure telemetry is best-effort before re-raising the original error except Exception: # nosec B110 - # Record downstream failures before re-raising them through the client stack. - _schedule_session_write( - request.app.state.auth_manager.db_path, - request_id, - _build_session_record( - request=request, - request_id=request_id, - status_code=500, - duration_ms=(time.perf_counter() - started_at) * 1000, - cache_hit=False, - body=body, - ), - ) + # Record downstream failures before re-raising — but only for an + # authenticated request, so an unauthenticated error can't drive an + # un-throttled DB write/thread spawn. (audit_30_06_26.md S1) + if getattr(request.state, "tenant_key", None) is not None: + _schedule_session_write( + request.app.state.auth_manager.db_path, + request_id, + _build_session_record( + request=request, + request_id=request_id, + status_code=500, + duration_ms=(time.perf_counter() - started_at) * 1000, + cache_hit=False, + body=body, + ), + ) raise response.headers["X-Request-Id"] = request_id + # Record analytics only for authenticated, non-rejected requests. This + # middleware runs OUTSIDE AuthMiddleware, so without this gate an + # unauthenticated/failed/throttled request would spawn a DB-writing + # thread and persist an attacker-controlled body with neither auth nor + # rate-limiting in front of it — a remote DoS. (audit_30_06_26.md S1) + if ( + getattr(request.state, "tenant_key", None) is None + or response.status_code in _UNRECORDED_STATUS_CODES + ): + return response background = response.background if background is None: background = BackgroundTasks() @@ -520,7 +542,9 @@ def _build_session_record( payload = {} question = payload.get("question") if isinstance(question, str): - query_text = question + # Truncate: analytics runs before the route validates the body, + # so an oversized question would otherwise be persisted verbatim. + query_text = question[:_MAX_QUERY_TEXT_CHARS] return { "request_id": request_id, diff --git a/src/serving/api/routers/agent_query.py b/src/serving/api/routers/agent_query.py index 5e83a2e..f044d5d 100644 --- a/src/serving/api/routers/agent_query.py +++ b/src/serving/api/routers/agent_query.py @@ -516,6 +516,21 @@ async def get_metric( ) _ensure_metric_allowed(req, metric_name) + # Reject windows the metric doesn't declare. The engine silently maps an + # unknown window to "1 hour" (and active_sessions ignores it entirely), so + # without this the response would echo the *requested* window while + # returning a *different* window's value, and each bogus window string would + # pollute the metric cache. Mirrors the alerts router. (audit_30_06_26.md A1) + available_windows = catalog.metrics[metric_name].available_windows + if window not in available_windows: + raise HTTPException( + status_code=422, + detail=( + f"Unsupported window '{window}' for metric '{metric_name}'. " + f"Available: {available_windows}" + ), + ) + as_of = _normalize_as_of(as_of) as_of_text = _as_of_iso_text(as_of) try: diff --git a/src/serving/api/routers/lineage.py b/src/serving/api/routers/lineage.py index 81220ac..8972dcd 100644 --- a/src/serving/api/routers/lineage.py +++ b/src/serving/api/routers/lineage.py @@ -4,6 +4,7 @@ from fastapi import APIRouter, HTTPException, Request from pydantic import BaseModel +from starlette.concurrency import run_in_threadpool router = APIRouter(prefix="/v1/lineage", tags=["lineage"]) @@ -66,48 +67,57 @@ def _quality_score(rows: list[dict], *, default: float | None = None) -> float | def _fetch_matching_events(request: Request, entity_type: str, entity_id: str) -> list[dict]: - conn = request.app.state.query_engine._conn - columns = {row[1] for row in conn.execute("PRAGMA table_info('pipeline_events')").fetchall()} - if "entity_id" not in columns: - return [] - - tenant_id = getattr(request.state, "tenant_id", None) - if tenant_id is not None and "tenant_id" not in columns and tenant_id != "default": - return [] - time_column = "processed_at" if "processed_at" in columns else "created_at" - select_columns = [ - "event_id", - "topic", - f"{time_column} AS processed_at", - ( - "COALESCE(tenant_id, 'default') AS tenant_id" - if "tenant_id" in columns - else "'default' AS tenant_id" - ), - "event_type" if "event_type" in columns else "NULL AS event_type", - "entity_id", - "latency_ms" if "latency_ms" in columns else "NULL AS latency_ms", - ] - where_clauses = ["entity_id = ?"] - params: list[object] = [entity_id] - if "entity_type" in columns: - where_clauses.append("entity_type = ?") - params.append(entity_type) - if tenant_id is not None and "tenant_id" in columns: - where_clauses.append("COALESCE(tenant_id, 'default') = ?") - params.append(str(tenant_id)) - - cursor = conn.execute( - ( - # selected columns come from the schema allowlist - f"SELECT {', '.join(select_columns)} " # nosec B608 - "FROM pipeline_events " - f"WHERE {' AND '.join(where_clauses)} ORDER BY {time_column} ASC" - ), - params, - ) - result_columns = [description[0] for description in cursor.description] - return [dict(zip(result_columns, row, strict=False)) for row in cursor.fetchall()] + # Runs on a worker thread (get_lineage offloads it) so the full-scan can't + # block the event loop and starve every other tenant on the worker. Use a + # dedicated cursor rather than the shared connection object so concurrent + # reads don't collide on it. (audit_30_06_26.md A2) + cursor = request.app.state.query_engine._conn.cursor() + try: + columns = { + row[1] for row in cursor.execute("PRAGMA table_info('pipeline_events')").fetchall() + } + if "entity_id" not in columns: + return [] + + tenant_id = getattr(request.state, "tenant_id", None) + if tenant_id is not None and "tenant_id" not in columns and tenant_id != "default": + return [] + time_column = "processed_at" if "processed_at" in columns else "created_at" + select_columns = [ + "event_id", + "topic", + f"{time_column} AS processed_at", + ( + "COALESCE(tenant_id, 'default') AS tenant_id" + if "tenant_id" in columns + else "'default' AS tenant_id" + ), + "event_type" if "event_type" in columns else "NULL AS event_type", + "entity_id", + "latency_ms" if "latency_ms" in columns else "NULL AS latency_ms", + ] + where_clauses = ["entity_id = ?"] + params: list[object] = [entity_id] + if "entity_type" in columns: + where_clauses.append("entity_type = ?") + params.append(entity_type) + if tenant_id is not None and "tenant_id" in columns: + where_clauses.append("COALESCE(tenant_id, 'default') = ?") + params.append(str(tenant_id)) + + cursor.execute( + ( + # selected columns come from the schema allowlist + f"SELECT {', '.join(select_columns)} " # nosec B608 + "FROM pipeline_events " + f"WHERE {' AND '.join(where_clauses)} ORDER BY {time_column} ASC" + ), + params, + ) + result_columns = [description[0] for description in cursor.description] + return [dict(zip(result_columns, row, strict=False)) for row in cursor.fetchall()] + finally: + cursor.close() @router.get( @@ -137,7 +147,7 @@ async def get_lineage(entity_type: str, entity_id: str, request: Request) -> Lin detail=f"API key '{tenant_key.name}' cannot access entity type '{entity_type}'.", ) - rows = _fetch_matching_events(request, entity_type, entity_id) + rows = await run_in_threadpool(_fetch_matching_events, request, entity_type, entity_id) if not rows: raise HTTPException( status_code=404, diff --git a/src/serving/api/routers/slo.py b/src/serving/api/routers/slo.py index ee3fa0b..9484fd9 100644 --- a/src/serving/api/routers/slo.py +++ b/src/serving/api/routers/slo.py @@ -7,6 +7,7 @@ from fastapi import APIRouter, FastAPI, HTTPException, Request from pydantic import BaseModel, Field +from starlette.concurrency import run_in_threadpool try: import yaml @@ -60,8 +61,11 @@ def load_slos(path: Path) -> list[SLODefinition]: def _pipeline_event_columns(request: Request) -> set[str]: - conn = request.app.state.query_engine._conn - return {row[1] for row in conn.execute("PRAGMA table_info('pipeline_events')").fetchall()} + cursor = request.app.state.query_engine._conn.cursor() + try: + return {row[1] for row in cursor.execute("PRAGMA table_info('pipeline_events')").fetchall()} + finally: + cursor.close() def _time_column(columns: set[str]) -> str | None: @@ -96,7 +100,10 @@ def _measurement_value( if time_column is None: return None - conn = request.app.state.query_engine._conn + # A dedicated cursor (not the shared connection object) keeps concurrent + # /v1/slo requests — offloaded to worker threads by get_slos — from + # colliding on the connection. (audit_30_06_26.md A2) + conn = request.app.state.query_engine._conn.cursor() window = f"{definition.window_days} days" tenant_sql, tenant_params = _tenant_filter(columns, _tenant_id(request)) @@ -195,9 +202,12 @@ def _error_budget_remaining(target: float, current: float) -> float: return max(0.0, min(1.0, 1.0 - consumed)) -@router.get("", response_model=SLOResponse) -async def get_slos(request: Request) -> SLOResponse: - definitions = load_slos(get_slo_config_path(request.app)) +def _compute_slo_statuses(request: Request, definitions: list[SLODefinition]) -> list[SLOStatus]: + # Runs on a worker thread (get_slos offloads it) so the per-SLO aggregate + # scans can't block the event loop for every tenant on the worker. The + # helpers each open their own short-lived cursor, so concurrent /v1/slo + # requests on different threads never collide on the shared connection. + # (audit_30_06_26.md A2) columns = _pipeline_event_columns(request) time_column = _time_column(columns) statuses = [] @@ -232,4 +242,11 @@ async def get_slos(request: Request) -> SLOResponse: ) ) + return statuses + + +@router.get("", response_model=SLOResponse) +async def get_slos(request: Request) -> SLOResponse: + definitions = load_slos(get_slo_config_path(request.app)) + statuses = await run_in_threadpool(_compute_slo_statuses, request, definitions) return SLOResponse(slos=statuses) diff --git a/src/serving/api/routers/stream.py b/src/serving/api/routers/stream.py index 1fd7c68..1293af4 100644 --- a/src/serving/api/routers/stream.py +++ b/src/serving/api/routers/stream.py @@ -5,9 +5,11 @@ from collections.abc import AsyncIterator from datetime import datetime +import duckdb from fastapi import APIRouter, Request from fastapi.responses import StreamingResponse from opentelemetry import trace +from starlette.concurrency import run_in_threadpool router = APIRouter(prefix="/v1/stream", tags=["stream"]) tracer = trace.get_tracer("agentflow.api") @@ -19,9 +21,38 @@ async def fetch_recent_events( entity_id: str | None = None, limit: int = 10, ) -> list[dict[str, object]]: - """Fetch recent pipeline events from DuckDB with optional filters.""" - conn = request.app.state.query_engine._conn - columns = {row[1] for row in conn.execute("PRAGMA table_info('pipeline_events')").fetchall()} + """Fetch recent pipeline events from DuckDB with optional filters. + + Offloaded to a worker thread: the SSE generator calls this once per second + per open stream, so running the scan inline would block the event loop (and + every other tenant on the worker) for the scan's duration. (audit_30 A2) + """ + return await run_in_threadpool(_fetch_recent_events_sync, request, event_type, entity_id, limit) + + +def _fetch_recent_events_sync( + request: Request, + event_type: str | None, + entity_id: str | None, + limit: int, +) -> list[dict[str, object]]: + # Use a dedicated cursor (not the shared connection object) so concurrent + # streams running on different worker threads don't collide on it. + cursor = request.app.state.query_engine._conn.cursor() + try: + return _fetch_recent_events_with_cursor(cursor, request, event_type, entity_id, limit) + finally: + cursor.close() + + +def _fetch_recent_events_with_cursor( + cursor: duckdb.DuckDBPyConnection, + request: Request, + event_type: str | None, + entity_id: str | None, + limit: int, +) -> list[dict[str, object]]: + columns = {row[1] for row in cursor.execute("PRAGMA table_info('pipeline_events')").fetchall()} time_column = "processed_at" if "processed_at" in columns else "created_at" tenant_id = getattr(request.state, "tenant_id", None) if tenant_id is not None and "tenant_id" not in columns and tenant_id != "default": @@ -80,8 +111,8 @@ async def fetch_recent_events( sql = f"{sql} ORDER BY {time_column} DESC LIMIT ?" params.append(limit) - rows = conn.execute(sql, params).fetchall() - result_columns = [description[0] for description in conn.description] + rows = cursor.execute(sql, params).fetchall() + result_columns = [description[0] for description in cursor.description] return [dict(zip(result_columns, row, strict=False)) for row in rows] diff --git a/src/serving/api/webhook_dispatcher.py b/src/serving/api/webhook_dispatcher.py index 5c2dd53..08ccacd 100644 --- a/src/serving/api/webhook_dispatcher.py +++ b/src/serving/api/webhook_dispatcher.py @@ -258,19 +258,55 @@ async def dispatch_new_events(self) -> None: seen_key = _seen_event_key(event) if not event_id or event_id in self.seen_event_ids or seen_key in self.seen_event_ids: continue - self.seen_event_ids.add(seen_key) tenant = str(event.get("tenant_id") or "default") + enqueued_all = True for webhook in webhooks_by_tenant.get(tenant, []): - if _matches_filters(event, webhook.filters): - # Record the delivery durably *before* attempting it, then - # attempt inline (low latency for the happy path). A failure - # leaves a 'pending' row that process_delivery_queue re-drives - # — surviving all-retries-failed and a process restart, which - # the in-memory seen-set alone could not (audit #3). - self._enqueue_delivery(webhook, event) + if not _matches_filters(event, webhook.filters): + continue + # Record the delivery durably *before* attempting it, then attempt + # inline (low latency for the happy path). A failure leaves a + # 'pending' row that process_delivery_queue re-drives — surviving + # all-retries-failed and a process restart, which the in-memory + # seen-set alone could not (audit #3). Each webhook is isolated: + # one webhook's exception must neither abort the scan (skipping + # later webhooks) nor mark the event seen before it was durably + # enqueued. (audit_30_06_26.md C2) + try: + inserted = self._enqueue_delivery(webhook, event) + except Exception as exc: + logger.warning( + "webhook_enqueue_failed", + webhook_id=webhook.id, + event_id=event_id, + error=str(exc), + ) + enqueued_all = False + continue + if not inserted: + # Already enqueued on an earlier scan (this event stayed unseen + # because some other webhook's enqueue failed); its durable row + # is re-driven by process_delivery_queue — don't re-POST inline. + continue + try: result = await self.deliver(webhook, event) self._record_delivery_outcome(webhook.id, event_id, result) + except Exception as exc: + # Durable row is already 'pending'; let process_delivery_queue + # re-drive it instead of unwinding the whole scan. + logger.warning( + "webhook_inline_delivery_failed", + webhook_id=webhook.id, + event_id=event_id, + error=str(exc), + ) + + # Mark the event seen only once every matching webhook is durably + # enqueued. This also drives metric-cache invalidation (main.py wraps + # this method and invalidates on seen-set growth), so it still runs + # for events with zero matching webhooks (enqueued_all stays True). + if enqueued_all: + self.seen_event_ids.add(seen_key) async def deliver(self, webhook: WebhookRegistration, event: dict) -> dict: """Deliver one event now (the ``/test`` endpoint and the inline dispatch @@ -424,14 +460,24 @@ def _fetch_pipeline_events(self, tenant: str | None = None) -> list[dict]: result_columns = [description[0] for description in cursor.description] return [dict(zip(result_columns, row, strict=False)) for row in cursor.fetchall()] - def _enqueue_delivery(self, webhook: WebhookRegistration, event: dict) -> None: + def _enqueue_delivery(self, webhook: WebhookRegistration, event: dict) -> bool: """Durably record a (webhook, event) delivery as ``pending`` (idempotent - on the primary key — a re-scan of the same event never duplicates it).""" + on the primary key — a re-scan of the same event never duplicates it). + + Returns ``True`` only when a new row is inserted, so the caller can + inline-deliver exactly the fresh rows and never re-POST a (webhook, + event) that was already enqueued on an earlier scan.""" event_id = str(event.get("event_id") or "") if not event_id: - return + return False conn = self.app.state.query_engine._conn ensure_webhook_delivery_queue_table(conn) + existing = conn.execute( + "SELECT 1 FROM webhook_delivery_queue WHERE webhook_id = ? AND event_id = ?", + [webhook.id, event_id], + ).fetchone() + if existing is not None: + return False now = datetime.now(UTC) conn.execute( """ @@ -452,6 +498,7 @@ def _enqueue_delivery(self, webhook: WebhookRegistration, event: dict) -> None: now, ], ) + return True def _record_delivery_outcome(self, webhook_id: str, event_id: str, result: dict) -> None: """Advance a queue row from the outcome of one delivery round: success → diff --git a/src/serving/masking.py b/src/serving/masking.py index 972f37d..3027e03 100644 --- a/src/serving/masking.py +++ b/src/serving/masking.py @@ -19,7 +19,14 @@ def __init__(self, config_path: str | Path = "config/pii_fields.yaml"): self.config_path = Path(config_path) self._config = yaml.safe_load(self.config_path.read_text(encoding="utf-8")) or {} - def mask(self, entity_type: str, data: dict, tenant: str) -> dict: + def mask( + self, + entity_type: str, + data: dict, + tenant: str, + *, + source_columns: dict[str, set[str]] | None = None, + ) -> dict: masking = self._config.get("masking", {}) if tenant in masking.get("pii_exempt_tenants", []): return dict(data) @@ -28,13 +35,36 @@ def mask(self, entity_type: str, data: dict, tenant: str) -> dict: masked = dict(data) for rule in rules: field = rule.get("field") - if field in masked: - masked[field] = self._apply_strategy( - masked[field], - rule.get("strategy", default_strategy), - ) + if not field: + continue + strategy = rule.get("strategy", default_strategy) + for output_col in self._output_columns_for_field(field, masked, source_columns): + masked[output_col] = self._apply_strategy(masked[output_col], strategy) return masked + def _output_columns_for_field( + self, + field: str, + data: dict, + source_columns: dict[str, set[str]] | None, + ) -> set[str]: + """Which result columns to mask for a rule's source ``field``. + + With projection lineage (from a SELECT), mask every output column that + *derives* from the source field — this catches a renamed/derived PII + column such as ``email AS contact`` or ``lower(email) AS e`` that the old + output-name match silently let through as cleartext. (audit_30_06_26.md D2) + Without lineage (a single-entity payload, ``SELECT *``, or unparseable + SQL) the output column keeps the source name, so match by name. + """ + if source_columns is not None: + return { + output_col + for output_col, sources in source_columns.items() + if field in sources and output_col in data + } + return {field} if field in data else set() + def mask_query_results( self, sql: str, @@ -48,15 +78,45 @@ def mask_query_results( } if not entity_types: return [dict(row) for row in rows], False + # Resolve projection lineage so a renamed/derived PII column is masked by + # what it's *built from*, not by its output name. (audit_30_06_26.md D2) + source_columns = self._projection_source_columns(sql) # Apply every matched entity's masking rules. A multi-entity JOIN must not # bypass masking — returning the rows unmasked leaked cleartext PII # (e.g. users_enriched JOIN orders_v2). Mask the union of all matched # entities rather than failing open. (audit_28_06_26.md #6) masked_rows = [dict(row) for row in rows] for entity_type in entity_types: - masked_rows = [self.mask(entity_type, row, tenant) for row in masked_rows] + masked_rows = [ + self.mask(entity_type, row, tenant, source_columns=source_columns) + for row in masked_rows + ] return masked_rows, masked_rows != rows + def _projection_source_columns(self, sql: str) -> dict[str, set[str]] | None: + """Map each output column to the source column names that feed it. + + Returns ``None`` when the projection can't be resolved precisely — a + ``SELECT *`` (whose outputs are the source names verbatim) or unparseable + SQL — so masking falls back to matching rule fields against output names. + """ + try: + parsed = sqlglot.parse_one(sql, read="duckdb") + except sqlglot.errors.ParseError: + return None + select = parsed.find(exp.Select) + if select is None: + return None + mapping: dict[str, set[str]] = {} + for projection in select.expressions: + if isinstance(projection, exp.Star) or projection.find(exp.Star) is not None: + return None + output_name = projection.alias_or_name + if not output_name: + continue + mapping[output_name] = {col.name for col in projection.find_all(exp.Column) if col.name} + return mapping + def _extract_table_names(self, sql: str) -> set[str]: try: parsed = sqlglot.parse_one(sql, read="duckdb") diff --git a/src/serving/semantic_layer/query/sql_builder.py b/src/serving/semantic_layer/query/sql_builder.py index 97c8e35..904228d 100644 --- a/src/serving/semantic_layer/query/sql_builder.py +++ b/src/serving/semantic_layer/query/sql_builder.py @@ -6,6 +6,7 @@ import sqlglot from sqlglot import exp +from sqlglot.optimizer.scope import traverse_scope from src.serving.api.auth import get_current_tenant_id @@ -79,32 +80,37 @@ def _scope_sql(self: SQLBuilderHost, sql: str, tenant_id: str | None) -> str: known_tables.add("pipeline_events") parsed = sqlglot.parse_one(sql, dialect="duckdb") - cte_names = { - cte.alias_or_name.lower() for cte in parsed.find_all(exp.CTE) if cte.alias_or_name - } + # Classify every table reference by scope so a CTE whose name collides + # with a real table — e.g. `WITH orders_v2 AS (SELECT * FROM orders_v2) + # SELECT * FROM orders_v2` — cannot hide the *physical* inner reference + # from tenant rescoping. The old global cte_names skip dropped any table + # whose name matched any CTE in the statement, so the inner physical + # `orders_v2` stayed unqualified, bound to the shared `main` schema, and + # leaked every tenant's rows (this is the *sole* isolation mechanism: + # one DuckDB DB, a schema per tenant, no per-connection search_path). + # Scope resolution rescopes the physical ref while leaving the genuine + # CTE reference alone. (audit_30_06_26.md D1; builds on audit_28 #5) + physical_tables = [ + table + for scope in traverse_scope(parsed) + for table in scope.tables + if table.name + and table.name.lower() in known_tables + and table.name.lower() not in {name.lower() for name in scope.cte_sources} + ] + schema = self._get_tenant_schema(tenant_id) if schema is None: - for table in parsed.find_all(exp.Table): - table_name = table.name - if ( - not table_name - or table.db - or table.catalog - or table_name.lower() not in known_tables - or table_name.lower() in cte_names - ): - continue - self._qualify_table(table_name, tenant_id) + for table in physical_tables: + if not table.db and not table.catalog: + # No tenant schema resolved: keep the "tenant context is + # required" guard firing for a physical tenant-scoped table + # even when its name collides with a CTE (the old skip let + # such a query silently read `main`). + self._qualify_table(table.name, tenant_id) return sql - for table in parsed.find_all(exp.Table): - table_name = table.name - if ( - not table_name - or table_name.lower() not in known_tables - or table_name.lower() in cte_names - ): - continue + for table in physical_tables: # Force the known table into the caller's tenant schema even if it # arrived already schema-qualified — defense-in-depth so a qualified # name can never read another tenant. validate_nl_sql already rejects @@ -112,6 +118,6 @@ def _scope_sql(self: SQLBuilderHost, sql: str, tenant_id: str | None) -> str: # reaches here through another caller. (audit_28_06_26.md #5) table.set("catalog", None) table.set("db", exp.to_identifier(schema, quoted=True)) - table.set("this", exp.to_identifier(table_name, quoted=True)) + table.set("this", exp.to_identifier(table.name, quoted=True)) return parsed.sql(dialect="duckdb") diff --git a/tests/unit/test_agent_query_async.py b/tests/unit/test_agent_query_async.py index 5515592..8523f91 100644 --- a/tests/unit/test_agent_query_async.py +++ b/tests/unit/test_agent_query_async.py @@ -1,96 +1,117 @@ -import asyncio -import time - -import httpx -import pytest -from fastapi import FastAPI - -from src.serving.api.routers.agent_query import router as agent_router -from src.serving.semantic_layer.catalog import DataCatalog - - -class SlowEngine: - def __init__(self, delay_seconds: float): - self.delay_seconds = delay_seconds - - def get_entity( - self, - entity_type: str, - entity_id: str, - tenant_id: str | None = None, - ) -> dict: - time.sleep(self.delay_seconds) - return { - "id": entity_id, - "entity_type": entity_type, - "tenant_id": tenant_id, - } - - def get_metric( - self, - metric_name: str, - window: str = "1h", - as_of=None, - tenant_id: str | None = None, - ) -> dict: - time.sleep(self.delay_seconds) - return { - "value": 42.0, - "unit": "USD", - "components": { - "metric_name": metric_name, - "window": window, - "tenant_id": tenant_id, - "as_of": as_of.isoformat() if as_of is not None else None, - }, - } - - def execute_nl_query( - self, - question: str, - context: dict | None = None, - tenant_id: str | None = None, - ) -> dict: - time.sleep(self.delay_seconds) - return { - "sql": "SELECT * FROM orders LIMIT 1", - "data": [{"question": question, "tenant_id": tenant_id, "context": context}], - "row_count": 1, - "execution_time_ms": int(self.delay_seconds * 1000), - "freshness_seconds": 0, - } - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - ("method", "url", "json_payload"), - [ - ("GET", "/v1/entity/order/ORD-20260401-0001", None), - ("GET", "/v1/metrics/revenue?window=1h", None), - ("POST", "/v1/query", {"question": "top orders"}), - ], - ids=["entity", "metric", "query"], -) -async def test_hot_path_endpoints_do_not_block_event_loop( - method: str, - url: str, - json_payload: dict | None, -): - app = FastAPI() - app.state.catalog = DataCatalog() - app.state.query_engine = SlowEngine(delay_seconds=0.3) - app.include_router(agent_router, prefix="/v1") - - transport = httpx.ASGITransport(app=app) - async with httpx.AsyncClient( - transport=transport, - base_url="http://testserver", - ) as client: - started_at = time.perf_counter() - responses = await asyncio.gather( - *[client.request(method, url, json=json_payload) for _ in range(4)] - ) - elapsed = time.perf_counter() - started_at - - assert all(response.status_code == 200 for response in responses) - assert elapsed < 0.9, f"Event loop blocked: {elapsed:.2f}s (expected < 0.9s)" +import asyncio +import time + +import httpx +import pytest +from fastapi import FastAPI + +from src.serving.api.routers.agent_query import router as agent_router +from src.serving.semantic_layer.catalog import DataCatalog + + +class SlowEngine: + def __init__(self, delay_seconds: float): + self.delay_seconds = delay_seconds + + def get_entity( + self, + entity_type: str, + entity_id: str, + tenant_id: str | None = None, + ) -> dict: + time.sleep(self.delay_seconds) + return { + "id": entity_id, + "entity_type": entity_type, + "tenant_id": tenant_id, + } + + def get_metric( + self, + metric_name: str, + window: str = "1h", + as_of=None, + tenant_id: str | None = None, + ) -> dict: + time.sleep(self.delay_seconds) + return { + "value": 42.0, + "unit": "USD", + "components": { + "metric_name": metric_name, + "window": window, + "tenant_id": tenant_id, + "as_of": as_of.isoformat() if as_of is not None else None, + }, + } + + def execute_nl_query( + self, + question: str, + context: dict | None = None, + tenant_id: str | None = None, + ) -> dict: + time.sleep(self.delay_seconds) + return { + "sql": "SELECT * FROM orders LIMIT 1", + "data": [{"question": question, "tenant_id": tenant_id, "context": context}], + "row_count": 1, + "execution_time_ms": int(self.delay_seconds * 1000), + "freshness_seconds": 0, + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("method", "url", "json_payload"), + [ + ("GET", "/v1/entity/order/ORD-20260401-0001", None), + ("GET", "/v1/metrics/revenue?window=1h", None), + ("POST", "/v1/query", {"question": "top orders"}), + ], + ids=["entity", "metric", "query"], +) +async def test_hot_path_endpoints_do_not_block_event_loop( + method: str, + url: str, + json_payload: dict | None, +): + app = FastAPI() + app.state.catalog = DataCatalog() + app.state.query_engine = SlowEngine(delay_seconds=0.3) + app.include_router(agent_router, prefix="/v1") + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient( + transport=transport, + base_url="http://testserver", + ) as client: + started_at = time.perf_counter() + responses = await asyncio.gather( + *[client.request(method, url, json=json_payload) for _ in range(4)] + ) + elapsed = time.perf_counter() - started_at + + assert all(response.status_code == 200 for response in responses) + assert elapsed < 0.9, f"Event loop blocked: {elapsed:.2f}s (expected < 0.9s)" + + +@pytest.mark.asyncio +async def test_get_metric_rejects_window_not_in_available_windows(): + # Pre-fix any window was accepted and silently computed as 1h (active_sessions + # always 30m) while the response echoed the requested window — a wrong value + # under a confident label, plus cache-key pollution. (audit_30_06_26.md A1) + app = FastAPI() + app.state.catalog = DataCatalog() + app.state.query_engine = SlowEngine(delay_seconds=0.0) + app.include_router(agent_router, prefix="/v1") + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + bad = await client.get("/v1/metrics/revenue?window=2h") + good = await client.get("/v1/metrics/revenue?window=1h") + bad_active = await client.get("/v1/metrics/active_sessions?window=24h") + + assert bad.status_code == 422 # 2h is not a declared window for revenue + assert good.status_code == 200 + assert bad_active.status_code == 422 # active_sessions only supports "now" diff --git a/tests/unit/test_alert_escalation_delivery.py b/tests/unit/test_alert_escalation_delivery.py index 14b0c48..a02d07f 100644 --- a/tests/unit/test_alert_escalation_delivery.py +++ b/tests/unit/test_alert_escalation_delivery.py @@ -113,6 +113,57 @@ async def test_escalation_level_not_advanced_on_delivery_failure( assert triggered == 0 +@pytest.mark.asyncio +async def test_resolved_does_not_advance_state_on_delivery_failure( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # The condition cleared (triggered=False) but the resolved page fails to + # deliver. Pre-fix the alert was forced to state="resolved"/fired_at=None + # regardless, so the resolved notification was lost for good and the receiver + # kept the incident open. Now fired_at stays set so the next tick re-attempts. + # (audit_30_06_26.md C1) + _patch_eval(monkeypatch, triggered=False) + _patch_deliver(monkeypatch, success=False) + + fired = datetime(2026, 6, 28, 11, 0, tzinfo=UTC) # 60 min before _NOW + alert = _alert( + fired_at=fired, + state="firing", + last_escalation_level=1, + last_condition_triggered=True, + ) + + result, changed, triggered = await escalation.dispatch_alert(None, alert, _NOW) + + assert result.fired_at == fired # not cleared + assert result.state != "resolved" + assert triggered == 0 + + +@pytest.mark.asyncio +async def test_resolved_advances_state_on_delivery_success( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Happy path still resolves: when the resolved page delivers, the alert + # advances to resolved and clears fired_at. + _patch_eval(monkeypatch, triggered=False) + _patch_deliver(monkeypatch, success=True) + + fired = datetime(2026, 6, 28, 11, 0, tzinfo=UTC) + alert = _alert( + fired_at=fired, + state="firing", + last_escalation_level=1, + last_condition_triggered=True, + ) + + result, changed, triggered = await escalation.dispatch_alert(None, alert, _NOW) + + assert result.fired_at is None + assert result.state == "resolved" + assert triggered == 1 + + # --- next_escalation_step: no intermediate-level skip (audit_28_06_26.md §5 medium) --- _FIRED = datetime(2026, 6, 28, 11, 0, tzinfo=UTC) # 60 min before _NOW diff --git a/tests/unit/test_analytics_middleware.py b/tests/unit/test_analytics_middleware.py index 32b8a91..71c3e06 100644 --- a/tests/unit/test_analytics_middleware.py +++ b/tests/unit/test_analytics_middleware.py @@ -1,3 +1,4 @@ +import json from pathlib import Path from types import SimpleNamespace @@ -112,3 +113,90 @@ def fail_ensure(*args, **kwargs): count = duckdb.connect(str(db_path)).execute("SELECT COUNT(*) FROM api_sessions").fetchone()[0] assert count == 1 + + +def _v1_request(app: FastAPI, *, method: str, path: str, body: bytes = b"") -> Request: + async def receive(): + return {"type": "http.request", "body": body, "more_body": False} + + return Request( + { + "type": "http", + "http_version": "1.1", + "method": method, + "path": path, + "raw_path": path.encode(), + "query_string": b"", + "headers": [], + "client": ("127.0.0.1", 12345), + "server": ("testserver", 80), + "scheme": "http", + "state": {}, + "app": app, + }, + receive=receive, + ) + + +@pytest.mark.anyio +async def test_analytics_middleware_skips_unauthenticated_request(tmp_path: Path): + # Analytics runs OUTSIDE AuthMiddleware. An unauthenticated /v1 request + # (no tenant_key on request.state, rejected downstream with 401) must NOT + # schedule a session write — otherwise unauthenticated, un-throttled traffic + # drives a DB-writing thread spawn per request. (audit_30_06_26.md S1) + app = FastAPI() + app.state.auth_manager = SimpleNamespace( + db_path=tmp_path / "usage.duckdb", + has_configured_keys=lambda: True, + ) + middleware = build_analytics_middleware() + request = _v1_request(app, method="POST", path="/v1/query", body=b'{"question": "x"}') + + async def call_next(_: Request) -> Response: + return Response(status_code=401) + + response = await middleware(request, call_next) + + assert response.status_code == 401 + assert getattr(response.background, "tasks", []) == [] + + +@pytest.mark.anyio +async def test_analytics_middleware_records_authenticated_request(tmp_path: Path): + # The happy path still records: an authenticated request schedules exactly + # one session-write background task. + app = FastAPI() + app.state.auth_manager = SimpleNamespace( + db_path=tmp_path / "usage.duckdb", + has_configured_keys=lambda: True, + ) + middleware = build_analytics_middleware() + request = _v1_request(app, method="GET", path="/v1/entity/order/ORD-1") + request.state.tenant_key = SimpleNamespace(tenant="acme", name="Agent") + + async def call_next(_: Request) -> Response: + return Response(status_code=200) + + response = await middleware(request, call_next) + + assert response.status_code == 200 + assert len(getattr(response.background, "tasks", [])) == 1 + + +def test_build_session_record_caps_query_text(tmp_path: Path): + # An oversized /v1/query body is truncated before it is persisted, so it + # can't be used to amplify storage. (audit_30_06_26.md S1) + app = FastAPI() + body = json.dumps({"question": "x" * 5000}).encode() + request = _v1_request(app, method="POST", path="/v1/query", body=body) + + record = analytics_module._build_session_record( + request=request, + request_id="req-1", + status_code=200, + duration_ms=1.0, + cache_hit=False, + body=body, + ) + + assert len(record["query_text"]) == 1000 diff --git a/tests/unit/test_blocking_routes_offloaded.py b/tests/unit/test_blocking_routes_offloaded.py new file mode 100644 index 0000000..c2b14ec --- /dev/null +++ b/tests/unit/test_blocking_routes_offloaded.py @@ -0,0 +1,102 @@ +"""Cold read endpoints must offload their synchronous DuckDB scans to a worker +thread (audit_30_06_26.md A2). The hot paths (entity/metric/query) are pinned by +test_agent_query_async; lineage was left running its scan inline on the event +loop, so concurrent requests serialized and blocked every other tenant on the +worker. This drives the lineage route through ASGI with a connection whose scan +sleeps, and asserts four concurrent requests overlap instead of serializing. +""" + +from __future__ import annotations + +import asyncio +import time +from datetime import UTC, datetime +from types import SimpleNamespace + +import httpx +import pytest +from fastapi import FastAPI + +from src.serving.api.routers.lineage import router as lineage_router +from src.serving.semantic_layer.catalog import DataCatalog + +_PRAGMA_ROWS = [ + (0, "event_id"), + (1, "topic"), + (2, "processed_at"), + (3, "tenant_id"), + (4, "event_type"), + (5, "entity_id"), + (6, "latency_ms"), +] + + +class _SlowExecutor: + """Mimics a DuckDB connection/cursor: ``execute`` returns self and the + provenance SELECT sleeps. Supporting ``execute`` directly (the pre-fix path) + *and* ``cursor()`` (the fixed path) lets the same test serialize on the old + code and overlap on the new.""" + + def __init__(self, delay_seconds: float) -> None: + self.delay_seconds = delay_seconds + self._mode = "pragma" + self.description: list[tuple[str]] = [] + + def execute(self, sql: str, params: object = None) -> _SlowExecutor: + if "PRAGMA" in sql: + self._mode = "pragma" + else: + # The provenance scan — this is what blocked the event loop inline. + time.sleep(self.delay_seconds) + self._mode = "select" + self.description = [ + ("event_id",), + ("topic",), + ("processed_at",), + ("tenant_id",), + ("event_type",), + ("entity_id",), + ("latency_ms",), + ] + return self + + def fetchall(self) -> list[tuple]: + if self._mode == "pragma": + return _PRAGMA_ROWS + return [ + ( + "E1", + "orders.raw", + datetime(2026, 6, 30, tzinfo=UTC), + "default", + "order.created", + "ORD-1", + 5.0, + ) + ] + + def close(self) -> None: + pass + + +class _SlowConn(_SlowExecutor): + def cursor(self) -> _SlowExecutor: + return _SlowExecutor(self.delay_seconds) + + +@pytest.mark.asyncio +async def test_lineage_does_not_block_event_loop() -> None: + app = FastAPI() + app.state.catalog = DataCatalog() + app.state.query_engine = SimpleNamespace(_conn=_SlowConn(delay_seconds=0.3)) + app.include_router(lineage_router) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + started_at = time.perf_counter() + responses = await asyncio.gather(*[client.get("/v1/lineage/order/ORD-1") for _ in range(4)]) + elapsed = time.perf_counter() - started_at + + assert all(response.status_code == 200 for response in responses) + # 4 × 0.3s serialized ≈ 1.2s; offloaded to the threadpool they overlap (~0.3s). + assert elapsed < 0.9, f"Event loop blocked: {elapsed:.2f}s (expected < 0.9s)" diff --git a/tests/unit/test_masking.py b/tests/unit/test_masking.py index 7c4d71a..f163e87 100644 --- a/tests/unit/test_masking.py +++ b/tests/unit/test_masking.py @@ -145,6 +145,70 @@ def test_mask_query_results_masks_single_entity(tmp_path: Path): assert rows == [{"email": "j***@example.com"}] +def _user_email_config(tmp_path: Path) -> Path: + return _write_pii_config( + tmp_path / "pii_fields.yaml", + """ + masking: + default_strategy: partial + entity_fields: + user: + - field: email + strategy: partial + pii_exempt_tenants: [] + """, + ) + + +def test_mask_query_results_masks_aliased_pii_column(tmp_path: Path): + # `email AS contact` previously bypassed masking: the output key was + # "contact", not the rule field "email", so the cleartext address was + # returned with no X-PII-Masked signal. Projection lineage now masks the + # column derived from email. (audit_30_06_26.md D2) + masker = PiiMasker(_user_email_config(tmp_path)) + + rows, masked = masker.mask_query_results( + "SELECT email AS contact FROM users_enriched", + [{"contact": "alice@example.com"}], + tenant="acme", + table_to_entity={"users_enriched": "user"}, + ) + + assert masked is True + assert rows == [{"contact": "a***@example.com"}] + + +def test_mask_query_results_masks_derived_pii_column(tmp_path: Path): + # A derived expression over a PII column is masked by what it's built from. + masker = PiiMasker(_user_email_config(tmp_path)) + + rows, masked = masker.mask_query_results( + "SELECT lower(email) AS e FROM users_enriched", + [{"e": "alice@example.com"}], + tenant="acme", + table_to_entity={"users_enriched": "user"}, + ) + + assert masked is True + assert rows == [{"e": "a***@example.com"}] + + +def test_mask_query_results_masks_select_star_by_name(tmp_path: Path): + # SELECT * has no resolvable projection lineage; masking falls back to + # matching rule fields against the (canonical) output column names. + masker = PiiMasker(_user_email_config(tmp_path)) + + rows, masked = masker.mask_query_results( + "SELECT * FROM users_enriched", + [{"email": "alice@example.com", "user_id": "U-1"}], + tenant="acme", + table_to_entity={"users_enriched": "user"}, + ) + + assert masked is True + assert rows == [{"email": "a***@example.com", "user_id": "U-1"}] + + def test_mask_query_results_masks_union_when_multiple_entities(tmp_path: Path): """A multi-entity JOIN must mask the union of all matched entities, not fail open. The old behaviour returned cleartext PII for any query touching !=1 diff --git a/tests/unit/test_query_engine.py b/tests/unit/test_query_engine.py index e763f37..ff2af55 100644 --- a/tests/unit/test_query_engine.py +++ b/tests/unit/test_query_engine.py @@ -49,6 +49,25 @@ def test_scope_sql_does_not_qualify_cte_aliases(engine: QueryEngine) -> None: assert ("orders_v2", "") in _tables(scoped) +def test_scope_sql_qualifies_physical_table_shadowed_by_cte_name(engine: QueryEngine) -> None: + # A CTE whose name collides with a real table must not hide the *physical* + # inner reference from tenant rescoping. Pre-fix the inner `orders_v2` was + # skipped (its name matched the CTE) and stayed bound to the shared `main` + # schema, leaking every tenant's rows. (audit_30_06_26.md D1) + scoped = engine._scope_sql( + "WITH orders_v2 AS (SELECT * FROM orders_v2) SELECT * FROM orders_v2", + tenant_id="tenant_a", + ) + + tables = _tables(scoped) + # The physical inner reference is now pinned to the caller's tenant schema. + assert ("orders_v2", "tenant_a") in tables + # The only unqualified `orders_v2` left is the outer CTE reference (1, not 2). + assert tables.count(("orders_v2", "")) == 1 + # And nothing fell back to the shared `main` schema. + assert all(db != "main" for _, db in tables) + + def test_scope_sql_qualifies_tables_after_subquery(engine: QueryEngine) -> None: scoped = engine._scope_sql( "SELECT * FROM (SELECT * FROM orders_v2) AS recent, users_enriched", diff --git a/tests/unit/test_webhook_dispatcher_unit.py b/tests/unit/test_webhook_dispatcher_unit.py index 56f2739..73c1a75 100644 --- a/tests/unit/test_webhook_dispatcher_unit.py +++ b/tests/unit/test_webhook_dispatcher_unit.py @@ -20,6 +20,7 @@ from types import SimpleNamespace import duckdb +import httpx import pytest from src.serving.api.webhook_dispatcher import ( @@ -276,6 +277,75 @@ def test_enqueue_delivery_is_idempotent_on_webhook_event() -> None: conn.close() +def test_enqueue_delivery_returns_true_only_for_a_new_row() -> None: + conn = duckdb.connect(":memory:") + try: + dispatcher = WebhookDispatcher(_stub_app(conn)) + webhook = SimpleNamespace(id="wh-1") + # A fresh (webhook, event) inserts and tells the caller to inline-deliver. + assert dispatcher._enqueue_delivery(webhook, _event("e1")) is True + # A re-scan of an already-queued pair is a no-op and must NOT be + # re-delivered inline (that would storm the receiver every poll cycle + # whenever an unrelated webhook left the event unseen). (audit_30_06_26.md C2) + assert dispatcher._enqueue_delivery(webhook, _event("e1")) is False + finally: + conn.close() + + +@pytest.mark.asyncio +async def test_dispatch_isolates_webhook_failure_and_enqueues_all( + config_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + # Two webhooks match the same event; the first one's inline delivery raises + # an error deliver() does not catch (e.g. httpx.InvalidURL). Pre-fix that + # exception propagated out of dispatch_new_events *after* the event was + # already marked seen, so the second webhook was never enqueued and its + # delivery was lost for good. Now each webhook is durably enqueued first and + # isolated, and the event is marked seen only once all are enqueued. + # (audit_30_06_26.md C2) + conn = duckdb.connect(":memory:") + try: + conn.execute( + "CREATE TABLE pipeline_events (event_id VARCHAR, topic VARCHAR, " + "tenant_id VARCHAR DEFAULT 'default', event_type VARCHAR, processed_at TIMESTAMP)" + ) + conn.execute( + "INSERT INTO pipeline_events VALUES " + "('e1', 'orders.raw', 'acme', 'order.created', NOW())" + ) + wh1 = create_webhook( + config_path, url="https://a.test/h1", tenant="acme", filters=WebhookFilters() + ) + wh2 = create_webhook( + config_path, url="https://b.test/h2", tenant="acme", filters=WebhookFilters() + ) + monkeypatch.setattr( + "src.serving.api.webhook_dispatcher.get_webhook_config_path", + lambda app: config_path, + ) + dispatcher = WebhookDispatcher(_stub_app(conn)) + + async def _deliver(webhook: object, event: dict) -> dict: + if getattr(webhook, "id", None) == wh1.id: + raise httpx.InvalidURL("boom") + return {"success": True, "status_code": 200, "event_id": event["event_id"]} + + monkeypatch.setattr(dispatcher, "deliver", _deliver) + + await dispatcher.dispatch_new_events() # must not raise + + # Both webhooks are durably enqueued despite wh1's inline failure... + assert _queue_row(conn, wh1.id, "e1") is not None + assert _queue_row(conn, wh2.id, "e1") is not None + # ...wh2 delivered, wh1 left 'pending' for process_delivery_queue to re-drive. + assert _queue_row(conn, wh2.id, "e1")[0] == "delivered" + assert _queue_row(conn, wh1.id, "e1")[0] == "pending" + # The event is marked seen (durably enqueued) so it isn't re-scanned. + assert "acme:e1" in dispatcher.seen_event_ids + finally: + conn.close() + + def test_record_outcome_success_marks_delivered() -> None: conn = duckdb.connect(":memory:") try: