diff --git a/go.mod b/go.mod index 61df78d7..763aa2e7 100644 --- a/go.mod +++ b/go.mod @@ -4,18 +4,22 @@ require ( github.com/18F/hmacauth v0.0.0-20151013130326-9232a6386b73 github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3 github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect + github.com/coreos/go-oidc v2.0.0+incompatible github.com/datadog/datadog-go v0.0.0-20180822151419-281ae9f2d895 + github.com/hashicorp/golang-lru v0.5.1 github.com/imdario/mergo v0.3.7 github.com/kelseyhightower/envconfig v1.3.0 github.com/mccutchen/go-httpbin v1.1.1 github.com/micro/go-micro v1.5.0 github.com/miscreant/miscreant-go v0.0.0-20181010193435-325cbd69228b github.com/mitchellh/mapstructure v1.1.2 + github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect github.com/rakyll/statik v0.1.6 github.com/sirupsen/logrus v1.4.2 golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 golang.org/x/sync v0.0.0-20190423024810-112230192c58 golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 google.golang.org/api v0.5.0 + gopkg.in/square/go-jose.v2 v2.3.1 gopkg.in/yaml.v2 v2.2.2 ) diff --git a/go.sum b/go.sum index 1cad5d50..c72150e4 100644 --- a/go.sum +++ b/go.sum @@ -43,6 +43,8 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/containerd/continuity v0.0.0-20181203112020-004b46473808/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y= github.com/containerd/continuity v0.0.0-20190426062206-aaeac12a7ffc/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y= github.com/coreos/etcd v3.3.13+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= +github.com/coreos/go-oidc v2.0.0+incompatible h1:+RStIopZ8wooMx+Vs5Bt8zMXxV1ABl5LbakNExNmZIg= +github.com/coreos/go-oidc v2.0.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc= github.com/datadog/datadog-go v0.0.0-20180822151419-281ae9f2d895 h1:VTq58gB0MvQpLAz2kfeSBchIAi9+zGRMTh+pyfXEoGs= github.com/datadog/datadog-go v0.0.0-20180822151419-281ae9f2d895/go.mod h1:Mo2ZYXXA9Kp6qoXibOPpsSwkwZ67pcianic5+LKUZvY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -218,6 +220,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= github.com/posener/complete v1.2.1/go.mod h1:6gapUrK/U1TAN7ciCoNRIdVC5sbdBTUh1DKN0g6uH7E= +github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU= +github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v0.9.2/go.mod h1:OsXs2jCmiKlQ1lTBmv21f2mNfw4xf/QclQDMrYNZzcM= github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso= @@ -275,6 +279,7 @@ golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190530122614-20be4c3c3ed5 h1:8dUaAV7K4uHsF56JQWkprecIQKdPHtR9jCHF5nB8uzc= golang.org/x/crypto v0.0.0-20190530122614-20be4c3c3ed5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -389,6 +394,8 @@ gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/go-playground/validator.v9 v9.29.0/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ= gopkg.in/redis.v3 v3.6.4/go.mod h1:6XeGv/CrsUFDU9aVbUdNykN7k1zVmoeg83KC9RbQfiU= +gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4= +gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/src-d/go-billy.v4 v4.2.1/go.mod h1:tm33zBoOwxjYHZIE+OV8bxTWFMJLrconzFMd38aARFk= gopkg.in/src-d/go-billy.v4 v4.3.0/go.mod h1:tm33zBoOwxjYHZIE+OV8bxTWFMJLrconzFMd38aARFk= gopkg.in/src-d/go-git-fixtures.v3 v3.1.1/go.mod h1:dLBcvytrw/TYZsNTWCnkNF2DSIlzWYqTe3rJR56Ac7g= diff --git a/internal/auth/authenticator.go b/internal/auth/authenticator.go index dc2c4d37..d713624a 100644 --- a/internal/auth/authenticator.go +++ b/internal/auth/authenticator.go @@ -108,7 +108,7 @@ func (p *Authenticator) newMux() http.Handler { serviceMux.HandleFunc("/start", p.withMethods(p.OAuthStart, "GET")) serviceMux.HandleFunc("/sign_in", p.withMethods(p.validateClientID(p.validateRedirectURI(p.validateSignature(p.SignIn))), "GET")) serviceMux.HandleFunc("/sign_out", p.withMethods(p.validateRedirectURI(p.validateSignature(p.SignOut)), "GET", "POST")) - serviceMux.HandleFunc("/callback", p.withMethods(p.OAuthCallback, "GET")) + serviceMux.HandleFunc("/callback", p.withMethods(p.OAuthCallback, "GET", "POST")) serviceMux.HandleFunc("/profile", p.withMethods(p.validateClientID(p.validateClientSecret(p.GetProfile)), "GET")) serviceMux.HandleFunc("/validate", p.withMethods(p.validateClientID(p.validateClientSecret(p.ValidateToken)), "GET")) serviceMux.HandleFunc("/redeem", p.withMethods(p.validateClientID(p.validateClientSecret(p.Redeem)), "POST")) @@ -460,8 +460,8 @@ func (p *Authenticator) OAuthStart(rw http.ResponseWriter, req *http.Request) { // Here we validate the redirect that is nested within the redirect_uri. // `authRedirectURL` points to step D, `proxyRedirectURL` points to step E. // - // A* B C D E - // /start -> Google -> auth /callback -> /sign_in -> proxy /callback + // A* B C D E + // /start -> IdProvider -> auth /callback -> /sign_in -> proxy /callback // // * you are here proxyRedirectURL, err := url.Parse(authRedirectURL.Query().Get("redirect_uri")) @@ -485,7 +485,7 @@ func (p *Authenticator) OAuthStart(rw http.ResponseWriter, req *http.Request) { func (p *Authenticator) redeemCode(host, code string) (*sessions.SessionState, error) { // The authenticator redeems `code` for an access token, and uses the token to request user - // info from the provider (Google). + // info from the provider. redirectURI := p.GetRedirectURI(host) // see providers/google.go#Redeem for more info @@ -501,7 +501,7 @@ func (p *Authenticator) redeemCode(host, code string) (*sessions.SessionState, e } func (p *Authenticator) getOAuthCallback(rw http.ResponseWriter, req *http.Request) (string, error) { - // After the provider (Google) redirects back to the sso proxy, the proxy uses this + // After the provider redirects back to the sso proxy, the proxy uses this // endpoint to set up auth cookies. logger := log.NewLogEntry() diff --git a/internal/auth/configuration.go b/internal/auth/configuration.go index c36334e5..0d245634 100644 --- a/internal/auth/configuration.go +++ b/internal/auth/configuration.go @@ -182,7 +182,9 @@ type ProviderConfig struct { Scope string `mapstructure:"scope"` // provider specific + AzureProviderConfig AzureProviderConfig `mapstructure:"azure"` GoogleProviderConfig GoogleProviderConfig `mapstructure:"google"` + OIDCProviderConfig OIDCProviderConfig `mapstructure:"oidc"` OktaProviderConfig OktaProviderConfig `mapstructure:"okta"` // caching @@ -225,6 +227,22 @@ func (pc ProviderConfig) Validate() error { return nil } +type AzureProviderConfig struct { + Tenant string `mapstructure:"tenant"` + ApprovalPrompt string `mapstructure:"prompt"` +} + +func (apc AzureProviderConfig) Validate() error { + if apc.Tenant == "" { + return xerrors.New("must specify tenant ID") + } + + if apc.ApprovalPrompt == "" { + apc.ApprovalPrompt = "consent" + } + return nil +} + type GoogleProviderConfig struct { Credentials string `mapstructure:"credentials"` Impersonate string `mapstructure:"impersonate"` @@ -250,6 +268,18 @@ func (gpc GoogleProviderConfig) Validate() error { return nil } +type OIDCProviderConfig struct { + DiscoveryURL string `mapstructure:"discovery"` +} + +func (opc OIDCProviderConfig) Validate() error { + if opc.DiscoveryURL == "" { + return xerrors.New("must specify discovery URL") + } + + return nil +} + type OktaProviderConfig struct { ServerID string `mapstructure:"server"` OrgURL string `mapstructure:"url"` diff --git a/internal/auth/options.go b/internal/auth/options.go index 5d15130e..cd041a02 100644 --- a/internal/auth/options.go +++ b/internal/auth/options.go @@ -25,7 +25,15 @@ func newProvider(pc ProviderConfig, sc SessionConfig) (providers.Provider, error var singleFlightProvider providers.Provider switch pc.ProviderType { - case providers.GoogleProviderName: // Google + case providers.AzureProviderName: + apc := pc.AzureProviderConfig + azureProvider, err := providers.NewAzureV2Provider(p) + if err != nil { + return nil, err + } + azureProvider.Configure(apc.Tenant) + singleFlightProvider = providers.NewSingleFlightProvider(azureProvider) + case providers.GoogleProviderName: gpc := pc.GoogleProviderConfig googleProvider, err := providers.NewGoogleProvider(p, gpc.ApprovalPrompt, @@ -41,6 +49,13 @@ func newProvider(pc ProviderConfig, sc SessionConfig) (providers.Provider, error googleProvider.GroupsCache = cache singleFlightProvider = providers.NewSingleFlightProvider(googleProvider) + case providers.OIDCProviderName: + opc := pc.OIDCProviderConfig + oidcProvider, err := providers.NewOIDCProvider(p, opc.DiscoveryURL) + if err != nil { + return nil, err + } + singleFlightProvider = providers.NewSingleFlightProvider(oidcProvider) case providers.OktaProviderName: opc := pc.OktaProviderConfig oktaProvider, err := providers.NewOktaProvider(p, opc.OrgURL, opc.ServerID) diff --git a/internal/auth/providers/azure.go b/internal/auth/providers/azure.go new file mode 100644 index 00000000..4c794abb --- /dev/null +++ b/internal/auth/providers/azure.go @@ -0,0 +1,309 @@ +package providers + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "errors" + "fmt" + "net/url" + "strings" + "time" + + "github.com/buzzfeed/sso/internal/pkg/aead" + "github.com/buzzfeed/sso/internal/pkg/sessions" + "github.com/datadog/datadog-go/statsd" + "golang.org/x/oauth2" + + log "github.com/buzzfeed/sso/internal/pkg/logging" +) + +var ( + azureOIDCConfigURLTemplate = "https://login.microsoftonline.com/{tenant}/v2.0" + azureOIDCProfileURL = "https://graph.microsoft.com/oidc/userinfo" + + // This is a compile-time check to make sure our types correctly implement the interface: + // https://medium.com/@matryer/c167afed3aae + _ Provider = &AzureV2Provider{} +) + +// AzureV2Provider is an Azure AD v2 specific implementation of the Provider interface. +type AzureV2Provider struct { + *ProviderData + *OIDCProvider + + Tenant string + + StatsdClient *statsd.Client + NonceCipher aead.Cipher + GraphService GraphService +} + +// NewAzureV2Provider creates a new AzureV2Provider struct +func NewAzureV2Provider(p *ProviderData) (*AzureV2Provider, error) { + if p.ProviderName == "" { + p.ProviderName = "Azure AD" + } + + if p.ClientSecret == "" { + return nil, errors.New("client secret cannot be empty") + } + // Can't guarantee the client secret will be 32 or 64 bytes in length, + // hash to derive a key, error on empty string to avoid silent failure. + key := sha256.Sum256([]byte(p.ClientSecret)) + nonceCipher, err := aead.NewMiscreantCipher(key[:]) + if err != nil { + return nil, err + } + + return &AzureV2Provider{ + ProviderData: p, + NonceCipher: nonceCipher, + OIDCProvider: nil, + }, nil +} + +// SetStatsdClient sets the azure provider statsd client +func (p *AzureV2Provider) SetStatsdClient(statsdClient *statsd.Client) { + p.StatsdClient = statsdClient +} + +// Redeem fulfills the Provider interface. +// The authenticator uses this method to redeem the code provided to /callback after the user logs into their Azure AD account. +func (p *AzureV2Provider) Redeem(redirectURL, code string) (*sessions.SessionState, error) { + ctx := context.Background() + c := oauth2.Config{ + ClientID: p.ClientID, + ClientSecret: p.ClientSecret, + Endpoint: oauth2.Endpoint{ + TokenURL: p.RedeemURL.String(), + }, + RedirectURL: redirectURL, + } + token, err := c.Exchange(ctx, code) + if err != nil { + return nil, fmt.Errorf("token exchange: %v", err) + } + + rawIDToken, ok := token.Extra("id_token").(string) + // BUG? rawIDToken is empty string here in current test cases. Test cases + // are shipping invalid ID tokens, but the Extra function doesn't seem to have + // code that would be affected by that. + if !ok || rawIDToken == "" { + fmt.Printf("token: %+v\n", token) + return nil, fmt.Errorf("token response did not contain an id_token") + } + + // should only happen if oidc autodiscovery is broken or unconfigured + if p.OIDCProvider == nil || p.OIDCProvider.Verifier == nil { + return nil, fmt.Errorf("oidc verifier missing") + } + + // Parse and verify ID Token payload. + idToken, err := p.OIDCProvider.Verifier.Verify(ctx, rawIDToken) + if err != nil { + return nil, fmt.Errorf("could not verify id_token: %v", err) + } + + // Extract custom claims. + var claims struct { + Email string `json:"email"` + UPN string `json:"upn"` + Nonce string `json:"nonce"` + } + if err := idToken.Claims(&claims); err != nil { + return nil, fmt.Errorf("failed to parse id_token claims: %v", err) + } + if claims.Email == "" { + return nil, fmt.Errorf("id_token did not contain an email") + } + if claims.Nonce == "" { + return nil, fmt.Errorf("id_token did not contain a nonce") + } + if !p.validateNonce(claims.Nonce) { + return nil, fmt.Errorf("unable to validate id_token nonce") + } + // TODO: test this w/ an account that uses an alias and compare email claim + // with UPN claim; UPN has usually been what you want, but I think it's not + // rendered as a full email address here. + + s := &sessions.SessionState{ + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + + RefreshDeadline: token.Expiry, + LifetimeDeadline: sessions.ExtendDeadline(p.SessionLifetimeTTL), + + Email: claims.Email, + User: claims.UPN, + } + + if p.GraphService != nil { + groupNames, err := p.GraphService.GetGroups(claims.Email) + if err != nil { + return nil, fmt.Errorf("could not get groups: %v", err) + } + s.Groups = groupNames + } + return s, nil +} + +// Configure sets the Azure tenant ID value for the provider +func (p *AzureV2Provider) Configure(tenant string) error { + p.Tenant = tenant + if p.Tenant == "" { + // TODO: See below, "common" is the right default value, and while + // Azure AD docs suggest this should work, it results in an error. + p.Tenant = "common" + } + discoveryURL := strings.Replace(azureOIDCConfigURLTemplate, "{tenant}", p.Tenant, -1) + + // Configure discoverable provider data. + var err error + p.OIDCProvider, err = NewOIDCProvider(p.ProviderData, discoveryURL) + if err != nil { + return err + } + + p.GraphService = NewMSGraphService(p.ClientID, p.ClientSecret, p.RedeemURL.String()) + return nil +} + +// RefreshSessionIfNeeded takes in a SessionState and +// returns false if the session is not refreshed and true if it is. +func (p *AzureV2Provider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { + if s == nil || !s.RefreshPeriodExpired() || s.RefreshToken == "" { + return false, nil + } + + newToken, duration, err := p.RefreshAccessToken(s.RefreshToken) + if err != nil { + return false, err + } + logger := log.NewLogEntry() + + s.AccessToken = newToken + + s.RefreshDeadline = time.Now().Add(duration).Truncate(time.Second) + logger.WithUser(s.Email).WithRefreshDeadline(s.RefreshDeadline).Info("refreshed access token") + + return true, nil +} + +// RefreshAccessToken uses default OAuth2 TokenSource method to get a new access token. +func (p *AzureV2Provider) RefreshAccessToken(refreshToken string) (string, time.Duration, error) { + if refreshToken == "" { + return "", 0, errors.New("missing refresh token") + } + logger := log.NewLogEntry() + logger.Info("refreshing access token") + + ctx := context.Background() + c := oauth2.Config{ + ClientID: p.ClientID, + ClientSecret: p.ClientSecret, + Endpoint: oauth2.Endpoint{ + TokenURL: p.RedeemURL.String(), + }, + } + t := oauth2.Token{ + RefreshToken: refreshToken, + } + ts := c.TokenSource(ctx, &t) + newToken, err := ts.Token() + if err != nil { + return "", 0, fmt.Errorf("token exchange: %v", err) + } + + return newToken.AccessToken, newToken.Expiry.Sub(time.Now()), nil +} + +// Revoke does nothing for Azure AD, but needs to be declared to avoid a +// not implemented error which would prevent clearing sessions on sign out. +func (p *AzureV2Provider) Revoke(s *sessions.SessionState) error { + return nil +} + +// ValidateSessionState attempts to validate the session state's access token. +func (p *AzureV2Provider) ValidateSessionState(s *sessions.SessionState) bool { + // return validateToken(p, s.AccessToken, nil) + // TODO Validate ID token + return true +} + +// GetSignInURL returns the sign in url with typical oauth parameters +func (p *AzureV2Provider) GetSignInURL(redirectURI, state string) string { + var a url.URL + a = *p.SignInURL + params, _ := url.ParseQuery(a.RawQuery) + params.Set("client_id", p.ClientID) + params.Set("response_type", "id_token code") + params.Set("redirect_uri", redirectURI) + params.Set("response_mode", "form_post") + params.Add("scope", p.Scope) + params.Add("state", state) + params.Set("prompt", "FIXME") + params.Set("nonce", p.calculateNonce(state)) // required parameter + a.RawQuery = params.Encode() + + return a.String() +} + +// calculateNonce generates a verifiable nonce from the state value. +// A nonce can be subsequently validated by attempting to decrypt it. +func (p *AzureV2Provider) calculateNonce(state string) string { + rawNonce, err := p.NonceCipher.Encrypt([]byte(state)) + if err != nil { + // GetSignInURL can't return an error and this shouldn't fail silently + panic(err) + } + return base64.URLEncoding.EncodeToString(rawNonce) +} + +// validateNonce attempts to decrypt the nonce value. If it decrypts +// successfully, the nonce is considered valid. +func (p *AzureV2Provider) validateNonce(nonce string) bool { + rawNonce, err := base64.URLEncoding.DecodeString(nonce) + if err != nil { + return false + } + state, err := p.NonceCipher.Decrypt(rawNonce) + if err != nil { + return false + } + // Sanity check to ensure state contains roughly what we expect + _, err = base64.URLEncoding.DecodeString(string(state)) + if err != nil { + return false + } + return true +} + +// ValidateGroupMembership takes in an email and the allowed groups and returns the groups that the email is part of in that list. +// If `allGroups` is an empty list it returns all the groups that the user belongs to. +func (p *AzureV2Provider) ValidateGroupMembership(email string, allGroups []string, _ string) ([]string, error) { + if p.GraphService == nil { + panic("provider has not been configured") + } + + userGroups, err := p.GraphService.GetGroups(email) + if err != nil { + return nil, err + } + + // if `allGroups` is empty use the groups resource + if len(allGroups) == 0 { + return userGroups, nil + } + + filtered := []string{} + for _, userGroup := range userGroups { + for _, allowedGroup := range allGroups { + if userGroup == allowedGroup { + filtered = append(filtered, userGroup) + } + } + } + + return filtered, nil +} diff --git a/internal/auth/providers/azure_test.go b/internal/auth/providers/azure_test.go new file mode 100644 index 00000000..9f1bed90 --- /dev/null +++ b/internal/auth/providers/azure_test.go @@ -0,0 +1,561 @@ +package providers + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "log" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "strings" + "testing" + "time" + + "github.com/buzzfeed/sso/internal/pkg/sessions" + "github.com/buzzfeed/sso/internal/pkg/testutil" + + jose "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" +) + +const ( + microsoftTenantID = "9188040d-6c67-4c5b-b112-36a304b66dad" + testClientID = "a4c35c92-e858-41e8-bd2c-ade04cb622b1" + testClientSecret = "4" // number chosen at random +) + +func newAzureProviderServer(redeemBody *[]byte, redeemCode int, pubKey *rsa.PublicKey) (*url.URL, *httptest.Server) { + var u *url.URL + s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.Header().Set("Content-Type", "application/json") + + // simple router for the mock IdP server + switch r.RequestURI { + case "/.well-known/openid-configuration": + rw.WriteHeader(200) + rw.Write([]byte(fmt.Sprintf(`{ + "token_endpoint":"%s/oauth2/token", + "jwks_uri":"%s/common/discovery/keys", + "issuer":"%s" + }`, u, u, u))) + case "/common/discovery/keys": + pub := jose.JSONWebKey{Key: pubKey, Algorithm: "RSA", Use: "sig"} + keyData, err := pub.MarshalJSON() + if err != nil { + panic(err) + } + rw.WriteHeader(200) + rw.Write([]byte(fmt.Sprintf(`{ + "keys":[%s] + }`, keyData))) + case "/oauth2/token": + rw.WriteHeader(redeemCode) + rw.Write(*redeemBody) + } + })) + u, _ = url.Parse(s.URL) + return u, s +} + +func newAzureV2Provider(providerData *ProviderData) *AzureV2Provider { + if providerData == nil { + providerData = &ProviderData{ + ProviderName: "", + ClientID: testClientID, + ClientSecret: testClientSecret, + SignInURL: &url.URL{}, + RedeemURL: &url.URL{}, + RevokeURL: &url.URL{}, + ProfileURL: &url.URL{}, + ValidateURL: &url.URL{}, + Scope: ""} + } + provider, err := NewAzureV2Provider(providerData) + if err != nil { + panic(err) + } + return provider +} + +func TestAzureV2ProviderDefaults(t *testing.T) { + expectedResults := []struct { + name string + providerData *ProviderData + signInURL string + redeemURL string + revokeURL string + profileURL string + validateURL string + scope string + }{ + { + name: "defaults", + signInURL: "https://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/oauth2/v2.0/authorize", + redeemURL: "https://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/oauth2/v2.0/token", + profileURL: "https://graph.microsoft.com/oidc/userinfo", + revokeURL: "", // does not exist + validateURL: "", // does not exist + scope: "openid email profile offline_access", + }, + { + name: "with provider overrides", + providerData: &ProviderData{ + ClientID: "1234", + ClientSecret: "4", // Number chosen at random + SignInURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/oauth/auth"}, + RedeemURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/oauth/token"}, + RevokeURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/oauth/deauth"}, + ProfileURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/oauth/profile"}, + ValidateURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/oauth/tokeninfo"}, + Scope: "profile", + }, + signInURL: "https://example.com/oauth/auth", + redeemURL: "https://example.com/oauth/token", + revokeURL: "https://example.com/oauth/deauth", + profileURL: "https://example.com/oauth/profile", + validateURL: "https://example.com/oauth/tokeninfo", + scope: "profile", + }, + } + for _, expected := range expectedResults { + t.Run(expected.name, func(t *testing.T) { + p := newAzureV2Provider(expected.providerData) + err := p.Configure(microsoftTenantID) + if err != nil { + t.Error(err) + } + if p == nil { + t.Errorf("azure provider was nil") + } + if p.Data().ProviderName != "Azure AD" { + t.Errorf("expected provider name Azure AD, got %s", p.Data().ProviderName) + } + if p.Data().SignInURL.String() != expected.signInURL { + log.Printf("expected %s", expected.signInURL) + log.Printf("got %s", p.Data().SignInURL.String()) + t.Errorf("unexpected signin url") + } + + if p.Data().RedeemURL.String() != expected.redeemURL { + log.Printf("expected %s", expected.redeemURL) + log.Printf("got %s", p.Data().RedeemURL.String()) + t.Errorf("unexpected redeem url") + } + + if p.Data().RevokeURL.String() != expected.revokeURL { + log.Printf("expected %s", expected.revokeURL) + log.Printf("got %s", p.Data().RevokeURL.String()) + t.Errorf("unexpected revoke url") + } + + if p.Data().ValidateURL.String() != expected.validateURL { + log.Printf("expected %s", expected.validateURL) + log.Printf("got %s", p.Data().ValidateURL.String()) + t.Errorf("unexpected validate url") + } + + if p.Data().ProfileURL.String() != expected.profileURL { + log.Printf("expected %s", expected.profileURL) + log.Printf("got %s", p.Data().ProfileURL.String()) + t.Errorf("unexpected profile url") + } + + if p.Data().Scope != expected.scope { + log.Printf("expected %s", expected.scope) + log.Printf("got %s", p.Data().Scope) + t.Errorf("unexpected scope") + } + }) + + } +} + +// claims represents public claim values (as specified in RFC 7519). +type claims struct { + Issuer string `json:"iss,omitempty"` + Audience string `json:"aud,omitempty"` + Expiry *jwt.NumericDate `json:"exp,omitempty"` + Name string `json:"name,omitempty"` + Email string `json:"email,omitempty"` + Nonce string `json:"nonce,omitempty"` + NotEmail string `json:"not_email,omitempty"` +} + +func TestAzureV2ProviderRedeem(t *testing.T) { + // For testing create the RSA key pair in the code + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Error(err) + } + // create Square.jose signing key + key := jose.SigningKey{Algorithm: jose.RS256, Key: privKey} + + // create a Square.jose RSA signer, used to sign the JWT + var signerOpts = jose.SignerOptions{} + signerOpts.WithType("JWT") + rsaSigner, err := jose.NewSigner(key, &signerOpts) + if err != nil { + t.Error(err) + } + + testCases := []struct { + name string + claims *claims + resp redeemResponse + expectedError bool + expectedSession *sessions.SessionState + }{ + { + name: "redeem", + claims: &claims{ + Issuer: "{mock-issuer}", + Audience: testClientID, + Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second)), + Name: "Michael Bland", + Email: "michael.bland@gsa.gov", + Nonce: "{mock-nonce}", + }, + resp: redeemResponse{ + AccessToken: "a1234", + ExpiresIn: 10, + RefreshToken: "refresh12345", + }, + expectedSession: &sessions.SessionState{ + Email: "michael.bland@gsa.gov", + AccessToken: "a1234", + RefreshToken: "refresh12345", + }, + }, + { + name: "missing issuer", + claims: &claims{ + Audience: testClientID, + Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second)), + Name: "Michael Bland", + Email: "michael.bland@gsa.gov", + Nonce: "{mock-nonce}", + }, + resp: redeemResponse{ + AccessToken: "a1234", + ExpiresIn: 10, + RefreshToken: "refresh12345", + }, + expectedError: true, + }, + { + name: "invalid issuer", + claims: &claims{ + Issuer: "https://example.com/bogus/issuer", + Audience: testClientID, + Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second)), + Name: "Michael Bland", + Email: "michael.bland@gsa.gov", + Nonce: "{mock-nonce}", + }, + resp: redeemResponse{ + AccessToken: "a1234", + ExpiresIn: 10, + RefreshToken: "refresh12345", + }, + expectedError: true, + }, + { + name: "missing nonce", + claims: &claims{ + Issuer: "{mock-issuer}", + Audience: testClientID, + Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second)), + Name: "Michael Bland", + Email: "michael.bland@gsa.gov", + }, + resp: redeemResponse{ + AccessToken: "a1234", + ExpiresIn: 10, + RefreshToken: "refresh12345", + }, + expectedError: true, + }, + { + name: "invalid nonce", + claims: &claims{ + Issuer: "{mock-issuer}", + Audience: testClientID, + Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second)), + Name: "Michael Bland", + Email: "michael.bland@gsa.gov", + Nonce: "123456789", + }, + resp: redeemResponse{ + AccessToken: "a1234", + ExpiresIn: 10, + RefreshToken: "refresh12345", + }, + expectedError: true, + }, + { + name: "invalid encoding", + resp: redeemResponse{ + AccessToken: "a1234", + IDToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." + `{"name": "Michael Bland","email": "michael.bland@gsa.gov"}` + ".SC_eJ3K04rLOPLLDIWEKwr0DPZqw5KlFySybzmxfM6Y", + }, + expectedError: true, + }, + { + name: "invalid json", + resp: redeemResponse{ + AccessToken: "a1234", + IDToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)) + ".SC_eJ3K04rLOPLLDIWEKwr0DPZqw5KlFySybzmxfM6Y", + }, + expectedError: true, + }, + { + name: "missing email", + claims: &claims{ + NotEmail: "missing", + }, + resp: redeemResponse{ + AccessToken: "a1234", + }, + expectedError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var body []byte + var server *httptest.Server + // pointer to body to bypass chicken/egg issue w/ mock server urls + providerURL, server := newAzureProviderServer(&body, http.StatusOK, &privKey.PublicKey) + defer server.Close() + // swap the global OIDC URL template for a test provider URL + azureOIDCConfigURLTemplate = providerURL.String() + + p := newAzureV2Provider(nil) + err = p.Configure(microsoftTenantID) + if err != nil { + t.Error(err) + } + + if tc.claims != nil { + // create an instance of Builder that uses the rsa signer + builder := jwt.Signed(rsaSigner) + + // add claims to the Builder + tc.claims.Issuer = strings.Replace(tc.claims.Issuer, "{mock-issuer}", providerURL.String(), -1) + tc.claims.Nonce = strings.Replace(tc.claims.Nonce, "{mock-nonce}", p.calculateNonce("1234"), -1) + builder = builder.Claims(tc.claims) + + // build and inject ID token into response + idToken, err := builder.CompactSerialize() + if err != nil { + t.Error(err) + } + tc.resp.IDToken = idToken + } + + body, err = json.Marshal(tc.resp) + testutil.Equal(t, nil, err) + + // graph service mock has to be set after p.Configure + p.GraphService = &MockMSGraphService{} + + session, err := p.Redeem("http://redirect/", "code1234") + if tc.expectedError && err == nil { + t.Errorf("expected redeem error but was nil") + } + if !tc.expectedError && err != nil { + t.Errorf("unexpected error %s", err) + } + if tc.expectedSession == nil && session != nil { + t.Errorf("expected session to be nil but it was %s", session) + } + if session != nil && tc.expectedSession != nil { + if session.Email != tc.expectedSession.Email { + log.Printf("expected email %s", tc.expectedSession.Email) + log.Printf("got %s", session.Email) + t.Errorf("unexpected session email") + } + + if session.AccessToken != tc.expectedSession.AccessToken { + log.Printf("expected access token %s", tc.expectedSession.AccessToken) + log.Printf("got %s", session.AccessToken) + t.Errorf("unexpected access token") + } + + if session.RefreshToken != tc.expectedSession.RefreshToken { + log.Printf("expected refresh token %s", tc.expectedSession.RefreshToken) + log.Printf("got %s", session.RefreshToken) + t.Errorf("unexpected session refresh token") + } + } + }) + } +} + +type groupsClientMock struct { +} + +func (c *groupsClientMock) Do(req *http.Request) (*http.Response, error) { + return &http.Response{}, nil +} + +func TestAzureV2GetSignInURL(t *testing.T) { + testCases := []struct { + name string + redirectURI string + state string + expectedParams url.Values + }{ + { + name: "nonce values passed to azure should validate, pass one", + redirectURI: "https://example.com/oauth/callback", + state: "1234", + expectedParams: url.Values{ + "redirect_uri": []string{"https://example.com/oauth/callback"}, + "response_mode": []string{"form_post"}, + "response_type": []string{"id_token code"}, + "scope": []string{"openid email profile offline_access"}, + "state": []string{"1234"}, + "client_id": []string{testClientID}, + "prompt": []string{"consent"}, + }, + }, + { + name: "nonce values passed to azure should validate, pass two", + redirectURI: "https://example.com/oauth/callback", + state: "1234", + expectedParams: url.Values{ + "redirect_uri": []string{"https://example.com/oauth/callback"}, + "response_mode": []string{"form_post"}, + "response_type": []string{"id_token code"}, + "scope": []string{"openid email profile offline_access"}, + "state": []string{"1234"}, + "client_id": []string{testClientID}, + "prompt": []string{"consent"}, + }, + }, + { + name: "nonce values passed to azure should validate, pass three", + redirectURI: "https://example.com/oauth/callback", + state: "4321", + expectedParams: url.Values{ + "redirect_uri": []string{"https://example.com/oauth/callback"}, + "response_mode": []string{"form_post"}, + "response_type": []string{"id_token code"}, + "scope": []string{"openid email profile offline_access"}, + "state": []string{"4321"}, + "client_id": []string{testClientID}, + "prompt": []string{"consent"}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + p := newAzureV2Provider(nil) + p.Scope = "openid email profile offline_access" + + signInURL := p.GetSignInURL(tc.redirectURI, tc.state) + parsedURL, err := url.Parse(signInURL) + if err != nil { + t.Error(err) + } + + for k := range tc.expectedParams { + if tc.expectedParams.Get(k) != parsedURL.Query().Get(k) { + t.Logf("expected param %s: %+v", k, tc.expectedParams.Get(k)) + t.Logf("got param %s: %+v", k, parsedURL.Query().Get(k)) + } + } + + nonce := parsedURL.Query().Get("nonce") + if p.validateNonce(nonce) != true { + t.Logf("expected valid nonce, got: %+v", nonce) + } + }) + } +} + +func TestAzureV2ValidateGroupMembers(t *testing.T) { + testCases := []struct { + name string + allowedGroups []string + mockedGroups []string + mockedError error + expectedGroups []string + expectedErrorString string + }{ + { + name: "allowed groups and groups resource output exactly match should return all groups", + allowedGroups: []string{"group1", "group2", "group3"}, + mockedGroups: []string{"group1", "group2", "group3"}, + expectedGroups: []string{"group1", "group2", "group3"}, + }, + { + name: "allowed groups should restrict to subset of groups", + allowedGroups: []string{"group1", "group2"}, + mockedGroups: []string{"group1", "group2", "group3"}, + expectedGroups: []string{"group1", "group2"}, + }, + { + name: "allowed groups superset should not restrict to subset of groups", + allowedGroups: []string{"group1", "group2", "group3"}, + mockedGroups: []string{"group1", "group2"}, + expectedGroups: []string{"group1", "group2"}, + }, + { + name: "groups allowed zero value should default to return all groups", + allowedGroups: []string{}, + mockedGroups: []string{"group1"}, + expectedGroups: []string{"group1"}, + }, + { + name: "empty inputs and error on groups resource should return error", + allowedGroups: []string{}, + mockedError: fmt.Errorf("error"), + expectedErrorString: "error", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + p := newAzureV2Provider(nil) + p.GraphService = &MockMSGraphService{ + Groups: tc.mockedGroups, + GroupsError: tc.mockedError, + } + + groups, err := p.ValidateGroupMembership("test@example.com", tc.allowedGroups, "") + + if err != nil { + if tc.expectedErrorString != err.Error() { + t.Errorf("expected error %s but err was %s", tc.expectedErrorString, err) + } + } + if !reflect.DeepEqual(tc.expectedGroups, groups) { + t.Logf("expected groups %v", tc.expectedGroups) + t.Logf("got groups %v", groups) + t.Errorf("unexpected groups returned") + } + }) + } +} diff --git a/internal/auth/providers/google.go b/internal/auth/providers/google.go index 9a8cf5fa..5bf37b60 100644 --- a/internal/auth/providers/google.go +++ b/internal/auth/providers/google.go @@ -21,6 +21,12 @@ import ( "github.com/datadog/datadog-go/statsd" ) +var ( + // This is a compile-time check to make sure our types correctly implement the interface: + // https://medium.com/@matryer/c167afed3aae + _ Provider = &GoogleProvider{} +) + // GoogleProvider is an implementation of the Provider interface. type GoogleProvider struct { *ProviderData diff --git a/internal/auth/providers/google_admin.go b/internal/auth/providers/google_admin.go index 14b319a6..bbfc54b0 100644 --- a/internal/auth/providers/google_admin.go +++ b/internal/auth/providers/google_admin.go @@ -18,6 +18,8 @@ import ( ) // AdminService wraps calls to provider admin APIs +// +// This interface allows the service to be more readily mocked in tests. type AdminService interface { ListMemberships(group string, depth int) (members []string, err error) CheckMemberships(groups []string, user string) (inGroups []string, errr error) diff --git a/internal/auth/providers/google_test.go b/internal/auth/providers/google_test.go index 596a1ac2..52cea26f 100644 --- a/internal/auth/providers/google_test.go +++ b/internal/auth/providers/google_test.go @@ -20,12 +20,14 @@ import ( func newProviderServer(body []byte, code int) (*url.URL, *httptest.Server) { s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.Header().Set("Content-Type", "application/json") rw.WriteHeader(code) rw.Write(body) })) u, _ := url.Parse(s.URL) return u, s } + func newGoogleProvider(providerData *ProviderData) *GoogleProvider { if providerData == nil { providerData = &ProviderData{ @@ -279,7 +281,7 @@ func TestGoogleProviderRevoke(t *testing.T) { } } -func TestValidateGroupMembers(t *testing.T) { +func TestGoogleValidateGroupMembers(t *testing.T) { testCases := []struct { name string inputAllowedGroups []string diff --git a/internal/auth/providers/ms_graph_api.go b/internal/auth/providers/ms_graph_api.go new file mode 100644 index 00000000..225c3776 --- /dev/null +++ b/internal/auth/providers/ms_graph_api.go @@ -0,0 +1,179 @@ +package providers + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "strings" + "sync" + + lru "github.com/hashicorp/golang-lru" + "golang.org/x/oauth2/clientcredentials" +) + +// The Microsoft Graph API provides a mechanism to obtain group membership +// information for users authenticated with the Azure AD provider. + +// azureGroupCacheSize controls the size of the caches of AD group info +const azureGroupCacheSize = 1024 + +// GraphService wraps calls to provider admin APIs +// +// This interface allows the service to be more readily mocked in tests. +type GraphService interface { + GetGroups(string) ([]string, error) +} + +// MSGraphService implements graph API calls for the Azure provider +type MSGraphService struct { + client *http.Client + groupMembershipCache *lru.Cache + groupNameCache *lru.Cache +} + +// NewMSGraphService creates a new graph service for getting groups +func NewMSGraphService(clientID string, clientSecret string, tokenURL string) *MSGraphService { + clientConfig := &clientcredentials.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + TokenURL: tokenURL, + Scopes: []string{ + "https://graph.microsoft.com/.default", + }, + } + ctx := context.Background() + client := clientConfig.Client(ctx) + memberCache, err := lru.New(azureGroupCacheSize) + if err != nil { + panic(err) // Should only happen if azureGroupCacheSize is a negative number + } + nameCache, err := lru.New(azureGroupCacheSize) + if err != nil { + panic(err) // Should only happen if azureGroupCacheSize is a negative number + } + return &MSGraphService{ + client: client, + groupMembershipCache: memberCache, + groupNameCache: nameCache, + } +} + +// GetGroups lists groups user belongs to. +func (gs *MSGraphService) GetGroups(email string) ([]string, error) { + if gs.client == nil { + return nil, errors.New("oauth client must be configured") + } + if email == "" { + return nil, errors.New("missing email") + } + + var wg sync.WaitGroup + var mux sync.Mutex + var err error + groupNames := make([]string, 0) + // See: https://developer.microsoft.com/en-us/graph/docs/api-reference/beta/api/user_getmembergroups + requestBody := `{"securityEnabledOnly": false}` + requestURL := fmt.Sprintf("https://graph.microsoft.com/beta/users/%s/getMemberGroups", url.PathEscape(email)) + for { + groupResponse, err := gs.client.Post(requestURL, "application/json", strings.NewReader(requestBody)) + if err != nil { + return nil, err + } + + groupData := struct { + // Link to next page of data, see: + // https://docs.microsoft.com/en-us/graph/query-parameters#skip-parameter + Next string `json:"@odata.nextLink"` + Value []string `json:"value"` + }{} + + body, err := ioutil.ReadAll(groupResponse.Body) + if err != nil { + return nil, err + } + if groupResponse.StatusCode != http.StatusOK { + return nil, fmt.Errorf("api error: %s", string(body)) + } + + err = json.Unmarshal(body, &groupData) + if err != nil { + return nil, err + } + + for _, groupID := range groupData.Value { + wg.Add(1) + id := groupID + go func(wg *sync.WaitGroup) { + defer wg.Done() + + var name string + // check the cache for the group name first + if cachedName, ok := gs.groupNameCache.Get(id); !ok { + // didn't have the group name, make concurrent API call to fetch it + name, err = gs.getGroupName(id) + // the err value is not shadowed in the goroutine, so if this isn't + // nil, it will return the err value after wg.Wait() is called + if err == nil { + // got the name ok, populate the cache + gs.groupNameCache.Add(id, name) + } + } else { + // cache hit + name = cachedName.(string) + } + mux.Lock() + groupNames = append(groupNames, name) + mux.Unlock() + }(&wg) + } + + if groupData.Next != "" { + requestURL = groupData.Next + } else { + break + } + } + wg.Wait() + // any err value set above will cause this to fail + if err != nil { + return nil, err + } + + return groupNames, nil +} + +// getGroupName returns the group name, preferentially pulling from cache +func (gs *MSGraphService) getGroupName(id string) (string, error) { + if gs.client == nil { + return "", errors.New("oauth client must be configured") + } + // See: https://developer.microsoft.com/en-us/graph/docs/api-reference/v1.0/api/group_get + requestURL := fmt.Sprintf("https://graph.microsoft.com/v1.0/groups/%s", url.PathEscape(id)) + groupMetaResponse, err := gs.client.Get(requestURL) + if err != nil { + return "", err + } + + groupMetadata := struct { + DisplayName string `json:"displayName"` + }{} + + body, err := ioutil.ReadAll(groupMetaResponse.Body) + if err != nil { + return "", err + } + if groupMetaResponse.StatusCode != http.StatusOK { + return "", fmt.Errorf("api error: %s", string(body)) + } + + err = json.Unmarshal(body, &groupMetadata) + if err != nil { + return "", err + } + + return groupMetadata.DisplayName, nil +} diff --git a/internal/auth/providers/ms_graph_mock.go b/internal/auth/providers/ms_graph_mock.go new file mode 100644 index 00000000..e841b842 --- /dev/null +++ b/internal/auth/providers/ms_graph_mock.go @@ -0,0 +1,12 @@ +package providers + +// MockMSGraphService is an implementation of GraphService to be used for testing +type MockMSGraphService struct { + Groups []string + GroupsError error +} + +// GetGroups mocks the GetGroups function +func (ms *MockMSGraphService) GetGroups(string) ([]string, error) { + return ms.Groups, ms.GroupsError +} diff --git a/internal/auth/providers/oidc.go b/internal/auth/providers/oidc.go new file mode 100644 index 00000000..7d1a5be5 --- /dev/null +++ b/internal/auth/providers/oidc.go @@ -0,0 +1,180 @@ +package providers + +import ( + "context" + "errors" + "fmt" + "net/url" + "time" + + "golang.org/x/oauth2" + + log "github.com/buzzfeed/sso/internal/pkg/logging" + "github.com/buzzfeed/sso/internal/pkg/sessions" + oidc "github.com/coreos/go-oidc" +) + +// OIDCProvider is a generic OpenID Connect provider +type OIDCProvider struct { + *ProviderData + + Verifier *oidc.IDTokenVerifier +} + +// NewOIDCProvider creates a new generic OpenID Connect provider +func NewOIDCProvider(p *ProviderData, discoveryURL string) (*OIDCProvider, error) { + provider := &OIDCProvider{ + ProviderData: p, + } + if p.ProviderName == "" { + p.ProviderName = "OpenID Connect" + } + + // Configure discoverable provider data. + oidcProvider, err := oidc.NewProvider(context.Background(), discoveryURL) + if err != nil { + // TODO: This seems like it _should_ work for "common", but it doesn't + // Does anyone actually want to use this with "common" though? + return nil, err + } + + provider.Verifier = oidcProvider.Verifier(&oidc.Config{ + ClientID: p.ClientID, + }) + // Set these only if they haven't been overridden + if p.SignInURL == nil || p.SignInURL.String() == "" { + p.SignInURL, err = url.Parse(oidcProvider.Endpoint().AuthURL) + if err != nil { + return nil, err + } + } + if p.RedeemURL == nil || p.RedeemURL.String() == "" { + p.RedeemURL, err = url.Parse(oidcProvider.Endpoint().TokenURL) + if err != nil { + return nil, err + } + } + if p.ProfileURL == nil || p.ProfileURL.String() == "" { + p.ProfileURL, err = url.Parse(azureOIDCProfileURL) + } + if err != nil { + return nil, err + } + if p.Scope == "" { + p.Scope = "openid email profile offline_access" + } + if p.RedeemURL.String() == "" { + return nil, errors.New("redeem url must be set") + } + + return provider, nil +} + +// Redeem fulfills the Provider interface. +// The authenticator uses this method to redeem the code provided to /callback after the user logs into their OpenID Connect account. +func (p *OIDCProvider) Redeem(redirectURL, code string) (*sessions.SessionState, error) { + ctx := context.Background() + c := oauth2.Config{ + ClientID: p.ClientID, + ClientSecret: p.ClientSecret, + Endpoint: oauth2.Endpoint{ + TokenURL: p.RedeemURL.String(), + }, + RedirectURL: redirectURL, + } + token, err := c.Exchange(ctx, code) + if err != nil { + return nil, fmt.Errorf("token exchange: %v", err) + } + + rawIDToken, ok := token.Extra("id_token").(string) + if !ok { + return nil, fmt.Errorf("token response did not contain an id_token") + } + + // Parse and verify ID Token payload. + idToken, err := p.Verifier.Verify(ctx, rawIDToken) + if err != nil { + return nil, fmt.Errorf("could not verify id_token: %v", err) + } + + // Extract custom claims. + var claims struct { + Email string `json:"email"` + Verified *bool `json:"email_verified"` + } + if err := idToken.Claims(&claims); err != nil { + return nil, fmt.Errorf("failed to parse id_token claims: %v", err) + } + + if claims.Email == "" { + return nil, fmt.Errorf("id_token did not contain an email") + } + if claims.Verified != nil && !*claims.Verified { + return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) + } + + s := &sessions.SessionState{ + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + RefreshDeadline: token.Expiry, + LifetimeDeadline: sessions.ExtendDeadline(p.SessionLifetimeTTL), + Email: claims.Email, + } + + return s, nil +} + +// RefreshSessionIfNeeded takes in a SessionState and +// returns false if the session is not refreshed and true if it is. +func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { + if s == nil || !s.RefreshPeriodExpired() || s.RefreshToken == "" { + return false, nil + } + + newToken, duration, err := p.RefreshAccessToken(s.RefreshToken) + if err != nil { + return false, err + } + logger := log.NewLogEntry() + + s.AccessToken = newToken + + s.RefreshDeadline = time.Now().Add(duration).Truncate(time.Second) + logger.WithUser(s.Email).WithRefreshDeadline(s.RefreshDeadline).Info("refreshed access token") + + return true, nil +} + +// RefreshAccessToken uses default OAuth2 TokenSource method to get a new access token. +func (p *OIDCProvider) RefreshAccessToken(refreshToken string) (string, time.Duration, error) { + if refreshToken == "" { + return "", 0, errors.New("missing refresh token") + } + + ctx := context.Background() + c := oauth2.Config{ + ClientID: p.ClientID, + ClientSecret: p.ClientSecret, + Endpoint: oauth2.Endpoint{ + TokenURL: p.RedeemURL.String(), + }, + } + t := oauth2.Token{ + RefreshToken: refreshToken, + } + ts := c.TokenSource(ctx, &t) + newToken, err := ts.Token() + if err != nil { + return "", 0, fmt.Errorf("token exchange: %v", err) + } + + return newToken.AccessToken, newToken.Expiry.Sub(time.Now()), nil +} + +// ValidateSessionState attempts to validate the session state's access token. +func (p *OIDCProvider) ValidateSessionState(s *sessions.SessionState) bool { + // return validateToken(p, s.AccessToken, nil) + // TODO Validate ID token + return true +} diff --git a/internal/auth/providers/provider_default.go b/internal/auth/providers/provider_default.go index 054ce77d..45b12b95 100644 --- a/internal/auth/providers/provider_default.go +++ b/internal/auth/providers/provider_default.go @@ -172,7 +172,7 @@ func (p *ProviderData) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, e return false, nil } -// RefreshAccessToken returns a nont implemented error. +// RefreshAccessToken returns a not implemented error. func (p *ProviderData) RefreshAccessToken(refreshToken string) (string, time.Duration, error) { return "", 0, ErrNotImplemented } diff --git a/internal/auth/providers/providers.go b/internal/auth/providers/providers.go index f69583dd..16ecc540 100644 --- a/internal/auth/providers/providers.go +++ b/internal/auth/providers/providers.go @@ -29,8 +29,12 @@ var ( ) const ( + // AzureProviderName identifies the Azure AD v2 provider + AzureProviderName = "azure_v2" // GoogleProviderName identifies the Google provider GoogleProviderName = "google" + // OIDCProviderName identifies the OpenID Connect provider + OIDCProviderName = "oidc" // OktaProviderName identities the Okta provider OktaProviderName = "okta" ) diff --git a/internal/auth/providers/singleflight_middleware.go b/internal/auth/providers/singleflight_middleware.go index 081d46dc..2db1bb17 100644 --- a/internal/auth/providers/singleflight_middleware.go +++ b/internal/auth/providers/singleflight_middleware.go @@ -15,7 +15,7 @@ import ( var ( // This is a compile-time check to make sure our types correctly implement the interface: - // https://medium.com/@matryer/golang-tip-compile-time-checks-to-ensure-your-type-satisfies-an-interface-c167afed3aae + // https://medium.com/@matryer/c167afed3aae _ Provider = &SingleFlightProvider{} ) diff --git a/internal/pkg/aead/aead.go b/internal/pkg/aead/aead.go index 5019ba05..d9a41aa4 100644 --- a/internal/pkg/aead/aead.go +++ b/internal/pkg/aead/aead.go @@ -89,8 +89,9 @@ func (c *MiscreantCipher) Decrypt(joined []byte) ([]byte, error) { return plaintext, nil } -// Marshal marshals the interface state as JSON, encrypts the JSON using the cipher -// and base64 encodes the binary value as a string and returns the result +// Marshal marshals the interface state as JSON, gzips it, encrypts the JSON +// using the cipher, and base64 encodes the binary value as a string and +// returns the result. func (c *MiscreantCipher) Marshal(s interface{}) (string, error) { // encode json value plaintext, err := json.Marshal(s) @@ -98,7 +99,7 @@ func (c *MiscreantCipher) Marshal(s interface{}) (string, error) { return "", err } - // gunzip the bytes + // gzip the bytes var jsonBuffer bytes.Buffer w := gzip.NewWriter(&jsonBuffer) w.Write(plaintext) @@ -115,8 +116,9 @@ func (c *MiscreantCipher) Marshal(s interface{}) (string, error) { return encoded, nil } -// Unmarshal takes the marshaled string, base64-decodes into a byte slice, decrypts the -// byte slice the passed cipher, and unmarshals the resulting JSON into the struct pointer passed +// Unmarshal takes the marshaled string, base64-decodes into a byte slice, +// decrypts the byte slice the pased cipher, gunzips it, and unmarshals +// the resulting JSON into the struct pointer passed. func (c *MiscreantCipher) Unmarshal(value string, s interface{}) error { // convert base64 string value to bytes ciphertext, err := base64.RawURLEncoding.DecodeString(value) @@ -130,7 +132,7 @@ func (c *MiscreantCipher) Unmarshal(value string, s interface{}) error { return err } - // gzip the bytes + // gunzip the bytes var jsonBuffer bytes.Buffer r, err := gzip.NewReader(bytes.NewBuffer(plaintext)) if err != nil { diff --git a/internal/proxy/providers/singleflight_middleware.go b/internal/proxy/providers/singleflight_middleware.go index 40e6fcf7..f21bb586 100644 --- a/internal/proxy/providers/singleflight_middleware.go +++ b/internal/proxy/providers/singleflight_middleware.go @@ -15,7 +15,7 @@ import ( var ( // This is a compile-time check to make sure our types correctly implement the interface: - // https://medium.com/@matryer/golang-tip-compile-time-checks-to-ensure-your-type-satisfies-an-interface-c167afed3aae + // https://medium.com/@matryer/c167afed3aae _ Provider = &SingleFlightProvider{} ) diff --git a/internal/proxy/providers/sso.go b/internal/proxy/providers/sso.go index 44eaff27..18a0d4bb 100644 --- a/internal/proxy/providers/sso.go +++ b/internal/proxy/providers/sso.go @@ -23,7 +23,7 @@ import ( var ( // This is a compile-time check to make sure our types correctly implement the interface: - // https://medium.com/@matryer/golang-tip-compile-time-checks-to-ensure-your-type-satisfies-an-interface-c167afed3aae + // https://medium.com/@matryer/c167afed3aae _ Provider = &SSOProvider{} ) @@ -177,9 +177,11 @@ func (p *SSOProvider) Redeem(redirectURL, code string) (*sessions.SessionState, // an authorized group. func (p *SSOProvider) ValidateGroup(email string, allowedGroups []string, accessToken string) ([]string, bool, error) { logger := log.NewLogEntry() + logger.Info("called sso.go ValidateGroup") logger.WithUser(email).WithAllowedGroups(allowedGroups).Info("validating groups") inGroups := []string{} + logger.Printf("allowedGroups: %v", allowedGroups) if len(allowedGroups) == 0 { return inGroups, true, nil } diff --git a/internal/proxy/proxy_config.go b/internal/proxy/proxy_config.go index 120eb1f9..2544052e 100644 --- a/internal/proxy/proxy_config.go +++ b/internal/proxy/proxy_config.go @@ -8,6 +8,7 @@ import ( "time" "github.com/18F/hmacauth" + log "github.com/buzzfeed/sso/internal/pkg/logging" "github.com/imdario/mergo" "gopkg.in/yaml.v2" ) @@ -198,6 +199,9 @@ func loadServiceConfigs(raw []byte, cluster, scheme string, configVars map[strin if err != nil { return nil, err } + logger := log.NewLogEntry() + logger.Printf("proxy.AllowedGroups: %v", proxy.AllowedGroups) + } for _, proxy := range configs {