From af3638c18808aff6a3fc27633a14c6360c779cf9 Mon Sep 17 00:00:00 2001 From: Omri SirComp Date: Wed, 20 May 2026 12:14:42 +0300 Subject: [PATCH] Support configurable PDP auth header --- horizon/authentication.py | 32 ++++++++++++++++++++++------ horizon/config.py | 7 ++++++ horizon/enforcer/api.py | 16 +------------- horizon/facts/client.py | 4 +++- horizon/facts/router.py | 22 +++++++++---------- horizon/proxy/api.py | 10 +++------ horizon/tests/test_authentication.py | 32 ++++++++++++++++++++++++++++ 7 files changed, 82 insertions(+), 41 deletions(-) create mode 100644 horizon/tests/test_authentication.py diff --git a/horizon/authentication.py b/horizon/authentication.py index c0596da9..959d2d94 100644 --- a/horizon/authentication.py +++ b/horizon/authentication.py @@ -1,17 +1,35 @@ from typing import Annotated -from fastapi import Header, HTTPException, status +from fastapi import Header, HTTPException, Request, status from horizon.config import MOCK_API_KEY, sidecar_config from horizon.startup.api_keys import get_env_api_key -def enforce_pdp_token(authorization: Annotated[str | None, Header()]): +def _parse_bearer_token(authorization: str | None, header_name: str) -> str: if authorization is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail="Missing Authorization header") - schema, token = authorization.split(" ") + raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=f"Missing {header_name} header") + parts = authorization.split(" ") + if len(parts) != 2: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=f"bad authz header: {authorization}") + schema, token = parts + + if schema.strip().lower() != "bearer": + raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail="Invalid PDP token") + return token.strip() + + +def extract_pdp_api_key(request: Request) -> str: + header_name = sidecar_config.AUTH_HEADER or "Authorization" + return _parse_bearer_token(request.headers.get(header_name), header_name) + + +def get_pdp_authorization_header(request: Request) -> str: + return f"Bearer {extract_pdp_api_key(request)}" + - if schema.strip().lower() != "bearer" or token.strip() != get_env_api_key(): +def enforce_pdp_token(request: Request): + if extract_pdp_api_key(request) != get_env_api_key(): raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail="Invalid PDP token") @@ -24,7 +42,7 @@ def enforce_pdp_control_key(authorization: Annotated[str | None, Header()]): if authorization is None: raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail="Missing Authorization header") - schema, token = authorization.split(" ") + token = _parse_bearer_token(authorization, "Authorization") - if schema.strip().lower() != "bearer" or token.strip() != sidecar_config.CONTAINER_CONTROL_KEY: + if token != sidecar_config.CONTAINER_CONTROL_KEY: raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail="Invalid PDP token") diff --git a/horizon/config.py b/horizon/config.py index 87fb42dc..e8ea41da 100644 --- a/horizon/config.py +++ b/horizon/config.py @@ -74,6 +74,13 @@ def __new__(cls, *, prefix=None, is_model=True): # noqa: ARG004 description="set this to your environment's API key if you prefer to use the environment level API key.", ) + # request header used to authenticate PDP API calls + AUTH_HEADER = confi.str( + "AUTH_HEADER", + "Authorization", + description="HTTP header used to read the PDP bearer token from incoming API calls.", + ) + # access token to your organization ORG_API_KEY = confi.str( "ORG_API_KEY", diff --git a/horizon/enforcer/api.py b/horizon/enforcer/api.py index 53a0e8f6..80f29773 100644 --- a/horizon/enforcer/api.py +++ b/horizon/enforcer/api.py @@ -16,7 +16,7 @@ from pydantic import parse_obj_as from starlette.responses import JSONResponse -from horizon.authentication import enforce_pdp_token +from horizon.authentication import enforce_pdp_token, extract_pdp_api_key from horizon.config import sidecar_config from horizon.enforcer.schemas import ( AllTenantsAuthorizationResult, @@ -62,20 +62,6 @@ ) -def extract_pdp_api_key(request: Request) -> str: - authorization: str = request.headers.get(AUTHZ_HEADER, "") - parts = authorization.split(" ") - if len(parts) != 2: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, - detail=f"bad authz header: {authorization}", - ) - schema, token = parts - if schema.strip().lower() != "bearer": - raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail="Invalid PDP token") - return token - - def transform_headers(request: Request) -> dict: token = extract_pdp_api_key(request) return { diff --git a/horizon/facts/client.py b/horizon/facts/client.py index 83e89676..7ccc6a1c 100644 --- a/horizon/facts/client.py +++ b/horizon/facts/client.py @@ -11,6 +11,7 @@ from starlette.responses import Response as FastApiResponse from starlette.responses import StreamingResponse +from horizon.authentication import get_pdp_authorization_header from horizon.config import sidecar_config from horizon.startup.api_keys import get_env_api_key from horizon.startup.remote_config import get_remote_config @@ -53,8 +54,9 @@ async def build_forward_request( :return: HTTPX request """ forward_headers = { - key: value for key, value in request.headers.items() if key.lower() in {"authorization", "content-type"} + key: value for key, value in request.headers.items() if key.lower() == "content-type" } + forward_headers["Authorization"] = get_pdp_authorization_header(request) if is_consistent_update: forward_headers[CONSISTENT_UPDATE_HEADER] = CONSISTENT_UPDATE_HEADER_VALUE remote_config = get_remote_config() diff --git a/horizon/facts/router.py b/horizon/facts/router.py index 6cdfe021..9dab0706 100644 --- a/horizon/facts/router.py +++ b/horizon/facts/router.py @@ -15,7 +15,7 @@ from loguru import logger from opal_common.schemas.data import DataSourceEntry -from horizon.authentication import enforce_pdp_token +from horizon.authentication import enforce_pdp_token, get_pdp_authorization_header from horizon.facts.client import FactsClient, FactsClientDependency from horizon.facts.dependencies import ( DataUpdateSubscriberDependency, @@ -51,7 +51,7 @@ async def create_user( obj_type="users", obj_id=body["id"], obj_key=body["key"], - authorization_header=r.headers.get("Authorization"), + authorization_header=get_pdp_authorization_header(r), update_id=update_id, ) ], @@ -78,7 +78,7 @@ async def create_tenant( obj_type="tenants", obj_id=body["id"], obj_key=body["key"], - authorization_header=r.headers.get("Authorization"), + authorization_header=get_pdp_authorization_header(r), update_id=update_id, ) ], @@ -106,7 +106,7 @@ async def sync_user( obj_type="users", obj_id=body["id"], obj_key=body["key"], - authorization_header=r.headers.get("Authorization"), + authorization_header=get_pdp_authorization_header(r), update_id=update_id, ) ], @@ -134,7 +134,7 @@ async def update_user( obj_type="users", obj_id=body["id"], obj_key=body["key"], - authorization_header=r.headers.get("Authorization"), + authorization_header=get_pdp_authorization_header(r), update_id=update_id, ) ], @@ -150,14 +150,14 @@ def create_role_assignment_data_entries( obj_type="role_assignments", obj_id=body["user_id"], obj_key=f"user:{body['user']}", - authorization_header=request.headers.get("Authorization"), + authorization_header=get_pdp_authorization_header(request), update_id=update_id, ) yield create_data_source_entry( obj_type="users", obj_id=body["user_id"], obj_key=body["user"], - authorization_header=request.headers.get("Authorization"), + authorization_header=get_pdp_authorization_header(request), update_id=update_id, ) else: @@ -167,7 +167,7 @@ def create_role_assignment_data_entries( obj_type="role_assignments", obj_id=body["user_id"], obj_key=f"user:{body['user']}", - authorization_header=request.headers.get("Authorization"), + authorization_header=get_pdp_authorization_header(request), update_id=update_id, ) @@ -271,7 +271,7 @@ async def create_resource_instance( obj_type="resource_instances", obj_id=body["id"], obj_key=f"{body['resource']}:{body['key']}", - authorization_header=r.headers.get("Authorization"), + authorization_header=get_pdp_authorization_header(r), update_id=update_id, ), ], @@ -299,7 +299,7 @@ async def update_resource_instance( obj_type="resource_instances", obj_id=body["id"], obj_key=f"{body['resource']}:{body['key']}", - authorization_header=r.headers.get("Authorization"), + authorization_header=get_pdp_authorization_header(r), update_id=update_id, ), ], @@ -326,7 +326,7 @@ async def create_relationship_tuple( obj_type="relationships", obj_id=body["object_id"], obj_key=body["object"], - authorization_header=r.headers.get("Authorization"), + authorization_header=get_pdp_authorization_header(r), update_id=update_id, ), ], diff --git a/horizon/proxy/api.py b/horizon/proxy/api.py index 5af9973a..e2d1ee63 100644 --- a/horizon/proxy/api.py +++ b/horizon/proxy/api.py @@ -11,6 +11,7 @@ from opal_common.logger import logger from pydantic import BaseModel, Field, parse_obj_as +from horizon.authentication import get_pdp_authorization_header from horizon.config import sidecar_config HTTP_GET = "GET" @@ -174,13 +175,7 @@ async def proxy_request_to_cloud_service( additional_headers: dict[str, str], timeout: int = sidecar_config.CONTROL_PLANE_TIMEOUT, ) -> Response: - auth_header = request.headers.get("Authorization") - if auth_header is None: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Must provide a bearer token!", - headers={"WWW-Authenticate": "Bearer"}, - ) + auth_header = get_pdp_authorization_header(request) path = f"{cloud_service_url}/{path}" params = dict(request.query_params) or {} @@ -191,6 +186,7 @@ async def proxy_request_to_cloud_service( for header_name in REQUIRED_HTTP_HEADERS: if header_name in original_headers: headers[header_name] = original_headers[header_name] + headers["authorization"] = auth_header # override host header (required by k8s ingress) try: diff --git a/horizon/tests/test_authentication.py b/horizon/tests/test_authentication.py new file mode 100644 index 00000000..88da1590 --- /dev/null +++ b/horizon/tests/test_authentication.py @@ -0,0 +1,32 @@ +from fastapi import Depends, FastAPI, Request +from fastapi.testclient import TestClient +from horizon.authentication import enforce_pdp_token, get_pdp_authorization_header +from horizon.config import sidecar_config + + +def test_pdp_auth_uses_configured_header() -> None: + previous_api_key = sidecar_config.API_KEY + previous_auth_header = sidecar_config.AUTH_HEADER + sidecar_config.API_KEY = "mock_api_key" + sidecar_config.AUTH_HEADER = "X-API-Key" + + app = FastAPI() + + @app.get("/protected", dependencies=[Depends(enforce_pdp_token)]) + def protected(request: Request): + return {"authorization": get_pdp_authorization_header(request)} + + try: + client = TestClient(app) + + response = client.get("/protected", headers={"X-API-Key": "Bearer mock_api_key"}) + + assert response.status_code == 200 + assert response.json() == {"authorization": "Bearer mock_api_key"} + + response = client.get("/protected", headers={"Authorization": "Bearer mock_api_key"}) + + assert response.status_code == 401 + finally: + sidecar_config.API_KEY = previous_api_key + sidecar_config.AUTH_HEADER = previous_auth_header