From c3ae45d1b73e073e86fdb3c73592f40699babb2a Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Thu, 4 Oct 2018 18:14:06 -0700 Subject: [PATCH 01/38] First pass port of Azure v2 provider --- Godeps | 5 + internal/auth/options.go | 7 +- internal/auth/providers/azure.go | 256 +++++++++++++++++++++++++ internal/auth/providers/google.go | 6 + internal/auth/providers/oidc.go | 92 +++++++++ internal/auth/providers/providers.go | 4 + internal/pkg/sessions/session_state.go | 1 + 7 files changed, 370 insertions(+), 1 deletion(-) create mode 100644 internal/auth/providers/azure.go create mode 100644 internal/auth/providers/oidc.go diff --git a/Godeps b/Godeps index 836ab345..47666da0 100644 --- a/Godeps +++ b/Godeps @@ -1,6 +1,8 @@ github.com/18F/hmacauth 1.0.1 gopkg.in/yaml.v2 v2 github.com/imdario/mergo v0.3.4 +github.com/bitly/go-simplejson da1a8928f709389522c8023062a3739f3b4af419 +github.com/mreiferson/go-options 77551d20752b54535462404ad9d877ebdb26e53d github.com/datadog/datadog-go/statsd c74bd0589c83817c93e4eff39ccae69d6c46df9b golang.org/x/oauth2 7fdf09982454086d5570c7db3e11f360194830ca golang.org/x/net/context 242b6b35177ec3909636b6cf6a47e8c2c6324b5d @@ -11,3 +13,6 @@ github.com/kelseyhightower/envconfig v1.3.0 github.com/miscreant/miscreant-go 6b98fbe3dd42dfd24a8ecbabdb3586ada20dc5f8 github.com/sirupsen/logrus e54a77765aca7bbdd8e56c1c54f60579968b2dc9 github.com/rakyll/statik v0.1.4 +github.com/coreos/go-oidc v2.0.0 +gopkg.in/square/go-jose.v2 v2.1.9 +github.com/pquerna/cachecontrol 1555304b9b35fdd2b425bccf1a5613677705e7d0 diff --git a/internal/auth/options.go b/internal/auth/options.go index 2e317eb7..b85b33c6 100644 --- a/internal/auth/options.go +++ b/internal/auth/options.go @@ -63,6 +63,7 @@ type Options struct { EmailDomains []string `envconfig:"SSO_EMAIL_DOMAIN"` ProxyRootDomains []string `envconfig:"PROXY_ROOT_DOMAIN"` + AzureTenant string `envconfig:"AZURE_TENANT"` GoogleAdminEmail string `envconfig:"GOOGLE_ADMIN_EMAIL"` GoogleServiceAccountJSON string `envconfig:"GOOGLE_SERVICE_ACCOUNT_JSON"` @@ -260,7 +261,11 @@ func newProvider(o *Options) (providers.Provider, error) { var singleFlightProvider providers.Provider switch o.Provider { - case providers.GoogleProviderName: // Google + case providers.AzureProviderName: + azureProvider := providers.NewAzureV2Provider(p) + azureProvider.Configure(o.AzureTenant) + singleFlightProvider = providers.NewSingleFlightProvider(azureProvider) + case providers.GoogleProviderName: if o.GoogleServiceAccountJSON != "" { _, err := os.Open(o.GoogleServiceAccountJSON) if err != nil { diff --git a/internal/auth/providers/azure.go b/internal/auth/providers/azure.go new file mode 100644 index 00000000..caa84134 --- /dev/null +++ b/internal/auth/providers/azure.go @@ -0,0 +1,256 @@ +package providers + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/bitly/oauth2_proxy/api" + "github.com/buzzfeed/sso/internal/pkg/sessions" + "golang.org/x/oauth2" + + log "github.com/buzzfeed/sso/internal/pkg/logging" + oidc "github.com/coreos/go-oidc" +) + +// AzureV2Provider is an Azure AD v2 specific implementation of the Provider interface. +type AzureV2Provider struct { + *ProviderData + *OIDCProvider + + Tenant string + PermittedGroups []string +} + +// NewAzureV2Provider creates a new AzureV2Provider struct +func NewAzureV2Provider(p *ProviderData) *AzureV2Provider { + p.ProviderName = "Azure v2.0" + return &AzureV2Provider{ProviderData: p, OIDCProvider: &OIDCProvider{ProviderData: p}} +} + +// Redeem fulfills the Provider interface. +// The authenticator uses this method to redeem the code provided to /callback after the user logs into their Azure AD account. +func (p *AzureV2Provider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { + ctx := context.Background() + c := oauth2.Config{ + ClientID: p.ClientID, + ClientSecret: p.ClientSecret, + Endpoint: oauth2.Endpoint{ + TokenURL: p.RedeemURL.String(), + }, + RedirectURL: redirectURL, + } + token, err := c.Exchange(ctx, code) + if err != nil { + return nil, fmt.Errorf("token exchange: %v", err) + } + + rawIDToken, ok := token.Extra("id_token").(string) + if !ok { + return nil, fmt.Errorf("token response did not contain an id_token") + } + + // Parse and verify ID Token payload. + idToken, err := p.OIDCProvider.Verifier.Verify(ctx, rawIDToken) + if err != nil { + return nil, fmt.Errorf("could not verify id_token: %v", err) + } + + // Extract custom claims. + var claims struct { + Email string `json:"email"` + UPN string `json:"upn"` + Roles []string `json:"roles"` + Verified *bool `json:"email_verified"` + } + if err := idToken.Claims(&claims); err != nil { + return nil, fmt.Errorf("failed to parse id_token claims: %v", err) + } + if claims.Email == "" { + return nil, fmt.Errorf("id_token did not contain an email") + } + if claims.Verified != nil && !*claims.Verified { + return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) + } + + s = &sessions.SessionState{ + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + RefreshDeadline: token.Expiry, + Email: claims.Email, + User: claims.UPN, + Groups: claims.Roles, + } + + return +} + +// Configure sets the Azure tenant ID value for the provider +func (p *AzureV2Provider) Configure(tenant string) { + p.Tenant = tenant + if tenant == "" { + p.Tenant = "common" + } + discoveryURL := fmt.Sprintf("https://login.microsoftonline.com/%s/v2.0", p.Tenant) + + // Configure discoverable provider data. + oidcProvider, err := oidc.NewProvider(context.Background(), discoveryURL) + if err != nil { + return + } + + logger := log.NewLogEntry() + + p.OIDCProvider.Verifier = oidcProvider.Verifier(&oidc.Config{ + ClientID: p.ClientID, + }) + p.SignInURL, err = url.Parse(oidcProvider.Endpoint().AuthURL) + if err != nil { + logger.Printf("Unable to parse OIDC Authentication URL: %v", err) + return + } + p.RedeemURL, err = url.Parse(oidcProvider.Endpoint().TokenURL) + if err != nil { + logger.Printf("Unable to parse OIDC Token URL: %v", err) + return + } + p.ProfileURL, err = url.Parse("https://graph.microsoft.com/oidc/userinfo") + if err != nil { + logger.Printf("Unable to parse OIDC UserInfo URL: %v", err) + return + } + if p.Scope == "" { + p.Scope = "openid email profile" + } +} + +// RefreshSessionIfNeeded takes in a SessionState and +// returns false if the session is not refreshed and true if it is. +func (p *AzureV2Provider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { + if s == nil || s.RefreshDeadline.After(time.Now()) || s.RefreshToken == "" { + return false, nil + } + logger := log.NewLogEntry() + + s.RefreshDeadline = time.Now().Add(time.Second).Truncate(time.Second) + logger.WithUser(s.Email).WithRefreshDeadline(s.RefreshDeadline).Info("refreshed access token") + return false, nil +} + +func getAzureHeader(accessToken string) http.Header { + header := make(http.Header) + header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) + return header +} + +// GetGroups lists groups user belongs to. Filter the desired names of groups (in case of huge group set) +func (p *AzureV2Provider) GetGroups(s *sessions.SessionState, f string) ([]string, error) { + logger := log.NewLogEntry() + logger.Printf("[azure_v2] GetGroups") + if s.AccessToken == "" { + return []string{}, errors.New("missing access token") + } + logger.Printf("Access Token %v", s.AccessToken) + + if s.IDToken == "" { + return []string{}, errors.New("missing id token") + } + logger.Printf("ID Token %v", s.IDToken) + + groupNames := make([]string, 0) + requestURL := p.ProfileURL.String() + for { + req, err := http.NewRequest("GET", requestURL, nil) + + if err != nil { + return []string{}, err + } + req.Header = getAzureHeader(s.AccessToken) + req.Header.Add("Content-Type", "application/json") + + groupData, err := api.Request(req) + if err != nil { + return []string{}, err + } + logger.Printf("Got Graph response: %v", groupData) + + for _, groupInfo := range groupData.Get("value").MustArray() { + v, ok := groupInfo.(map[string]interface{}) + if !ok { + continue + } + dname := v["displayName"].(string) + if strings.Contains(dname, f) { + groupNames = append(groupNames, dname) + } + + } + + if nextlink := groupData.Get("@odata.nextLink").MustString(); nextlink != "" { + requestURL = nextlink + } else { + break + } + } + + return groupNames, nil +} + +// GetSignInURL returns the sign in url with typical oauth parameters +func (p *AzureV2Provider) GetSignInURL(redirectURI, state string) string { + var a url.URL + a = *p.SignInURL + params, _ := url.ParseQuery(a.RawQuery) + params.Set("client_id", p.ClientID) + params.Set("response_type", "id_token code") + params.Set("redirect_uri", redirectURI) + params.Set("response_mode", "form_post") + params.Add("scope", p.Scope) + params.Add("state", state) + params.Set("prompt", p.ApprovalPrompt) + params.Set("nonce", "FIXME") // FIXME + a.RawQuery = params.Encode() + + return a.String() +} + +// SetGroupRestriction limits which groups are allowed +func (p *AzureV2Provider) SetGroupRestriction(groups []string) { + logger := log.NewLogEntry() + if len(groups) == 1 && strings.Index(groups[0], "|") >= 0 { + p.PermittedGroups = strings.Split(groups[0], "|") + } else { + p.PermittedGroups = groups + } + logger.Printf("Set group restrictions. Allowed groups are:") + for _, pGroup := range p.PermittedGroups { + logger.Printf("\t'%s'", pGroup) + } +} + +// ValidateGroup ensures that a user is a member of a permitted group +func (p *AzureV2Provider) ValidateGroup(s *sessions.SessionState) bool { + if len(p.PermittedGroups) != 0 { + for _, pGroup := range p.PermittedGroups { + + if contains(s.Groups, pGroup) { + return true + } + } + return false + } + return true +} + +func contains(s []string, e string) bool { + for _, a := range s { + if a == e { + return true + } + } + return false +} diff --git a/internal/auth/providers/google.go b/internal/auth/providers/google.go index b2faf551..e619216f 100644 --- a/internal/auth/providers/google.go +++ b/internal/auth/providers/google.go @@ -21,6 +21,12 @@ import ( "github.com/datadog/datadog-go/statsd" ) +var ( + // This is a compile-time check to make sure our types correctly implement the interface: + // https://medium.com/@matryer/golang-tip-compile-time-checks-to-ensure-your-type-satisfies-an-interface-c167afed3aae + _ Provider = &GoogleProvider{} +) + // GoogleProvider is an implementation of the Provider interface. type GoogleProvider struct { *ProviderData diff --git a/internal/auth/providers/oidc.go b/internal/auth/providers/oidc.go new file mode 100644 index 00000000..7aa0d3b9 --- /dev/null +++ b/internal/auth/providers/oidc.go @@ -0,0 +1,92 @@ +package providers + +import ( + "context" + "fmt" + "time" + + "golang.org/x/oauth2" + + "github.com/buzzfeed/sso/internal/pkg/sessions" + oidc "github.com/coreos/go-oidc" +) + +// OIDCProvider is a generic OpenID Connect provider +type OIDCProvider struct { + *ProviderData + + Verifier *oidc.IDTokenVerifier +} + +// NewOIDCProvider creates a new generic OpenID Connect provider +func NewOIDCProvider(p *ProviderData) *OIDCProvider { + p.ProviderName = "OpenID Connect" + return &OIDCProvider{ProviderData: p} +} + +// Redeem fulfills the Provider interface. +// The authenticator uses this method to redeem the code provided to /callback after the user logs into their OpenID Connect account. +func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { + ctx := context.Background() + c := oauth2.Config{ + ClientID: p.ClientID, + ClientSecret: p.ClientSecret, + Endpoint: oauth2.Endpoint{ + TokenURL: p.RedeemURL.String(), + }, + RedirectURL: redirectURL, + } + token, err := c.Exchange(ctx, code) + if err != nil { + return nil, fmt.Errorf("token exchange: %v", err) + } + + rawIDToken, ok := token.Extra("id_token").(string) + if !ok { + return nil, fmt.Errorf("token response did not contain an id_token") + } + + // Parse and verify ID Token payload. + idToken, err := p.Verifier.Verify(ctx, rawIDToken) + if err != nil { + return nil, fmt.Errorf("could not verify id_token: %v", err) + } + + // Extract custom claims. + var claims struct { + Email string `json:"email"` + Verified *bool `json:"email_verified"` + } + if err := idToken.Claims(&claims); err != nil { + return nil, fmt.Errorf("failed to parse id_token claims: %v", err) + } + + if claims.Email == "" { + return nil, fmt.Errorf("id_token did not contain an email") + } + if claims.Verified != nil && !*claims.Verified { + return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) + } + + s = &sessions.SessionState{ + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + RefreshDeadline: token.Expiry, + Email: claims.Email, + } + + return +} + +// RefreshSessionIfNeeded takes in a SessionState and +// returns false if the session is not refreshed and true if it is. +func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { + if s == nil || s.RefreshDeadline.After(time.Now()) || s.RefreshToken == "" { + return false, nil + } + + origExpiration := s.RefreshDeadline + s.RefreshDeadline = time.Now().Add(time.Second).Truncate(time.Second) + fmt.Printf("refreshed access token %s (expired on %s)\n", s, origExpiration) + return false, nil +} diff --git a/internal/auth/providers/providers.go b/internal/auth/providers/providers.go index 4feed8a4..10405960 100644 --- a/internal/auth/providers/providers.go +++ b/internal/auth/providers/providers.go @@ -28,8 +28,12 @@ var ( ) const ( + // AzureProviderName identifies the Azure AD v2 provider + AzureProviderName = "azure" // GoogleProviderName identifies the Google provider GoogleProviderName = "google" + // OIDCProviderName identifies the OpenID Connect provider + OIDCProviderName = "oidc" ) // Provider is an interface exposing functions necessary to authenticate with a given provider. diff --git a/internal/pkg/sessions/session_state.go b/internal/pkg/sessions/session_state.go index 090b073d..5cc2e492 100644 --- a/internal/pkg/sessions/session_state.go +++ b/internal/pkg/sessions/session_state.go @@ -19,6 +19,7 @@ var ( type SessionState struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` + IDToken string `json:"id_token"` RefreshDeadline time.Time `json:"refresh_deadline"` LifetimeDeadline time.Time `json:"lifetime_deadline"` From 5acfd70a5edd41f36f6170dd4599a0875ec80fa9 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Wed, 14 Nov 2018 17:14:15 -0800 Subject: [PATCH 02/38] Porting of Azure AD provider largely complete --- internal/auth/authenticator.go | 10 +- internal/auth/options.go | 2 + internal/auth/providers/azure.go | 252 ++++++---- internal/auth/providers/azure_graph.go | 170 +++++++ internal/auth/providers/azure_graph_mock.go | 12 + internal/auth/providers/azure_test.go | 476 ++++++++++++++++++ internal/auth/providers/google_test.go | 4 +- internal/auth/providers/internal_util.go | 1 + internal/auth/providers/oidc.go | 56 ++- internal/auth/providers/provider_default.go | 2 +- .../auth/providers/singleflight_middleware.go | 2 + internal/pkg/aead/aead.go | 18 +- internal/pkg/sessions/session_state.go | 1 - internal/proxy/oauthproxy.go | 106 ++++ internal/proxy/providers/sso.go | 2 + internal/proxy/proxy_config.go | 4 + 16 files changed, 994 insertions(+), 124 deletions(-) create mode 100644 internal/auth/providers/azure_graph.go create mode 100644 internal/auth/providers/azure_graph_mock.go create mode 100644 internal/auth/providers/azure_test.go diff --git a/internal/auth/authenticator.go b/internal/auth/authenticator.go index c7035c99..b07b4d76 100644 --- a/internal/auth/authenticator.go +++ b/internal/auth/authenticator.go @@ -180,7 +180,7 @@ func (p *Authenticator) newMux() http.Handler { serviceMux.HandleFunc("/start", p.withMethods(p.OAuthStart, "GET")) serviceMux.HandleFunc("/sign_in", p.withMethods(p.validateClientID(p.validateRedirectURI(p.validateSignature(p.SignIn))), "GET")) serviceMux.HandleFunc("/sign_out", p.withMethods(p.validateRedirectURI(p.validateSignature(p.SignOut)), "GET", "POST")) - serviceMux.HandleFunc("/oauth2/callback", p.withMethods(p.OAuthCallback, "GET")) + serviceMux.HandleFunc("/oauth2/callback", p.withMethods(p.OAuthCallback, "GET", "POST")) serviceMux.HandleFunc("/profile", p.withMethods(p.validateClientID(p.validateClientSecret(p.GetProfile)), "GET")) serviceMux.HandleFunc("/validate", p.withMethods(p.validateClientID(p.validateClientSecret(p.ValidateToken)), "GET")) serviceMux.HandleFunc("/redeem", p.withMethods(p.validateClientID(p.validateClientSecret(p.Redeem)), "POST")) @@ -533,8 +533,8 @@ func (p *Authenticator) OAuthStart(rw http.ResponseWriter, req *http.Request) { // Here we validate the redirect that is nested within the redirect_uri. // `authRedirectURL` points to step D, `proxyRedirectURL` points to step E. // - // A* B C D E - // /start -> Google -> auth /callback -> /sign_in -> proxy /callback + // A* B C D E + // /start -> IdP -> auth /callback -> /sign_in -> proxy /callback // // * you are here proxyRedirectURL, err := url.Parse(authRedirectURL.Query().Get("redirect_uri")) @@ -558,7 +558,7 @@ func (p *Authenticator) OAuthStart(rw http.ResponseWriter, req *http.Request) { func (p *Authenticator) redeemCode(host, code string) (*sessions.SessionState, error) { // The authenticator redeems `code` for an access token, and uses the token to request user - // info from the provider (Google). + // info from the provider. redirectURI := p.GetRedirectURI(host) // see providers/google.go#Redeem for more info @@ -574,7 +574,7 @@ func (p *Authenticator) redeemCode(host, code string) (*sessions.SessionState, e } func (p *Authenticator) getOAuthCallback(rw http.ResponseWriter, req *http.Request) (string, error) { - // After the provider (Google) redirects back to the sso proxy, the proxy uses this + // After the provider redirects back to the sso proxy, the proxy uses this // endpoint to set up auth cookies. logger := log.NewLogEntry() diff --git a/internal/auth/options.go b/internal/auth/options.go index b85b33c6..7e9666a5 100644 --- a/internal/auth/options.go +++ b/internal/auth/options.go @@ -313,6 +313,8 @@ func AssignStatsdClient(opts *Options) func(*Authenticator) error { proxy.StatsdClient = StatsdClient switch v := proxy.provider.(type) { + case *providers.AzureV2Provider: + v.SetStatsdClient(StatsdClient) case *providers.GoogleProvider: v.SetStatsdClient(StatsdClient) case *providers.SingleFlightProvider: diff --git a/internal/auth/providers/azure.go b/internal/auth/providers/azure.go index caa84134..cac4c638 100644 --- a/internal/auth/providers/azure.go +++ b/internal/auth/providers/azure.go @@ -9,27 +9,43 @@ import ( "strings" "time" - "github.com/bitly/oauth2_proxy/api" "github.com/buzzfeed/sso/internal/pkg/sessions" + "github.com/datadog/datadog-go/statsd" "golang.org/x/oauth2" log "github.com/buzzfeed/sso/internal/pkg/logging" oidc "github.com/coreos/go-oidc" ) +var ( + azureOIDCConfigURL = "https://login.microsoftonline.com/{tenant}/v2.0" + azureOIDCProfileURL = "https://graph.microsoft.com/oidc/userinfo" +) + // AzureV2Provider is an Azure AD v2 specific implementation of the Provider interface. type AzureV2Provider struct { *ProviderData *OIDCProvider - Tenant string - PermittedGroups []string + Tenant string + + StatsdClient *statsd.Client + + GraphService GraphService } // NewAzureV2Provider creates a new AzureV2Provider struct func NewAzureV2Provider(p *ProviderData) *AzureV2Provider { - p.ProviderName = "Azure v2.0" - return &AzureV2Provider{ProviderData: p, OIDCProvider: &OIDCProvider{ProviderData: p}} + p.ProviderName = "Azure AD" + return &AzureV2Provider{ + ProviderData: p, + OIDCProvider: &OIDCProvider{ProviderData: p}, + } +} + +// SetStatsdClient sets the azure provider statsd client +func (p *AzureV2Provider) SetStatsdClient(statsdClient *statsd.Client) { + p.StatsdClient = statsdClient } // Redeem fulfills the Provider interface. @@ -50,10 +66,19 @@ func (p *AzureV2Provider) Redeem(redirectURL, code string) (s *sessions.SessionS } rawIDToken, ok := token.Extra("id_token").(string) - if !ok { + // BUG? rawIDToken is empty string here in current test cases. Test cases + // are shipping invalid ID tokens, but the Extra function doesn't seem to have + // code that would be affected by that. + if !ok || rawIDToken == "" { + fmt.Printf("token: %+v\n", token) return nil, fmt.Errorf("token response did not contain an id_token") } + // should only happen if oidc autodiscovery is broken + if p.OIDCProvider.Verifier == nil { + return nil, fmt.Errorf("oidc verifier missing") + } + // Parse and verify ID Token payload. idToken, err := p.OIDCProvider.Verifier.Verify(ctx, rawIDToken) if err != nil { @@ -62,10 +87,8 @@ func (p *AzureV2Provider) Redeem(redirectURL, code string) (s *sessions.SessionS // Extract custom claims. var claims struct { - Email string `json:"email"` - UPN string `json:"upn"` - Roles []string `json:"roles"` - Verified *bool `json:"email_verified"` + Email string `json:"email"` + UPN string `json:"upn"` } if err := idToken.Claims(&claims); err != nil { return nil, fmt.Errorf("failed to parse id_token claims: %v", err) @@ -73,131 +96,141 @@ func (p *AzureV2Provider) Redeem(redirectURL, code string) (s *sessions.SessionS if claims.Email == "" { return nil, fmt.Errorf("id_token did not contain an email") } - if claims.Verified != nil && !*claims.Verified { - return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) - } + // TODO: test this w/ an account that uses an alias and compare email claim + // with UPN claim; UPN has usually been what you want, but I think it's not + // rendered as a full email address here. + // FIXME: validate nonce against session s = &sessions.SessionState{ - AccessToken: token.AccessToken, - RefreshToken: token.RefreshToken, - RefreshDeadline: token.Expiry, - Email: claims.Email, - User: claims.UPN, - Groups: claims.Roles, + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + + RefreshDeadline: token.Expiry, + LifetimeDeadline: sessions.ExtendDeadline(p.SessionLifetimeTTL), + + Email: claims.Email, + User: claims.UPN, } + if p.GraphService != nil { + groupNames, err := p.GraphService.GetGroups(claims.Email) + if err != nil { + return nil, fmt.Errorf("could not get groups: %v", err) + } + s.Groups = groupNames + } return } // Configure sets the Azure tenant ID value for the provider -func (p *AzureV2Provider) Configure(tenant string) { +func (p *AzureV2Provider) Configure(tenant string) error { p.Tenant = tenant if tenant == "" { p.Tenant = "common" } - discoveryURL := fmt.Sprintf("https://login.microsoftonline.com/%s/v2.0", p.Tenant) + discoveryURL := strings.Replace(azureOIDCConfigURL, "{tenant}", p.Tenant, -1) // Configure discoverable provider data. oidcProvider, err := oidc.NewProvider(context.Background(), discoveryURL) if err != nil { - return + // FIXME: this seems like it _should_ work for "common", but it doesn't + // Does anyone actually want to use this with "common" though? + return err } - logger := log.NewLogEntry() - p.OIDCProvider.Verifier = oidcProvider.Verifier(&oidc.Config{ ClientID: p.ClientID, }) - p.SignInURL, err = url.Parse(oidcProvider.Endpoint().AuthURL) - if err != nil { - logger.Printf("Unable to parse OIDC Authentication URL: %v", err) - return + // Set these only if they haven't been overridden + if p.SignInURL == nil || p.SignInURL.String() == "" { + p.SignInURL, err = url.Parse(oidcProvider.Endpoint().AuthURL) + if err != nil { + return err + } } - p.RedeemURL, err = url.Parse(oidcProvider.Endpoint().TokenURL) - if err != nil { - logger.Printf("Unable to parse OIDC Token URL: %v", err) - return + if p.RedeemURL == nil || p.RedeemURL.String() == "" { + p.RedeemURL, err = url.Parse(oidcProvider.Endpoint().TokenURL) + if err != nil { + return err + } + } + if p.ProfileURL == nil || p.ProfileURL.String() == "" { + p.ProfileURL, err = url.Parse(azureOIDCProfileURL) } - p.ProfileURL, err = url.Parse("https://graph.microsoft.com/oidc/userinfo") if err != nil { - logger.Printf("Unable to parse OIDC UserInfo URL: %v", err) - return + return err } if p.Scope == "" { - p.Scope = "openid email profile" + p.Scope = "openid email profile offline_access" } + if p.RedeemURL.String() == "" { + return errors.New("redeem url must be set") + } + p.GraphService = NewAzureGraphService(p.ClientID, p.ClientSecret, p.RedeemURL.String()) + return nil } // RefreshSessionIfNeeded takes in a SessionState and // returns false if the session is not refreshed and true if it is. func (p *AzureV2Provider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { - if s == nil || s.RefreshDeadline.After(time.Now()) || s.RefreshToken == "" { + if s == nil || !s.RefreshPeriodExpired() || s.RefreshToken == "" { return false, nil } + + newToken, duration, err := p.RefreshAccessToken(s.RefreshToken) + if err != nil { + return false, err + } logger := log.NewLogEntry() - s.RefreshDeadline = time.Now().Add(time.Second).Truncate(time.Second) + s.AccessToken = newToken + + s.RefreshDeadline = time.Now().Add(duration).Truncate(time.Second) logger.WithUser(s.Email).WithRefreshDeadline(s.RefreshDeadline).Info("refreshed access token") - return false, nil -} -func getAzureHeader(accessToken string) http.Header { - header := make(http.Header) - header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) - return header + return true, nil } -// GetGroups lists groups user belongs to. Filter the desired names of groups (in case of huge group set) -func (p *AzureV2Provider) GetGroups(s *sessions.SessionState, f string) ([]string, error) { - logger := log.NewLogEntry() - logger.Printf("[azure_v2] GetGroups") - if s.AccessToken == "" { - return []string{}, errors.New("missing access token") +// RefreshAccessToken uses default OAuth2 TokenSource method to get a new +// access token. +func (p *AzureV2Provider) RefreshAccessToken(refreshToken string) (string, time.Duration, error) { + if refreshToken == "" { + return "", 0, errors.New("missing refresh token") } - logger.Printf("Access Token %v", s.AccessToken) + logger := log.NewLogEntry() + logger.Info("refreshing access token") - if s.IDToken == "" { - return []string{}, errors.New("missing id token") + ctx := context.Background() + c := oauth2.Config{ + ClientID: p.ClientID, + ClientSecret: p.ClientSecret, + Endpoint: oauth2.Endpoint{ + TokenURL: p.RedeemURL.String(), + }, + } + t := oauth2.Token{ + RefreshToken: refreshToken, + } + ts := c.TokenSource(ctx, &t) + newToken, err := ts.Token() + if err != nil { + return "", 0, fmt.Errorf("token exchange: %v", err) } - logger.Printf("ID Token %v", s.IDToken) - - groupNames := make([]string, 0) - requestURL := p.ProfileURL.String() - for { - req, err := http.NewRequest("GET", requestURL, nil) - - if err != nil { - return []string{}, err - } - req.Header = getAzureHeader(s.AccessToken) - req.Header.Add("Content-Type", "application/json") - - groupData, err := api.Request(req) - if err != nil { - return []string{}, err - } - logger.Printf("Got Graph response: %v", groupData) - - for _, groupInfo := range groupData.Get("value").MustArray() { - v, ok := groupInfo.(map[string]interface{}) - if !ok { - continue - } - dname := v["displayName"].(string) - if strings.Contains(dname, f) { - groupNames = append(groupNames, dname) - } - } + return newToken.AccessToken, newToken.Expiry.Sub(time.Now()), nil +} - if nextlink := groupData.Get("@odata.nextLink").MustString(); nextlink != "" { - requestURL = nextlink - } else { - break - } - } +func getAzureHeader(accessToken string) http.Header { + header := make(http.Header) + header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) + return header +} - return groupNames, nil +// ValidateSessionState attempts to validate the session state's access token. +func (p *AzureV2Provider) ValidateSessionState(s *sessions.SessionState) bool { + // return validateToken(p, s.AccessToken, nil) + // TODO Validate ID token + return true } // GetSignInURL returns the sign in url with typical oauth parameters @@ -212,38 +245,39 @@ func (p *AzureV2Provider) GetSignInURL(redirectURI, state string) string { params.Add("scope", p.Scope) params.Add("state", state) params.Set("prompt", p.ApprovalPrompt) - params.Set("nonce", "FIXME") // FIXME + params.Set("nonce", "FIXME") // FIXME, maybe change to session state struct a.RawQuery = params.Encode() return a.String() } -// SetGroupRestriction limits which groups are allowed -func (p *AzureV2Provider) SetGroupRestriction(groups []string) { - logger := log.NewLogEntry() - if len(groups) == 1 && strings.Index(groups[0], "|") >= 0 { - p.PermittedGroups = strings.Split(groups[0], "|") - } else { - p.PermittedGroups = groups +// ValidateGroupMembership takes in an email and the allowed groups and returns the groups that the email is part of in that list. +// If `allGroups` is an empty list it returns all the groups that the user belongs to. +func (p *AzureV2Provider) ValidateGroupMembership(email string, allGroups []string) ([]string, error) { + if p.GraphService == nil { + panic("provider has not been configured") } - logger.Printf("Set group restrictions. Allowed groups are:") - for _, pGroup := range p.PermittedGroups { - logger.Printf("\t'%s'", pGroup) + + userGroups, err := p.GraphService.GetGroups(email) + if err != nil { + return nil, err } -} -// ValidateGroup ensures that a user is a member of a permitted group -func (p *AzureV2Provider) ValidateGroup(s *sessions.SessionState) bool { - if len(p.PermittedGroups) != 0 { - for _, pGroup := range p.PermittedGroups { + // if `allGroups` is empty use the groups resource + if len(allGroups) == 0 { + return userGroups, nil + } - if contains(s.Groups, pGroup) { - return true + filtered := []string{} + for _, userGroup := range userGroups { + for _, allowedGroup := range allGroups { + if userGroup == allowedGroup { + filtered = append(filtered, userGroup) } } - return false } - return true + + return filtered, nil } func contains(s []string, e string) bool { diff --git a/internal/auth/providers/azure_graph.go b/internal/auth/providers/azure_graph.go new file mode 100644 index 00000000..65cdbd03 --- /dev/null +++ b/internal/auth/providers/azure_graph.go @@ -0,0 +1,170 @@ +package providers + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "strings" + "sync" + + lru "github.com/hashicorp/golang-lru" + "golang.org/x/oauth2/clientcredentials" +) + +// AzureGroupCacheSize controls the size of the caches of AD group info +const AzureGroupCacheSize = 1024 + +// GraphService wraps calls to provider admin APIs +type GraphService interface { + GetGroups(string) ([]string, error) +} + +// AzureGraphService implements graph API calls for the Azure provider +type AzureGraphService struct { + client *http.Client + groupMembershipCache *lru.Cache + groupNameCache *lru.Cache +} + +// NewAzureGraphService creates a new graph service for getting groups +func NewAzureGraphService(clientID string, clientSecret string, tokenURL string) *AzureGraphService { + clientConfig := &clientcredentials.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + TokenURL: tokenURL, + Scopes: []string{ + "https://graph.microsoft.com/.default", + }, + } + ctx := context.Background() + client := clientConfig.Client(ctx) + memberCache, err := lru.New(AzureGroupCacheSize) + if err != nil { + panic(err) // Should only happen if AzureGroupCacheSize is a negative number + } + nameCache, err := lru.New(AzureGroupCacheSize) + if err != nil { + panic(err) // Should only happen if AzureGroupCacheSize is a negative number + } + return &AzureGraphService{ + client: client, + groupMembershipCache: memberCache, + groupNameCache: nameCache, + } +} + +// GetGroups lists groups user belongs to. +func (gs *AzureGraphService) GetGroups(email string) ([]string, error) { + if gs.client == nil { + return []string{}, errors.New("oauth client must be configured") + } + if email == "" { + return []string{}, errors.New("missing email") + } + + var wg sync.WaitGroup + var mtx sync.Mutex + var err error + groupNames := make([]string, 0) + // See: https://developer.microsoft.com/en-us/graph/docs/api-reference/beta/api/user_getmembergroups + requestBody := `{"securityEnabledOnly": false}` + requestURL := fmt.Sprintf("https://graph.microsoft.com/beta/users/%s/getMemberGroups", url.PathEscape(email)) + for { + groupResponse, err := gs.client.Post(requestURL, "application/json", strings.NewReader(requestBody)) + if err != nil { + return []string{}, err + } + + groupData := struct { + Next string `json:"@odata.nextLink"` + Value []string `json:"value"` + }{} + + body, err := ioutil.ReadAll(groupResponse.Body) + if err != nil { + return []string{}, err + } + if groupResponse.StatusCode >= 400 { + return []string{}, fmt.Errorf("api error: %s", string(body)) + } + + err = json.Unmarshal(body, &groupData) + if err != nil { + return []string{}, err + } + + for _, groupID := range groupData.Value { + wg.Add(1) + id := groupID + go func(wg *sync.WaitGroup) { + defer wg.Done() + + var name string + // check the cache for the group name first + cachedName, ok := gs.groupNameCache.Get(id) + if !ok { + // didn't have the group name, make concurrent API call to fetch it + name, err = gs.getGroupName(id) + if err == nil { + // got the name ok, populate the cache + gs.groupNameCache.Add(id, name) + } + } else { + // cache hit + name = cachedName.(string) + } + mtx.Lock() + groupNames = append(groupNames, name) + mtx.Unlock() + }(&wg) + } + + if groupData.Next != "" { + requestURL = groupData.Next + } else { + break + } + } + wg.Wait() + if err != nil { + return []string{}, err + } + + return groupNames, nil +} + +// getGroupName returns the group name, preferentially pulling from cache +func (gs *AzureGraphService) getGroupName(id string) (string, error) { + if gs.client == nil { + return "", errors.New("oauth client must be configured") + } + // See: https://developer.microsoft.com/en-us/graph/docs/api-reference/v1.0/api/group_get + requestURL := fmt.Sprintf("https://graph.microsoft.com/v1.0/groups/%s", url.PathEscape(id)) + groupMetaResponse, err := gs.client.Get(requestURL) + if err != nil { + return "", err + } + + groupMetadata := struct { + DisplayName string `json:"displayName"` + }{} + + body, err := ioutil.ReadAll(groupMetaResponse.Body) + if err != nil { + return "", err + } + if groupMetaResponse.StatusCode >= 400 { + return "", fmt.Errorf("api error: %s", string(body)) + } + + err = json.Unmarshal(body, &groupMetadata) + if err != nil { + return "", err + } + + return groupMetadata.DisplayName, nil +} diff --git a/internal/auth/providers/azure_graph_mock.go b/internal/auth/providers/azure_graph_mock.go new file mode 100644 index 00000000..044ce01f --- /dev/null +++ b/internal/auth/providers/azure_graph_mock.go @@ -0,0 +1,12 @@ +package providers + +// MockAzureGraphService is an implementation of GraphService to be used for testing +type MockAzureGraphService struct { + Groups []string + GroupsError error +} + +// GetGroups mocks the GetGroups function +func (ms *MockAzureGraphService) GetGroups(string) ([]string, error) { + return ms.Groups, ms.GroupsError +} diff --git a/internal/auth/providers/azure_test.go b/internal/auth/providers/azure_test.go new file mode 100644 index 00000000..40dea165 --- /dev/null +++ b/internal/auth/providers/azure_test.go @@ -0,0 +1,476 @@ +package providers + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "log" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "strings" + "testing" + "time" + + "github.com/buzzfeed/sso/internal/pkg/groups" + "github.com/buzzfeed/sso/internal/pkg/sessions" + "github.com/buzzfeed/sso/internal/pkg/testutil" + + jose "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" +) + +const MicrosoftTenantID = "9188040d-6c67-4c5b-b112-36a304b66dad" +const TestClientID = "a4c35c92-e858-41e8-bd2c-ade04cb622b1" + +func newAzureProviderServer(redeemBody *[]byte, redeemCode int, pubKey *rsa.PublicKey) (*url.URL, *httptest.Server) { + var u *url.URL + s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.Header().Set("Content-Type", "application/json") + + // simple router for the mock IdP server + switch r.RequestURI { + case "/.well-known/openid-configuration": + rw.WriteHeader(200) + rw.Write([]byte(fmt.Sprintf(`{ + "token_endpoint":"%s/oauth2/token", + "jwks_uri":"%s/common/discovery/keys", + "issuer":"%s" + }`, u, u, u))) + case "/common/discovery/keys": + pub := jose.JSONWebKey{Key: pubKey, Algorithm: "RSA", Use: "sig"} + keyData, err := pub.MarshalJSON() + if err != nil { + panic(err) + } + rw.WriteHeader(200) + rw.Write([]byte(fmt.Sprintf(`{ + "keys":[%s] + }`, keyData))) + case "/oauth2/token": + rw.WriteHeader(redeemCode) + rw.Write(*redeemBody) + } + })) + u, _ = url.Parse(s.URL) + return u, s +} + +func newAzureV2Provider(providerData *ProviderData) *AzureV2Provider { + if providerData == nil { + providerData = &ProviderData{ + ProviderName: "", + SignInURL: &url.URL{}, + RedeemURL: &url.URL{}, + RevokeURL: &url.URL{}, + ProfileURL: &url.URL{}, + ValidateURL: &url.URL{}, + Scope: ""} + } + return NewAzureV2Provider(providerData) +} + +func TestAzureV2ProviderDefaults(t *testing.T) { + expectedResults := []struct { + name string + providerData *ProviderData + signInURL string + redeemURL string + revokeURL string + profileURL string + validateURL string + scope string + }{ + { + name: "defaults", + signInURL: "https://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/oauth2/v2.0/authorize", + redeemURL: "https://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/oauth2/v2.0/token", + profileURL: "https://graph.microsoft.com/oidc/userinfo", + revokeURL: "", // does not exist + validateURL: "", // does not exist + scope: "openid email profile offline_access", + }, + { + name: "with provider overrides", + providerData: &ProviderData{ + SignInURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/oauth/auth"}, + RedeemURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/oauth/token"}, + RevokeURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/oauth/deauth"}, + ProfileURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/oauth/profile"}, + ValidateURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/oauth/tokeninfo"}, + Scope: "profile"}, + signInURL: "https://example.com/oauth/auth", + redeemURL: "https://example.com/oauth/token", + revokeURL: "https://example.com/oauth/deauth", + profileURL: "https://example.com/oauth/profile", + validateURL: "https://example.com/oauth/tokeninfo", + scope: "profile", + }, + } + for _, expected := range expectedResults { + t.Run(expected.name, func(t *testing.T) { + p := newAzureV2Provider(expected.providerData) + err := p.Configure(MicrosoftTenantID) + if err != nil { + t.Error(err) + } + if p == nil { + t.Errorf("azure provider was nil") + } + if p.Data().ProviderName != "Azure AD" { + t.Errorf("expected provider name Azure AD, got %s", p.Data().ProviderName) + } + if p.Data().SignInURL.String() != expected.signInURL { + log.Printf("expected %s", expected.signInURL) + log.Printf("got %s", p.Data().SignInURL.String()) + t.Errorf("unexpected signin url") + } + + if p.Data().RedeemURL.String() != expected.redeemURL { + log.Printf("expected %s", expected.redeemURL) + log.Printf("got %s", p.Data().RedeemURL.String()) + t.Errorf("unexpected redeem url") + } + + if p.Data().RevokeURL.String() != expected.revokeURL { + log.Printf("expected %s", expected.revokeURL) + log.Printf("got %s", p.Data().RevokeURL.String()) + t.Errorf("unexpected revoke url") + } + + if p.Data().ValidateURL.String() != expected.validateURL { + log.Printf("expected %s", expected.validateURL) + log.Printf("got %s", p.Data().ValidateURL.String()) + t.Errorf("unexpected validate url") + } + + if p.Data().ProfileURL.String() != expected.profileURL { + log.Printf("expected %s", expected.profileURL) + log.Printf("got %s", p.Data().ProfileURL.String()) + t.Errorf("unexpected profile url") + } + + if p.Data().Scope != expected.scope { + log.Printf("expected %s", expected.scope) + log.Printf("got %s", p.Data().Scope) + t.Errorf("unexpected scope") + } + }) + + } +} + +// claims represents public claim values (as specified in RFC 7519). +type claims struct { + Issuer string `json:"iss,omitempty"` + Audience string `json:"aud,omitempty"` + Expiry jwt.NumericDate `json:"exp,omitempty"` + Name string `json:"name,omitempty"` + Email string `json:"email,omitempty"` + NotEmail string `json:"not_email,omitempty"` +} + +func TestAzureV2ProviderRedeem(t *testing.T) { + // For testing create the RSA key pair in the code + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Error(err) + } + // create Square.jose signing key + key := jose.SigningKey{Algorithm: jose.RS256, Key: privKey} + + // create a Square.jose RSA signer, used to sign the JWT + var signerOpts = jose.SignerOptions{} + signerOpts.WithType("JWT") + rsaSigner, err := jose.NewSigner(key, &signerOpts) + if err != nil { + t.Error(err) + } + + testCases := []struct { + name string + claims *claims + resp redeemResponse + expectedError bool + expectedSession *sessions.SessionState + }{ + { + name: "redeem", + claims: &claims{ + Issuer: "{mock-issuer}", + Audience: TestClientID, + Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second)), + Name: "Michael Bland", + Email: "michael.bland@gsa.gov", + }, + resp: redeemResponse{ + AccessToken: "a1234", + ExpiresIn: 10, + RefreshToken: "refresh12345", + }, + expectedSession: &sessions.SessionState{ + Email: "michael.bland@gsa.gov", + AccessToken: "a1234", + RefreshToken: "refresh12345", + }, + }, + { + name: "missing issuer", + claims: &claims{ + Audience: TestClientID, + Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second)), + Name: "Michael Bland", + Email: "michael.bland@gsa.gov", + }, + resp: redeemResponse{ + AccessToken: "a1234", + ExpiresIn: 10, + RefreshToken: "refresh12345", + }, + expectedError: true, + }, + { + name: "invalid issuer", + claims: &claims{ + Issuer: "https://example.com/bogus/issuer", + Audience: TestClientID, + Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second)), + Name: "Michael Bland", + Email: "michael.bland@gsa.gov", + }, + resp: redeemResponse{ + AccessToken: "a1234", + ExpiresIn: 10, + RefreshToken: "refresh12345", + }, + expectedError: true, + }, + { + name: "invalid encoding", + resp: redeemResponse{ + AccessToken: "a1234", + IDToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." + `{"name": "Michael Bland","email": "michael.bland@gsa.gov"}` + ".SC_eJ3K04rLOPLLDIWEKwr0DPZqw5KlFySybzmxfM6Y", + }, + expectedError: true, + }, + { + name: "invalid json", + resp: redeemResponse{ + AccessToken: "a1234", + IDToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)) + ".SC_eJ3K04rLOPLLDIWEKwr0DPZqw5KlFySybzmxfM6Y", + }, + expectedError: true, + }, + { + name: "missing email", + claims: &claims{ + NotEmail: "missing", + }, + resp: redeemResponse{ + AccessToken: "a1234", + }, + expectedError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var body []byte + var server *httptest.Server + // pointer to body to bypass chicken/egg issue w/ mock server urls + providerURL, server := newAzureProviderServer(&body, http.StatusOK, &privKey.PublicKey) + defer server.Close() + azureOIDCConfigURL = providerURL.String() + + if tc.claims != nil { + // create an instance of Builder that uses the rsa signer + builder := jwt.Signed(rsaSigner) + + // add claims to the Builder + tc.claims.Issuer = strings.Replace(tc.claims.Issuer, "{mock-issuer}", providerURL.String(), -1) + builder = builder.Claims(tc.claims) + + // build and inject ID token into response + idToken, err := builder.CompactSerialize() + if err != nil { + t.Error(err) + } + tc.resp.IDToken = idToken + } + + body, err = json.Marshal(tc.resp) + testutil.Equal(t, nil, err) + + p := newAzureV2Provider(nil) + p.ClientID = TestClientID + p.ClientSecret = "456" + err = p.Configure(MicrosoftTenantID) + if err != nil { + t.Error(err) + } + // graph service mock has to be set after p.Configure + p.GraphService = &MockAzureGraphService{} + + session, err := p.Redeem("http://redirect/", "code1234") + if tc.expectedError && err == nil { + t.Errorf("expected redeem error but was nil") + } + if !tc.expectedError && err != nil { + t.Errorf("unexpected error %s", err) + } + if tc.expectedSession == nil && session != nil { + t.Errorf("expected session to be nil but it was %s", session) + } + if session != nil && tc.expectedSession != nil { + if session.Email != tc.expectedSession.Email { + log.Printf("expected email %s", tc.expectedSession.Email) + log.Printf("got %s", session.Email) + t.Errorf("unexpected session email") + } + + if session.AccessToken != tc.expectedSession.AccessToken { + log.Printf("expected access token %s", tc.expectedSession.AccessToken) + log.Printf("got %s", session.AccessToken) + t.Errorf("unexpected access token") + } + + if session.RefreshToken != tc.expectedSession.RefreshToken { + log.Printf("expected refresh token %s", tc.expectedSession.RefreshToken) + log.Printf("got %s", session.RefreshToken) + t.Errorf("unexpected session refresh token") + } + } + }) + } +} + +type groupsClientMock struct { +} + +func (c *groupsClientMock) Do(req *http.Request) (*http.Response, error) { + return &http.Response{}, nil +} + +func TestAzureV2ValidateGroupMembers(t *testing.T) { + testCases := []struct { + name string + inputAllowedGroups []string + groups []string + groupsError error + getMembersFunc func(string) (groups.MemberSet, bool) + expectedGroups []string + expectedErrorString string + }{ + { + name: "empty input groups should return an empty string", + inputAllowedGroups: []string{}, + groups: []string{"group1"}, + expectedGroups: []string{"group1"}, + getMembersFunc: func(string) (groups.MemberSet, bool) { return nil, false }, + }, + { + name: "empty inputs and error on groups resource should return error", + inputAllowedGroups: []string{}, + getMembersFunc: func(string) (groups.MemberSet, bool) { return nil, false }, + groupsError: fmt.Errorf("error"), + expectedErrorString: "error", + }, + { + name: "member exists in cache, should not call groups resource", + inputAllowedGroups: []string{"group1"}, + groupsError: fmt.Errorf("should not get here"), + getMembersFunc: func(string) (groups.MemberSet, bool) { return groups.MemberSet{"email": {}}, true }, + expectedGroups: []string{"group1"}, + }, + { + name: "member does not exist in cache, should still not call groups resource", + inputAllowedGroups: []string{"group1"}, + groupsError: fmt.Errorf("should not get here"), + getMembersFunc: func(string) (groups.MemberSet, bool) { return groups.MemberSet{}, true }, + expectedGroups: []string{}, + }, + { + name: "subset of groups are not cached, calls groups resource", + inputAllowedGroups: []string{"group1", "group2"}, + groups: []string{"group1", "group2", "group3"}, + groupsError: nil, + getMembersFunc: func(group string) (groups.MemberSet, bool) { + switch group { + case "group1": + return groups.MemberSet{"email": {}}, true + default: + return groups.MemberSet{}, false + } + }, + expectedGroups: []string{"group1", "group2"}, + }, + { + name: "subset of groups are not cached, calls groups resource with error", + inputAllowedGroups: []string{"group1", "group2"}, + groupsError: fmt.Errorf("error"), + getMembersFunc: func(group string) (groups.MemberSet, bool) { + switch group { + case "group1": + return groups.MemberSet{"email": {}}, true + default: + return groups.MemberSet{}, false + } + }, + expectedErrorString: "error", + }, + { + name: "subset of groups not there, does not call groups resource", + inputAllowedGroups: []string{"group1", "group2"}, + groups: []string{"group1", "group2", "group3"}, + groupsError: fmt.Errorf("should not get here"), + getMembersFunc: func(group string) (groups.MemberSet, bool) { + switch group { + case "group1": + return groups.MemberSet{"email": {}}, true + default: + return groups.MemberSet{}, true + } + }, + expectedGroups: []string{"group1"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + p := newAzureV2Provider(nil) + p.GraphService = &MockAzureGraphService{Groups: tc.groups, GroupsError: tc.groupsError} + + groups, err := p.ValidateGroupMembership("email", tc.inputAllowedGroups) + + if err != nil { + if tc.expectedErrorString != err.Error() { + t.Errorf("expected error %s but err was %s", tc.expectedErrorString, err) + } + } + if !reflect.DeepEqual(tc.expectedGroups, groups) { + t.Logf("expected groups %v", tc.expectedGroups) + t.Logf("got groups %v", groups) + t.Errorf("unexpected groups returned") + } + + }) + } +} diff --git a/internal/auth/providers/google_test.go b/internal/auth/providers/google_test.go index c5fc39b0..0eaed4ff 100644 --- a/internal/auth/providers/google_test.go +++ b/internal/auth/providers/google_test.go @@ -20,12 +20,14 @@ import ( func newProviderServer(body []byte, code int) (*url.URL, *httptest.Server) { s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.Header().Set("Content-Type", "application/json") rw.WriteHeader(code) rw.Write(body) })) u, _ := url.Parse(s.URL) return u, s } + func newGoogleProvider(providerData *ProviderData) *GoogleProvider { if providerData == nil { providerData = &ProviderData{ @@ -310,7 +312,7 @@ func TestGoogleProviderRevoke(t *testing.T) { } } -func TestValidateGroupMembers(t *testing.T) { +func TestGoogleValidateGroupMembers(t *testing.T) { testCases := []struct { name string inputAllowedGroups []string diff --git a/internal/auth/providers/internal_util.go b/internal/auth/providers/internal_util.go index 2505cf46..cf1940bc 100644 --- a/internal/auth/providers/internal_util.go +++ b/internal/auth/providers/internal_util.go @@ -50,6 +50,7 @@ func stripParam(param, endpoint string) string { func validateToken(p Provider, accessToken string, header http.Header) bool { logger := log.NewLogEntry() + logger.Info(p.Data().ValidateURL) if accessToken == "" || p.Data().ValidateURL == nil { return false } diff --git a/internal/auth/providers/oidc.go b/internal/auth/providers/oidc.go index 7aa0d3b9..01d8acb1 100644 --- a/internal/auth/providers/oidc.go +++ b/internal/auth/providers/oidc.go @@ -2,11 +2,13 @@ package providers import ( "context" + "errors" "fmt" "time" "golang.org/x/oauth2" + log "github.com/buzzfeed/sso/internal/pkg/logging" "github.com/buzzfeed/sso/internal/pkg/sessions" oidc "github.com/coreos/go-oidc" ) @@ -81,12 +83,56 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionStat // RefreshSessionIfNeeded takes in a SessionState and // returns false if the session is not refreshed and true if it is. func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { - if s == nil || s.RefreshDeadline.After(time.Now()) || s.RefreshToken == "" { + // if s == nil || s.RefreshDeadline.After(time.Now()) || s.RefreshToken == "" { + // return false, nil + // } + // + // origExpiration := s.RefreshDeadline + // s.RefreshDeadline = time.Now().Add(time.Second).Truncate(time.Second) + // fmt.Printf("refreshed access token %s (expired on %s)\n", s, origExpiration) + // return false, nil + + if s == nil || !s.RefreshPeriodExpired() || s.RefreshToken == "" { return false, nil } - origExpiration := s.RefreshDeadline - s.RefreshDeadline = time.Now().Add(time.Second).Truncate(time.Second) - fmt.Printf("refreshed access token %s (expired on %s)\n", s, origExpiration) - return false, nil + newToken, duration, err := p.RefreshAccessToken(s.RefreshToken) + if err != nil { + return false, err + } + logger := log.NewLogEntry() + + s.AccessToken = newToken + + s.RefreshDeadline = time.Now().Add(duration).Truncate(time.Second) + logger.WithUser(s.Email).WithRefreshDeadline(s.RefreshDeadline).Info("refreshed access token") + + return true, nil +} + +// RefreshAccessToken uses default OAuth2 TokenSource method to get a new +// access token. +func (p *OIDCProvider) RefreshAccessToken(refreshToken string) (string, time.Duration, error) { + if refreshToken == "" { + return "", 0, errors.New("missing refresh token") + } + + ctx := context.Background() + c := oauth2.Config{ + ClientID: p.ClientID, + ClientSecret: p.ClientSecret, + Endpoint: oauth2.Endpoint{ + TokenURL: p.RedeemURL.String(), + }, + } + t := oauth2.Token{ + RefreshToken: refreshToken, + } + ts := c.TokenSource(ctx, &t) + newToken, err := ts.Token() + if err != nil { + return "", 0, fmt.Errorf("token exchange: %v", err) + } + + return newToken.AccessToken, newToken.Expiry.Sub(time.Now()), nil } diff --git a/internal/auth/providers/provider_default.go b/internal/auth/providers/provider_default.go index 3295ad21..91eabfbb 100644 --- a/internal/auth/providers/provider_default.go +++ b/internal/auth/providers/provider_default.go @@ -122,7 +122,7 @@ func (p *ProviderData) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, e return false, nil } -// RefreshAccessToken returns a nont implemented error. +// RefreshAccessToken returns a not implemented error. func (p *ProviderData) RefreshAccessToken(refreshToken string) (string, time.Duration, error) { return "", 0, ErrNotImplemented } diff --git a/internal/auth/providers/singleflight_middleware.go b/internal/auth/providers/singleflight_middleware.go index afe9e77e..995b4c1c 100644 --- a/internal/auth/providers/singleflight_middleware.go +++ b/internal/auth/providers/singleflight_middleware.go @@ -63,6 +63,8 @@ func (p *SingleFlightProvider) do(endpoint, key string, fn func() (interface{}, func (p *SingleFlightProvider) AssignStatsdClient(StatsdClient *statsd.Client) { p.StatsdClient = StatsdClient switch v := p.provider.(type) { + case *AzureV2Provider: + v.SetStatsdClient(StatsdClient) case *GoogleProvider: v.SetStatsdClient(StatsdClient) } diff --git a/internal/pkg/aead/aead.go b/internal/pkg/aead/aead.go index a837719e..78df486f 100644 --- a/internal/pkg/aead/aead.go +++ b/internal/pkg/aead/aead.go @@ -1,10 +1,13 @@ package aead import ( + "bytes" + "compress/gzip" "crypto/cipher" "encoding/base64" "encoding/json" "fmt" + "io" "sync" miscreant "github.com/miscreant/miscreant-go" @@ -95,8 +98,14 @@ func (c *MiscreantCipher) Marshal(s interface{}) (string, error) { return "", err } + // gzip the bytes + var jsonBuffer bytes.Buffer + w := gzip.NewWriter(&jsonBuffer) + w.Write(plaintext) + w.Close() + // encrypt the JSON - ciphertext, err := c.Encrypt(plaintext) + ciphertext, err := c.Encrypt(jsonBuffer.Bytes()) if err != nil { return "", err } @@ -121,8 +130,13 @@ func (c *MiscreantCipher) Unmarshal(value string, s interface{}) error { return err } + // gunzip the bytes + var jsonBuffer bytes.Buffer + r, err := gzip.NewReader(bytes.NewBuffer(plaintext)) + io.Copy(&jsonBuffer, r) + // unmarshal bytes - err = json.Unmarshal(plaintext, s) + err = json.Unmarshal(jsonBuffer.Bytes(), s) if err != nil { return err } diff --git a/internal/pkg/sessions/session_state.go b/internal/pkg/sessions/session_state.go index 5cc2e492..090b073d 100644 --- a/internal/pkg/sessions/session_state.go +++ b/internal/pkg/sessions/session_state.go @@ -19,7 +19,6 @@ var ( type SessionState struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` - IDToken string `json:"id_token"` RefreshDeadline time.Time `json:"refresh_deadline"` LifetimeDeadline time.Time `json:"lifetime_deadline"` diff --git a/internal/proxy/oauthproxy.go b/internal/proxy/oauthproxy.go index da9d9a75..5bd5549e 100755 --- a/internal/proxy/oauthproxy.go +++ b/internal/proxy/oauthproxy.go @@ -341,6 +341,10 @@ func NewOAuthProxy(opts *Options, optFuncs ...func(*OAuthProxy) error) (*OAuthPr } for _, upstreamConfig := range opts.upstreamConfigs { + logger.Printf("upstreamConfig.Route: %v", upstreamConfig.Route) + logger.Printf("upstreamConfig.RouteConfig: %v", upstreamConfig.RouteConfig) + logger.Printf("upstreamConfig.RouteConfig.Options: %v", upstreamConfig.RouteConfig.Options) + switch route := upstreamConfig.Route.(type) { case *SimpleRoute: reverseProxy := NewReverseProxy(route.ToURL, upstreamConfig) @@ -470,6 +474,99 @@ func (p *OAuthProxy) redeemCode(host, code string) (*sessions.SessionState, erro return s, nil } +// MakeSessionCookie constructs a session cookie given the request, an expiration time and the current time. +func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { + return p.makeCookie(req, p.CookieName, value, expiration, now) +} + +// MakeCSRFCookie creates a CSRF cookie given the request, an expiration time, and the current time. +func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { + return p.makeCookie(req, p.CSRFCookieName, value, expiration, now) +} + +func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie { + logger := log.NewLogEntry() + + domain := req.Host + if h, _, err := net.SplitHostPort(domain); err == nil { + domain = h + } + if p.CookieDomain != "" { + if !strings.HasSuffix(domain, p.CookieDomain) { + logger.WithRequestHost(domain).WithCookieDomain(p.CookieDomain).Warn( + "using configured cookie domain") + } + domain = p.CookieDomain + } + + return &http.Cookie{ + Name: name, + Value: value, + Path: "/", + Domain: domain, + HttpOnly: p.CookieHTTPOnly, + Secure: p.CookieSecure, + Expires: now.Add(expiration), + } +} + +// ClearCSRFCookie clears the CSRF cookie from the request +func (p *OAuthProxy) ClearCSRFCookie(rw http.ResponseWriter, req *http.Request) { + http.SetCookie(rw, p.MakeCSRFCookie(req, "", time.Hour*-1, time.Now())) +} + +// SetCSRFCookie sets the CSRFCookie creates a CSRF cookie in a given request +func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, val string) { + http.SetCookie(rw, p.MakeCSRFCookie(req, val, p.CookieExpire, time.Now())) +} + +// ClearSessionCookie clears the session cookie from a request +func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) { + http.SetCookie(rw, p.MakeSessionCookie(req, "", time.Hour*-1, time.Now())) +} + +// SetSessionCookie creates a sesion cookie based on the value and the expiration time. +func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { + http.SetCookie(rw, p.MakeSessionCookie(req, val, p.CookieExpire, time.Now())) +} + +// LoadCookiedSession returns a SessionState from the cookie in the request. +func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, error) { + logger := log.NewLogEntry().WithRemoteAddress(getRemoteAddr(req)) + + c, err := req.Cookie(p.CookieName) + if err != nil { + // always http.ErrNoCookie + return nil, err + } + + session, err := providers.UnmarshalSession(c.Value, p.CookieCipher) + if err != nil { + tags := []string{"error:unmarshaling_session"} + p.StatsdClient.Incr("application_error", tags, 1.0) + logger.Error(err, "unable to unmarshal session") + return nil, ErrInvalidSession + } + + logger.Printf("session loaded: %v", session) + + return session, nil +} + +// SaveSession saves a session state to a request cookie. +func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error { + value, err := providers.MarshalSession(s, p.CookieCipher) + if err != nil { + return err + } + + p.SetSessionCookie(rw, req, value) + logger := log.NewLogEntry().WithRemoteAddress(getRemoteAddr(req)) + logger.Printf("session saved: %v", s) + + return nil +} + // RobotsTxt sets the User-Agent header in the response to be "Disallow" func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter, _ *http.Request) { rw.WriteHeader(http.StatusOK) @@ -788,6 +885,12 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { allowedGroups := route.upstreamConfig.AllowedGroups inGroups, validGroup, err := p.provider.ValidateGroup(session.Email, allowedGroups) + logger.Printf("route: %v", route) + logger.Printf("route.upstreamConfig: %v", route.upstreamConfig) + logger.Printf("inGroups: %v", inGroups) + logger.Printf("allowedGroups: %v", allowedGroups) + logger.Printf("validGroup: %v", validGroup) + logger.Printf("err: %v", err) if err != nil { tags = append(tags, "error:user_group_failed") p.StatsdClient.Incr("provider_error", tags, 1.0) @@ -997,6 +1100,9 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) (er return ErrUserNotAuthorized } + logger.Printf("proxied full session: %v", session) + logger.Printf("proxied groups: %v", session.Groups) + req.Header.Set("X-Forwarded-User", session.User) if p.PassAccessToken && session.AccessToken != "" { diff --git a/internal/proxy/providers/sso.go b/internal/proxy/providers/sso.go index 00e9e4ad..8ed1b845 100644 --- a/internal/proxy/providers/sso.go +++ b/internal/proxy/providers/sso.go @@ -170,9 +170,11 @@ func (p *SSOProvider) Redeem(redirectURL, code string) (*sessions.SessionState, // an authorized group. func (p *SSOProvider) ValidateGroup(email string, allowedGroups []string) ([]string, bool, error) { logger := log.NewLogEntry() + logger.Info("called sso.go ValidateGroup") logger.WithUser(email).WithAllowedGroups(allowedGroups).Info("validating groups") inGroups := []string{} + logger.Printf("allowedGroups: %v", allowedGroups) if len(allowedGroups) == 0 { return inGroups, true, nil } diff --git a/internal/proxy/proxy_config.go b/internal/proxy/proxy_config.go index 1fd6e915..853748aa 100644 --- a/internal/proxy/proxy_config.go +++ b/internal/proxy/proxy_config.go @@ -8,6 +8,7 @@ import ( "time" "github.com/18F/hmacauth" + log "github.com/buzzfeed/sso/internal/pkg/logging" "github.com/imdario/mergo" "gopkg.in/yaml.v2" ) @@ -185,6 +186,9 @@ func loadServiceConfigs(raw []byte, cluster, scheme string, configVars map[strin if err != nil { return nil, err } + logger := log.NewLogEntry() + logger.Printf("proxy.AllowedGroups: %v", proxy.AllowedGroups) + } for _, proxy := range configs { From 83e43a1787ac028681bef0cb24a9c1d9050ca428 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Thu, 15 Nov 2018 16:37:07 -0800 Subject: [PATCH 03/38] Fix group tests, add sign-in tests, and generate nonces --- internal/auth/providers/azure.go | 18 ++- internal/auth/providers/azure_test.go | 171 ++++++++++++++++---------- 2 files changed, 119 insertions(+), 70 deletions(-) diff --git a/internal/auth/providers/azure.go b/internal/auth/providers/azure.go index cac4c638..1d24a145 100644 --- a/internal/auth/providers/azure.go +++ b/internal/auth/providers/azure.go @@ -2,6 +2,9 @@ package providers import ( "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" "errors" "fmt" "net/http" @@ -99,7 +102,6 @@ func (p *AzureV2Provider) Redeem(redirectURL, code string) (s *sessions.SessionS // TODO: test this w/ an account that uses an alias and compare email claim // with UPN claim; UPN has usually been what you want, but I think it's not // rendered as a full email address here. - // FIXME: validate nonce against session s = &sessions.SessionState{ AccessToken: token.AccessToken, @@ -245,12 +247,24 @@ func (p *AzureV2Provider) GetSignInURL(redirectURI, state string) string { params.Add("scope", p.Scope) params.Add("state", state) params.Set("prompt", p.ApprovalPrompt) - params.Set("nonce", "FIXME") // FIXME, maybe change to session state struct + params.Set("nonce", p.calculateNonce(state)) // required parameter a.RawQuery = params.Encode() return a.String() } +// calculateNonce generates a deterministic nonce from the state value. +// We don't have a session state pointer but we need to generate a nonce +// that we can verify statelessly later. We can only use what's in the +// params and provider struct to assemble a nonce. State is guaranteed to be +// indistinguishable from random and will always change. +func (p *AzureV2Provider) calculateNonce(state string) string { + key := []byte(p.ClientID + p.ClientSecret) + h := hmac.New(sha256.New, key) + h.Write([]byte(state)) + return base64.URLEncoding.EncodeToString(h.Sum(nil))[:8] +} + // ValidateGroupMembership takes in an email and the allowed groups and returns the groups that the email is part of in that list. // If `allGroups` is an empty list it returns all the groups that the user belongs to. func (p *AzureV2Provider) ValidateGroupMembership(email string, allGroups []string) ([]string, error) { diff --git a/internal/auth/providers/azure_test.go b/internal/auth/providers/azure_test.go index 40dea165..df439a3e 100644 --- a/internal/auth/providers/azure_test.go +++ b/internal/auth/providers/azure_test.go @@ -15,7 +15,6 @@ import ( "testing" "time" - "github.com/buzzfeed/sso/internal/pkg/groups" "github.com/buzzfeed/sso/internal/pkg/sessions" "github.com/buzzfeed/sso/internal/pkg/testutil" @@ -369,96 +368,133 @@ func (c *groupsClientMock) Do(req *http.Request) (*http.Response, error) { return &http.Response{}, nil } -func TestAzureV2ValidateGroupMembers(t *testing.T) { +func TestAzureV2GetSignInURL(t *testing.T) { testCases := []struct { - name string - inputAllowedGroups []string - groups []string - groupsError error - getMembersFunc func(string) (groups.MemberSet, bool) - expectedGroups []string - expectedErrorString string + name string + redirectURI string + state string + expectedParams url.Values }{ { - name: "empty input groups should return an empty string", - inputAllowedGroups: []string{}, - groups: []string{"group1"}, - expectedGroups: []string{"group1"}, - getMembersFunc: func(string) (groups.MemberSet, bool) { return nil, false }, + name: "nonce values passed to azure should be deterministic, pass one", + redirectURI: "https://example.com/oauth/callback", + state: "1234", + expectedParams: url.Values{ + "redirect_uri": []string{"https://example.com/oauth/callback"}, + "response_mode": []string{"form_post"}, + "response_type": []string{"id_token code"}, + "scope": []string{"openid email profile offline_access"}, + "state": []string{"1234"}, + "client_id": []string{TestClientID}, + "nonce": []string{"KEB9Aopa"}, + "prompt": []string{"consent"}, + }, }, { - name: "empty inputs and error on groups resource should return error", - inputAllowedGroups: []string{}, - getMembersFunc: func(string) (groups.MemberSet, bool) { return nil, false }, - groupsError: fmt.Errorf("error"), - expectedErrorString: "error", + name: "nonce values passed to azure should be deterministic, pass two", + redirectURI: "https://example.com/oauth/callback", + state: "1234", + expectedParams: url.Values{ + "redirect_uri": []string{"https://example.com/oauth/callback"}, + "response_mode": []string{"form_post"}, + "response_type": []string{"id_token code"}, + "scope": []string{"openid email profile offline_access"}, + "state": []string{"1234"}, + "client_id": []string{TestClientID}, + "nonce": []string{"KEB9Aopa"}, + "prompt": []string{"consent"}, + }, }, { - name: "member exists in cache, should not call groups resource", - inputAllowedGroups: []string{"group1"}, - groupsError: fmt.Errorf("should not get here"), - getMembersFunc: func(string) (groups.MemberSet, bool) { return groups.MemberSet{"email": {}}, true }, - expectedGroups: []string{"group1"}, + name: "nonce values passed to azure should be deterministic, pass three", + redirectURI: "https://example.com/oauth/callback", + state: "4321", + expectedParams: url.Values{ + "redirect_uri": []string{"https://example.com/oauth/callback"}, + "response_mode": []string{"form_post"}, + "response_type": []string{"id_token code"}, + "scope": []string{"openid email profile offline_access"}, + "state": []string{"4321"}, + "client_id": []string{TestClientID}, + "nonce": []string{"x_PhEN0K"}, + "prompt": []string{"consent"}, + }, }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + p := newAzureV2Provider(nil) + p.ClientID = TestClientID + p.ClientSecret = "456" + p.Scope = "openid email profile offline_access" + p.ApprovalPrompt = "consent" + + signInURL := p.GetSignInURL(tc.redirectURI, tc.state) + parsedURL, err := url.Parse(signInURL) + if err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(tc.expectedParams, parsedURL.Query()) { + t.Logf("expected params %+v", tc.expectedParams) + t.Logf("got params %+v", parsedURL.Query()) + t.Errorf("unexpected params returned") + } + }) + } +} + +func TestAzureV2ValidateGroupMembers(t *testing.T) { + testCases := []struct { + name string + allowedGroups []string + mockedGroups []string + mockedError error + expectedGroups []string + expectedErrorString string + }{ { - name: "member does not exist in cache, should still not call groups resource", - inputAllowedGroups: []string{"group1"}, - groupsError: fmt.Errorf("should not get here"), - getMembersFunc: func(string) (groups.MemberSet, bool) { return groups.MemberSet{}, true }, - expectedGroups: []string{}, + name: "allowed groups and groups resource output exactly match should return all groups", + allowedGroups: []string{"group1", "group2", "group3"}, + mockedGroups: []string{"group1", "group2", "group3"}, + expectedGroups: []string{"group1", "group2", "group3"}, }, { - name: "subset of groups are not cached, calls groups resource", - inputAllowedGroups: []string{"group1", "group2"}, - groups: []string{"group1", "group2", "group3"}, - groupsError: nil, - getMembersFunc: func(group string) (groups.MemberSet, bool) { - switch group { - case "group1": - return groups.MemberSet{"email": {}}, true - default: - return groups.MemberSet{}, false - } - }, + name: "allowed groups should restrict to subset of groups", + allowedGroups: []string{"group1", "group2"}, + mockedGroups: []string{"group1", "group2", "group3"}, expectedGroups: []string{"group1", "group2"}, }, { - name: "subset of groups are not cached, calls groups resource with error", - inputAllowedGroups: []string{"group1", "group2"}, - groupsError: fmt.Errorf("error"), - getMembersFunc: func(group string) (groups.MemberSet, bool) { - switch group { - case "group1": - return groups.MemberSet{"email": {}}, true - default: - return groups.MemberSet{}, false - } - }, - expectedErrorString: "error", + name: "allowed groups superset should not restrict to subset of groups", + allowedGroups: []string{"group1", "group2", "group3"}, + mockedGroups: []string{"group1", "group2"}, + expectedGroups: []string{"group1", "group2"}, }, { - name: "subset of groups not there, does not call groups resource", - inputAllowedGroups: []string{"group1", "group2"}, - groups: []string{"group1", "group2", "group3"}, - groupsError: fmt.Errorf("should not get here"), - getMembersFunc: func(group string) (groups.MemberSet, bool) { - switch group { - case "group1": - return groups.MemberSet{"email": {}}, true - default: - return groups.MemberSet{}, true - } - }, + name: "groups allowed zero value should default to return all groups", + allowedGroups: []string{}, + mockedGroups: []string{"group1"}, expectedGroups: []string{"group1"}, }, + { + name: "empty inputs and error on groups resource should return error", + allowedGroups: []string{}, + mockedError: fmt.Errorf("error"), + expectedErrorString: "error", + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { p := newAzureV2Provider(nil) - p.GraphService = &MockAzureGraphService{Groups: tc.groups, GroupsError: tc.groupsError} + p.GraphService = &MockAzureGraphService{ + Groups: tc.mockedGroups, + GroupsError: tc.mockedError, + } - groups, err := p.ValidateGroupMembership("email", tc.inputAllowedGroups) + groups, err := p.ValidateGroupMembership("test@example.com", tc.allowedGroups) if err != nil { if tc.expectedErrorString != err.Error() { @@ -470,7 +506,6 @@ func TestAzureV2ValidateGroupMembers(t *testing.T) { t.Logf("got groups %v", groups) t.Errorf("unexpected groups returned") } - }) } } From 4e51eaea69715d81389734402bd294c4513cad60 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Thu, 15 Nov 2018 17:32:08 -0800 Subject: [PATCH 04/38] Add golang-lru to Godeps --- Godeps | 1 + 1 file changed, 1 insertion(+) diff --git a/Godeps b/Godeps index 47666da0..9bd42775 100644 --- a/Godeps +++ b/Godeps @@ -16,3 +16,4 @@ github.com/rakyll/statik v0.1.4 github.com/coreos/go-oidc v2.0.0 gopkg.in/square/go-jose.v2 v2.1.9 github.com/pquerna/cachecontrol 1555304b9b35fdd2b425bccf1a5613677705e7d0 +github.com/hashicorp/golang-lru v0.5.0 From 413d831b02cbbfa6a0a56242eba7e8dd6b632bf3 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Tue, 8 Jan 2019 15:24:31 -0800 Subject: [PATCH 05/38] Expand abbreviation in comment --- internal/auth/authenticator.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/auth/authenticator.go b/internal/auth/authenticator.go index b07b4d76..e443a30c 100644 --- a/internal/auth/authenticator.go +++ b/internal/auth/authenticator.go @@ -533,8 +533,8 @@ func (p *Authenticator) OAuthStart(rw http.ResponseWriter, req *http.Request) { // Here we validate the redirect that is nested within the redirect_uri. // `authRedirectURL` points to step D, `proxyRedirectURL` points to step E. // - // A* B C D E - // /start -> IdP -> auth /callback -> /sign_in -> proxy /callback + // A* B C D E + // /start -> IdProvider -> auth /callback -> /sign_in -> proxy /callback // // * you are here proxyRedirectURL, err := url.Parse(authRedirectURL.Query().Get("redirect_uri")) From 3fcdff41f0857d96a769ab9364ecdb8b184fc78b Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Tue, 8 Jan 2019 17:18:33 -0800 Subject: [PATCH 06/38] Update documentation and compile time checking of provider interface --- internal/auth/providers/azure.go | 4 ++++ internal/auth/providers/google.go | 2 +- internal/auth/providers/singleflight_middleware.go | 2 +- internal/proxy/providers/singleflight_middleware.go | 2 +- internal/proxy/providers/sso.go | 2 +- 5 files changed, 8 insertions(+), 4 deletions(-) diff --git a/internal/auth/providers/azure.go b/internal/auth/providers/azure.go index 1d24a145..9342fae5 100644 --- a/internal/auth/providers/azure.go +++ b/internal/auth/providers/azure.go @@ -23,6 +23,10 @@ import ( var ( azureOIDCConfigURL = "https://login.microsoftonline.com/{tenant}/v2.0" azureOIDCProfileURL = "https://graph.microsoft.com/oidc/userinfo" + + // This is a compile-time check to make sure our types correctly implement the interface: + // https://medium.com/@matryer/c167afed3aae + _ Provider = &AzureV2Provider{} ) // AzureV2Provider is an Azure AD v2 specific implementation of the Provider interface. diff --git a/internal/auth/providers/google.go b/internal/auth/providers/google.go index e619216f..eda238cf 100644 --- a/internal/auth/providers/google.go +++ b/internal/auth/providers/google.go @@ -23,7 +23,7 @@ import ( var ( // This is a compile-time check to make sure our types correctly implement the interface: - // https://medium.com/@matryer/golang-tip-compile-time-checks-to-ensure-your-type-satisfies-an-interface-c167afed3aae + // https://medium.com/@matryer/c167afed3aae _ Provider = &GoogleProvider{} ) diff --git a/internal/auth/providers/singleflight_middleware.go b/internal/auth/providers/singleflight_middleware.go index 995b4c1c..8a1d781d 100644 --- a/internal/auth/providers/singleflight_middleware.go +++ b/internal/auth/providers/singleflight_middleware.go @@ -15,7 +15,7 @@ import ( var ( // This is a compile-time check to make sure our types correctly implement the interface: - // https://medium.com/@matryer/golang-tip-compile-time-checks-to-ensure-your-type-satisfies-an-interface-c167afed3aae + // https://medium.com/@matryer/c167afed3aae _ Provider = &SingleFlightProvider{} ) diff --git a/internal/proxy/providers/singleflight_middleware.go b/internal/proxy/providers/singleflight_middleware.go index 1fb03fd2..7e1b4bdc 100644 --- a/internal/proxy/providers/singleflight_middleware.go +++ b/internal/proxy/providers/singleflight_middleware.go @@ -15,7 +15,7 @@ import ( var ( // This is a compile-time check to make sure our types correctly implement the interface: - // https://medium.com/@matryer/golang-tip-compile-time-checks-to-ensure-your-type-satisfies-an-interface-c167afed3aae + // https://medium.com/@matryer/c167afed3aae _ Provider = &SingleFlightProvider{} ) diff --git a/internal/proxy/providers/sso.go b/internal/proxy/providers/sso.go index 8ed1b845..2b39bcab 100644 --- a/internal/proxy/providers/sso.go +++ b/internal/proxy/providers/sso.go @@ -23,7 +23,7 @@ import ( var ( // This is a compile-time check to make sure our types correctly implement the interface: - // https://medium.com/@matryer/golang-tip-compile-time-checks-to-ensure-your-type-satisfies-an-interface-c167afed3aae + // https://medium.com/@matryer/c167afed3aae _ Provider = &SSOProvider{} ) From 315d1bd74dc36d359a19b96b30d6f15fcc57fa31 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Tue, 8 Jan 2019 17:20:26 -0800 Subject: [PATCH 07/38] Update Azure provider constant for consistency --- internal/auth/providers/providers.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/auth/providers/providers.go b/internal/auth/providers/providers.go index 10405960..1629726b 100644 --- a/internal/auth/providers/providers.go +++ b/internal/auth/providers/providers.go @@ -29,7 +29,7 @@ var ( const ( // AzureProviderName identifies the Azure AD v2 provider - AzureProviderName = "azure" + AzureProviderName = "azure_v2" // GoogleProviderName identifies the Google provider GoogleProviderName = "google" // OIDCProviderName identifies the OpenID Connect provider From 38d052322fbfe6b42232196fe8b69c52ff1f3d62 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Tue, 8 Jan 2019 17:27:06 -0800 Subject: [PATCH 08/38] Fix up comments related to common default tenant ID --- internal/auth/providers/azure.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/internal/auth/providers/azure.go b/internal/auth/providers/azure.go index 9342fae5..c405082c 100644 --- a/internal/auth/providers/azure.go +++ b/internal/auth/providers/azure.go @@ -131,7 +131,9 @@ func (p *AzureV2Provider) Redeem(redirectURL, code string) (s *sessions.SessionS // Configure sets the Azure tenant ID value for the provider func (p *AzureV2Provider) Configure(tenant string) error { p.Tenant = tenant - if tenant == "" { + if p.Tenant == "" { + // TODO: See below, "common" is the right default value, and while + // Azure AD docs suggest this should work, it results in an error. p.Tenant = "common" } discoveryURL := strings.Replace(azureOIDCConfigURL, "{tenant}", p.Tenant, -1) @@ -139,7 +141,7 @@ func (p *AzureV2Provider) Configure(tenant string) error { // Configure discoverable provider data. oidcProvider, err := oidc.NewProvider(context.Background(), discoveryURL) if err != nil { - // FIXME: this seems like it _should_ work for "common", but it doesn't + // TODO: This seems like it _should_ work for "common", but it doesn't // Does anyone actually want to use this with "common" though? return err } From 4f8b3302a35089006b9d0de65794210b3075ec07 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Tue, 8 Jan 2019 17:29:45 -0800 Subject: [PATCH 09/38] Document odata pagination hyperlinking --- internal/auth/providers/azure_graph.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/auth/providers/azure_graph.go b/internal/auth/providers/azure_graph.go index 65cdbd03..0bb6fbda 100644 --- a/internal/auth/providers/azure_graph.go +++ b/internal/auth/providers/azure_graph.go @@ -80,6 +80,8 @@ func (gs *AzureGraphService) GetGroups(email string) ([]string, error) { } groupData := struct { + // Link to next page of data, see: + // https://docs.microsoft.com/en-us/graph/query-parameters#skip-parameter Next string `json:"@odata.nextLink"` Value []string `json:"value"` }{} From 23e5067822ae3b01948be79f731ada3dcd9988e9 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Tue, 8 Jan 2019 17:32:12 -0800 Subject: [PATCH 10/38] Remove unused helper function --- internal/auth/providers/azure.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/internal/auth/providers/azure.go b/internal/auth/providers/azure.go index c405082c..80c2a11c 100644 --- a/internal/auth/providers/azure.go +++ b/internal/auth/providers/azure.go @@ -7,7 +7,6 @@ import ( "encoding/base64" "errors" "fmt" - "net/http" "net/url" "strings" "time" @@ -228,12 +227,6 @@ func (p *AzureV2Provider) RefreshAccessToken(refreshToken string) (string, time. return newToken.AccessToken, newToken.Expiry.Sub(time.Now()), nil } -func getAzureHeader(accessToken string) http.Header { - header := make(http.Header) - header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) - return header -} - // ValidateSessionState attempts to validate the session state's access token. func (p *AzureV2Provider) ValidateSessionState(s *sessions.SessionState) bool { // return validateToken(p, s.AccessToken, nil) From 97957bc6c7a5afec25e11dc9b002fc6a461c5f8e Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Tue, 8 Jan 2019 17:34:46 -0800 Subject: [PATCH 11/38] Remove unused contains helper function --- internal/auth/providers/azure.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/internal/auth/providers/azure.go b/internal/auth/providers/azure.go index 80c2a11c..2ad3f461 100644 --- a/internal/auth/providers/azure.go +++ b/internal/auth/providers/azure.go @@ -292,12 +292,3 @@ func (p *AzureV2Provider) ValidateGroupMembership(email string, allGroups []stri return filtered, nil } - -func contains(s []string, e string) bool { - for _, a := range s { - if a == e { - return true - } - } - return false -} From 70b7c20deea225483a404b401d1fe9c580f267d0 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Tue, 8 Jan 2019 17:36:44 -0800 Subject: [PATCH 12/38] Removed commented out code --- internal/auth/providers/oidc.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/internal/auth/providers/oidc.go b/internal/auth/providers/oidc.go index 01d8acb1..18de54e8 100644 --- a/internal/auth/providers/oidc.go +++ b/internal/auth/providers/oidc.go @@ -83,15 +83,6 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionStat // RefreshSessionIfNeeded takes in a SessionState and // returns false if the session is not refreshed and true if it is. func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { - // if s == nil || s.RefreshDeadline.After(time.Now()) || s.RefreshToken == "" { - // return false, nil - // } - // - // origExpiration := s.RefreshDeadline - // s.RefreshDeadline = time.Now().Add(time.Second).Truncate(time.Second) - // fmt.Printf("refreshed access token %s (expired on %s)\n", s, origExpiration) - // return false, nil - if s == nil || !s.RefreshPeriodExpired() || s.RefreshToken == "" { return false, nil } From 0d4d57ed729546ec4b82d34e970aeb69c99ad18b Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Tue, 8 Jan 2019 17:37:59 -0800 Subject: [PATCH 13/38] Remove unnecessary new line from comment --- internal/auth/providers/azure.go | 3 +-- internal/auth/providers/oidc.go | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/internal/auth/providers/azure.go b/internal/auth/providers/azure.go index 2ad3f461..39bfe345 100644 --- a/internal/auth/providers/azure.go +++ b/internal/auth/providers/azure.go @@ -198,8 +198,7 @@ func (p *AzureV2Provider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool return true, nil } -// RefreshAccessToken uses default OAuth2 TokenSource method to get a new -// access token. +// RefreshAccessToken uses default OAuth2 TokenSource method to get a new access token. func (p *AzureV2Provider) RefreshAccessToken(refreshToken string) (string, time.Duration, error) { if refreshToken == "" { return "", 0, errors.New("missing refresh token") diff --git a/internal/auth/providers/oidc.go b/internal/auth/providers/oidc.go index 18de54e8..2aae35b2 100644 --- a/internal/auth/providers/oidc.go +++ b/internal/auth/providers/oidc.go @@ -101,8 +101,7 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, e return true, nil } -// RefreshAccessToken uses default OAuth2 TokenSource method to get a new -// access token. +// RefreshAccessToken uses default OAuth2 TokenSource method to get a new access token. func (p *OIDCProvider) RefreshAccessToken(refreshToken string) (string, time.Duration, error) { if refreshToken == "" { return "", 0, errors.New("missing refresh token") From 837ad426868a5864cd3ae8b9dddadcc69e3a59bb Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Tue, 8 Jan 2019 18:02:34 -0800 Subject: [PATCH 14/38] Switch things to package private that don't need to be exported --- internal/auth/providers/azure_graph.go | 12 ++++++------ internal/auth/providers/azure_test.go | 26 ++++++++++++++------------ 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/internal/auth/providers/azure_graph.go b/internal/auth/providers/azure_graph.go index 0bb6fbda..1ad9e161 100644 --- a/internal/auth/providers/azure_graph.go +++ b/internal/auth/providers/azure_graph.go @@ -15,8 +15,8 @@ import ( "golang.org/x/oauth2/clientcredentials" ) -// AzureGroupCacheSize controls the size of the caches of AD group info -const AzureGroupCacheSize = 1024 +// azureGroupCacheSize controls the size of the caches of AD group info +const azureGroupCacheSize = 1024 // GraphService wraps calls to provider admin APIs type GraphService interface { @@ -42,13 +42,13 @@ func NewAzureGraphService(clientID string, clientSecret string, tokenURL string) } ctx := context.Background() client := clientConfig.Client(ctx) - memberCache, err := lru.New(AzureGroupCacheSize) + memberCache, err := lru.New(azureGroupCacheSize) if err != nil { - panic(err) // Should only happen if AzureGroupCacheSize is a negative number + panic(err) // Should only happen if azureGroupCacheSize is a negative number } - nameCache, err := lru.New(AzureGroupCacheSize) + nameCache, err := lru.New(azureGroupCacheSize) if err != nil { - panic(err) // Should only happen if AzureGroupCacheSize is a negative number + panic(err) // Should only happen if azureGroupCacheSize is a negative number } return &AzureGraphService{ client: client, diff --git a/internal/auth/providers/azure_test.go b/internal/auth/providers/azure_test.go index df439a3e..de4e1e15 100644 --- a/internal/auth/providers/azure_test.go +++ b/internal/auth/providers/azure_test.go @@ -22,8 +22,10 @@ import ( "gopkg.in/square/go-jose.v2/jwt" ) -const MicrosoftTenantID = "9188040d-6c67-4c5b-b112-36a304b66dad" -const TestClientID = "a4c35c92-e858-41e8-bd2c-ade04cb622b1" +const ( + microsoftTenantID = "9188040d-6c67-4c5b-b112-36a304b66dad" + testClientID = "a4c35c92-e858-41e8-bd2c-ade04cb622b1" +) func newAzureProviderServer(redeemBody *[]byte, redeemCode int, pubKey *rsa.PublicKey) (*url.URL, *httptest.Server) { var u *url.URL @@ -127,7 +129,7 @@ func TestAzureV2ProviderDefaults(t *testing.T) { for _, expected := range expectedResults { t.Run(expected.name, func(t *testing.T) { p := newAzureV2Provider(expected.providerData) - err := p.Configure(MicrosoftTenantID) + err := p.Configure(microsoftTenantID) if err != nil { t.Error(err) } @@ -215,7 +217,7 @@ func TestAzureV2ProviderRedeem(t *testing.T) { name: "redeem", claims: &claims{ Issuer: "{mock-issuer}", - Audience: TestClientID, + Audience: testClientID, Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second)), Name: "Michael Bland", Email: "michael.bland@gsa.gov", @@ -234,7 +236,7 @@ func TestAzureV2ProviderRedeem(t *testing.T) { { name: "missing issuer", claims: &claims{ - Audience: TestClientID, + Audience: testClientID, Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second)), Name: "Michael Bland", Email: "michael.bland@gsa.gov", @@ -250,7 +252,7 @@ func TestAzureV2ProviderRedeem(t *testing.T) { name: "invalid issuer", claims: &claims{ Issuer: "https://example.com/bogus/issuer", - Audience: TestClientID, + Audience: testClientID, Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second)), Name: "Michael Bland", Email: "michael.bland@gsa.gov", @@ -319,9 +321,9 @@ func TestAzureV2ProviderRedeem(t *testing.T) { testutil.Equal(t, nil, err) p := newAzureV2Provider(nil) - p.ClientID = TestClientID + p.ClientID = testClientID p.ClientSecret = "456" - err = p.Configure(MicrosoftTenantID) + err = p.Configure(microsoftTenantID) if err != nil { t.Error(err) } @@ -385,7 +387,7 @@ func TestAzureV2GetSignInURL(t *testing.T) { "response_type": []string{"id_token code"}, "scope": []string{"openid email profile offline_access"}, "state": []string{"1234"}, - "client_id": []string{TestClientID}, + "client_id": []string{testClientID}, "nonce": []string{"KEB9Aopa"}, "prompt": []string{"consent"}, }, @@ -400,7 +402,7 @@ func TestAzureV2GetSignInURL(t *testing.T) { "response_type": []string{"id_token code"}, "scope": []string{"openid email profile offline_access"}, "state": []string{"1234"}, - "client_id": []string{TestClientID}, + "client_id": []string{testClientID}, "nonce": []string{"KEB9Aopa"}, "prompt": []string{"consent"}, }, @@ -415,7 +417,7 @@ func TestAzureV2GetSignInURL(t *testing.T) { "response_type": []string{"id_token code"}, "scope": []string{"openid email profile offline_access"}, "state": []string{"4321"}, - "client_id": []string{TestClientID}, + "client_id": []string{testClientID}, "nonce": []string{"x_PhEN0K"}, "prompt": []string{"consent"}, }, @@ -425,7 +427,7 @@ func TestAzureV2GetSignInURL(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { p := newAzureV2Provider(nil) - p.ClientID = TestClientID + p.ClientID = testClientID p.ClientSecret = "456" p.Scope = "openid email profile offline_access" p.ApprovalPrompt = "consent" From 82c429b48d89e023bb83ee3640dd7091831d47a0 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Tue, 8 Jan 2019 18:06:23 -0800 Subject: [PATCH 15/38] Fix return value on error --- internal/auth/providers/azure_graph.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/internal/auth/providers/azure_graph.go b/internal/auth/providers/azure_graph.go index 1ad9e161..9e88b9a0 100644 --- a/internal/auth/providers/azure_graph.go +++ b/internal/auth/providers/azure_graph.go @@ -60,10 +60,10 @@ func NewAzureGraphService(clientID string, clientSecret string, tokenURL string) // GetGroups lists groups user belongs to. func (gs *AzureGraphService) GetGroups(email string) ([]string, error) { if gs.client == nil { - return []string{}, errors.New("oauth client must be configured") + return nil, errors.New("oauth client must be configured") } if email == "" { - return []string{}, errors.New("missing email") + return nil, errors.New("missing email") } var wg sync.WaitGroup @@ -76,7 +76,7 @@ func (gs *AzureGraphService) GetGroups(email string) ([]string, error) { for { groupResponse, err := gs.client.Post(requestURL, "application/json", strings.NewReader(requestBody)) if err != nil { - return []string{}, err + return nil, err } groupData := struct { @@ -88,15 +88,15 @@ func (gs *AzureGraphService) GetGroups(email string) ([]string, error) { body, err := ioutil.ReadAll(groupResponse.Body) if err != nil { - return []string{}, err + return nil, err } if groupResponse.StatusCode >= 400 { - return []string{}, fmt.Errorf("api error: %s", string(body)) + return nil, fmt.Errorf("api error: %s", string(body)) } err = json.Unmarshal(body, &groupData) if err != nil { - return []string{}, err + return nil, err } for _, groupID := range groupData.Value { @@ -133,7 +133,7 @@ func (gs *AzureGraphService) GetGroups(email string) ([]string, error) { } wg.Wait() if err != nil { - return []string{}, err + return nil, err } return groupNames, nil From bd344fe97b93597895c245a65567866ec96a5c55 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Tue, 8 Jan 2019 18:07:39 -0800 Subject: [PATCH 16/38] This API should only return 200 on success --- internal/auth/providers/azure_graph.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/auth/providers/azure_graph.go b/internal/auth/providers/azure_graph.go index 9e88b9a0..8125b980 100644 --- a/internal/auth/providers/azure_graph.go +++ b/internal/auth/providers/azure_graph.go @@ -90,7 +90,7 @@ func (gs *AzureGraphService) GetGroups(email string) ([]string, error) { if err != nil { return nil, err } - if groupResponse.StatusCode >= 400 { + if groupResponse.StatusCode != http.StatusOK { return nil, fmt.Errorf("api error: %s", string(body)) } From e0e9804e8f070f938edcb3ff77db4323c2e7c547 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Tue, 8 Jan 2019 18:09:06 -0800 Subject: [PATCH 17/38] Consistent mutex variable names --- internal/auth/providers/azure_graph.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/auth/providers/azure_graph.go b/internal/auth/providers/azure_graph.go index 8125b980..799a9045 100644 --- a/internal/auth/providers/azure_graph.go +++ b/internal/auth/providers/azure_graph.go @@ -67,7 +67,7 @@ func (gs *AzureGraphService) GetGroups(email string) ([]string, error) { } var wg sync.WaitGroup - var mtx sync.Mutex + var mux sync.Mutex var err error groupNames := make([]string, 0) // See: https://developer.microsoft.com/en-us/graph/docs/api-reference/beta/api/user_getmembergroups @@ -119,9 +119,9 @@ func (gs *AzureGraphService) GetGroups(email string) ([]string, error) { // cache hit name = cachedName.(string) } - mtx.Lock() + mux.Lock() groupNames = append(groupNames, name) - mtx.Unlock() + mux.Unlock() }(&wg) } From 523434673574fa3965b24a347373206c1313b024 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Tue, 8 Jan 2019 18:11:13 -0800 Subject: [PATCH 18/38] This API should only return 200 on success --- internal/auth/providers/azure_graph.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/auth/providers/azure_graph.go b/internal/auth/providers/azure_graph.go index 799a9045..d2da76b6 100644 --- a/internal/auth/providers/azure_graph.go +++ b/internal/auth/providers/azure_graph.go @@ -159,7 +159,7 @@ func (gs *AzureGraphService) getGroupName(id string) (string, error) { if err != nil { return "", err } - if groupMetaResponse.StatusCode >= 400 { + if groupMetaResponse.StatusCode != http.StatusOK { return "", fmt.Errorf("api error: %s", string(body)) } From 515bb6a115919563581d6c9bed78a8426d0af3d2 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Tue, 8 Jan 2019 18:16:38 -0800 Subject: [PATCH 19/38] Make it clearer that this is a template value --- internal/auth/providers/azure.go | 6 +++--- internal/auth/providers/azure_test.go | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/internal/auth/providers/azure.go b/internal/auth/providers/azure.go index 39bfe345..ec4ebdaf 100644 --- a/internal/auth/providers/azure.go +++ b/internal/auth/providers/azure.go @@ -20,8 +20,8 @@ import ( ) var ( - azureOIDCConfigURL = "https://login.microsoftonline.com/{tenant}/v2.0" - azureOIDCProfileURL = "https://graph.microsoft.com/oidc/userinfo" + azureOIDCConfigURLTemplate = "https://login.microsoftonline.com/{tenant}/v2.0" + azureOIDCProfileURL = "https://graph.microsoft.com/oidc/userinfo" // This is a compile-time check to make sure our types correctly implement the interface: // https://medium.com/@matryer/c167afed3aae @@ -135,7 +135,7 @@ func (p *AzureV2Provider) Configure(tenant string) error { // Azure AD docs suggest this should work, it results in an error. p.Tenant = "common" } - discoveryURL := strings.Replace(azureOIDCConfigURL, "{tenant}", p.Tenant, -1) + discoveryURL := strings.Replace(azureOIDCConfigURLTemplate, "{tenant}", p.Tenant, -1) // Configure discoverable provider data. oidcProvider, err := oidc.NewProvider(context.Background(), discoveryURL) diff --git a/internal/auth/providers/azure_test.go b/internal/auth/providers/azure_test.go index de4e1e15..f13a716f 100644 --- a/internal/auth/providers/azure_test.go +++ b/internal/auth/providers/azure_test.go @@ -299,7 +299,8 @@ func TestAzureV2ProviderRedeem(t *testing.T) { // pointer to body to bypass chicken/egg issue w/ mock server urls providerURL, server := newAzureProviderServer(&body, http.StatusOK, &privKey.PublicKey) defer server.Close() - azureOIDCConfigURL = providerURL.String() + // swap the global OIDC URL template for a test provider URL + azureOIDCConfigURLTemplate = providerURL.String() if tc.claims != nil { // create an instance of Builder that uses the rsa signer From 359275be0751ff6382522f7c26c1bd9c3dd1dcd6 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Tue, 8 Jan 2019 18:19:49 -0800 Subject: [PATCH 20/38] Drop usage of named return values --- internal/auth/providers/azure.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/auth/providers/azure.go b/internal/auth/providers/azure.go index ec4ebdaf..c93a5783 100644 --- a/internal/auth/providers/azure.go +++ b/internal/auth/providers/azure.go @@ -56,7 +56,7 @@ func (p *AzureV2Provider) SetStatsdClient(statsdClient *statsd.Client) { // Redeem fulfills the Provider interface. // The authenticator uses this method to redeem the code provided to /callback after the user logs into their Azure AD account. -func (p *AzureV2Provider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { +func (p *AzureV2Provider) Redeem(redirectURL, code string) (*sessions.SessionState, error) { ctx := context.Background() c := oauth2.Config{ ClientID: p.ClientID, @@ -106,7 +106,7 @@ func (p *AzureV2Provider) Redeem(redirectURL, code string) (s *sessions.SessionS // with UPN claim; UPN has usually been what you want, but I think it's not // rendered as a full email address here. - s = &sessions.SessionState{ + s := &sessions.SessionState{ AccessToken: token.AccessToken, RefreshToken: token.RefreshToken, @@ -124,7 +124,7 @@ func (p *AzureV2Provider) Redeem(redirectURL, code string) (s *sessions.SessionS } s.Groups = groupNames } - return + return s, nil } // Configure sets the Azure tenant ID value for the provider From 0ff780c30a9a502a2e54f306e869916d6e777ad3 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Tue, 8 Jan 2019 18:24:20 -0800 Subject: [PATCH 21/38] Update comments to accurately reflect what's happening in Marshal/Unmarshal --- internal/pkg/aead/aead.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/internal/pkg/aead/aead.go b/internal/pkg/aead/aead.go index 78df486f..efb5734d 100644 --- a/internal/pkg/aead/aead.go +++ b/internal/pkg/aead/aead.go @@ -89,8 +89,9 @@ func (c *MiscreantCipher) Decrypt(joined []byte) ([]byte, error) { return plaintext, nil } -// Marshal marshals the interface state as JSON, encrypts the JSON using the cipher -// and base64 encodes the binary value as a string and returns the result +// Marshal marshals the interface state as JSON, gzips it, encrypts the JSON +// using the cipher, and base64 encodes the binary value as a string and +// returns the result. func (c *MiscreantCipher) Marshal(s interface{}) (string, error) { // encode json value plaintext, err := json.Marshal(s) @@ -115,8 +116,9 @@ func (c *MiscreantCipher) Marshal(s interface{}) (string, error) { return encoded, nil } -// Unmarshal takes the marshaled string, base64-decodes into a byte slice, decrypts the -// byte slice the pased cipher, and unmarshals the resulting JSON into the struct pointer passed +// Unmarshal takes the marshaled string, base64-decodes into a byte slice, +// decrypts the byte slice the pased cipher, gunzips it, and unmarshals +// the resulting JSON into the struct pointer passed. func (c *MiscreantCipher) Unmarshal(value string, s interface{}) error { // convert base64 string value to bytes ciphertext, err := base64.RawURLEncoding.DecodeString(value) From 91b3768267b1f89af757e736e61292012b4696b5 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Tue, 8 Jan 2019 18:26:32 -0800 Subject: [PATCH 22/38] Drop another usage of named return values --- internal/auth/providers/oidc.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/auth/providers/oidc.go b/internal/auth/providers/oidc.go index 2aae35b2..32105aeb 100644 --- a/internal/auth/providers/oidc.go +++ b/internal/auth/providers/oidc.go @@ -28,7 +28,7 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider { // Redeem fulfills the Provider interface. // The authenticator uses this method to redeem the code provided to /callback after the user logs into their OpenID Connect account. -func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { +func (p *OIDCProvider) Redeem(redirectURL, code string) (*sessions.SessionState, error) { ctx := context.Background() c := oauth2.Config{ ClientID: p.ClientID, @@ -70,14 +70,14 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionStat return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) } - s = &sessions.SessionState{ + s := &sessions.SessionState{ AccessToken: token.AccessToken, RefreshToken: token.RefreshToken, RefreshDeadline: token.Expiry, Email: claims.Email, } - return + return s, nil } // RefreshSessionIfNeeded takes in a SessionState and From dfc50408329f994518ea115b592b651698751b03 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Tue, 8 Jan 2019 18:29:26 -0800 Subject: [PATCH 23/38] Combine cache lookup and success check into a single if statement --- internal/auth/providers/azure_graph.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/internal/auth/providers/azure_graph.go b/internal/auth/providers/azure_graph.go index d2da76b6..ca45d34c 100644 --- a/internal/auth/providers/azure_graph.go +++ b/internal/auth/providers/azure_graph.go @@ -107,8 +107,7 @@ func (gs *AzureGraphService) GetGroups(email string) ([]string, error) { var name string // check the cache for the group name first - cachedName, ok := gs.groupNameCache.Get(id) - if !ok { + if cachedName, ok := gs.groupNameCache.Get(id); !ok { // didn't have the group name, make concurrent API call to fetch it name, err = gs.getGroupName(id) if err == nil { From e786df6a344d302d938d32fb1e291d5cdb40e49d Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Tue, 8 Jan 2019 18:32:17 -0800 Subject: [PATCH 24/38] Add explanatory text to interface comments to highlight purpose in mocks --- internal/auth/providers/azure_graph.go | 2 ++ internal/auth/providers/google_admin.go | 2 ++ 2 files changed, 4 insertions(+) diff --git a/internal/auth/providers/azure_graph.go b/internal/auth/providers/azure_graph.go index ca45d34c..c8d1ed4b 100644 --- a/internal/auth/providers/azure_graph.go +++ b/internal/auth/providers/azure_graph.go @@ -19,6 +19,8 @@ import ( const azureGroupCacheSize = 1024 // GraphService wraps calls to provider admin APIs +// +// This interface allows the service to be more readily mocked in tests. type GraphService interface { GetGroups(string) ([]string, error) } diff --git a/internal/auth/providers/google_admin.go b/internal/auth/providers/google_admin.go index 3699fb8d..3e996705 100644 --- a/internal/auth/providers/google_admin.go +++ b/internal/auth/providers/google_admin.go @@ -18,6 +18,8 @@ import ( ) // AdminService wraps calls to provider admin APIs +// +// This interface allows the service to be more readily mocked in tests. type AdminService interface { ListMemberships(group string, depth int) (members []string, err error) CheckMemberships(groups []string, user string) (inGroups []string, errr error) From ce8b2d0ee3c2b2dc8020f661c28503ba2cced984 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Fri, 11 Jan 2019 17:50:33 -0800 Subject: [PATCH 25/38] Drop debug logging lines --- internal/auth/providers/internal_util.go | 1 - internal/proxy/oauthproxy.go | 2 -- 2 files changed, 3 deletions(-) diff --git a/internal/auth/providers/internal_util.go b/internal/auth/providers/internal_util.go index cf1940bc..2505cf46 100644 --- a/internal/auth/providers/internal_util.go +++ b/internal/auth/providers/internal_util.go @@ -50,7 +50,6 @@ func stripParam(param, endpoint string) string { func validateToken(p Provider, accessToken string, header http.Header) bool { logger := log.NewLogEntry() - logger.Info(p.Data().ValidateURL) if accessToken == "" || p.Data().ValidateURL == nil { return false } diff --git a/internal/proxy/oauthproxy.go b/internal/proxy/oauthproxy.go index 5bd5549e..823683f4 100755 --- a/internal/proxy/oauthproxy.go +++ b/internal/proxy/oauthproxy.go @@ -548,8 +548,6 @@ func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionSt return nil, ErrInvalidSession } - logger.Printf("session loaded: %v", session) - return session, nil } From 759506e4c103c01547eeeaf421be42e55766ad00 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Sat, 12 Jan 2019 03:00:24 -0800 Subject: [PATCH 26/38] Switch from HMAC to AEAD to simplify nonce validation --- internal/auth/options.go | 5 +- internal/auth/providers/azure.go | 64 +++++++++++++++---- internal/auth/providers/azure_test.go | 90 ++++++++++++++++++++------- 3 files changed, 124 insertions(+), 35 deletions(-) diff --git a/internal/auth/options.go b/internal/auth/options.go index 7e9666a5..76baeaa7 100644 --- a/internal/auth/options.go +++ b/internal/auth/options.go @@ -262,7 +262,10 @@ func newProvider(o *Options) (providers.Provider, error) { var singleFlightProvider providers.Provider switch o.Provider { case providers.AzureProviderName: - azureProvider := providers.NewAzureV2Provider(p) + azureProvider, err := providers.NewAzureV2Provider(p) + if err != nil { + return nil, err + } azureProvider.Configure(o.AzureTenant) singleFlightProvider = providers.NewSingleFlightProvider(azureProvider) case providers.GoogleProviderName: diff --git a/internal/auth/providers/azure.go b/internal/auth/providers/azure.go index c93a5783..e8ea261a 100644 --- a/internal/auth/providers/azure.go +++ b/internal/auth/providers/azure.go @@ -2,7 +2,6 @@ package providers import ( "context" - "crypto/hmac" "crypto/sha256" "encoding/base64" "errors" @@ -11,6 +10,7 @@ import ( "strings" "time" + "github.com/buzzfeed/sso/internal/pkg/aead" "github.com/buzzfeed/sso/internal/pkg/sessions" "github.com/datadog/datadog-go/statsd" "golang.org/x/oauth2" @@ -36,17 +36,30 @@ type AzureV2Provider struct { Tenant string StatsdClient *statsd.Client - + NonceCipher aead.Cipher GraphService GraphService } // NewAzureV2Provider creates a new AzureV2Provider struct -func NewAzureV2Provider(p *ProviderData) *AzureV2Provider { +func NewAzureV2Provider(p *ProviderData) (*AzureV2Provider, error) { p.ProviderName = "Azure AD" + + if p.ClientSecret == "" { + return nil, errors.New("client secret cannot be empty") + } + // Can't guarantee the client secret will be 32 or 64 bytes in length, + // hash to derive a key, error on empty string to avoid silent failure. + key := sha256.Sum256([]byte(p.ClientSecret)) + nonceCipher, err := aead.NewMiscreantCipher(key[:]) + if err != nil { + return nil, err + } + return &AzureV2Provider{ ProviderData: p, + NonceCipher: nonceCipher, OIDCProvider: &OIDCProvider{ProviderData: p}, - } + }, nil } // SetStatsdClient sets the azure provider statsd client @@ -95,6 +108,7 @@ func (p *AzureV2Provider) Redeem(redirectURL, code string) (*sessions.SessionSta var claims struct { Email string `json:"email"` UPN string `json:"upn"` + Nonce string `json:"nonce"` } if err := idToken.Claims(&claims); err != nil { return nil, fmt.Errorf("failed to parse id_token claims: %v", err) @@ -102,6 +116,12 @@ func (p *AzureV2Provider) Redeem(redirectURL, code string) (*sessions.SessionSta if claims.Email == "" { return nil, fmt.Errorf("id_token did not contain an email") } + if claims.Nonce == "" { + return nil, fmt.Errorf("id_token did not contain a nonce") + } + if !p.validateNonce(claims.Nonce) { + return nil, fmt.Errorf("unable to validate id_token nonce") + } // TODO: test this w/ an account that uses an alias and compare email claim // with UPN claim; UPN has usually been what you want, but I think it's not // rendered as a full email address here. @@ -251,16 +271,34 @@ func (p *AzureV2Provider) GetSignInURL(redirectURI, state string) string { return a.String() } -// calculateNonce generates a deterministic nonce from the state value. -// We don't have a session state pointer but we need to generate a nonce -// that we can verify statelessly later. We can only use what's in the -// params and provider struct to assemble a nonce. State is guaranteed to be -// indistinguishable from random and will always change. +// calculateNonce generates a verifiable nonce from the state value. +// A nonce can be subsequently validated by attempting to decrypt it. func (p *AzureV2Provider) calculateNonce(state string) string { - key := []byte(p.ClientID + p.ClientSecret) - h := hmac.New(sha256.New, key) - h.Write([]byte(state)) - return base64.URLEncoding.EncodeToString(h.Sum(nil))[:8] + rawNonce, err := p.NonceCipher.Encrypt([]byte(state)) + if err != nil { + // GetSignInURL can't return an error and this shouldn't fail silently + panic(err) + } + return base64.URLEncoding.EncodeToString(rawNonce) +} + +// validateNonce attempts to decrypt the nonce value. If it decrypts +// successfully, the nonce is considered valid. +func (p *AzureV2Provider) validateNonce(nonce string) bool { + rawNonce, err := base64.URLEncoding.DecodeString(nonce) + if err != nil { + return false + } + state, err := p.NonceCipher.Decrypt(rawNonce) + if err != nil { + return false + } + // Sanity check to ensure state contains roughly what we expect + _, err = base64.URLEncoding.DecodeString(string(state)) + if err != nil { + return false + } + return true } // ValidateGroupMembership takes in an email and the allowed groups and returns the groups that the email is part of in that list. diff --git a/internal/auth/providers/azure_test.go b/internal/auth/providers/azure_test.go index f13a716f..8716b09d 100644 --- a/internal/auth/providers/azure_test.go +++ b/internal/auth/providers/azure_test.go @@ -25,6 +25,7 @@ import ( const ( microsoftTenantID = "9188040d-6c67-4c5b-b112-36a304b66dad" testClientID = "a4c35c92-e858-41e8-bd2c-ade04cb622b1" + testClientSecret = "4" // number chosen at random ) func newAzureProviderServer(redeemBody *[]byte, redeemCode int, pubKey *rsa.PublicKey) (*url.URL, *httptest.Server) { @@ -64,6 +65,8 @@ func newAzureV2Provider(providerData *ProviderData) *AzureV2Provider { if providerData == nil { providerData = &ProviderData{ ProviderName: "", + ClientID: testClientID, + ClientSecret: testClientSecret, SignInURL: &url.URL{}, RedeemURL: &url.URL{}, RevokeURL: &url.URL{}, @@ -71,7 +74,11 @@ func newAzureV2Provider(providerData *ProviderData) *AzureV2Provider { ValidateURL: &url.URL{}, Scope: ""} } - return NewAzureV2Provider(providerData) + provider, err := NewAzureV2Provider(providerData) + if err != nil { + panic(err) + } + return provider } func TestAzureV2ProviderDefaults(t *testing.T) { @@ -97,6 +104,8 @@ func TestAzureV2ProviderDefaults(t *testing.T) { { name: "with provider overrides", providerData: &ProviderData{ + ClientID: "1234", + ClientSecret: "4", // Number chosen at random SignInURL: &url.URL{ Scheme: "https", Host: "example.com", @@ -117,7 +126,8 @@ func TestAzureV2ProviderDefaults(t *testing.T) { Scheme: "https", Host: "example.com", Path: "/oauth/tokeninfo"}, - Scope: "profile"}, + Scope: "profile", + }, signInURL: "https://example.com/oauth/auth", redeemURL: "https://example.com/oauth/token", revokeURL: "https://example.com/oauth/deauth", @@ -186,6 +196,7 @@ type claims struct { Expiry jwt.NumericDate `json:"exp,omitempty"` Name string `json:"name,omitempty"` Email string `json:"email,omitempty"` + Nonce string `json:"nonce,omitempty"` NotEmail string `json:"not_email,omitempty"` } @@ -221,6 +232,7 @@ func TestAzureV2ProviderRedeem(t *testing.T) { Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second)), Name: "Michael Bland", Email: "michael.bland@gsa.gov", + Nonce: "{mock-nonce}", }, resp: redeemResponse{ AccessToken: "a1234", @@ -240,6 +252,7 @@ func TestAzureV2ProviderRedeem(t *testing.T) { Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second)), Name: "Michael Bland", Email: "michael.bland@gsa.gov", + Nonce: "{mock-nonce}", }, resp: redeemResponse{ AccessToken: "a1234", @@ -256,6 +269,40 @@ func TestAzureV2ProviderRedeem(t *testing.T) { Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second)), Name: "Michael Bland", Email: "michael.bland@gsa.gov", + Nonce: "{mock-nonce}", + }, + resp: redeemResponse{ + AccessToken: "a1234", + ExpiresIn: 10, + RefreshToken: "refresh12345", + }, + expectedError: true, + }, + { + name: "missing nonce", + claims: &claims{ + Issuer: "{mock-issuer}", + Audience: testClientID, + Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second)), + Name: "Michael Bland", + Email: "michael.bland@gsa.gov", + }, + resp: redeemResponse{ + AccessToken: "a1234", + ExpiresIn: 10, + RefreshToken: "refresh12345", + }, + expectedError: true, + }, + { + name: "invalid nonce", + claims: &claims{ + Issuer: "{mock-issuer}", + Audience: testClientID, + Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second)), + Name: "Michael Bland", + Email: "michael.bland@gsa.gov", + Nonce: "123456789", }, resp: redeemResponse{ AccessToken: "a1234", @@ -302,12 +349,19 @@ func TestAzureV2ProviderRedeem(t *testing.T) { // swap the global OIDC URL template for a test provider URL azureOIDCConfigURLTemplate = providerURL.String() + p := newAzureV2Provider(nil) + err = p.Configure(microsoftTenantID) + if err != nil { + t.Error(err) + } + if tc.claims != nil { // create an instance of Builder that uses the rsa signer builder := jwt.Signed(rsaSigner) // add claims to the Builder tc.claims.Issuer = strings.Replace(tc.claims.Issuer, "{mock-issuer}", providerURL.String(), -1) + tc.claims.Nonce = strings.Replace(tc.claims.Nonce, "{mock-nonce}", p.calculateNonce("1234"), -1) builder = builder.Claims(tc.claims) // build and inject ID token into response @@ -321,13 +375,6 @@ func TestAzureV2ProviderRedeem(t *testing.T) { body, err = json.Marshal(tc.resp) testutil.Equal(t, nil, err) - p := newAzureV2Provider(nil) - p.ClientID = testClientID - p.ClientSecret = "456" - err = p.Configure(microsoftTenantID) - if err != nil { - t.Error(err) - } // graph service mock has to be set after p.Configure p.GraphService = &MockAzureGraphService{} @@ -379,7 +426,7 @@ func TestAzureV2GetSignInURL(t *testing.T) { expectedParams url.Values }{ { - name: "nonce values passed to azure should be deterministic, pass one", + name: "nonce values passed to azure should validate, pass one", redirectURI: "https://example.com/oauth/callback", state: "1234", expectedParams: url.Values{ @@ -389,12 +436,11 @@ func TestAzureV2GetSignInURL(t *testing.T) { "scope": []string{"openid email profile offline_access"}, "state": []string{"1234"}, "client_id": []string{testClientID}, - "nonce": []string{"KEB9Aopa"}, "prompt": []string{"consent"}, }, }, { - name: "nonce values passed to azure should be deterministic, pass two", + name: "nonce values passed to azure should validate, pass two", redirectURI: "https://example.com/oauth/callback", state: "1234", expectedParams: url.Values{ @@ -404,12 +450,11 @@ func TestAzureV2GetSignInURL(t *testing.T) { "scope": []string{"openid email profile offline_access"}, "state": []string{"1234"}, "client_id": []string{testClientID}, - "nonce": []string{"KEB9Aopa"}, "prompt": []string{"consent"}, }, }, { - name: "nonce values passed to azure should be deterministic, pass three", + name: "nonce values passed to azure should validate, pass three", redirectURI: "https://example.com/oauth/callback", state: "4321", expectedParams: url.Values{ @@ -419,7 +464,6 @@ func TestAzureV2GetSignInURL(t *testing.T) { "scope": []string{"openid email profile offline_access"}, "state": []string{"4321"}, "client_id": []string{testClientID}, - "nonce": []string{"x_PhEN0K"}, "prompt": []string{"consent"}, }, }, @@ -428,8 +472,6 @@ func TestAzureV2GetSignInURL(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { p := newAzureV2Provider(nil) - p.ClientID = testClientID - p.ClientSecret = "456" p.Scope = "openid email profile offline_access" p.ApprovalPrompt = "consent" @@ -439,10 +481,16 @@ func TestAzureV2GetSignInURL(t *testing.T) { t.Error(err) } - if !reflect.DeepEqual(tc.expectedParams, parsedURL.Query()) { - t.Logf("expected params %+v", tc.expectedParams) - t.Logf("got params %+v", parsedURL.Query()) - t.Errorf("unexpected params returned") + for k := range tc.expectedParams { + if tc.expectedParams.Get(k) != parsedURL.Query().Get(k) { + t.Logf("expected param %s: %+v", k, tc.expectedParams.Get(k)) + t.Logf("got param %s: %+v", k, parsedURL.Query().Get(k)) + } + } + + nonce := parsedURL.Query().Get("nonce") + if p.validateNonce(nonce) != true { + t.Logf("expected valid nonce, got: %+v", nonce) } }) } From 1f06ed22bd606ea15ee1b9315ffd335ac87e517b Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Sat, 12 Jan 2019 03:07:19 -0800 Subject: [PATCH 27/38] Rename to ms_graph_api.go and add top-level comment --- internal/auth/providers/azure.go | 2 +- .../{azure_graph.go => ms_graph_api.go} | 17 ++++++++++------- 2 files changed, 11 insertions(+), 8 deletions(-) rename internal/auth/providers/{azure_graph.go => ms_graph_api.go} (87%) diff --git a/internal/auth/providers/azure.go b/internal/auth/providers/azure.go index e8ea261a..b3f861b6 100644 --- a/internal/auth/providers/azure.go +++ b/internal/auth/providers/azure.go @@ -193,7 +193,7 @@ func (p *AzureV2Provider) Configure(tenant string) error { if p.RedeemURL.String() == "" { return errors.New("redeem url must be set") } - p.GraphService = NewAzureGraphService(p.ClientID, p.ClientSecret, p.RedeemURL.String()) + p.GraphService = NewMSGraphService(p.ClientID, p.ClientSecret, p.RedeemURL.String()) return nil } diff --git a/internal/auth/providers/azure_graph.go b/internal/auth/providers/ms_graph_api.go similarity index 87% rename from internal/auth/providers/azure_graph.go rename to internal/auth/providers/ms_graph_api.go index c8d1ed4b..52ba5cf4 100644 --- a/internal/auth/providers/azure_graph.go +++ b/internal/auth/providers/ms_graph_api.go @@ -15,6 +15,9 @@ import ( "golang.org/x/oauth2/clientcredentials" ) +// The Microsoft Graph API provides a mechanism to obtain group membership +// information for users authenticated with the Azure AD provider. + // azureGroupCacheSize controls the size of the caches of AD group info const azureGroupCacheSize = 1024 @@ -25,15 +28,15 @@ type GraphService interface { GetGroups(string) ([]string, error) } -// AzureGraphService implements graph API calls for the Azure provider -type AzureGraphService struct { +// MSGraphService implements graph API calls for the Azure provider +type MSGraphService struct { client *http.Client groupMembershipCache *lru.Cache groupNameCache *lru.Cache } -// NewAzureGraphService creates a new graph service for getting groups -func NewAzureGraphService(clientID string, clientSecret string, tokenURL string) *AzureGraphService { +// NewMSGraphService creates a new graph service for getting groups +func NewMSGraphService(clientID string, clientSecret string, tokenURL string) *MSGraphService { clientConfig := &clientcredentials.Config{ ClientID: clientID, ClientSecret: clientSecret, @@ -52,7 +55,7 @@ func NewAzureGraphService(clientID string, clientSecret string, tokenURL string) if err != nil { panic(err) // Should only happen if azureGroupCacheSize is a negative number } - return &AzureGraphService{ + return &MSGraphService{ client: client, groupMembershipCache: memberCache, groupNameCache: nameCache, @@ -60,7 +63,7 @@ func NewAzureGraphService(clientID string, clientSecret string, tokenURL string) } // GetGroups lists groups user belongs to. -func (gs *AzureGraphService) GetGroups(email string) ([]string, error) { +func (gs *MSGraphService) GetGroups(email string) ([]string, error) { if gs.client == nil { return nil, errors.New("oauth client must be configured") } @@ -141,7 +144,7 @@ func (gs *AzureGraphService) GetGroups(email string) ([]string, error) { } // getGroupName returns the group name, preferentially pulling from cache -func (gs *AzureGraphService) getGroupName(id string) (string, error) { +func (gs *MSGraphService) getGroupName(id string) (string, error) { if gs.client == nil { return "", errors.New("oauth client must be configured") } From fcf4c96edbea1af392cb685d2e6fa99c66d9a3f5 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Mon, 14 Jan 2019 16:12:38 -0800 Subject: [PATCH 28/38] Move generic OIDC discovery logic into generic OIDC provider --- internal/auth/providers/azure.go | 44 +++++---------------------- internal/auth/providers/google.go | 4 ++- internal/auth/providers/oidc.go | 50 +++++++++++++++++++++++++++++-- 3 files changed, 58 insertions(+), 40 deletions(-) diff --git a/internal/auth/providers/azure.go b/internal/auth/providers/azure.go index b3f861b6..c7da4590 100644 --- a/internal/auth/providers/azure.go +++ b/internal/auth/providers/azure.go @@ -16,7 +16,6 @@ import ( "golang.org/x/oauth2" log "github.com/buzzfeed/sso/internal/pkg/logging" - oidc "github.com/coreos/go-oidc" ) var ( @@ -42,7 +41,9 @@ type AzureV2Provider struct { // NewAzureV2Provider creates a new AzureV2Provider struct func NewAzureV2Provider(p *ProviderData) (*AzureV2Provider, error) { - p.ProviderName = "Azure AD" + if p.ProviderName == "" { + p.ProviderName = "Azure AD" + } if p.ClientSecret == "" { return nil, errors.New("client secret cannot be empty") @@ -58,7 +59,7 @@ func NewAzureV2Provider(p *ProviderData) (*AzureV2Provider, error) { return &AzureV2Provider{ ProviderData: p, NonceCipher: nonceCipher, - OIDCProvider: &OIDCProvider{ProviderData: p}, + OIDCProvider: nil, }, nil } @@ -93,8 +94,8 @@ func (p *AzureV2Provider) Redeem(redirectURL, code string) (*sessions.SessionSta return nil, fmt.Errorf("token response did not contain an id_token") } - // should only happen if oidc autodiscovery is broken - if p.OIDCProvider.Verifier == nil { + // should only happen if oidc autodiscovery is broken or unconfigured + if p.OIDCProvider == nil || p.OIDCProvider.Verifier == nil { return nil, fmt.Errorf("oidc verifier missing") } @@ -158,41 +159,12 @@ func (p *AzureV2Provider) Configure(tenant string) error { discoveryURL := strings.Replace(azureOIDCConfigURLTemplate, "{tenant}", p.Tenant, -1) // Configure discoverable provider data. - oidcProvider, err := oidc.NewProvider(context.Background(), discoveryURL) + var err error + p.OIDCProvider, err = NewOIDCProvider(p.ProviderData, discoveryURL) if err != nil { - // TODO: This seems like it _should_ work for "common", but it doesn't - // Does anyone actually want to use this with "common" though? return err } - p.OIDCProvider.Verifier = oidcProvider.Verifier(&oidc.Config{ - ClientID: p.ClientID, - }) - // Set these only if they haven't been overridden - if p.SignInURL == nil || p.SignInURL.String() == "" { - p.SignInURL, err = url.Parse(oidcProvider.Endpoint().AuthURL) - if err != nil { - return err - } - } - if p.RedeemURL == nil || p.RedeemURL.String() == "" { - p.RedeemURL, err = url.Parse(oidcProvider.Endpoint().TokenURL) - if err != nil { - return err - } - } - if p.ProfileURL == nil || p.ProfileURL.String() == "" { - p.ProfileURL, err = url.Parse(azureOIDCProfileURL) - } - if err != nil { - return err - } - if p.Scope == "" { - p.Scope = "openid email profile offline_access" - } - if p.RedeemURL.String() == "" { - return errors.New("redeem url must be set") - } p.GraphService = NewMSGraphService(p.ClientID, p.ClientSecret, p.RedeemURL.String()) return nil } diff --git a/internal/auth/providers/google.go b/internal/auth/providers/google.go index eda238cf..7246ee88 100644 --- a/internal/auth/providers/google.go +++ b/internal/auth/providers/google.go @@ -47,7 +47,9 @@ func NewGoogleProvider(p *ProviderData, adminEmail, credsFilePath string) (*Goog } } - p.ProviderName = "Google" + if p.ProviderName == "" { + p.ProviderName = "Google" + } if p.SignInURL.String() == "" { p.SignInURL = &url.URL{Scheme: "https", Host: "accounts.google.com", diff --git a/internal/auth/providers/oidc.go b/internal/auth/providers/oidc.go index 32105aeb..445c8dc0 100644 --- a/internal/auth/providers/oidc.go +++ b/internal/auth/providers/oidc.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net/url" "time" "golang.org/x/oauth2" @@ -21,9 +22,52 @@ type OIDCProvider struct { } // NewOIDCProvider creates a new generic OpenID Connect provider -func NewOIDCProvider(p *ProviderData) *OIDCProvider { - p.ProviderName = "OpenID Connect" - return &OIDCProvider{ProviderData: p} +func NewOIDCProvider(p *ProviderData, discoveryURL string) (*OIDCProvider, error) { + provider := &OIDCProvider{ + ProviderData: p, + } + if p.ProviderName == "" { + p.ProviderName = "OpenID Connect" + } + + // Configure discoverable provider data. + oidcProvider, err := oidc.NewProvider(context.Background(), discoveryURL) + if err != nil { + // TODO: This seems like it _should_ work for "common", but it doesn't + // Does anyone actually want to use this with "common" though? + return nil, err + } + + provider.Verifier = oidcProvider.Verifier(&oidc.Config{ + ClientID: p.ClientID, + }) + // Set these only if they haven't been overridden + if p.SignInURL == nil || p.SignInURL.String() == "" { + p.SignInURL, err = url.Parse(oidcProvider.Endpoint().AuthURL) + if err != nil { + return nil, err + } + } + if p.RedeemURL == nil || p.RedeemURL.String() == "" { + p.RedeemURL, err = url.Parse(oidcProvider.Endpoint().TokenURL) + if err != nil { + return nil, err + } + } + if p.ProfileURL == nil || p.ProfileURL.String() == "" { + p.ProfileURL, err = url.Parse(azureOIDCProfileURL) + } + if err != nil { + return nil, err + } + if p.Scope == "" { + p.Scope = "openid email profile offline_access" + } + if p.RedeemURL.String() == "" { + return nil, errors.New("redeem url must be set") + } + + return provider, nil } // Redeem fulfills the Provider interface. From d19cddf62acafa001fa2dfa65ac93d46a60ebc77 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Mon, 14 Jan 2019 16:35:52 -0800 Subject: [PATCH 29/38] Add clarification around error handling --- internal/auth/providers/ms_graph_api.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/internal/auth/providers/ms_graph_api.go b/internal/auth/providers/ms_graph_api.go index 52ba5cf4..225c3776 100644 --- a/internal/auth/providers/ms_graph_api.go +++ b/internal/auth/providers/ms_graph_api.go @@ -115,6 +115,8 @@ func (gs *MSGraphService) GetGroups(email string) ([]string, error) { if cachedName, ok := gs.groupNameCache.Get(id); !ok { // didn't have the group name, make concurrent API call to fetch it name, err = gs.getGroupName(id) + // the err value is not shadowed in the goroutine, so if this isn't + // nil, it will return the err value after wg.Wait() is called if err == nil { // got the name ok, populate the cache gs.groupNameCache.Add(id, name) @@ -136,6 +138,7 @@ func (gs *MSGraphService) GetGroups(email string) ([]string, error) { } } wg.Wait() + // any err value set above will cause this to fail if err != nil { return nil, err } From 55d4776e5ce5f9da5b9380389f9deb6539f05b74 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Thu, 17 Jan 2019 18:31:42 -0800 Subject: [PATCH 30/38] Update mock file to match rename of graph service struct --- internal/auth/providers/azure_graph_mock.go | 12 ------------ internal/auth/providers/azure_test.go | 4 ++-- internal/auth/providers/ms_graph_mock.go | 12 ++++++++++++ 3 files changed, 14 insertions(+), 14 deletions(-) delete mode 100644 internal/auth/providers/azure_graph_mock.go create mode 100644 internal/auth/providers/ms_graph_mock.go diff --git a/internal/auth/providers/azure_graph_mock.go b/internal/auth/providers/azure_graph_mock.go deleted file mode 100644 index 044ce01f..00000000 --- a/internal/auth/providers/azure_graph_mock.go +++ /dev/null @@ -1,12 +0,0 @@ -package providers - -// MockAzureGraphService is an implementation of GraphService to be used for testing -type MockAzureGraphService struct { - Groups []string - GroupsError error -} - -// GetGroups mocks the GetGroups function -func (ms *MockAzureGraphService) GetGroups(string) ([]string, error) { - return ms.Groups, ms.GroupsError -} diff --git a/internal/auth/providers/azure_test.go b/internal/auth/providers/azure_test.go index 8716b09d..b839ba74 100644 --- a/internal/auth/providers/azure_test.go +++ b/internal/auth/providers/azure_test.go @@ -376,7 +376,7 @@ func TestAzureV2ProviderRedeem(t *testing.T) { testutil.Equal(t, nil, err) // graph service mock has to be set after p.Configure - p.GraphService = &MockAzureGraphService{} + p.GraphService = &MockMSGraphService{} session, err := p.Redeem("http://redirect/", "code1234") if tc.expectedError && err == nil { @@ -540,7 +540,7 @@ func TestAzureV2ValidateGroupMembers(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { p := newAzureV2Provider(nil) - p.GraphService = &MockAzureGraphService{ + p.GraphService = &MockMSGraphService{ Groups: tc.mockedGroups, GroupsError: tc.mockedError, } diff --git a/internal/auth/providers/ms_graph_mock.go b/internal/auth/providers/ms_graph_mock.go new file mode 100644 index 00000000..e841b842 --- /dev/null +++ b/internal/auth/providers/ms_graph_mock.go @@ -0,0 +1,12 @@ +package providers + +// MockMSGraphService is an implementation of GraphService to be used for testing +type MockMSGraphService struct { + Groups []string + GroupsError error +} + +// GetGroups mocks the GetGroups function +func (ms *MockMSGraphService) GetGroups(string) ([]string, error) { + return ms.Groups, ms.GroupsError +} From fd7b43fe5f5ae3a9d9bd890fde5d51dfefb239f9 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Fri, 18 Jan 2019 16:01:03 -0800 Subject: [PATCH 31/38] Remove methods again --- internal/proxy/oauthproxy.go | 91 ------------------------------------ 1 file changed, 91 deletions(-) diff --git a/internal/proxy/oauthproxy.go b/internal/proxy/oauthproxy.go index 823683f4..0fea992a 100755 --- a/internal/proxy/oauthproxy.go +++ b/internal/proxy/oauthproxy.go @@ -474,97 +474,6 @@ func (p *OAuthProxy) redeemCode(host, code string) (*sessions.SessionState, erro return s, nil } -// MakeSessionCookie constructs a session cookie given the request, an expiration time and the current time. -func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { - return p.makeCookie(req, p.CookieName, value, expiration, now) -} - -// MakeCSRFCookie creates a CSRF cookie given the request, an expiration time, and the current time. -func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { - return p.makeCookie(req, p.CSRFCookieName, value, expiration, now) -} - -func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie { - logger := log.NewLogEntry() - - domain := req.Host - if h, _, err := net.SplitHostPort(domain); err == nil { - domain = h - } - if p.CookieDomain != "" { - if !strings.HasSuffix(domain, p.CookieDomain) { - logger.WithRequestHost(domain).WithCookieDomain(p.CookieDomain).Warn( - "using configured cookie domain") - } - domain = p.CookieDomain - } - - return &http.Cookie{ - Name: name, - Value: value, - Path: "/", - Domain: domain, - HttpOnly: p.CookieHTTPOnly, - Secure: p.CookieSecure, - Expires: now.Add(expiration), - } -} - -// ClearCSRFCookie clears the CSRF cookie from the request -func (p *OAuthProxy) ClearCSRFCookie(rw http.ResponseWriter, req *http.Request) { - http.SetCookie(rw, p.MakeCSRFCookie(req, "", time.Hour*-1, time.Now())) -} - -// SetCSRFCookie sets the CSRFCookie creates a CSRF cookie in a given request -func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, val string) { - http.SetCookie(rw, p.MakeCSRFCookie(req, val, p.CookieExpire, time.Now())) -} - -// ClearSessionCookie clears the session cookie from a request -func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) { - http.SetCookie(rw, p.MakeSessionCookie(req, "", time.Hour*-1, time.Now())) -} - -// SetSessionCookie creates a sesion cookie based on the value and the expiration time. -func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { - http.SetCookie(rw, p.MakeSessionCookie(req, val, p.CookieExpire, time.Now())) -} - -// LoadCookiedSession returns a SessionState from the cookie in the request. -func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, error) { - logger := log.NewLogEntry().WithRemoteAddress(getRemoteAddr(req)) - - c, err := req.Cookie(p.CookieName) - if err != nil { - // always http.ErrNoCookie - return nil, err - } - - session, err := providers.UnmarshalSession(c.Value, p.CookieCipher) - if err != nil { - tags := []string{"error:unmarshaling_session"} - p.StatsdClient.Incr("application_error", tags, 1.0) - logger.Error(err, "unable to unmarshal session") - return nil, ErrInvalidSession - } - - return session, nil -} - -// SaveSession saves a session state to a request cookie. -func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error { - value, err := providers.MarshalSession(s, p.CookieCipher) - if err != nil { - return err - } - - p.SetSessionCookie(rw, req, value) - logger := log.NewLogEntry().WithRemoteAddress(getRemoteAddr(req)) - logger.Printf("session saved: %v", s) - - return nil -} - // RobotsTxt sets the User-Agent header in the response to be "Disallow" func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter, _ *http.Request) { rw.WriteHeader(http.StatusOK) From 4c8c289568494e9e306adefda165a83108519fae Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Fri, 18 Jan 2019 16:06:14 -0800 Subject: [PATCH 32/38] Remove debug lines --- internal/proxy/oauthproxy.go | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/internal/proxy/oauthproxy.go b/internal/proxy/oauthproxy.go index 0fea992a..da9d9a75 100755 --- a/internal/proxy/oauthproxy.go +++ b/internal/proxy/oauthproxy.go @@ -341,10 +341,6 @@ func NewOAuthProxy(opts *Options, optFuncs ...func(*OAuthProxy) error) (*OAuthPr } for _, upstreamConfig := range opts.upstreamConfigs { - logger.Printf("upstreamConfig.Route: %v", upstreamConfig.Route) - logger.Printf("upstreamConfig.RouteConfig: %v", upstreamConfig.RouteConfig) - logger.Printf("upstreamConfig.RouteConfig.Options: %v", upstreamConfig.RouteConfig.Options) - switch route := upstreamConfig.Route.(type) { case *SimpleRoute: reverseProxy := NewReverseProxy(route.ToURL, upstreamConfig) @@ -792,12 +788,6 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { allowedGroups := route.upstreamConfig.AllowedGroups inGroups, validGroup, err := p.provider.ValidateGroup(session.Email, allowedGroups) - logger.Printf("route: %v", route) - logger.Printf("route.upstreamConfig: %v", route.upstreamConfig) - logger.Printf("inGroups: %v", inGroups) - logger.Printf("allowedGroups: %v", allowedGroups) - logger.Printf("validGroup: %v", validGroup) - logger.Printf("err: %v", err) if err != nil { tags = append(tags, "error:user_group_failed") p.StatsdClient.Incr("provider_error", tags, 1.0) @@ -1007,9 +997,6 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) (er return ErrUserNotAuthorized } - logger.Printf("proxied full session: %v", session) - logger.Printf("proxied groups: %v", session.Groups) - req.Header.Set("X-Forwarded-User", session.User) if p.PassAccessToken && session.AccessToken != "" { From b95cb548dba62d510d41b26605ef8853e4d96cf0 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Wed, 30 Jan 2019 10:47:09 -0800 Subject: [PATCH 33/38] Get sign out working --- internal/auth/providers/azure.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/internal/auth/providers/azure.go b/internal/auth/providers/azure.go index c7da4590..025cf038 100644 --- a/internal/auth/providers/azure.go +++ b/internal/auth/providers/azure.go @@ -218,6 +218,12 @@ func (p *AzureV2Provider) RefreshAccessToken(refreshToken string) (string, time. return newToken.AccessToken, newToken.Expiry.Sub(time.Now()), nil } +// Revoke does nothing for Azure AD, but needs to be declared to avoid a +// not implemented error which would prevent clearing sessions on sign out. +func (p *AzureV2Provider) Revoke(s *sessions.SessionState) error { + return nil +} + // ValidateSessionState attempts to validate the session state's access token. func (p *AzureV2Provider) ValidateSessionState(s *sessions.SessionState) bool { // return validateToken(p, s.AccessToken, nil) From b5d595d975b7de951ce065e1b2a5158fe002a7c1 Mon Sep 17 00:00:00 2001 From: Kevin O'Connor Date: Wed, 13 Feb 2019 14:35:46 -0800 Subject: [PATCH 34/38] Extend lifetime deadline for OIDC provider Co-Authored-By: sporkmonger --- internal/auth/providers/oidc.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/auth/providers/oidc.go b/internal/auth/providers/oidc.go index 445c8dc0..0a1348f4 100644 --- a/internal/auth/providers/oidc.go +++ b/internal/auth/providers/oidc.go @@ -118,6 +118,7 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (*sessions.SessionState, AccessToken: token.AccessToken, RefreshToken: token.RefreshToken, RefreshDeadline: token.Expiry, + LifetimeDeadline: sessions.ExtendDeadline(p.SessionLifetimeTTL), Email: claims.Email, } From c30c7c8e97b78e7e5efa17d5db9e940b6ab03887 Mon Sep 17 00:00:00 2001 From: Kevin O'Connor Date: Wed, 13 Feb 2019 20:01:25 -0800 Subject: [PATCH 35/38] Add OIDC discovery URL environment var Co-Authored-By: sporkmonger --- internal/auth/options.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/auth/options.go b/internal/auth/options.go index 76baeaa7..ba1014df 100644 --- a/internal/auth/options.go +++ b/internal/auth/options.go @@ -66,6 +66,7 @@ type Options struct { AzureTenant string `envconfig:"AZURE_TENANT"` GoogleAdminEmail string `envconfig:"GOOGLE_ADMIN_EMAIL"` GoogleServiceAccountJSON string `envconfig:"GOOGLE_SERVICE_ACCOUNT_JSON"` + OIDCDiscoveryURL string `envconfig:"OIDC_DISCOVERY_URL"` Footer string `envconfig:"FOOTER"` From 0dba59cb80f7704aa325cfaff1bd82ce3cee06c0 Mon Sep 17 00:00:00 2001 From: Kevin O'Connor Date: Wed, 13 Feb 2019 20:02:00 -0800 Subject: [PATCH 36/38] Allow OIDC as a provider in options Co-Authored-By: sporkmonger --- internal/auth/options.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/internal/auth/options.go b/internal/auth/options.go index ba1014df..6a189c99 100644 --- a/internal/auth/options.go +++ b/internal/auth/options.go @@ -269,6 +269,12 @@ func newProvider(o *Options) (providers.Provider, error) { } azureProvider.Configure(o.AzureTenant) singleFlightProvider = providers.NewSingleFlightProvider(azureProvider) + case providers.OIDCProviderName: + oidcProvider, err := providers.NewOIDCProvider(p, o.OIDCDiscoveryURL) + if err != nil { + return nil, err + } + singleFlightProvider = providers.NewSingleFlightProvider(oidcProvider) case providers.GoogleProviderName: if o.GoogleServiceAccountJSON != "" { _, err := os.Open(o.GoogleServiceAccountJSON) From a7362350f3ee3dd112aa12e22a1d00c8613283dd Mon Sep 17 00:00:00 2001 From: Kevin O'Connor Date: Wed, 13 Feb 2019 20:03:20 -0800 Subject: [PATCH 37/38] OIDC doesn't require a token validation endpoint (though extensions do support it) Co-Authored-By: sporkmonger --- internal/auth/providers/oidc.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/internal/auth/providers/oidc.go b/internal/auth/providers/oidc.go index 0a1348f4..6b5723d1 100644 --- a/internal/auth/providers/oidc.go +++ b/internal/auth/providers/oidc.go @@ -171,3 +171,10 @@ func (p *OIDCProvider) RefreshAccessToken(refreshToken string) (string, time.Dur return newToken.AccessToken, newToken.Expiry.Sub(time.Now()), nil } + +// ValidateSessionState attempts to validate the session state's access token. +func (p *OIDCProvider) ValidateSessionState(s *sessions.SessionState) bool { + // return validateToken(p, s.AccessToken, nil) + // TODO Validate ID token + return true +} From 552ce5bfa7fcc1f36c922abedc885938d90ed054 Mon Sep 17 00:00:00 2001 From: Bob Aman Date: Mon, 1 Jul 2019 18:42:58 -0700 Subject: [PATCH 38/38] go fmt --- internal/auth/providers/oidc.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/auth/providers/oidc.go b/internal/auth/providers/oidc.go index 6b5723d1..7d1a5be5 100644 --- a/internal/auth/providers/oidc.go +++ b/internal/auth/providers/oidc.go @@ -115,11 +115,11 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (*sessions.SessionState, } s := &sessions.SessionState{ - AccessToken: token.AccessToken, - RefreshToken: token.RefreshToken, - RefreshDeadline: token.Expiry, + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + RefreshDeadline: token.Expiry, LifetimeDeadline: sessions.ExtendDeadline(p.SessionLifetimeTTL), - Email: claims.Email, + Email: claims.Email, } return s, nil