Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 2 additions & 11 deletions infrastructure/configuration/airflow.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[api]
auth_backends = airflow.api.auth.backend.basic_auth
auth_backends = configuration.keycloak_bearer_auth, airflow.api.auth.backend.session

[core]
executor = CeleryExecutor
Expand All @@ -16,22 +16,13 @@ parallelism = 10000
broker_url = sqs://
celery_config_options = configuration.celery_config.CELERY_CONFIG


[github_enterprise]
api_rev = v3
host = github.com
client_id = Iv23lil9JEmXAM6QJlFe
client_secret = 8cbd483d2cb4e73599dffba93dbd0295ef0830c5
oauth_callback_route = /home
allowed_teams = VEDA

[webserver]
authenticate = True
auth_backends = airflow.contrib.auth.backends.github_enterprise_auth
dag_default_view = grid
expose_config = true
dag_orientation = TB
warn_deployment_exposure = false
base_url = ${sm2a_base_url}

# On ECS, you can deploy the CloudWatch agent as a sidecar to your application container to collect metrics.
# https://docs.aws.amazon.com/AmazonCloudWatch/latest/monitoring/deploy_servicelens_CloudWatch_agent_deploy_ECS.html
Expand Down
3 changes: 1 addition & 2 deletions infrastructure/configuration/airflow.cfg.tmpl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[api]
auth_backends = airflow.api.auth.backend.basic_auth
auth_backends = configuration.keycloak_bearer_auth, airflow.api.auth.backend.basic_auth

[core]
executor = CeleryExecutor
Expand All @@ -16,7 +16,6 @@ parallelism = 10000
broker_url = sqs://
celery_config_options = configuration.celery_config.CELERY_CONFIG


[webserver]
authenticate = True
dag_default_view = grid
Expand Down
146 changes: 146 additions & 0 deletions infrastructure/configuration/keycloak_bearer_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""
Custom Airflow API auth backend: Keycloak Bearer tokens.

Configured via `[api] auth_backends = configuration.keycloak_bearer_auth` in `airflow.cfg`.

