Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"net/http"
"net/url"
"slices"
"strings"
"time"

Expand All @@ -27,11 +28,15 @@ var (

// Discover calls the discovery endpoint of the provided issuer and returns its configuration
// It accepts an optional argument "wellknownUrl" which can be used to overide the dicovery endpoint url
func Discover(ctx context.Context, issuer string, httpClient *http.Client, wellKnownUrl ...string) (*oidc.DiscoveryConfiguration, error) {
func Discover(ctx context.Context, issuers []string, httpClient *http.Client, wellKnownUrl ...string) (*oidc.DiscoveryConfiguration, error) {
ctx, span := Tracer.Start(ctx, "Discover")
defer span.End()

wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint
wellKnown := ""
if len(issuers) > 0 {
wellKnown = strings.TrimSuffix(issuers[0], "/") + oidc.DiscoveryEndpoint
}

if len(wellKnownUrl) == 1 && wellKnownUrl[0] != "" {
wellKnown = wellKnownUrl[0]
}
Expand All @@ -48,7 +53,7 @@ func Discover(ctx context.Context, issuer string, httpClient *http.Client, wellK
logger.Debug("discover", "config", discoveryConfig)
}

if false && discoveryConfig.Issuer != issuer {
if !slices.Contains(issuers, discoveryConfig.Issuer) {
Copy link
Copy Markdown

@mmgopher mmgopher Oct 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How it is suppose to work? This check doesn't make sense for me. You do not know tenantid at this stage (for multi-tenant mode). For multi-tenant the issuer is https://login.microsoftonline.com/common/ and the dicoveryConfig.Issuer = https://sts.windows.net/{tenantid}/. The placeholder {tenantid} is not resolved with real tenantid. It will work only if you have string https://sts.windows.net/{tenantid}/ in the slice issuers []string

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly, I have added https://sts.windows.net/{tenantid}/ to the permitted issuer list in config

return nil, oidc.ErrIssuerInvalid
}
return discoveryConfig, nil
Expand Down
4 changes: 2 additions & 2 deletions pkg/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ import (
"net/http"
"testing"

"github.com/datasapiens/oidc/v3/pkg/oidc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/datasapiens/oidc/v3/pkg/oidc"
)

func TestDiscover(t *testing.T) {
Expand Down Expand Up @@ -45,7 +45,7 @@ func TestDiscover(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Discover(context.Background(), tt.args.issuer, http.DefaultClient, tt.args.wellKnownUrl...)
got, err := Discover(context.Background(), []string{tt.args.issuer}, http.DefaultClient, tt.args.wellKnownUrl...)
require.ErrorIs(t, err, tt.wantErr)
if tt.wantFields == nil {
return
Expand Down
2 changes: 1 addition & 1 deletion pkg/client/profile/jwt_profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func NewJWTProfileTokenSource(ctx context.Context, issuer, clientID, keyID strin
opt(source)
}
if source.tokenEndpoint == "" {
config, err := client.Discover(ctx, issuer, source.httpClient)
config, err := client.Discover(ctx, []string{issuer}, source.httpClient)
if err != nil {
return nil, err
}
Expand Down
26 changes: 21 additions & 5 deletions pkg/client/rp/relying_party.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ var DefaultUnauthorizedHandler UnauthorizedHandler = func(w http.ResponseWriter,
}

type relyingParty struct {
issuer string
issuers []string
DiscoveryEndpoint string
endpoints Endpoints
oauthConfig *oauth2.Config
Expand All @@ -128,7 +128,15 @@ func (rp *relyingParty) OAuthConfig() *oauth2.Config {
}

func (rp *relyingParty) Issuer() string {
return rp.issuer
if len(rp.issuers) == 0 {
return ""
}

return rp.issuers[0]
}

func (rp *relyingParty) Issuers() []string {
return rp.issuers
}

func (rp *relyingParty) IsPKCE() bool {
Expand Down Expand Up @@ -169,7 +177,7 @@ func (rp *relyingParty) GetRevokeEndpoint() string {

func (rp *relyingParty) IDTokenVerifier() *IDTokenVerifier {
if rp.idTokenVerifier == nil {
rp.idTokenVerifier = NewIDTokenVerifier(rp.issuer, rp.oauthConfig.ClientID, NewRemoteKeySet(rp.httpClient, rp.endpoints.JKWsURL), rp.verifierOpts...)
rp.idTokenVerifier = NewIDTokenVerifier(rp.issuers, rp.oauthConfig.ClientID, NewRemoteKeySet(rp.httpClient, rp.endpoints.JKWsURL), rp.verifierOpts...)
}
return rp.idTokenVerifier
}
Expand Down Expand Up @@ -236,8 +244,16 @@ func NewRelyingPartyOAuth(config *oauth2.Config, options ...Option) (RelyingPart
// issuer, clientID, clientSecret, redirectURI, scopes and possible configOptions
// it will run discovery on the provided issuer and use the found endpoints
func NewRelyingPartyOIDC(ctx context.Context, issuer, clientID, clientSecret, redirectURI string, scopes []string, options ...Option) (RelyingParty, error) {

return NewRelyingPartyOIDCWithIssuers(ctx, []string{issuer}, clientID, clientSecret, redirectURI, scopes, options...)
}

// NewRelyingPartyOIDCWithIssuers creates an (OIDC) RelyingParty with the given
// issuers, clientID, clientSecret, redirectURI, scopes and possible configOptions
// it will run discovery on the provided issuers and use the found endpoints
func NewRelyingPartyOIDCWithIssuers(ctx context.Context, issuers []string, clientID, clientSecret, redirectURI string, scopes []string, options ...Option) (RelyingParty, error) {
rp := &relyingParty{
issuer: issuer,
issuers: issuers,
oauthConfig: &oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
Expand All @@ -256,7 +272,7 @@ func NewRelyingPartyOIDC(ctx context.Context, issuer, clientID, clientSecret, re
}
}
ctx = logCtxWithRPData(ctx, rp, "function", "NewRelyingPartyOIDC")
discoveryConfiguration, err := client.Discover(ctx, rp.issuer, rp.httpClient, rp.DiscoveryEndpoint)
discoveryConfiguration, err := client.Discover(ctx, rp.issuers, rp.httpClient, rp.DiscoveryEndpoint)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/client/rp/relying_party_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (

func Test_verifyTokenResponse(t *testing.T) {
verifier := &IDTokenVerifier{
Issuer: tu.ValidIssuer,
Issuers: []string{tu.ValidIssuer},
MaxAgeIAT: 2 * time.Minute,
ClientID: tu.ValidClientID,
Offset: time.Second,
Expand Down
6 changes: 3 additions & 3 deletions pkg/client/rp/verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v *IDTokenV
return nilClaims, err
}

if err = oidc.CheckIssuer(claims, v.Issuer); err != nil {
if err = oidc.CheckIssuer(claims, v.Issuers); err != nil {
return nilClaims, err
}

Expand Down Expand Up @@ -110,9 +110,9 @@ func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAl
}

// NewIDTokenVerifier returns a oidc.Verifier suitable for ID token verification.
func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...VerifierOption) *IDTokenVerifier {
func NewIDTokenVerifier(issuers []string, clientID string, keySet oidc.KeySet, options ...VerifierOption) *IDTokenVerifier {
v := &IDTokenVerifier{
Issuer: issuer,
Issuers: issuers,
ClientID: clientID,
KeySet: keySet,
Offset: time.Second,
Expand Down
12 changes: 6 additions & 6 deletions pkg/client/rp/verifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@ import (
"testing"
"time"

tu "github.com/datasapiens/oidc/v3/internal/testutil"
"github.com/datasapiens/oidc/v3/pkg/oidc"
jose "github.com/go-jose/go-jose/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tu "github.com/datasapiens/oidc/v3/internal/testutil"
"github.com/datasapiens/oidc/v3/pkg/oidc"
)

func TestVerifyTokens(t *testing.T) {
verifier := &IDTokenVerifier{
Issuer: tu.ValidIssuer,
Issuers: []string{tu.ValidIssuer},
MaxAgeIAT: 2 * time.Minute,
Offset: time.Second,
SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
Expand Down Expand Up @@ -93,7 +93,7 @@ func TestVerifyTokens(t *testing.T) {

func TestVerifyIDToken(t *testing.T) {
verifier := &IDTokenVerifier{
Issuer: tu.ValidIssuer,
Issuers: []string{tu.ValidIssuer},
MaxAgeIAT: 2 * time.Minute,
Offset: time.Second,
SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
Expand Down Expand Up @@ -341,7 +341,7 @@ func TestNewIDTokenVerifier(t *testing.T) {
},
},
want: &IDTokenVerifier{
Issuer: tu.ValidIssuer,
Issuers: []string{tu.ValidIssuer},
Offset: time.Minute,
MaxAgeIAT: time.Hour,
ClientID: tu.ValidClientID,
Expand All @@ -355,7 +355,7 @@ func TestNewIDTokenVerifier(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewIDTokenVerifier(tt.args.issuer, tt.args.clientID, tt.args.keySet, tt.args.options...)
got := NewIDTokenVerifier([]string{tt.args.issuer}, tt.args.clientID, tt.args.keySet, tt.args.options...)
assert.Equal(t, tt.want, got)
})
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/client/rp/verifier_tokens_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ const idToken = `eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ.eyJhY3IiOiJzb21ldGhpbmciLCJh
const accessToken = `eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ.eyJhdWQiOlsidW5pdCIsInRlc3QiXSwiYmFyIjp7ImNvdW50IjoyMiwidGFncyI6WyJzb21lIiwidGFncyJdfSwiZXhwIjo0ODAyMjM4NjgyLCJmb28iOiJIZWxsbywgV29ybGQhIiwiaWF0IjoxNjc4MTAxMDIxLCJpc3MiOiJsb2NhbC5jb20iLCJqdGkiOiI5ODc2IiwibmJmIjoxNjc4MTAxMDIxLCJzdWIiOiJ0aW1AbG9jYWwuY29tIn0.Zrz3LWSRjCMJZUMaI5dUbW4vGdSmEeJQ3ouhaX0bcW9rdFFLgBI4K2FWJhNivq8JDmCGSxwLu3mI680GWmDaEoAx1M5sCO9lqfIZHGZh-lfAXk27e6FPLlkTDBq8Bx4o4DJ9Fw0hRJGjUTjnYv5cq1vo2-UqldasL6CwTbkzNC_4oQFfRtuodC4Ql7dZ1HRv5LXuYx7KPkOssLZtV9cwtJp5nFzKjcf2zEE_tlbjcpynMwypornRUp1EhCWKRUGkJhJeiP71ECY5pQhShfjBu9Nc5wDpSnZmnk2S4YsPrRK3QkE-iEkas8BfsOCrGoErHjEJexAIDjasGO5PFLWfCA`

func ExampleVerifyTokens_customClaims() {
v := rp.NewIDTokenVerifier("local.com", "555666", tu.KeySet{},
v := rp.NewIDTokenVerifier([]string{"local.com"}, "555666", tu.KeySet{},
rp.WithNonce(func(ctx context.Context) string { return "12345" }),
)

Expand Down
2 changes: 1 addition & 1 deletion pkg/client/rs/resource_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func newResourceServer(ctx context.Context, issuer string, authorizer func() (an
optFunc(rs)
}
if rs.introspectURL == "" || rs.tokenURL == "" {
config, err := client.Discover(ctx, rs.issuer, rs.httpClient)
config, err := client.Discover(ctx, []string{issuer}, rs.httpClient)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/client/tokenexchange/tokenexchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import (
"net/http"
"time"

"github.com/go-jose/go-jose/v4"
"github.com/datasapiens/oidc/v3/pkg/client"
httphelper "github.com/datasapiens/oidc/v3/pkg/http"
"github.com/datasapiens/oidc/v3/pkg/oidc"
"github.com/go-jose/go-jose/v4"
)

type TokenExchanger interface {
Expand Down Expand Up @@ -55,7 +55,7 @@ func newOAuthTokenExchange(ctx context.Context, issuer string, authorizer func()
}

if te.tokenEndpoint == "" {
config, err := client.Discover(ctx, issuer, te.httpClient)
config, err := client.Discover(ctx, []string{issuer}, te.httpClient)
if err != nil {
return nil, err
}
Expand Down
9 changes: 4 additions & 5 deletions pkg/oidc/verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ var (
// functions. Use package specific constructor functions to know
// which values need to be set.
type Verifier struct {
Issuer string
Issuers []string
MaxAgeIAT time.Duration
Offset time.Duration
ClientID string
Expand Down Expand Up @@ -115,10 +115,9 @@ func CheckSubject(claims Claims) error {
return nil
}

func CheckIssuer(claims Claims, issuer string) error {
return nil
if claims.GetIssuer() != issuer {
return fmt.Errorf("%w: Expected: %s, got: %s", ErrIssuerInvalid, issuer, claims.GetIssuer())
func CheckIssuer(claims Claims, issuers []string) error {
if !slices.Contains(issuers, claims.GetIssuer()) {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here the check makes sense as claims.GetIssuer is the real issuer: https://sts.windows.net/26f2a995-d6fd-4a27-9e38-b8bac98e4ce5/

return fmt.Errorf("%w: Expected one of: %v, got: %s", ErrIssuerInvalid, issuers, claims.GetIssuer())
}
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/oidc/verifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func TestCheckIssuer(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckIssuer(tt.claims, issuer)
err := CheckIssuer(tt.claims, []string{issuer})
assert.ErrorIs(t, err, tt.wantErr)
})
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/op/verifier_access_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ func WithSupportedAccessTokenSigningAlgorithms(algs ...string) AccessTokenVerifi
// NewAccessTokenVerifier returns a AccessTokenVerifier suitable for access token verification.
func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTokenVerifierOpt) *AccessTokenVerifier {
verifier := &AccessTokenVerifier{
Issuer: issuer,
KeySet: keySet,
Issuers: []string{issuer},
KeySet: keySet,
}
for _, opt := range opts {
opt(verifier)
Expand All @@ -44,7 +44,7 @@ func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v *Acce
return nilClaims, err
}

if err := oidc.CheckIssuer(claims, v.Issuer); err != nil {
if err := oidc.CheckIssuer(claims, v.Issuers); err != nil {
return nilClaims, err
}

Expand Down
12 changes: 6 additions & 6 deletions pkg/op/verifier_access_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tu "github.com/datasapiens/oidc/v3/internal/testutil"
"github.com/datasapiens/oidc/v3/pkg/oidc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNewAccessTokenVerifier(t *testing.T) {
Expand All @@ -29,8 +29,8 @@ func TestNewAccessTokenVerifier(t *testing.T) {
keySet: tu.KeySet{},
},
want: &AccessTokenVerifier{
Issuer: tu.ValidIssuer,
KeySet: tu.KeySet{},
Issuers: []string{tu.ValidIssuer},
KeySet: tu.KeySet{},
},
},
{
Expand All @@ -43,7 +43,7 @@ func TestNewAccessTokenVerifier(t *testing.T) {
},
},
want: &AccessTokenVerifier{
Issuer: tu.ValidIssuer,
Issuers: []string{tu.ValidIssuer},
KeySet: tu.KeySet{},
SupportedSignAlgs: []string{"ABC", "DEF"},
},
Expand All @@ -59,7 +59,7 @@ func TestNewAccessTokenVerifier(t *testing.T) {

func TestVerifyAccessToken(t *testing.T) {
verifier := &AccessTokenVerifier{
Issuer: tu.ValidIssuer,
Issuers: []string{tu.ValidIssuer},
MaxAgeIAT: 2 * time.Minute,
Offset: time.Second,
SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
Expand Down
6 changes: 3 additions & 3 deletions pkg/op/verifier_id_token_hint.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ func WithSupportedIDTokenHintSigningAlgorithms(algs ...string) IDTokenHintVerifi

func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHintVerifierOpt) *IDTokenHintVerifier {
verifier := &IDTokenHintVerifier{
Issuer: issuer,
KeySet: keySet,
Issuers: []string{issuer},
KeySet: keySet,
}
for _, opt := range opts {
opt(verifier)
Expand Down Expand Up @@ -60,7 +60,7 @@ func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v *IDTo
return nilClaims, err
}

if err := oidc.CheckIssuer(claims, v.Issuer); err != nil {
if err := oidc.CheckIssuer(claims, v.Issuers); err != nil {
return nilClaims, err
}

Expand Down
12 changes: 6 additions & 6 deletions pkg/op/verifier_id_token_hint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tu "github.com/datasapiens/oidc/v3/internal/testutil"
"github.com/datasapiens/oidc/v3/pkg/oidc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNewIDTokenHintVerifier(t *testing.T) {
Expand All @@ -30,8 +30,8 @@ func TestNewIDTokenHintVerifier(t *testing.T) {
keySet: tu.KeySet{},
},
want: &IDTokenHintVerifier{
Issuer: tu.ValidIssuer,
KeySet: tu.KeySet{},
Issuers: []string{tu.ValidIssuer},
KeySet: tu.KeySet{},
},
},
{
Expand All @@ -44,7 +44,7 @@ func TestNewIDTokenHintVerifier(t *testing.T) {
},
},
want: &IDTokenHintVerifier{
Issuer: tu.ValidIssuer,
Issuers: []string{tu.ValidIssuer},
KeySet: tu.KeySet{},
SupportedSignAlgs: []string{"ABC", "DEF"},
},
Expand All @@ -67,7 +67,7 @@ func Test_IDTokenHintExpiredError(t *testing.T) {

func TestVerifyIDTokenHint(t *testing.T) {
verifier := &IDTokenHintVerifier{
Issuer: tu.ValidIssuer,
Issuers: []string{tu.ValidIssuer},
MaxAgeIAT: 2 * time.Minute,
Offset: time.Second,
SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
Expand Down
Loading