diff --git a/server/config/development.yaml b/server/config/development.yaml index 64a5b15e52..2a736c6727 100644 --- a/server/config/development.yaml +++ b/server/config/development.yaml @@ -39,6 +39,8 @@ auth: audience: myorg-dev organization: org_xxxxxxxxxxxx invitation: + additionalClaims: + group: "local" tls: caFile: certFile: diff --git a/server/config/docker.yaml b/server/config/docker.yaml index f5a842b93c..6af8f64ac6 100644 --- a/server/config/docker.yaml +++ b/server/config/docker.yaml @@ -23,6 +23,7 @@ batchActionsDisabled: {{ env "TEMPORAL_BATCH_ACTIONS_DISABLED" | default "false" startWorkflowDisabled: {{ env "TEMPORAL_START_WORKFLOW_DISABLED" | default "false" }} hideWorkflowQueryErrors: {{ env "TEMPORAL_HIDE_WORKFLOW_QUERY_ERRORS" | default "false" }} refreshWorkflowCountsDisabled: {{ env "TEMPORAL_REFRESH_WORKFLOW_COUNTS_DISABLED" | default "false" }} +additionalClaims: {{ env "TEMPORAL_OAUTH_ADDITIONAL_CLAIMS" | default "false" }} cors: cookieInsecure: {{ env "TEMPORAL_CSRF_COOKIE_INSECURE" | default "false" }} allowOrigins: diff --git a/server/docker/README.md b/server/docker/README.md index d987796eca..259988b633 100644 --- a/server/docker/README.md +++ b/server/docker/README.md @@ -25,6 +25,7 @@ docker run \ -e TEMPORAL_TLS_KEY=../cluster.key \ -e TEMPORAL_TLS_ENABLE_HOST_VERIFICATION=true \ -e TEMPORAL_TLS_SERVER_NAME=tls-server \ + -e TEMPORAL_OAUTH_ADDITIONAL_CLAIMS="group: \"mygroup\"" \ temporalio/ui:latest ``` diff --git a/server/server/auth/oidc.go b/server/server/auth/oidc.go index d863dc0d71..db122a24f0 100644 --- a/server/server/auth/oidc.go +++ b/server/server/auth/oidc.go @@ -58,7 +58,7 @@ type Claims struct { Picture string `json:"picture"` } -func ExchangeCode(ctx context.Context, r *http.Request, config *oauth2.Config, provider *oidc.Provider) (*User, error) { +func ExchangeCode(ctx context.Context, r *http.Request, config *oauth2.Config, provider *oidc.Provider, additionalClaimsConfig map[string]string) (*User, error) { state, err := r.Cookie("state") if err != nil { return nil, echo.NewHTTPError(http.StatusBadRequest, "State cookie is not set in request") @@ -98,6 +98,17 @@ func ExchangeCode(ctx context.Context, r *http.Request, config *oauth2.Config, p return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } + if additionalClaimsConfig != nil { + var claimTokenValues map[string]interface{} + if err := idToken.Claims(&claimTokenValues); err != nil { + return nil, echo.NewHTTPError(http.StatusInternalServerError, "Error parse claims") + } + + if err := VerifyAdditionalClaims(additionalClaimsConfig, claimTokenValues); err != nil { + return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + } + user := User{ OAuth2Token: oauth2Token, IDToken: &IDToken{ @@ -108,3 +119,24 @@ func ExchangeCode(ctx context.Context, r *http.Request, config *oauth2.Config, p return &user, nil } + +func VerifyAdditionalClaims(additionalClaims map[string]string, claimTokenValues map[string]interface{}) error { + for configClaimKey, configClaimValue := range additionalClaims { + claimValues := claimTokenValues[configClaimKey] + if claimValues != nil { + switch t := claimValues.(type) { + case string: + if t == configClaimValue { + return nil + } + case []interface{}: + for _, claimValue := range t { + if claimValue == configClaimValue { + return nil + } + } + } + } + } + return echo.NewHTTPError(http.StatusInternalServerError, "No additional Claims defined") +} diff --git a/server/server/auth/oidc_test.go b/server/server/auth/oidc_test.go new file mode 100644 index 0000000000..3b52451f6f --- /dev/null +++ b/server/server/auth/oidc_test.go @@ -0,0 +1,30 @@ +package auth + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestAdditionalClaimsError(t *testing.T) { + configAdditionalClaims := map[string]string{ + "foo": "bar", + } + claimTokenValues := map[string]interface{}{ + "foo": []interface{}{"value1", "value2", "value3"}, + "group": []interface{}{"value1", "value2", "bar"}, + } + err := VerifyAdditionalClaims(configAdditionalClaims, claimTokenValues) + assert.Error(t, err) +} + +func TestAdditionalClaimsSuccess(t *testing.T) { + configAdditionalClaims := map[string]string{ + "group": "bar", + } + claimTokenValues := map[string]interface{}{ + "foo": []interface{}{"value1", "value2", "value3"}, + "group": []interface{}{"value1", "value2", "bar"}, + } + err := VerifyAdditionalClaims(configAdditionalClaims, claimTokenValues) + assert.NoError(t, err) +} diff --git a/server/server/config/config.go b/server/server/config/config.go index d545b0d3e7..1126ca7d62 100644 --- a/server/server/config/config.go +++ b/server/server/config/config.go @@ -113,7 +113,8 @@ type ( // CallbackURL - URL for the callback URL, ex. https://localhost:8080/sso/callback CallbackURL string `yaml:"callbackUrl"` // Options added as URL query params when redirecting to auth provider. Can be used to configure custom auth flows such as Auth0 invitation flow. - Options map[string]interface{} `yaml:"options"` + Options map[string]interface{} `yaml:"options"` + AdditionalClaims map[string]string `yaml:"additionalClaims"` } Codec struct { diff --git a/server/server/route/auth.go b/server/server/route/auth.go index 21cab99728..832fc25471 100644 --- a/server/server/route/auth.go +++ b/server/server/route/auth.go @@ -77,8 +77,8 @@ func SetAuthRoutes(e *echo.Echo, cfgProvider *config.ConfigProviderWithRefresh) api := e.Group("/auth") api.GET("/sso", authenticate(&oauthCfg, providerCfg.Options)) - api.GET("/sso/callback", authenticateCb(ctx, &oauthCfg, provider)) - api.GET("/sso_callback", authenticateCb(ctx, &oauthCfg, provider)) // compatibility with UI v1 + api.GET("/sso/callback", authenticateCb(ctx, &oauthCfg, provider, providerCfg.AdditionalClaims)) + api.GET("/sso_callback", authenticateCb(ctx, &oauthCfg, provider, providerCfg.AdditionalClaims)) // compatibility with UI v1 } func authenticate(config *oauth2.Config, options map[string]interface{}) func(echo.Context) error { @@ -119,9 +119,9 @@ func authenticate(config *oauth2.Config, options map[string]interface{}) func(ec } } -func authenticateCb(ctx context.Context, oauthCfg *oauth2.Config, provider *oidc.Provider) func(echo.Context) error { +func authenticateCb(ctx context.Context, oauthCfg *oauth2.Config, provider *oidc.Provider, additionalClaims map[string]string) func(echo.Context) error { return func(c echo.Context) error { - user, err := auth.ExchangeCode(ctx, c.Request(), oauthCfg, provider) + user, err := auth.ExchangeCode(ctx, c.Request(), oauthCfg, provider, additionalClaims) if err != nil { return err }