Docs: https://airflow.apache.org/docs/apache-airflow-providers-fab/1.5.4/auth-manager/api-authentication.html#roll-your-own-api-authentication
"""

from __future__ import annotations

import functools
import logging
import os
from typing import Any, Callable

import jwt
import requests
from flask import Response, request
from flask_appbuilder.security.manager import AUTH_DB, AUTH_LDAP, AUTH_OAUTH
from flask_appbuilder.security.sqla.models import User
from flask_login import login_user

from airflow.www.extensions import get_auth_manager

log = logging.getLogger(__name__)


class _AuthError(Exception):
    """Carry a ready-made HTTP error response out of the authentication flow.

    Raised by ``_auth_current_user`` and caught by the wrapper built in
    ``requires_authentication``, which returns ``response`` unchanged.
    """

    def __init__(self, response: Response):
        # Forward the response to ``Exception`` as well, so ``args`` keeps
        # holding it exactly as it does when ``__init__`` is not chained.
        super().__init__(response)
        self.response = response


def init_app(app) -> None:
    """No-op initialisation hook: nothing is configured on ``app``."""
    return None


def requires_authentication(fn: Callable) -> Callable:
    """Decorate ``fn`` so it only runs for an authenticated request.

    The wrapper resolves the caller with ``_auth_current_user()``:

    * an ``_AuthError`` short-circuits with the response attached to it;
    * no user (``None``) yields a plain ``401 Unauthorized`` advertising the
      ``Bearer`` scheme;
    * otherwise ``fn`` is invoked with the original arguments.
    """

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any):
        try:
            current_user = _auth_current_user()
        except _AuthError as auth_error:
            return auth_error.response

        if current_user is None:
            return Response("Unauthorized", 401, {"WWW-Authenticate": "Bearer"})
        return fn(*args, **kwargs)

    return wrapper


def _auth_current_user() -> User | None:
    """Authenticate the current request and log the resulting user in.

    With ``AUTH_OAUTH`` the request must carry ``Authorization: Bearer <jwt>``.
    The token is checked for well-formedness locally, validated against the
    Keycloak introspection endpoint, and must grant at least one of the
    allowed client roles before the user is resolved via ``auth_user_oauth``.
    For any other auth type, HTTP Basic credentials are checked against LDAP
    or the database, depending on the security manager's ``auth_type``.

    Returns:
        The authenticated user, or ``None`` when the request carries no usable
        credentials or the security manager does not accept them.

    Raises:
        _AuthError: carrying a 401 response when the token is malformed,
            cannot be introspected or is not active, and a 403 response when
            it holds none of the allowed roles.
    """
    client_id = os.getenv("KEYCLOAK_CLIENT_ID", "")
    ab_security_manager = get_auth_manager().security_manager
    user = None
    if ab_security_manager.auth_type == AUTH_OAUTH:
        # Parse the header properly rather than stripping "Bearer " wherever
        # it occurs: only "<scheme> <token>" with a (case-insensitive) Bearer
        # scheme is accepted, so e.g. Basic credentials are never treated as
        # a token.
        header = request.headers.get("Authorization", "")
        parts = header.split(None, 1)
        if len(parts) != 2 or parts[0].lower() != "bearer":
            return None
        token = parts[1].strip()
        if not token:
            return None

        # The unverified decode only supplies profile fields; whether the
        # token is genuine is decided by the introspection call below. A token
        # that is not even a well-formed JWT must yield 401, not an unhandled
        # exception.
        try:
            token_decoded = jwt.decode(
                token, options={"verify_signature": False, "verify_aud": False}
            )
        except jwt.InvalidTokenError as err:
            log.warning("Rejected malformed bearer token: %s", err)
            raise _AuthError(_unauthorized("Malformed bearer token")) from err

        try:
            claims = _introspect_token(token)
        except Exception:
            log.exception("Keycloak token introspection failed")
            raise _AuthError(_unauthorized("Token introspection failed"))

        if not claims.get("active"):
            raise _AuthError(_unauthorized("Token is not active"))

        # ``or {}`` / ``or []`` also cover claims that are present but null.
        resource_access = claims.get("resource_access") or {}
        token_roles = (resource_access.get(client_id) or {}).get("roles") or []
        allowed_roles = ["Admin", "User", "Dag_Launcher"]
        if not any(role in token_roles for role in allowed_roles):
            raise _AuthError(_forbidden("Token is valid but missing required role"))

        username = (
            token_decoded.get("preferred_username")
            or token_decoded.get("username")
            or token_decoded.get("sub", "")
        )
        userinfo = {
            "username": username,
            "email": token_decoded.get("email"),
            "first_name": token_decoded.get("given_name"),
            "last_name": token_decoded.get("family_name"),
            "role_keys": token_roles,
        }
        user = ab_security_manager.auth_user_oauth(userinfo)
    else:
        auth = request.authorization
        if auth is None or not auth.username or not auth.password:
            return None
        if ab_security_manager.auth_type == AUTH_LDAP:
            user = ab_security_manager.auth_user_ldap(auth.username, auth.password)
        elif ab_security_manager.auth_type == AUTH_DB:
            user = ab_security_manager.auth_user_db(auth.username, auth.password)

    log.info("user: %s", user)
    if user is not None:
        login_user(user, remember=False)
    return user


def _unauthorized(message: str) -> Response:
    """Build a plain-text ``401`` response carrying a Bearer challenge.

    Args:
        message: Text used as the response body.
    """
    challenge = {"WWW-Authenticate": 'Bearer realm="airflow"'}
    return Response(message, status=401, headers=challenge, mimetype="text/plain")


def _forbidden(message: str) -> Response:
    """Build a plain-text ``403`` response with ``message`` as its body."""
    return Response(message, status=403, mimetype="text/plain")


def _keycloak_config() -> tuple[str, str, str, str]:
    """Read the Keycloak connection settings from the environment.

    Returns:
        ``(base_url, realm, client_id, client_secret)``; trailing slashes are
        stripped from ``base_url``.

    Raises:
        RuntimeError: if any of the four values is unset or empty.
    """
    settings = (
        os.getenv("KEYCLOAK_BASE_URL", "").rstrip("/"),
        os.getenv("KEYCLOAK_REALM", ""),
        os.getenv("KEYCLOAK_CLIENT_ID", ""),
        os.getenv("KEYCLOAK_CLIENT_SECRET", ""),
    )
    if all(settings):
        return settings
    raise RuntimeError(
        "Missing Keycloak env vars. Required: KEYCLOAK_BASE_URL, KEYCLOAK_REALM, "
        "KEYCLOAK_CLIENT_ID, KEYCLOAK_CLIENT_SECRET"
    )


def _introspect_token(token: str) -> dict[str, Any]:
    """Ask the Keycloak token-introspection endpoint about ``token``.

    The request is sent with the configured client id and secret as its
    ``auth`` credentials and times out after 10 seconds.

    Returns:
        The decoded JSON body of the introspection response.

    Raises:
        RuntimeError: if the Keycloak settings are incomplete, or the endpoint
            answers with a status code of 400 or above.
    """
    base_url, realm, client_id, client_secret = _keycloak_config()
    endpoint = f"{base_url}/realms/{realm}/protocol/openid-connect/token/introspect"

    response = requests.post(
        endpoint,
        data={"token": token},
        auth=(client_id, client_secret),
        timeout=10,
    )
    if response.status_code < 400:
        return response.json()
    raise RuntimeError(f"Keycloak introspection failed: {response.status_code} {response.text}")
Empty file removed plugins/.gitkeep
Empty file.
112 changes: 112 additions & 0 deletions plugins/keycloak_bearer_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""
Custom Airflow API auth backend: Keycloak Bearer tokens.

Configured via `[api] auth_backends = keycloak_bearer_auth` in `airflow.cfg`.

Docs: https://airflow.apache.org/docs/apache-airflow-providers-fab/1.5.4/auth-manager/api-authentication.html#roll-your-own-api-authentication
"""

