Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions horizon/authentication.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,35 @@
from typing import Annotated

from fastapi import Header, HTTPException, status
from fastapi import Header, HTTPException, Request, status

from horizon.config import MOCK_API_KEY, sidecar_config
from horizon.startup.api_keys import get_env_api_key


def enforce_pdp_token(authorization: Annotated[str | None, Header()]):
def _parse_bearer_token(authorization: str | None, header_name: str) -> str:
if authorization is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail="Missing Authorization header")
schema, token = authorization.split(" ")
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=f"Missing {header_name} header")
parts = authorization.split(" ")
if len(parts) != 2:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=f"bad authz header: {authorization}")
schema, token = parts

if schema.strip().lower() != "bearer":
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail="Invalid PDP token")
return token.strip()


def extract_pdp_api_key(request: Request) -> str:
header_name = sidecar_config.AUTH_HEADER or "Authorization"
return _parse_bearer_token(request.headers.get(header_name), header_name)


def get_pdp_authorization_header(request: Request) -> str:
return f"Bearer {extract_pdp_api_key(request)}"


if schema.strip().lower() != "bearer" or token.strip() != get_env_api_key():
def enforce_pdp_token(request: Request):
if extract_pdp_api_key(request) != get_env_api_key():
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail="Invalid PDP token")


Expand All @@ -24,7 +42,7 @@ def enforce_pdp_control_key(authorization: Annotated[str | None, Header()]):

if authorization is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail="Missing Authorization header")
schema, token = authorization.split(" ")
token = _parse_bearer_token(authorization, "Authorization")

if schema.strip().lower() != "bearer" or token.strip() != sidecar_config.CONTAINER_CONTROL_KEY:
if token != sidecar_config.CONTAINER_CONTROL_KEY:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail="Invalid PDP token")
7 changes: 7 additions & 0 deletions horizon/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,13 @@ def __new__(cls, *, prefix=None, is_model=True): # noqa: ARG004
description="set this to your environment's API key if you prefer to use the environment level API key.",
)

# request header used to authenticate PDP API calls
AUTH_HEADER = confi.str(
"AUTH_HEADER",
"Authorization",
description="HTTP header used to read the PDP bearer token from incoming API calls.",
)

# access token to your organization
ORG_API_KEY = confi.str(
"ORG_API_KEY",
Expand Down
16 changes: 1 addition & 15 deletions horizon/enforcer/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pydantic import parse_obj_as
from starlette.responses import JSONResponse

from horizon.authentication import enforce_pdp_token
from horizon.authentication import enforce_pdp_token, extract_pdp_api_key
from horizon.config import sidecar_config
from horizon.enforcer.schemas import (
AllTenantsAuthorizationResult,
Expand Down Expand Up @@ -62,20 +62,6 @@
)


def extract_pdp_api_key(request: Request) -> str:
authorization: str = request.headers.get(AUTHZ_HEADER, "")
parts = authorization.split(" ")
if len(parts) != 2:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED,
detail=f"bad authz header: {authorization}",
)
schema, token = parts
if schema.strip().lower() != "bearer":
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail="Invalid PDP token")
return token


