-
Notifications
You must be signed in to change notification settings - Fork 52
Configurable groups claim key for SSO providers #9367
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
ccdf77f
a5ab1b4
d0b0a64
43e30fb
355e88f
935b6ac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,6 +37,8 @@ | |
| from infrahub.log import get_logger | ||
|
|
||
| if TYPE_CHECKING: | ||
| from collections.abc import Mapping | ||
|
|
||
| import httpx | ||
|
|
||
| from infrahub.database import InfrahubDatabase | ||
|
|
@@ -587,6 +589,50 @@ async def get_groups_from_provider( | |
| return [] | ||
|
|
||
|
|
||
| def extract_sso_groups( | ||
| *, | ||
| payload: Mapping[str, Any], | ||
| claim_key: str, | ||
| provider_name: str, | ||
| source: str, | ||
| ) -> list[str]: | ||
| if claim_key not in payload: | ||
| log.warning( | ||
| "sso groups claim miss", | ||
| provider=provider_name, | ||
| source=source, | ||
| configured_claim=claim_key, | ||
| available_keys=sorted(payload.keys()), | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The miss log dumps all top-level claim keys at WARN. Key names aren't usually PII, but URI-namespaced claims can embed tenant/customer identifiers (e.g. |
||
| miss_reason="absent", | ||
| ) | ||
| return [] | ||
|
|
||
| value = payload[claim_key] | ||
| if not isinstance(value, list): | ||
| log.warning( | ||
| "sso groups claim miss", | ||
| provider=provider_name, | ||
| source=source, | ||
| configured_claim=claim_key, | ||
| available_keys=sorted(payload.keys()), | ||
| miss_reason="not_list", | ||
| ) | ||
| return [] | ||
|
|
||
| if not all(isinstance(item, str) for item in value): | ||
| log.warning( | ||
| "sso groups claim miss", | ||
| provider=provider_name, | ||
| source=source, | ||
| configured_claim=claim_key, | ||
| available_keys=sorted(payload.keys()), | ||
| miss_reason="list_has_non_string", | ||
| ) | ||
| return [] | ||
|
|
||
| return value | ||
|
|
||
|
|
||
| def safe_get_response_body(response: httpx.Response, raise_error_on_empty_body: bool = True) -> str | dict[str, Any]: | ||
| """Safely extract response body from HTTP response. | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -602,6 +602,21 @@ class SecurityOIDCBaseSettings(BaseSettings): | |
| pkce_enabled: bool = Field( | ||
| default=True, description="Enable PKCE (RFC 7636) with S256 method for authorization code flow" | ||
| ) | ||
| groups_claim: str = Field( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. P3: New groups_claim logic duplicated in two base classes. Easy drift later. Extract one shared mixin/helper and reuse it. Prompt for AI agents |
||
| default="groups", | ||
| description=( | ||
| "Top-level key in the IdP claim payload from which the user's groups are read. " | ||
| "Defaults to `groups`. Set per provider when your IdP emits group memberships " | ||
| "under a different claim name (e.g., `roles`)." | ||
| ), | ||
| ) | ||
|
|
||
| @field_validator("groups_claim") | ||
| @classmethod | ||
| def _validate_groups_claim(cls, value: str) -> str: | ||
| if not value.strip(): | ||
| raise ValueError("groups_claim must not be empty or whitespace-only") | ||
| return value | ||
|
|
||
|
|
||
| class SecurityOIDCSettings(SecurityOIDCBaseSettings): | ||
|
|
@@ -657,6 +672,21 @@ class SecurityOAuth2BaseSettings(BaseSettings): | |
| pkce_enabled: bool = Field( | ||
| default=True, description="Enable PKCE (RFC 7636) with S256 method for authorization code flow" | ||
| ) | ||
| groups_claim: str = Field( | ||
| default="groups", | ||
| description=( | ||
| "Top-level key in the IdP claim payload from which the user's groups are read. " | ||
| "Defaults to `groups`. Set per provider when your IdP emits group memberships " | ||
| "under a different claim name (e.g., `roles`)." | ||
| ), | ||
| ) | ||
|
|
||
| @field_validator("groups_claim") | ||
| @classmethod | ||
| def _validate_groups_claim(cls, value: str) -> str: | ||
| if not value.strip(): | ||
| raise ValueError("groups_claim must not be empty or whitespace-only") | ||
| return value | ||
|
|
||
|
|
||
| class SecurityOAuth2Settings(SecurityOAuth2BaseSettings): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,148 @@ | ||
| import json | ||
|
|
||
| import httpx | ||
| from structlog.testing import capture_logs | ||
|
|
||
| from infrahub.auth.auth import extract_sso_groups | ||
| from infrahub.config import SecurityOAuth2Provider1 | ||
| from infrahub.services import InfrahubServices | ||
| from tests.adapters.http import MemoryHTTP | ||
|
|
||
|
|
||
| def _build_provider(groups_claim: str = "groups") -> SecurityOAuth2Provider1: | ||
| return SecurityOAuth2Provider1( | ||
| client_id="infrahub-user-client", | ||
| client_secret="secret", | ||
| authorization_url="https://idp.example.com/auth", | ||
| token_url="https://idp.example.com/token", | ||
| userinfo_url="https://idp.example.com/userinfo", | ||
| groups_claim=groups_claim, | ||
| ) | ||
|
|
||
|
|
||
| async def test_oauth2_userinfo_extracts_groups_from_custom_claim_key() -> None: | ||
| memory_http = MemoryHTTP() | ||
| service = await InfrahubServices.new(http=memory_http) | ||
| provider = _build_provider(groups_claim="roles") | ||
|
|
||
| userinfo_body = { | ||
| "sub": "u1", | ||
| "name": "Otto", | ||
| "email": "o@x.com", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since we are using fake data, we should use a RFC 2606/BCP 32 reserved domain rather than domain owned by an actual company. |
||
| "roles": ["network-engineering"], | ||
| } | ||
| memory_http.add_get_response( | ||
| url=provider.userinfo_url, | ||
| response=httpx.Response(status_code=200, content=json.dumps(userinfo_body)), | ||
| ) | ||
|
|
||
| userinfo_response = await service.http.get(provider.userinfo_url) | ||
| user_info = userinfo_response.json() | ||
|
|
||
| sso_groups = extract_sso_groups( | ||
| payload=user_info, | ||
| claim_key=provider.groups_claim, | ||
| provider_name="provider1", | ||
| source="oauth2_userinfo", | ||
| ) | ||
|
|
||
| assert sso_groups == ["network-engineering"] | ||
|
|
||
|
|
||
| async def test_oauth2_userinfo_custom_claim_key_does_not_read_groups_key() -> None: | ||
| memory_http = MemoryHTTP() | ||
| service = await InfrahubServices.new(http=memory_http) | ||
| provider = _build_provider(groups_claim="roles") | ||
|
|
||
| userinfo_body = { | ||
| "sub": "u1", | ||
| "name": "Otto", | ||
| "email": "o@x.com", | ||
| "groups": ["legacy-group"], | ||
| } | ||
| memory_http.add_get_response( | ||
| url=provider.userinfo_url, | ||
| response=httpx.Response(status_code=200, content=json.dumps(userinfo_body)), | ||
| ) | ||
|
|
||
| userinfo_response = await service.http.get(provider.userinfo_url) | ||
| user_info = userinfo_response.json() | ||
|
|
||
| with capture_logs() as records: | ||
| sso_groups = extract_sso_groups( | ||
| payload=user_info, | ||
| claim_key=provider.groups_claim, | ||
| provider_name="provider1", | ||
| source="oauth2_userinfo", | ||
| ) | ||
|
|
||
| assert sso_groups == [] | ||
| warnings = [r for r in records if r.get("event") == "sso groups claim miss"] | ||
| assert len(warnings) == 1 | ||
| assert warnings[0]["miss_reason"] == "absent" | ||
| assert warnings[0]["source"] == "oauth2_userinfo" | ||
|
|
||
|
|
||
| async def test_default_claim_key_preserves_existing_behavior() -> None: | ||
| memory_http = MemoryHTTP() | ||
| service = await InfrahubServices.new(http=memory_http) | ||
| provider = _build_provider() | ||
|
|
||
| assert provider.groups_claim == "groups" | ||
|
|
||
| userinfo_body = { | ||
| "sub": "u", | ||
| "name": "Otto", | ||
| "email": "o@x.com", | ||
| "groups": ["admin-otter"], | ||
| } | ||
| memory_http.add_get_response( | ||
| url=provider.userinfo_url, | ||
| response=httpx.Response(status_code=200, content=json.dumps(userinfo_body)), | ||
| ) | ||
|
|
||
| userinfo_response = await service.http.get(provider.userinfo_url) | ||
| user_info = userinfo_response.json() | ||
|
|
||
| with capture_logs() as records: | ||
| sso_groups = extract_sso_groups( | ||
| payload=user_info, | ||
| claim_key=provider.groups_claim, | ||
| provider_name="provider1", | ||
| source="oauth2_userinfo", | ||
| ) | ||
|
|
||
| assert sso_groups == ["admin-otter"] | ||
| assert not any(r.get("event") == "sso groups claim miss" for r in records) | ||
|
|
||
|
|
||
| async def test_oauth2_and_oidc_with_different_claim_keys_coexist() -> None: | ||
| oauth2_provider = _build_provider(groups_claim="memberships") | ||
| oauth2_payload = { | ||
| "sub": "u1", | ||
| "name": "Otto", | ||
| "email": "o@x.com", | ||
| "memberships": ["membership-a"], | ||
| } | ||
| oauth2_groups = extract_sso_groups( | ||
| payload=oauth2_payload, | ||
| claim_key=oauth2_provider.groups_claim, | ||
| provider_name="provider1", | ||
| source="oauth2_userinfo", | ||
| ) | ||
|
|
||
| oidc_payload = { | ||
| "sub": "u2", | ||
| "name": "Otto", | ||
| "email": "o@x.com", | ||
| "roles": ["role-x"], | ||
| } | ||
| oidc_groups = extract_sso_groups( | ||
| payload=oidc_payload, | ||
| claim_key="roles", | ||
| provider_name="provider2", | ||
| source="oidc_id_token", | ||
| ) | ||
|
|
||
| assert oauth2_groups == ["membership-a"] | ||
| assert oidc_groups == ["role-x"] | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Every caller passes
claim_keyandprovider_nameexplicitly. The defaults ("groups"/"") only exist for tests. Suggest dropping the defaults and making both required.