diff --git a/backend/infrahub/api/oauth2.py b/backend/infrahub/api/oauth2.py index f008f645eab..8d01c4814f7 100644 --- a/backend/infrahub/api/oauth2.py +++ b/backend/infrahub/api/oauth2.py @@ -15,6 +15,7 @@ from infrahub.auth.auth import ( ExternalIdentity, SSOStateCache, + extract_sso_groups, get_groups_from_provider, signin_sso_account, validate_auth_response, @@ -145,9 +146,12 @@ async def token( validate_auth_response(response=userinfo_response, provider_type="OAuth 2.0") user_info = userinfo_response.json() - sso_groups = user_info.get("groups", []) or await get_groups_from_provider( - provider=provider, service=service, payload=payload, user_info=user_info - ) + sso_groups = extract_sso_groups( + payload=user_info, + claim_key=provider.groups_claim, + provider_name=provider_name, + source="oauth2_userinfo", + ) or await get_groups_from_provider(provider=provider, service=service, payload=payload, user_info=user_info) log.info( "SSO user authenticated", diff --git a/backend/infrahub/api/oidc.py b/backend/infrahub/api/oidc.py index 7831064df04..0f7225c5476 100644 --- a/backend/infrahub/api/oidc.py +++ b/backend/infrahub/api/oidc.py @@ -17,6 +17,7 @@ from infrahub.auth.auth import ( ExternalIdentity, SSOStateCache, + extract_sso_groups, get_groups_from_provider, signin_sso_account, validate_auth_response, @@ -189,9 +190,19 @@ async def token( validate_auth_response(response=userinfo_response, provider_type="OIDC") user_info: dict[str, Any] = userinfo_response.json() sso_groups = ( - user_info.get("groups") + extract_sso_groups( + payload=user_info, + claim_key=provider.groups_claim, + provider_name=provider_name, + source="oidc_userinfo", + ) or await _get_id_token_groups( - oidc_config=oidc_config, service=service, payload=payload, client_id=provider.client_id + oidc_config=oidc_config, + service=service, + payload=payload, + client_id=provider.client_id, + claim_key=provider.groups_claim, + provider_name=provider_name, ) or await get_groups_from_provider(provider=provider, service=service, payload=payload, user_info=user_info) ) @@ -256,7 +267,13 @@ async def token( async def _get_id_token_groups( - oidc_config: OIDCDiscoveryConfig, service: InfrahubServices, payload: dict[str, Any], client_id: str + oidc_config: OIDCDiscoveryConfig, + service: InfrahubServices, + payload: dict[str, Any], + client_id: str, + *, + claim_key: str = "groups", + provider_name: str = "", ) -> list[str]: id_token = payload.get("id_token") if not id_token: @@ -278,4 +295,9 @@ async def _get_id_token_groups( options={"verify_signature": False, "verify_aud": False, "verify_iss": False}, ) - return decoded_token.get("groups", []) + return extract_sso_groups( + payload=decoded_token, + claim_key=claim_key, + provider_name=provider_name, + source="oidc_id_token", + ) diff --git a/backend/infrahub/auth/auth.py b/backend/infrahub/auth/auth.py index c79e4cbe267..929e96504db 100644 --- a/backend/infrahub/auth/auth.py +++ b/backend/infrahub/auth/auth.py @@ -37,6 +37,8 @@ from infrahub.log import get_logger if TYPE_CHECKING: + from collections.abc import Mapping + import httpx from infrahub.database import InfrahubDatabase @@ -587,6 +589,50 @@ async def get_groups_from_provider( return [] +def extract_sso_groups( + *, + payload: Mapping[str, Any], + claim_key: str, + provider_name: str, + source: str, +) -> list[str]: + if claim_key not in payload: + log.warning( + "sso groups claim miss", + provider=provider_name, + source=source, + configured_claim=claim_key, + available_keys=sorted(payload.keys()), + miss_reason="absent", + ) + return [] + + value = payload[claim_key] + if not isinstance(value, list): + log.warning( + "sso groups claim miss", + provider=provider_name, + source=source, + configured_claim=claim_key, + available_keys=sorted(payload.keys()), + miss_reason="not_list", + ) + return [] + + if not all(isinstance(item, str) for item in value): + log.warning( + "sso groups claim miss", + provider=provider_name, + source=source, + configured_claim=claim_key, + available_keys=sorted(payload.keys()), + miss_reason="list_has_non_string", + ) + return [] + + return value + + def safe_get_response_body(response: httpx.Response, raise_error_on_empty_body: bool = True) -> str | dict[str, Any]: """Safely extract response body from HTTP response. diff --git a/backend/infrahub/config.py b/backend/infrahub/config.py index 2493953c65f..9b254da2c6a 100644 --- a/backend/infrahub/config.py +++ b/backend/infrahub/config.py @@ -602,6 +602,21 @@ class SecurityOIDCBaseSettings(BaseSettings): pkce_enabled: bool = Field( default=True, description="Enable PKCE (RFC 7636) with S256 method for authorization code flow" ) + groups_claim: str = Field( + default="groups", + description=( + "Top-level key in the IdP claim payload from which the user's groups are read. " + "Defaults to `groups`. Set per provider when your IdP emits group memberships " + "under a different claim name (e.g., `roles`)." + ), + ) + + @field_validator("groups_claim") + @classmethod + def _validate_groups_claim(cls, value: str) -> str: + if not value.strip(): + raise ValueError("groups_claim must not be empty or whitespace-only") + return value class SecurityOIDCSettings(SecurityOIDCBaseSettings): @@ -657,6 +672,21 @@ class SecurityOAuth2BaseSettings(BaseSettings): pkce_enabled: bool = Field( default=True, description="Enable PKCE (RFC 7636) with S256 method for authorization code flow" ) + groups_claim: str = Field( + default="groups", + description=( + "Top-level key in the IdP claim payload from which the user's groups are read. " + "Defaults to `groups`. Set per provider when your IdP emits group memberships " + "under a different claim name (e.g., `roles`)." + ), + ) + + @field_validator("groups_claim") + @classmethod + def _validate_groups_claim(cls, value: str) -> str: + if not value.strip(): + raise ValueError("groups_claim must not be empty or whitespace-only") + return value class SecurityOAuth2Settings(SecurityOAuth2BaseSettings): diff --git a/backend/tests/fixtures/config_files/sso_config_methods.toml b/backend/tests/fixtures/config_files/sso_config_methods.toml index a780f8ea2d7..23420ee4fdf 100644 --- a/backend/tests/fixtures/config_files/sso_config_methods.toml +++ b/backend/tests/fixtures/config_files/sso_config_methods.toml @@ -12,6 +12,7 @@ userinfo_url = "http://localhost:8180/infrahub-users/infrahub/protocol/openid-co display_label = "Keycloak Users" icon = "mdi:security-lock-outline" userinfo_method = "post" +groups_claim = "roles" [security.oauth2_provider_settings.provider2] client_id = "infrahub-admin-client" @@ -30,6 +31,7 @@ discovery_url = "http://localhost:8180/realms/infrahub-users/.well-known/openid- display_label = "OIDC Users" icon = "mdi:security-lock-outline" userinfo_method = "post" +groups_claim = "roles" [security.oidc_provider_settings.provider2] diff --git a/backend/tests/unit/api/test_oauth2.py b/backend/tests/unit/api/test_oauth2.py new file mode 100644 index 00000000000..5740e982a25 --- /dev/null +++ b/backend/tests/unit/api/test_oauth2.py @@ -0,0 +1,148 @@ +import json + +import httpx +from structlog.testing import capture_logs + +from infrahub.auth.auth import extract_sso_groups +from infrahub.config import SecurityOAuth2Provider1 +from infrahub.services import InfrahubServices +from tests.adapters.http import MemoryHTTP + + +def _build_provider(groups_claim: str = "groups") -> SecurityOAuth2Provider1: + return SecurityOAuth2Provider1( + client_id="infrahub-user-client", + client_secret="secret", + authorization_url="https://idp.example.com/auth", + token_url="https://idp.example.com/token", + userinfo_url="https://idp.example.com/userinfo", + groups_claim=groups_claim, + ) + + +async def test_oauth2_userinfo_extracts_groups_from_custom_claim_key() -> None: + memory_http = MemoryHTTP() + service = await InfrahubServices.new(http=memory_http) + provider = _build_provider(groups_claim="roles") + + userinfo_body = { + "sub": "u1", + "name": "Otto", + "email": "o@x.com", + "roles": ["network-engineering"], + } + memory_http.add_get_response( + url=provider.userinfo_url, + response=httpx.Response(status_code=200, content=json.dumps(userinfo_body)), + ) + + userinfo_response = await service.http.get(provider.userinfo_url) + user_info = userinfo_response.json() + + sso_groups = extract_sso_groups( + payload=user_info, + claim_key=provider.groups_claim, + provider_name="provider1", + source="oauth2_userinfo", + ) + + assert sso_groups == ["network-engineering"] + + +async def test_oauth2_userinfo_custom_claim_key_does_not_read_groups_key() -> None: + memory_http = MemoryHTTP() + service = await InfrahubServices.new(http=memory_http) + provider = _build_provider(groups_claim="roles") + + userinfo_body = { + "sub": "u1", + "name": "Otto", + "email": "o@x.com", + "groups": ["legacy-group"], + } + memory_http.add_get_response( + url=provider.userinfo_url, + response=httpx.Response(status_code=200, content=json.dumps(userinfo_body)), + ) + + userinfo_response = await service.http.get(provider.userinfo_url) + user_info = userinfo_response.json() + + with capture_logs() as records: + sso_groups = extract_sso_groups( + payload=user_info, + claim_key=provider.groups_claim, + provider_name="provider1", + source="oauth2_userinfo", + ) + + assert sso_groups == [] + warnings = [r for r in records if r.get("event") == "sso groups claim miss"] + assert len(warnings) == 1 + assert warnings[0]["miss_reason"] == "absent" + assert warnings[0]["source"] == "oauth2_userinfo" + + +async def test_default_claim_key_preserves_existing_behavior() -> None: + memory_http = MemoryHTTP() + service = await InfrahubServices.new(http=memory_http) + provider = _build_provider() + + assert provider.groups_claim == "groups" + + userinfo_body = { + "sub": "u", + "name": "Otto", + "email": "o@x.com", + "groups": ["admin-otter"], + } + memory_http.add_get_response( + url=provider.userinfo_url, + response=httpx.Response(status_code=200, content=json.dumps(userinfo_body)), + ) + + userinfo_response = await service.http.get(provider.userinfo_url) + user_info = userinfo_response.json() + + with capture_logs() as records: + sso_groups = extract_sso_groups( + payload=user_info, + claim_key=provider.groups_claim, + provider_name="provider1", + source="oauth2_userinfo", + ) + + assert sso_groups == ["admin-otter"] + assert not any(r.get("event") == "sso groups claim miss" for r in records) + + +async def test_oauth2_and_oidc_with_different_claim_keys_coexist() -> None: + oauth2_provider = _build_provider(groups_claim="memberships") + oauth2_payload = { + "sub": "u1", + "name": "Otto", + "email": "o@x.com", + "memberships": ["membership-a"], + } + oauth2_groups = extract_sso_groups( + payload=oauth2_payload, + claim_key=oauth2_provider.groups_claim, + provider_name="provider1", + source="oauth2_userinfo", + ) + + oidc_payload = { + "sub": "u2", + "name": "Otto", + "email": "o@x.com", + "roles": ["role-x"], + } + oidc_groups = extract_sso_groups( + payload=oidc_payload, + claim_key="roles", + provider_name="provider2", + source="oidc_id_token", + ) + + assert oauth2_groups == ["membership-a"] + assert oidc_groups == ["role-x"] diff --git a/backend/tests/unit/api/test_oidc.py b/backend/tests/unit/api/test_oidc.py index da647706656..16a40199f11 100644 --- a/backend/tests/unit/api/test_oidc.py +++ b/backend/tests/unit/api/test_oidc.py @@ -7,6 +7,7 @@ import httpx from jwcrypto import jwk, jwt from pydantic import HttpUrl +from structlog.testing import capture_logs from infrahub.api.oidc import OIDCDiscoveryConfig, _get_id_token_groups from infrahub.services import InfrahubServices @@ -114,7 +115,15 @@ def __init__(self) -> None: ] } - def generate_token_response(self, username: str, groups: list[str], client_id: str, issuer: str) -> dict[str, Any]: + def generate_token_response( + self, + username: str, + groups: list[str], + client_id: str, + issuer: str, + *, + claim_key: str = "groups", + ) -> dict[str, Any]: current_time = int(time.time()) expiration_time = current_time + 600 @@ -128,7 +137,7 @@ def generate_token_response(self, username: str, groups: list[str], client_id: s "iat": current_time, "auth_time": current_time, "name": username, - "groups": groups, + claim_key: groups, }, ) id_token.make_signed_token(self.key) @@ -178,3 +187,146 @@ def generate_token_response(self, username: str, groups: list[str], client_id: s ), }, ) + + +async def test_get_id_token_groups_with_custom_claim_key() -> None: + memory_http = MemoryHTTP() + service = await InfrahubServices.new(http=memory_http) + client_id = "testing-oidc-roles" + + helper = OIDCTestHelper() + token_response = helper.generate_token_response( + username="testuser", + groups=["network-engineering"], + client_id=client_id, + issuer=str(OIDC_CONFIG.issuer), + claim_key="roles", + ) + + memory_http.add_get_response( + url=str(OIDC_CONFIG.jwks_uri), + response=httpx.Response(status_code=200, content=json.dumps(helper.jwks_payload)), + ) + + groups = await _get_id_token_groups( + oidc_config=OIDC_CONFIG, + service=service, + payload=token_response, + client_id=client_id, + claim_key="roles", + provider_name="provider1", + ) + + assert groups == ["network-engineering"] + + +async def test_get_id_token_groups_with_custom_claim_key_miss_emits_warning() -> None: + memory_http = MemoryHTTP() + service = await InfrahubServices.new(http=memory_http) + client_id = "testing-oidc-miss" + + helper = OIDCTestHelper() + token_response = helper.generate_token_response( + username="testuser", + groups=["operators"], + client_id=client_id, + issuer=str(OIDC_CONFIG.issuer), + ) + + memory_http.add_get_response( + url=str(OIDC_CONFIG.jwks_uri), + response=httpx.Response(status_code=200, content=json.dumps(helper.jwks_payload)), + ) + + with capture_logs() as records: + groups = await _get_id_token_groups( + oidc_config=OIDC_CONFIG, + service=service, + payload=token_response, + client_id=client_id, + claim_key="roles", + provider_name="provider1", + ) + + assert groups == [] + warnings = [r for r in records if r.get("event") == "sso groups claim miss"] + assert len(warnings) == 1 + assert warnings[0]["source"] == "oidc_id_token" + assert warnings[0]["configured_claim"] == "roles" + assert warnings[0]["miss_reason"] == "absent" + assert warnings[0]["provider"] == "provider1" + + +async def test_default_claim_key_preserves_existing_behavior() -> None: + memory_http = MemoryHTTP() + service = await InfrahubServices.new(http=memory_http) + client_id = "testing-oidc-default" + + helper = OIDCTestHelper() + token_response = helper.generate_token_response( + username="testuser", + groups=["operators"], + client_id=client_id, + issuer=str(OIDC_CONFIG.issuer), + ) + + memory_http.add_get_response( + url=str(OIDC_CONFIG.jwks_uri), + response=httpx.Response(status_code=200, content=json.dumps(helper.jwks_payload)), + ) + + with capture_logs() as records: + groups = await _get_id_token_groups( + oidc_config=OIDC_CONFIG, + service=service, + payload=token_response, + client_id=client_id, + ) + + assert groups == ["operators"] + assert not any(r.get("event") == "sso groups claim miss" for r in records) + + +async def test_two_providers_use_independent_claim_keys() -> None: + memory_http = MemoryHTTP() + service = await InfrahubServices.new(http=memory_http) + client_id_1 = "testing-oidc-p1" + client_id_2 = "testing-oidc-p2" + + helper = OIDCTestHelper() + token_response_provider1 = helper.generate_token_response( + username="alice", + groups=["ops"], + client_id=client_id_1, + issuer=str(OIDC_CONFIG.issuer), + claim_key="roles", + ) + token_response_provider2 = helper.generate_token_response( + username="bob", + groups=["dev"], + client_id=client_id_2, + issuer=str(OIDC_CONFIG.issuer), + ) + + memory_http.add_get_response( + url=str(OIDC_CONFIG.jwks_uri), + response=httpx.Response(status_code=200, content=json.dumps(helper.jwks_payload)), + ) + + groups_p2 = await _get_id_token_groups( + oidc_config=OIDC_CONFIG, + service=service, + payload=token_response_provider2, + client_id=client_id_2, + ) + groups_p1 = await _get_id_token_groups( + oidc_config=OIDC_CONFIG, + service=service, + payload=token_response_provider1, + client_id=client_id_1, + claim_key="roles", + provider_name="provider1", + ) + + assert groups_p1 == ["ops"] + assert groups_p2 == ["dev"] diff --git a/backend/tests/unit/api/test_sso_groups_claim.py b/backend/tests/unit/api/test_sso_groups_claim.py new file mode 100644 index 00000000000..b2d037dd676 --- /dev/null +++ b/backend/tests/unit/api/test_sso_groups_claim.py @@ -0,0 +1,194 @@ +from dataclasses import dataclass +from typing import Any + +import pytest +from structlog.testing import capture_logs + +from infrahub.auth.auth import extract_sso_groups + + +@dataclass +class HitCase: + name: str + claim_key: str + payload: dict[str, Any] + expected: list[str] + + +HIT_CASES: list[HitCase] = [ + HitCase( + name="default_key_with_groups", + claim_key="groups", + payload={"sub": "u1", "name": "Otto", "email": "o@x.com", "groups": ["admins"]}, + expected=["admins"], + ), + HitCase( + name="default_key_empty_list", + claim_key="groups", + payload={"sub": "u1", "groups": []}, + expected=[], + ), + HitCase( + name="custom_key_roles", + claim_key="roles", + payload={"sub": "u1", "name": "Otto", "email": "o@x.com", "roles": ["network-engineering"]}, + expected=["network-engineering"], + ), + HitCase( + name="custom_key_memberships", + claim_key="memberships", + payload={"sub": "u1", "memberships": ["g1", "g2", "g3"]}, + expected=["g1", "g2", "g3"], + ), + HitCase( + name="namespaced_uri_key", + claim_key="https://example.com/claims/groups", + payload={"sub": "u1", "https://example.com/claims/groups": ["ops"]}, + expected=["ops"], + ), +] + + +@pytest.mark.parametrize("case", HIT_CASES, ids=lambda c: c.name) +def test_extract_sso_groups_hit_returns_list_verbatim(case: HitCase) -> None: + with capture_logs() as records: + result = extract_sso_groups( + payload=case.payload, + claim_key=case.claim_key, + provider_name="provider1", + source="oidc_userinfo", + ) + assert result == case.expected + assert all(record.get("event") != "sso groups claim miss" for record in records) + + +@dataclass +class MissCase: + name: str + claim_key: str + payload: dict[str, Any] + expected_reason: str + + +MISS_CASES: list[MissCase] = [ + MissCase( + name="absent_key", + claim_key="roles", + payload={"sub": "u1", "groups": ["admins"]}, + expected_reason="absent", + ), + MissCase( + name="value_is_string", + claim_key="groups", + payload={"sub": "u1", "groups": "admins"}, + expected_reason="not_list", + ), + MissCase( + name="value_is_int", + claim_key="groups", + payload={"sub": "u1", "groups": 42}, + expected_reason="not_list", + ), + MissCase( + name="value_is_dict", + claim_key="groups", + payload={"sub": "u1", "groups": {"a": "b"}}, + expected_reason="not_list", + ), + MissCase( + name="value_is_none", + claim_key="groups", + payload={"sub": "u1", "groups": None}, + expected_reason="not_list", + ), + MissCase( + name="list_with_int", + claim_key="groups", + payload={"sub": "u1", "groups": ["admin", 7]}, + expected_reason="list_has_non_string", + ), + MissCase( + name="list_with_dict", + claim_key="groups", + payload={"sub": "u1", "groups": [{"name": "admin"}]}, + expected_reason="list_has_non_string", + ), +] + + +@pytest.mark.parametrize("case", MISS_CASES, ids=lambda c: c.name) +def test_extract_sso_groups_miss_returns_empty_and_warns(case: MissCase) -> None: + with capture_logs() as records: + result = extract_sso_groups( + payload=case.payload, + claim_key=case.claim_key, + provider_name="provider1", + source="oidc_userinfo", + ) + + assert result == [] + + warnings = [r for r in records if r.get("event") == "sso groups claim miss"] + assert len(warnings) == 1 + warning = warnings[0] + assert warning["log_level"] == "warning" + assert warning["provider"] == "provider1" + assert warning["source"] == "oidc_userinfo" + assert warning["configured_claim"] == case.claim_key + assert warning["available_keys"] == sorted(case.payload.keys()) + assert warning["miss_reason"] == case.expected_reason + + +def test_warning_never_includes_payload_values() -> None: + payload = { + "sub": "user-12345", + "email": "otter@example.com", + "name": "Otto the Otter", + "aud": "client-abc", + "iss": "https://idp.example.com/realms/infrahub", + "fake_token": "eyJhbGciOiJSUzI1NiJ9.payload.signature", + } + sensitive_values = list(payload.values()) + + with capture_logs() as records: + extract_sso_groups( + payload=payload, + claim_key="roles", + provider_name="provider1", + source="oidc_userinfo", + ) + + warnings = [r for r in records if r.get("event") == "sso groups claim miss"] + assert len(warnings) == 1 + serialized = repr(warnings[0]) + for value in sensitive_values: + assert value not in serialized + + +def test_every_miss_emits_warning_no_throttling() -> None: + payload_miss = {"sub": "u", "groups": ["admins"]} + payload_hit = {"sub": "u", "roles": ["ops"]} + + with capture_logs() as records: + for _ in range(3): + extract_sso_groups( + payload=payload_miss, + claim_key="roles", + provider_name="provider1", + source="oidc_userinfo", + ) + extract_sso_groups( + payload=payload_hit, + claim_key="roles", + provider_name="provider1", + source="oidc_userinfo", + ) + extract_sso_groups( + payload=payload_miss, + claim_key="roles", + provider_name="provider1", + source="oidc_userinfo", + ) + + warnings = [r for r in records if r.get("event") == "sso groups claim miss"] + assert len(warnings) == 4 diff --git a/backend/tests/unit/config/test_config.py b/backend/tests/unit/config/test_config.py index dbbf8c98d64..a1488865a40 100644 --- a/backend/tests/unit/config/test_config.py +++ b/backend/tests/unit/config/test_config.py @@ -9,6 +9,10 @@ SETTINGS, GitSettings, MainSettings, + SecurityOAuth2Provider1, + SecurityOAuth2Provider2, + SecurityOIDCProvider1, + SecurityOIDCProvider2, Settings, StorageSettings, UserInfoMethod, @@ -70,3 +74,72 @@ def test_storage_max_file_size_environment_variable() -> None: with patch.dict(os.environ, {"INFRAHUB_STORAGE_MAX_FILE_SIZE": "75"}): assert StorageSettings().max_file_size == 75 assert isinstance(SETTINGS.storage.max_file_size, int) + + +def _build_oauth2_provider(groups_claim: str = "groups") -> SecurityOAuth2Provider1: + return SecurityOAuth2Provider1( + client_id="infrahub-client", + client_secret="secret", + authorization_url="https://idp.example.com/auth", + token_url="https://idp.example.com/token", + userinfo_url="https://idp.example.com/userinfo", + groups_claim=groups_claim, + ) + + +def _build_oauth2_provider_2(groups_claim: str = "groups") -> SecurityOAuth2Provider2: + return SecurityOAuth2Provider2( + client_id="infrahub-client", + client_secret="secret", + authorization_url="https://idp.example.com/auth", + token_url="https://idp.example.com/token", + userinfo_url="https://idp.example.com/userinfo", + groups_claim=groups_claim, + ) + + +def _build_oidc_provider(groups_claim: str = "groups") -> SecurityOIDCProvider1: + return SecurityOIDCProvider1( + client_id="infrahub-client", + client_secret="secret", + discovery_url="https://idp.example.com/.well-known/openid-configuration", + groups_claim=groups_claim, + ) + + +def _build_oidc_provider_2(groups_claim: str = "groups") -> SecurityOIDCProvider2: + return SecurityOIDCProvider2( + client_id="infrahub-client", + client_secret="secret", + discovery_url="https://idp.example.com/.well-known/openid-configuration", + groups_claim=groups_claim, + ) + + +def test_groups_claim_default_is_groups() -> None: + assert _build_oauth2_provider().groups_claim == "groups" + assert _build_oauth2_provider_2().groups_claim == "groups" + assert _build_oidc_provider().groups_claim == "groups" + assert _build_oidc_provider_2().groups_claim == "groups" + + +@pytest.mark.parametrize("empty_value", ["", " ", "\t", "\n", " \t\n "]) +def test_groups_claim_empty_string_is_rejected_at_startup_oauth2(empty_value: str) -> None: + with pytest.raises(ValidationError, match=r"groups_claim must not be empty or whitespace-only"): + _build_oauth2_provider(groups_claim=empty_value) + + +@pytest.mark.parametrize("empty_value", ["", " ", "\t", "\n", " \t\n "]) +def test_groups_claim_empty_string_is_rejected_at_startup_oidc(empty_value: str) -> None: + with pytest.raises(ValidationError, match=r"groups_claim must not be empty or whitespace-only"): + _build_oidc_provider(groups_claim=empty_value) + + +def test_fixture_loaded_providers_have_expected_groups_claim(helper: TestHelper) -> None: + config_file = str(helper.get_fixtures_dir() / "config_files" / "sso_config_methods.toml") + config = load(config_file_name=config_file) + + assert config.security.get_oauth2_provider("provider1").groups_claim == "roles" + assert config.security.get_oauth2_provider("provider2").groups_claim == "groups" + assert config.security.get_oidc_provider("provider1").groups_claim == "roles" + assert config.security.get_oidc_provider("provider2").groups_claim == "groups" diff --git a/changelog/+sso-groups-claim.added.md b/changelog/+sso-groups-claim.added.md new file mode 100644 index 00000000000..00834af6db1 --- /dev/null +++ b/changelog/+sso-groups-claim.added.md @@ -0,0 +1 @@ +Per-provider `groups_claim` setting for OAuth2 and OIDC providers: configure the JSON key used to extract the user's groups from the IdP claim payload (default `groups`). See the SSO guide for details. diff --git a/docs/archive/guides/sso.mdx b/docs/archive/guides/sso.mdx index ff45043131b..b9835130913 100644 --- a/docs/archive/guides/sso.mdx +++ b/docs/archive/guides/sso.mdx @@ -381,6 +381,39 @@ Configure your identity provider application to include group information in the Refer to your provider's documentation for instructions on "group claims" or "configuring OAuth2/OIDC group mappings". ::: +##### Customizing the groups claim key + +By default, Infrahub reads the user's groups from a top-level `groups` key in the identity provider's claim payload (the OAuth2 userinfo response, the OIDC userinfo response, or the decoded OIDC `id_token`). If your identity provider emits group memberships under a different key — for example `roles` or `memberships` — set the `groups_claim` field on the provider settings. + + + + +```toml +[security.oidc_provider_settings.provider1] +client_id = "infrahub-client" +client_secret = "..." +discovery_url = "https://idp.example.com/.well-known/openid-configuration" +groups_claim = "roles" +``` + + + + +```bash +# OIDC provider +export INFRAHUB_OIDC_PROVIDER1_GROUPS_CLAIM='roles' + +# OAuth2 provider +export INFRAHUB_OAUTH2_PROVIDER1_GROUPS_CLAIM='roles' +``` + + + + +:::info +The value of `groups_claim` is matched literally against the top-level keys of the claim payload — no path expressions, no namespace expansion, no case-folding. The same key is used at all three extraction points (OAuth2 userinfo, OIDC userinfo, OIDC `id_token`). +::: + #### Step 2: Create corresponding groups in Infrahub Create groups in Infrahub that match the groups sent by your identity provider. diff --git a/docs/docs/reference/configuration.mdx b/docs/docs/reference/configuration.mdx index 570e8863add..4c18f6af5b5 100644 --- a/docs/docs/reference/configuration.mdx +++ b/docs/docs/reference/configuration.mdx @@ -212,6 +212,7 @@ Configuration settings for the message bus. | `INFRAHUB_OAUTH2_GOOGLE_ICON` | None | string | mdi:google | | `INFRAHUB_OAUTH2_GOOGLE_USERINFO_METHOD` | None | string (post, get) | get | | `INFRAHUB_OAUTH2_GOOGLE_PKCE_ENABLED` | Enable PKCE (RFC 7636) with S256 method for authorization code flow | boolean | True | +| `INFRAHUB_OAUTH2_GOOGLE_GROUPS_CLAIM` | Top-level key in the IdP claim payload from which the user's groups are read. Defaults to `groups`. Set per provider when your IdP emits group memberships under a different claim name (e.g., `roles`). | string | groups | | `INFRAHUB_OAUTH2_GOOGLE_CLIENT_ID` | Client ID of the application created in the auth provider | string | None | | `INFRAHUB_OAUTH2_GOOGLE_CLIENT_SECRET` | Client secret as defined in auth provider | None | None | | `INFRAHUB_OAUTH2_GOOGLE_AUTHORIZATION_URL` | None | string | https://accounts.google.com/o/oauth2/auth | @@ -224,6 +225,7 @@ Configuration settings for the message bus. | `INFRAHUB_OAUTH2_PROVIDER1_ICON` | None | string | mdi:account-key | | `INFRAHUB_OAUTH2_PROVIDER1_USERINFO_METHOD` | None | string (post, get) | get | | `INFRAHUB_OAUTH2_PROVIDER1_PKCE_ENABLED` | Enable PKCE (RFC 7636) with S256 method for authorization code flow | boolean | True | +| `INFRAHUB_OAUTH2_PROVIDER1_GROUPS_CLAIM` | Top-level key in the IdP claim payload from which the user's groups are read. Defaults to `groups`. Set per provider when your IdP emits group memberships under a different claim name (e.g., `roles`). | string | groups | | `INFRAHUB_OAUTH2_PROVIDER1_CLIENT_ID` | Client ID of the application created in the auth provider | string | None | | `INFRAHUB_OAUTH2_PROVIDER1_CLIENT_SECRET` | Client secret as defined in auth provider | None | None | | `INFRAHUB_OAUTH2_PROVIDER1_AUTHORIZATION_URL` | None | string | None | @@ -234,6 +236,7 @@ Configuration settings for the message bus. | `INFRAHUB_OAUTH2_PROVIDER2_ICON` | None | string | mdi:account-key | | `INFRAHUB_OAUTH2_PROVIDER2_USERINFO_METHOD` | None | string (post, get) | get | | `INFRAHUB_OAUTH2_PROVIDER2_PKCE_ENABLED` | Enable PKCE (RFC 7636) with S256 method for authorization code flow | boolean | True | +| `INFRAHUB_OAUTH2_PROVIDER2_GROUPS_CLAIM` | Top-level key in the IdP claim payload from which the user's groups are read. Defaults to `groups`. Set per provider when your IdP emits group memberships under a different claim name (e.g., `roles`). | string | groups | | `INFRAHUB_OAUTH2_PROVIDER2_CLIENT_ID` | Client ID of the application created in the auth provider | string | None | | `INFRAHUB_OAUTH2_PROVIDER2_CLIENT_SECRET` | Client secret as defined in auth provider | None | None | | `INFRAHUB_OAUTH2_PROVIDER2_AUTHORIZATION_URL` | None | string | None | @@ -250,6 +253,7 @@ Configuration settings for the message bus. | `INFRAHUB_OIDC_GOOGLE_DISPLAY_LABEL` | None | string | Google | | `INFRAHUB_OIDC_GOOGLE_USERINFO_METHOD` | None | string (post, get) | get | | `INFRAHUB_OIDC_GOOGLE_PKCE_ENABLED` | Enable PKCE (RFC 7636) with S256 method for authorization code flow | boolean | True | +| `INFRAHUB_OIDC_GOOGLE_GROUPS_CLAIM` | Top-level key in the IdP claim payload from which the user's groups are read. Defaults to `groups`. Set per provider when your IdP emits group memberships under a different claim name (e.g., `roles`). | string | groups | | `INFRAHUB_OIDC_GOOGLE_CLIENT_ID` | Client ID of the application created in the auth provider | string | None | | `INFRAHUB_OIDC_GOOGLE_CLIENT_SECRET` | Client secret as defined in auth provider | None | None | | `INFRAHUB_OIDC_GOOGLE_DISCOVERY_URL` | None | string | https://accounts.google.com/.well-known/openid-configuration | @@ -260,6 +264,7 @@ Configuration settings for the message bus. | `INFRAHUB_OIDC_PROVIDER1_DISPLAY_LABEL` | None | string | Single Sign on | | `INFRAHUB_OIDC_PROVIDER1_USERINFO_METHOD` | None | string (post, get) | get | | `INFRAHUB_OIDC_PROVIDER1_PKCE_ENABLED` | Enable PKCE (RFC 7636) with S256 method for authorization code flow | boolean | True | +| `INFRAHUB_OIDC_PROVIDER1_GROUPS_CLAIM` | Top-level key in the IdP claim payload from which the user's groups are read. Defaults to `groups`. Set per provider when your IdP emits group memberships under a different claim name (e.g., `roles`). | string | groups | | `INFRAHUB_OIDC_PROVIDER1_CLIENT_ID` | Client ID of the application created in the auth provider | string | None | | `INFRAHUB_OIDC_PROVIDER1_CLIENT_SECRET` | Client secret as defined in auth provider | None | None | | `INFRAHUB_OIDC_PROVIDER1_DISCOVERY_URL` | The OIDC discovery URL xyz/.well-known/openid-configuration | string | None | @@ -268,6 +273,7 @@ Configuration settings for the message bus. | `INFRAHUB_OIDC_PROVIDER2_DISPLAY_LABEL` | None | string | Single Sign on | | `INFRAHUB_OIDC_PROVIDER2_USERINFO_METHOD` | None | string (post, get) | get | | `INFRAHUB_OIDC_PROVIDER2_PKCE_ENABLED` | Enable PKCE (RFC 7636) with S256 method for authorization code flow | boolean | True | +| `INFRAHUB_OIDC_PROVIDER2_GROUPS_CLAIM` | Top-level key in the IdP claim payload from which the user's groups are read. Defaults to `groups`. Set per provider when your IdP emits group memberships under a different claim name (e.g., `roles`). | string | groups | | `INFRAHUB_OIDC_PROVIDER2_CLIENT_ID` | Client ID of the application created in the auth provider | string | None | | `INFRAHUB_OIDC_PROVIDER2_CLIENT_SECRET` | Client secret as defined in auth provider | None | None | | `INFRAHUB_OIDC_PROVIDER2_DISCOVERY_URL` | The OIDC discovery URL xyz/.well-known/openid-configuration | string | None |