from __future__ import annotations

import functools
import logging
import os
from typing import Any, Callable
import jwt

import requests
from flask import Response, request

log = logging.getLogger(__name__)

def init_app(app) -> None:
    """No-op initialisation hook: ``app`` is accepted but left untouched."""
    return None

def requires_authentication(fn: Callable) -> Callable:
    """Decorate ``fn`` so it only runs for requests with an accepted token.

    The wrapper checks, in order:

    1. a Bearer token is present (otherwise 401);
    2. Keycloak introspection succeeds and reports the token as active
       (otherwise 401);
    3. the token carries at least one of ``Admin``, ``User`` or
       ``Dag_Launcher`` (otherwise 403).

    Only then is ``fn`` called with the original arguments.
    """

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any):
        bearer = _extract_bearer_token()
        if not bearer:
            return _unauthorized("Missing Bearer token")

        try:
            introspection = _introspect_token(bearer)
        except Exception:
            # Any failure while talking to Keycloak is logged and reported to
            # the client as an authentication failure.
            log.exception("Keycloak token introspection failed")
            return _unauthorized("Token introspection failed")

        if not introspection.get("active"):
            return _unauthorized("Token is not active")

        granted = _extract_roles(bearer)
        accepted = ("Admin", "User", "Dag_Launcher")
        if any(role in granted for role in accepted):
            return fn(*args, **kwargs)
        return _forbidden("Token is valid but missing required role")

    return wrapper


def _extract_bearer_token() -> str | None:
    """Return the Bearer token from the request's ``Authorization`` header.

    Returns:
        The token with surrounding whitespace removed, or ``None`` when the
        header is missing or empty, is not of the form ``<scheme> <value>``,
        uses a scheme other than ``Bearer`` (compared case-insensitively), or
        carries an empty value.
    """
    header = request.headers.get("Authorization", "")
    pieces = header.split(None, 1) if header else []
    if len(pieces) == 2 and pieces[0].lower() == "bearer":
        return pieces[1].strip() or None
    return None


