Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 158 additions & 3 deletions apps/api/app/services/s3_events/subscription_service.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,33 @@
"""SNS subscription confirmation handling."""
from __future__ import annotations

import asyncio
import socket
from collections.abc import Sequence

from loguru import logger
from urllib.parse import ParseResult, parse_qs, urlparse, urlunparse

from shared.core.config import settings
from shared.services.http.pinned_outbound import send_pinned_outbound_request
from shared.services.http.url_security import validate_http_url_and_resolve_ip_async
from shared.services.http.url_security import (
HTTPURLValidationResult,
validate_http_url_and_resolve_ip_async,
)


# Timeout, in seconds, passed to the pinned outbound GET request that
# confirms an SNS subscription.
SNS_SUBSCRIPTION_TIMEOUT_SECONDS = 10
# Hostnames treated as LocalStack endpoints. Candidates are lowercased and
# stripped of any trailing dot before being compared against this set.
LOCALSTACK_HOSTNAMES = frozenset(
    {
        "localstack",
        "localhost.localstack.cloud",
    }
)


async def confirm_sns_subscription(subscribe_url: str) -> dict[str, str]:
validation = await validate_http_url_and_resolve_ip_async(subscribe_url)
rewritten_url = _rewrite_localstack_subscribe_url(subscribe_url)
validation = await _validate_sns_confirmation_url(rewritten_url)
if not validation.is_valid:
logger.warning(
f"SNS subscription confirmation URL failed validation: {validation.error_message}"
Expand All @@ -25,7 +41,7 @@ async def confirm_sns_subscription(subscribe_url: str) -> dict[str, str]:
try:
response = await send_pinned_outbound_request(
method="GET",
url=subscribe_url,
url=rewritten_url,
pinned_ip=validation.validated_ip,
timeout_seconds=SNS_SUBSCRIPTION_TIMEOUT_SECONDS,
)
Expand All @@ -45,3 +61,142 @@ async def confirm_sns_subscription(subscribe_url: str) -> dict[str, str]:
except Exception as exc:
logger.error(f"Failed to reach the SNS confirmation URL: {exc}")
return {"message": "SNS subscription confirmation failed"}


async def _validate_sns_confirmation_url(url: str) -> HTTPURLValidationResult:
    """Validate an SNS confirmation URL and resolve the IP address to pin.

    URLs that match the configured LocalStack storage endpoint are resolved
    by the LocalStack-specific resolver; every other URL goes through the
    general ``validate_http_url_and_resolve_ip_async`` validator.

    Args:
        url: The confirmation URL to validate.

    Returns:
        The validation result produced by whichever validator was selected.
    """
    validator = (
        _resolve_configured_localstack_confirmation_url
        if _is_configured_localstack_confirmation_url(url)
        else validate_http_url_and_resolve_ip_async
    )
    return await validator(url)


async def _resolve_configured_localstack_confirmation_url(
    url: str,
) -> HTTPURLValidationResult:
    """Resolve the hostname of a LocalStack confirmation URL to an IP address.

    This path performs a plain DNS lookup through the running event loop and
    does not call ``validate_http_url_and_resolve_ip_async``.

    Args:
        url: The confirmation URL whose hostname should be resolved.

    Returns:
        A valid result carrying the first resolved IPv4/IPv6 address, or an
        invalid result with ``failure_reason`` set to ``"missing_hostname"``
        or ``"hostname_resolution_failed"``.
    """
    host = urlparse(url).hostname
    if not host:
        return HTTPURLValidationResult(
            is_valid=False,
            url=url,
            error_message="URL must include a hostname",
            failure_reason="missing_hostname",
        )

    def resolution_failure(message: str) -> HTTPURLValidationResult:
        # Both resolution failure modes share everything except the message.
        return HTTPURLValidationResult(
            is_valid=False,
            url=url,
            hostname=host,
            error_message=message,
            failure_reason="hostname_resolution_failed",
        )

    try:
        event_loop = asyncio.get_running_loop()
        resolved_entries = await event_loop.getaddrinfo(host, None)
    except socket.gaierror as exc:
        return resolution_failure(f"Unable to resolve hostname {host}: {exc}")

    pinned_ip = _extract_first_resolved_ip(resolved_entries)
    if not pinned_ip:
        # The lookup succeeded but yielded no usable IPv4/IPv6 entry.
        return resolution_failure(f"Unable to resolve hostname {host}")

    return HTTPURLValidationResult(
        is_valid=True,
        url=url,
        hostname=host,
        validated_ip=pinned_ip,
    )


def _rewrite_localstack_subscribe_url(subscribe_url: str) -> str:
    """Point a LocalStack confirmation URL at the configured storage endpoint.

    The scheme and network location of ``subscribe_url`` are replaced with
    those of the configured storage endpoint when all of the following hold:
    a usable storage endpoint is configured, its hostname is a LocalStack
    hostname, the subscribe URL's hostname is a LocalStack hostname, and the
    subscribe URL carries an ``Action=ConfirmSubscription`` query parameter.
    Path, query and fragment are left untouched.

    Args:
        subscribe_url: The subscribe URL as received; treated as untrusted.

    Returns:
        The rewritten URL, or ``subscribe_url`` unchanged when any condition
        is not met or the URL cannot be parsed.
    """
    storage_endpoint = _parse_configured_storage_endpoint()
    if storage_endpoint is None or not _is_localstack_endpoint(storage_endpoint):
        return subscribe_url

    try:
        parsed_subscribe_url = urlparse(subscribe_url)
    except ValueError:
        # urlparse rejects some malformed inputs (e.g. unmatched brackets in
        # the netloc). This runs before validation, so hand the URL back
        # untouched and let the validation step report the failure instead
        # of raising out of the request handler.
        return subscribe_url

    if not _is_confirm_subscription_action(parsed_subscribe_url):
        return subscribe_url

    if not _is_localstack_endpoint(parsed_subscribe_url):
        return subscribe_url

    return urlunparse(
        parsed_subscribe_url._replace(
            scheme=storage_endpoint.scheme,
            netloc=storage_endpoint.netloc,
        )
    )


def _is_configured_localstack_confirmation_url(url: str) -> bool:
    """Tell whether ``url`` targets the configured LocalStack storage endpoint.

    A ``True`` result routes the URL to the LocalStack-specific resolution
    path instead of the general URL validator, so every uncertain case
    answers ``False``.

    Args:
        url: The confirmation URL to classify; treated as untrusted.

    Returns:
        ``True`` only when a usable storage endpoint is configured, both the
        endpoint and ``url`` use a LocalStack hostname, ``url`` carries an
        ``Action=ConfirmSubscription`` query parameter, and both share the
        same (scheme, hostname, port) origin. ``False`` otherwise, including
        when ``url`` or its port cannot be parsed.
    """
    storage_endpoint = _parse_configured_storage_endpoint()
    if storage_endpoint is None or not _is_localstack_endpoint(storage_endpoint):
        return False

    try:
        parsed_url = urlparse(url)
    except ValueError:
        # Unparseable input can never match the configured endpoint.
        return False

    if not _is_confirm_subscription_action(parsed_url):
        return False

    if not _is_localstack_endpoint(parsed_url):
        return False

    try:
        # Reading ``.port`` raises ValueError for a non-numeric or
        # out-of-range port on either side of the comparison.
        return _endpoint_origin(parsed_url) == _endpoint_origin(storage_endpoint)
    except ValueError:
        return False


def _parse_configured_storage_endpoint() -> ParseResult | None:
    """Return the configured S3 endpoint as a parsed URL, if it is usable.

    Reads ``settings.S3_ENDPOINT_URL``. The endpoint counts as unusable when
    it is unset or blank, cannot be parsed, has an invalid port, does not use
    the ``http`` or ``https`` scheme, or has no hostname.

    Returns:
        The parsed endpoint, or ``None`` when the endpoint is unusable.
    """
    # Tolerate an unset (None) setting as well as an empty or blank string.
    endpoint_url = (settings.S3_ENDPOINT_URL or "").strip()
    if not endpoint_url:
        return None

    try:
        parsed_endpoint = urlparse(endpoint_url)
        # Reading ``port`` forces urllib to validate it now, so callers that
        # read it later cannot hit a ValueError from a bad configured value.
        _ = parsed_endpoint.port
    except ValueError:
        # Same outcome as the other unusable-endpoint cases below.
        return None

    if parsed_endpoint.scheme not in {"http", "https"}:
        return None
    if not parsed_endpoint.hostname:
        return None
    return parsed_endpoint


def _is_confirm_subscription_action(parsed_url: ParseResult) -> bool:
actions = parse_qs(parsed_url.query).get("Action", [])
return any(action.lower() == "confirmsubscription" for action in actions)


def _is_localstack_endpoint(parsed_url: ParseResult) -> bool:
    """Report whether the URL's hostname is one of ``LOCALSTACK_HOSTNAMES``.

    The hostname is lowercased and stripped of trailing dots before the
    lookup; a URL without a hostname never matches.
    """
    raw_host = parsed_url.hostname or ""
    normalized_host = raw_host.rstrip(".").lower()
    return normalized_host in LOCALSTACK_HOSTNAMES


def _endpoint_origin(parsed_url: ParseResult) -> tuple[str, str, int | None]:
return (
parsed_url.scheme.lower(),
(parsed_url.hostname or "").rstrip(".").lower(),
parsed_url.port,
)


def _extract_first_resolved_ip(address_infos: Sequence[object]) -> str | None:
for address_info in address_infos:
if not isinstance(address_info, tuple) or len(address_info) < 5:
continue
family = address_info[0]
socket_address = address_info[4]
if family not in (socket.AF_INET, socket.AF_INET6):
continue
if not isinstance(socket_address, tuple) or not socket_address:
continue
address = socket_address[0]
if isinstance(address, str) and address:
return address
return None
19 changes: 19 additions & 0 deletions apps/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ async def lifespan(app: FastAPI):
await load_rules(session)
logger.info("rate limit rules loaded at startup; restart the pod to apply changes")

from shared.services.telemetry.runtime import start_self_hosted_telemetry

app.state.self_hosted_telemetry_client = await start_self_hosted_telemetry(
settings,
service_name="knowhere-api",
api_healthy=True,
postgres_healthy=True,
redis_healthy=True,
)

mcp_server = getattr(app.state, "retrieval_mcp_server", None)
mcp_session_manager = getattr(mcp_server, "session_manager", None)

Expand All @@ -78,6 +88,15 @@ async def lifespan(app: FastAPI):
else:
yield

try:
from shared.services.telemetry.runtime import stop_self_hosted_telemetry

await stop_self_hosted_telemetry(
getattr(app.state, "self_hosted_telemetry_client", None)
)
except Exception as e:
logger.error(f"self-hosted telemetry shutdown failed: {e}")

try:
from shared.services.retrieval.stats.recorder import (
drain_retrieval_hit_stats_updates,
Expand Down
75 changes: 75 additions & 0 deletions apps/api/tests/contract/test_s3_event_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import socket
from collections.abc import Callable
from contextlib import AbstractAsyncContextManager
from types import SimpleNamespace
from typing import cast
from uuid import uuid4

Expand Down Expand Up @@ -278,6 +279,80 @@ def get(self, url: str, *args: object, **kwargs: object) -> object:
assert contacted_urls == []


@pytest.mark.asyncio
async def test_should_confirm_a_configured_localstack_subscription_url_in_self_hosted_runtime(
    monkeypatch: MonkeyPatch,
) -> None:
    """Confirm a LocalStack subscribe URL via the configured storage endpoint.

    With ``S3_ENDPOINT_URL`` set to ``http://localstack:4566``, a subscribe
    URL on ``localhost.localstack.cloud`` must be sent to ``localstack:4566``
    with the same query string, pinned to the address returned by the
    patched resolver.
    """
    confirmation_query = (
        "?Action=ConfirmSubscription"
        "&TopicArn=arn:aws:sns:us-west-1:000000000000:test"
        "&Token=contract-token"
    )
    recorded_calls: list[dict[str, str]] = []

    class FakeOutboundResponse:
        status = 200

    async def send_fake_pinned_outbound_request(
        *,
        method: str,
        url: str,
        pinned_ip: str,
        timeout_seconds: float,
    ) -> FakeOutboundResponse:
        # Capture the outbound call instead of touching the network.
        recorded_calls.append(
            dict(
                method=method,
                url=url,
                pinned_ip=pinned_ip,
                timeout_seconds=str(timeout_seconds),
            )
        )
        return FakeOutboundResponse()

    def resolve_localstack_address(
        host: str,
        port: int | None,
        *args: object,
        **kwargs: object,
    ) -> list[tuple[socket.AddressFamily, socket.SocketKind, int, str, tuple[str, int]]]:
        # Only the rewritten hostname may ever be resolved.
        assert host == "localstack"
        return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("172.18.0.10", 0))]

    service_module = importlib.import_module(
        "app.services.s3_events.subscription_service"
    )
    replacements: list[tuple[object, str, object]] = [
        (
            service_module,
            "settings",
            SimpleNamespace(S3_ENDPOINT_URL="http://localstack:4566"),
        ),
        (socket, "getaddrinfo", resolve_localstack_address),
        (
            service_module,
            "send_pinned_outbound_request",
            send_fake_pinned_outbound_request,
        ),
    ]
    for target, attribute, replacement in replacements:
        monkeypatch.setattr(target, attribute, replacement)

    response = await service_module.confirm_sns_subscription(
        "http://localhost.localstack.cloud:4566/" + confirmation_query
    )

    assert response == {"message": "SNS subscription confirmed"}
    expected_call = {
        "method": "GET",
        "url": "http://localstack:4566/" + confirmation_query,
        "pinned_ip": "172.18.0.10",
        "timeout_seconds": "10",
    }
    assert recorded_calls == [expected_call]


@pytest.mark.asyncio
async def test_should_return_ok_for_a_malformed_event_payload_without_triggering_retries(
api_client_factory: Callable[[], AbstractAsyncContextManager[AsyncClient]],
Expand Down
Loading
Loading