Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions backend/infrahub/api/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from infrahub.auth.auth import (
ExternalIdentity,
SSOStateCache,
extract_sso_groups,
get_groups_from_provider,
signin_sso_account,
validate_auth_response,
Expand Down Expand Up @@ -145,9 +146,12 @@ async def token(

validate_auth_response(response=userinfo_response, provider_type="OAuth 2.0")
user_info = userinfo_response.json()
sso_groups = user_info.get("groups", []) or await get_groups_from_provider(
provider=provider, service=service, payload=payload, user_info=user_info
)
sso_groups = extract_sso_groups(
payload=user_info,
claim_key=provider.groups_claim,
provider_name=provider_name,
source="oauth2_userinfo",
) or await get_groups_from_provider(provider=provider, service=service, payload=payload, user_info=user_info)

log.info(
"SSO user authenticated",
Expand Down
30 changes: 26 additions & 4 deletions backend/infrahub/api/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from infrahub.auth.auth import (
ExternalIdentity,
SSOStateCache,
extract_sso_groups,
get_groups_from_provider,
signin_sso_account,
validate_auth_response,
Expand Down Expand Up @@ -189,9 +190,19 @@ async def token(
validate_auth_response(response=userinfo_response, provider_type="OIDC")
user_info: dict[str, Any] = userinfo_response.json()
sso_groups = (
user_info.get("groups")
extract_sso_groups(
payload=user_info,
claim_key=provider.groups_claim,
provider_name=provider_name,
source="oidc_userinfo",
)
or await _get_id_token_groups(
oidc_config=oidc_config, service=service, payload=payload, client_id=provider.client_id
oidc_config=oidc_config,
service=service,
payload=payload,
client_id=provider.client_id,
claim_key=provider.groups_claim,
provider_name=provider_name,
)
or await get_groups_from_provider(provider=provider, service=service, payload=payload, user_info=user_info)
)
Expand Down Expand Up @@ -256,7 +267,13 @@ async def token(


async def _get_id_token_groups(
oidc_config: OIDCDiscoveryConfig, service: InfrahubServices, payload: dict[str, Any], client_id: str
oidc_config: OIDCDiscoveryConfig,
service: InfrahubServices,
payload: dict[str, Any],
client_id: str,
*,
claim_key: str = "groups",
provider_name: str = "",
Comment on lines +274 to +276
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Every caller passes claim_key and provider_name explicitly. The defaults ("groups" / "") only exist for tests. Suggest dropping the defaults and making both required.

) -> list[str]:
id_token = payload.get("id_token")
if not id_token:
Expand All @@ -278,4 +295,9 @@ async def _get_id_token_groups(
options={"verify_signature": False, "verify_aud": False, "verify_iss": False},
)

return decoded_token.get("groups", [])
return extract_sso_groups(
payload=decoded_token,
claim_key=claim_key,
provider_name=provider_name,
source="oidc_id_token",
)
46 changes: 46 additions & 0 deletions backend/infrahub/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
from infrahub.log import get_logger

if TYPE_CHECKING:
from collections.abc import Mapping

import httpx

from infrahub.database import InfrahubDatabase
Expand Down Expand Up @@ -587,6 +589,50 @@ async def get_groups_from_provider(
return []


def extract_sso_groups(
*,
payload: Mapping[str, Any],
claim_key: str,
provider_name: str,
source: str,
) -> list[str]:
if claim_key not in payload:
log.warning(
"sso groups claim miss",
provider=provider_name,
source=source,
configured_claim=claim_key,
available_keys=sorted(payload.keys()),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The miss log dumps all top-level claim keys at WARN. Key names aren't usually PII, but URI-namespaced claims can embed tenant/customer identifiers (e.g. https://acme-corp.example/claims/...). Is WARN the right level for this, or would DEBUG be more appropriate?

miss_reason="absent",
)
return []

value = payload[claim_key]
if not isinstance(value, list):
log.warning(
"sso groups claim miss",
provider=provider_name,
source=source,
configured_claim=claim_key,
available_keys=sorted(payload.keys()),
miss_reason="not_list",
)
return []

if not all(isinstance(item, str) for item in value):
log.warning(
"sso groups claim miss",
provider=provider_name,
source=source,
configured_claim=claim_key,
available_keys=sorted(payload.keys()),
miss_reason="list_has_non_string",
)
return []

return value


def safe_get_response_body(response: httpx.Response, raise_error_on_empty_body: bool = True) -> str | dict[str, Any]:
"""Safely extract response body from HTTP response.

Expand Down
30 changes: 30 additions & 0 deletions backend/infrahub/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,21 @@ class SecurityOIDCBaseSettings(BaseSettings):
pkce_enabled: bool = Field(
default=True, description="Enable PKCE (RFC 7636) with S256 method for authorization code flow"
)
groups_claim: str = Field(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P3: New groups_claim logic duplicated in two base classes. Easy drift later. Extract one shared mixin/helper and reuse it.

Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At backend/infrahub/config.py, line 604:

<comment>New groups_claim logic duplicated in two base classes. Easy drift later. Extract one shared mixin/helper and reuse it.</comment>

<file context>
@@ -601,6 +601,21 @@ class SecurityOIDCBaseSettings(BaseSettings):
     pkce_enabled: bool = Field(
         default=True, description="Enable PKCE (RFC 7636) with S256 method for authorization code flow"
     )
+    groups_claim: str = Field(
+        default="groups",
+        description=(
</file context>

default="groups",
description=(
"Top-level key in the IdP claim payload from which the user's groups are read. "
"Defaults to `groups`. Set per provider when your IdP emits group memberships "
"under a different claim name (e.g., `roles`)."
),
)

@field_validator("groups_claim")
@classmethod
def _validate_groups_claim(cls, value: str) -> str:
if not value.strip():
raise ValueError("groups_claim must not be empty or whitespace-only")
return value


class SecurityOIDCSettings(SecurityOIDCBaseSettings):
Expand Down Expand Up @@ -657,6 +672,21 @@ class SecurityOAuth2BaseSettings(BaseSettings):
pkce_enabled: bool = Field(
default=True, description="Enable PKCE (RFC 7636) with S256 method for authorization code flow"
)
groups_claim: str = Field(
default="groups",
description=(
"Top-level key in the IdP claim payload from which the user's groups are read. "
"Defaults to `groups`. Set per provider when your IdP emits group memberships "
"under a different claim name (e.g., `roles`)."
),
)

@field_validator("groups_claim")
@classmethod
def _validate_groups_claim(cls, value: str) -> str:
if not value.strip():
raise ValueError("groups_claim must not be empty or whitespace-only")
return value


class SecurityOAuth2Settings(SecurityOAuth2BaseSettings):
Expand Down
2 changes: 2 additions & 0 deletions backend/tests/fixtures/config_files/sso_config_methods.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ userinfo_url = "http://localhost:8180/infrahub-users/infrahub/protocol/openid-co
display_label = "Keycloak Users"
icon = "mdi:security-lock-outline"
userinfo_method = "post"
groups_claim = "roles"

[security.oauth2_provider_settings.provider2]
client_id = "infrahub-admin-client"
Expand All @@ -30,6 +31,7 @@ discovery_url = "http://localhost:8180/realms/infrahub-users/.well-known/openid-
display_label = "OIDC Users"
icon = "mdi:security-lock-outline"
userinfo_method = "post"
groups_claim = "roles"


[security.oidc_provider_settings.provider2]
Expand Down
148 changes: 148 additions & 0 deletions backend/tests/unit/api/test_oauth2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import json

import httpx
from structlog.testing import capture_logs

from infrahub.auth.auth import extract_sso_groups
from infrahub.config import SecurityOAuth2Provider1
from infrahub.services import InfrahubServices
from tests.adapters.http import MemoryHTTP


def _build_provider(groups_claim: str = "groups") -> SecurityOAuth2Provider1:
return SecurityOAuth2Provider1(
client_id="infrahub-user-client",
client_secret="secret",
authorization_url="https://idp.example.com/auth",
token_url="https://idp.example.com/token",
userinfo_url="https://idp.example.com/userinfo",
groups_claim=groups_claim,
)


async def test_oauth2_userinfo_extracts_groups_from_custom_claim_key() -> None:
memory_http = MemoryHTTP()
service = await InfrahubServices.new(http=memory_http)
provider = _build_provider(groups_claim="roles")

userinfo_body = {
"sub": "u1",
"name": "Otto",
"email": "o@x.com",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we are using fake data, we should use a RFC 2606/BCP 32 reserved domain rather than domain owned by an actual company.

"roles": ["network-engineering"],
}
memory_http.add_get_response(
url=provider.userinfo_url,
response=httpx.Response(status_code=200, content=json.dumps(userinfo_body)),
)

userinfo_response = await service.http.get(provider.userinfo_url)
user_info = userinfo_response.json()

sso_groups = extract_sso_groups(
payload=user_info,
claim_key=provider.groups_claim,
provider_name="provider1",
source="oauth2_userinfo",
)

assert sso_groups == ["network-engineering"]


async def test_oauth2_userinfo_custom_claim_key_does_not_read_groups_key() -> None:
memory_http = MemoryHTTP()
service = await InfrahubServices.new(http=memory_http)
provider = _build_provider(groups_claim="roles")

userinfo_body = {
"sub": "u1",
"name": "Otto",
"email": "o@x.com",
"groups": ["legacy-group"],
}
memory_http.add_get_response(
url=provider.userinfo_url,
response=httpx.Response(status_code=200, content=json.dumps(userinfo_body)),
)

userinfo_response = await service.http.get(provider.userinfo_url)
user_info = userinfo_response.json()

with capture_logs() as records:
sso_groups = extract_sso_groups(
payload=user_info,
claim_key=provider.groups_claim,
provider_name="provider1",
source="oauth2_userinfo",
)

assert sso_groups == []
warnings = [r for r in records if r.get("event") == "sso groups claim miss"]
assert len(warnings) == 1
assert warnings[0]["miss_reason"] == "absent"
assert warnings[0]["source"] == "oauth2_userinfo"


async def test_default_claim_key_preserves_existing_behavior() -> None:
memory_http = MemoryHTTP()
service = await InfrahubServices.new(http=memory_http)
provider = _build_provider()

assert provider.groups_claim == "groups"

userinfo_body = {
"sub": "u",
"name": "Otto",
"email": "o@x.com",
"groups": ["admin-otter"],
}
memory_http.add_get_response(
url=provider.userinfo_url,
response=httpx.Response(status_code=200, content=json.dumps(userinfo_body)),
)

userinfo_response = await service.http.get(provider.userinfo_url)
user_info = userinfo_response.json()

with capture_logs() as records:
sso_groups = extract_sso_groups(
payload=user_info,
claim_key=provider.groups_claim,
provider_name="provider1",
source="oauth2_userinfo",
)

assert sso_groups == ["admin-otter"]
assert not any(r.get("event") == "sso groups claim miss" for r in records)


async def test_oauth2_and_oidc_with_different_claim_keys_coexist() -> None:
oauth2_provider = _build_provider(groups_claim="memberships")
oauth2_payload = {
"sub": "u1",
"name": "Otto",
"email": "o@x.com",
"memberships": ["membership-a"],
}
oauth2_groups = extract_sso_groups(
payload=oauth2_payload,
claim_key=oauth2_provider.groups_claim,
provider_name="provider1",
source="oauth2_userinfo",
)

oidc_payload = {
"sub": "u2",
"name": "Otto",
"email": "o@x.com",
"roles": ["role-x"],
}
oidc_groups = extract_sso_groups(
payload=oidc_payload,
claim_key="roles",
provider_name="provider2",
source="oidc_id_token",
)

assert oauth2_groups == ["membership-a"]
assert oidc_groups == ["role-x"]
Loading
Loading