def _unauthorized(message: str) -> Response:
    """Return a plain-text ``401`` response that includes a Bearer challenge.

    Args:
        message: Text used as the response body.
    """
    bearer_challenge = {"WWW-Authenticate": 'Bearer realm="airflow"'}
    return Response(message, status=401, headers=bearer_challenge, mimetype="text/plain")

def _forbidden(message: str) -> Response:
    """Return a plain-text ``403`` response whose body is ``message``."""
    return Response(message, status=403, mimetype="text/plain")


def _keycloak_config() -> tuple[str, str, str, str]:
    """Collect the Keycloak settings from environment variables.

    Returns:
        ``(base_url, realm, client_id, client_secret)``, where ``base_url``
        has any trailing ``/`` characters removed.

    Raises:
        RuntimeError: if one or more of the four values is unset or empty.
    """
    values = (
        os.getenv("KEYCLOAK_BASE_URL", "").rstrip("/"),
        os.getenv("KEYCLOAK_REALM", ""),
        os.getenv("KEYCLOAK_CLIENT_ID", ""),
        os.getenv("KEYCLOAK_CLIENT_SECRET", ""),
    )
    if not all(values):
        raise RuntimeError(
            "Missing Keycloak env vars. Required: KEYCLOAK_BASE_URL, KEYCLOAK_REALM, "
            "KEYCLOAK_CLIENT_ID, KEYCLOAK_CLIENT_SECRET"
        )
    return values


def _introspect_token(token: str) -> dict[str, Any]:
    """Submit ``token`` to the Keycloak token-introspection endpoint.

    The POST passes the configured client id and secret as its ``auth``
    credentials and uses a 10-second timeout.

    Returns:
        The decoded JSON body of the introspection response.

    Raises:
        RuntimeError: if the Keycloak settings are incomplete, or the endpoint
            responds with a status code of 400 or above.
    """
    base_url, realm, client_id, client_secret = _keycloak_config()
    introspect_url = f"{base_url}/realms/{realm}/protocol/openid-connect/token/introspect"

    reply = requests.post(
        introspect_url,
        data={"token": token},
        auth=(client_id, client_secret),
        timeout=10,
    )
    if reply.status_code < 400:
        return reply.json()
    raise RuntimeError(f"Keycloak introspection failed: {reply.status_code} {reply.text}")


def _extract_roles(token: str) -> set[str]:
    """Return the roles ``token`` grants for the configured Keycloak client.

    The JWT payload is decoded without signature or audience verification, so
    the result is only meaningful for a token already validated elsewhere
    (``requires_authentication`` introspects it first).

    Args:
        token: The encoded JWT.

    Returns:
        The roles listed under ``resource_access[<KEYCLOAK_CLIENT_ID>].roles``,
        as a set. The set is empty when ``resource_access``, the client entry
        or its ``roles`` list is missing or null.
    """
    client_id = os.getenv("KEYCLOAK_CLIENT_ID", "")
    decoded = jwt.decode(token, options={"verify_signature": False, "verify_aud": False})
    # ``or {}`` / ``or []`` cover claims that are absent as well as present
    # but null; the previous ``decoded.get("resource_access").get(...)``
    # raised AttributeError when the claim was missing.
    resource_access = decoded.get("resource_access") or {}
    client_access = resource_access.get(client_id) or {}
    # Build an actual set so the value matches the declared return type.
    return set(client_access.get("roles") or [])
2 changes: 1 addition & 1 deletion sm2a-local-config/local_airflow.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[api]
auth_backends = airflow.api.auth.backend.basic_auth
auth_backends = airflow.api.auth.backend.session

[core]
executor = CeleryExecutor
Expand Down