def transform_headers(request: Request) -> dict:
token = extract_pdp_api_key(request)
return {
Expand Down
4 changes: 3 additions & 1 deletion horizon/facts/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from starlette.responses import Response as FastApiResponse
from starlette.responses import StreamingResponse

from horizon.authentication import get_pdp_authorization_header
from horizon.config import sidecar_config
from horizon.startup.api_keys import get_env_api_key
from horizon.startup.remote_config import get_remote_config
Expand Down Expand Up @@ -53,8 +54,9 @@ async def build_forward_request(
:return: HTTPX request
"""
forward_headers = {
key: value for key, value in request.headers.items() if key.lower() in {"authorization", "content-type"}
key: value for key, value in request.headers.items() if key.lower() == "content-type"
}
forward_headers["Authorization"] = get_pdp_authorization_header(request)
if is_consistent_update:
forward_headers[CONSISTENT_UPDATE_HEADER] = CONSISTENT_UPDATE_HEADER_VALUE
remote_config = get_remote_config()
Expand Down
22 changes: 11 additions & 11 deletions horizon/facts/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from loguru import logger
from opal_common.schemas.data import DataSourceEntry

from horizon.authentication import enforce_pdp_token
from horizon.authentication import enforce_pdp_token, get_pdp_authorization_header
from horizon.facts.client import FactsClient, FactsClientDependency
from horizon.facts.dependencies import (
DataUpdateSubscriberDependency,
Expand Down Expand Up @@ -51,7 +51,7 @@ async def create_user(
obj_type="users",
obj_id=body["id"],
obj_key=body["key"],
authorization_header=r.headers.get("Authorization"),
authorization_header=get_pdp_authorization_header(r),
update_id=update_id,
)
],
Expand All @@ -78,7 +78,7 @@ async def create_tenant(
obj_type="tenants",
obj_id=body["id"],
obj_key=body["key"],
authorization_header=r.headers.get("Authorization"),
authorization_header=get_pdp_authorization_header(r),
update_id=update_id,
)
],
Expand Down Expand Up @@ -106,7 +106,7 @@ async def sync_user(
obj_type="users",
obj_id=body["id"],
obj_key=body["key"],
authorization_header=r.headers.get("Authorization"),
authorization_header=get_pdp_authorization_header(r),
update_id=update_id,
)
],
Expand Down Expand Up @@ -134,7 +134,7 @@ async def update_user(
obj_type="users",
obj_id=body["id"],
obj_key=body["key"],
authorization_header=r.headers.get("Authorization"),
authorization_header=get_pdp_authorization_header(r),
update_id=update_id,
)
],
Expand All @@ -150,14 +150,14 @@ def create_role_assignment_data_entries(
obj_type="role_assignments",
obj_id=body["user_id"],
obj_key=f"user:{body['user']}",
authorization_header=request.headers.get("Authorization"),
authorization_header=get_pdp_authorization_header(request),
update_id=update_id,
)
yield create_data_source_entry(
obj_type="users",
obj_id=body["user_id"],
obj_key=body["user"],
authorization_header=request.headers.get("Authorization"),
authorization_header=get_pdp_authorization_header(request),
update_id=update_id,
)
else:
Expand All @@ -167,7 +167,7 @@ def create_role_assignment_data_entries(
obj_type="role_assignments",
obj_id=body["user_id"],
obj_key=f"user:{body['user']}",
authorization_header=request.headers.get("Authorization"),
authorization_header=get_pdp_authorization_header(request),
update_id=update_id,
)

Expand Down Expand Up @@ -271,7 +271,7 @@ async def create_resource_instance(
obj_type="resource_instances",
obj_id=body["id"],
obj_key=f"{body['resource']}:{body['key']}",
authorization_header=r.headers.get("Authorization"),
authorization_header=get_pdp_authorization_header(r),
update_id=update_id,
),
],
Expand Down Expand Up @@ -299,7 +299,7 @@ async def update_resource_instance(
obj_type="resource_instances",
obj_id=body["id"],
obj_key=f"{body['resource']}:{body['key']}",
authorization_header=r.headers.get("Authorization"),
authorization_header=get_pdp_authorization_header(r),
update_id=update_id,
),
],
Expand All @@ -326,7 +326,7 @@ async def create_relationship_tuple(
obj_type="relationships",
obj_id=body["object_id"],
obj_key=body["object"],
authorization_header=r.headers.get("Authorization"),
authorization_header=get_pdp_authorization_header(r),
update_id=update_id,
),
],
Expand Down
10 changes: 3 additions & 7 deletions horizon/proxy/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from opal_common.logger import logger
from pydantic import BaseModel, Field, parse_obj_as

from horizon.authentication import get_pdp_authorization_header
from horizon.config import sidecar_config

HTTP_GET = "GET"
Expand Down Expand Up @@ -174,13 +175,7 @@ async def proxy_request_to_cloud_service(
additional_headers: dict[str, str],
timeout: int = sidecar_config.CONTROL_PLANE_TIMEOUT,
) -> Response:
auth_header = request.headers.get("Authorization")
if auth_header is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Must provide a bearer token!",
headers={"WWW-Authenticate": "Bearer"},
)
auth_header = get_pdp_authorization_header(request)
path = f"{cloud_service_url}/{path}"
params = dict(request.query_params) or {}

Expand All @@ -191,6 +186,7 @@ async def proxy_request_to_cloud_service(
for header_name in REQUIRED_HTTP_HEADERS:
if header_name in original_headers:
headers[header_name] = original_headers[header_name]
headers["authorization"] = auth_header

# override host header (required by k8s ingress)
try:
Expand Down
32 changes: 32 additions & 0 deletions horizon/tests/test_authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from fastapi import Depends, FastAPI, Request
from fastapi.testclient import TestClient
from horizon.authentication import enforce_pdp_token, get_pdp_authorization_header
from horizon.config import sidecar_config


def test_pdp_auth_uses_configured_header() -> None:
previous_api_key = sidecar_config.API_KEY
previous_auth_header = sidecar_config.AUTH_HEADER
sidecar_config.API_KEY = "mock_api_key"
sidecar_config.AUTH_HEADER = "X-API-Key"

app = FastAPI()

@app.get("/protected", dependencies=[Depends(enforce_pdp_token)])
def protected(request: Request):
return {"authorization": get_pdp_authorization_header(request)}

try:
client = TestClient(app)

response = client.get("/protected", headers={"X-API-Key": "Bearer mock_api_key"})

assert response.status_code == 200
assert response.json() == {"authorization": "Bearer mock_api_key"}

response = client.get("/protected", headers={"Authorization": "Bearer mock_api_key"})

assert response.status_code == 401
finally:
sidecar_config.API_KEY = previous_api_key
sidecar_config.AUTH_HEADER = previous_auth_header