Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
7fd1d7e
decode only once in middleware
MariemBaccari Sep 9, 2025
03a06f3
fix formatting error
MariemBaccari Sep 9, 2025
b20dc4c
[logging] Add subject field to log output (#1263)
MariemBaccari Sep 9, 2025
fc27a16
remove merge messages
MariemBaccari Sep 9, 2025
625f684
fix unit test
MariemBaccari Sep 9, 2025
1ef90ba
edit error message
MariemBaccari Sep 9, 2025
5860a2e
Revert changes and improve middleware declarations
MariemBaccari Sep 18, 2025
30dd60f
address pr comments
MariemBaccari Oct 21, 2025
06b3ae2
fix context key nit
MariemBaccari Oct 21, 2025
86bb4ca
Merge branch 'master' into 1051-dup-token
MariemBaccari Oct 21, 2025
363c8da
add missing context
MariemBaccari Oct 21, 2025
9fd7ae0
clarify middleware doc
MariemBaccari Oct 21, 2025
36df8df
decode only once in middleware
MariemBaccari Sep 9, 2025
929fbe6
fix formatting error
MariemBaccari Sep 9, 2025
0621a00
[logging] Add subject field to log output (#1263)
MariemBaccari Sep 9, 2025
7316590
remove merge messages
MariemBaccari Sep 9, 2025
1a74ef7
fix unit test
MariemBaccari Sep 9, 2025
1325293
edit error message
MariemBaccari Sep 9, 2025
c11145a
Revert changes and improve middleware declarations
MariemBaccari Sep 18, 2025
d78dc68
address pr comments
MariemBaccari Oct 21, 2025
33cca92
fix context key nit
MariemBaccari Oct 21, 2025
09f7926
add missing context
MariemBaccari Oct 21, 2025
26eae64
clarify middleware doc
MariemBaccari Oct 21, 2025
ac80360
Merge branch '1051-dup-token' of github.com:Orbitalize/dss into 1051-…
MariemBaccari Oct 21, 2025
6e443ca
fix auth_test.go
MariemBaccari Oct 21, 2025
ed2de91
address review comments
MariemBaccari Dec 15, 2025
71210a2
use keyclaims var name
MariemBaccari Dec 15, 2025
62ed06c
remove added newline
MariemBaccari Dec 15, 2025
97891bb
move ctxkey
MariemBaccari Dec 15, 2025
817115b
use t.context in all testing
MariemBaccari Dec 15, 2025
32e429e
address review comments
MariemBaccari Dec 26, 2025
e33f6b4
log value from claims.FromContext
MariemBaccari Dec 26, 2025
1e5755f
remove unused ctxkey
MariemBaccari Dec 26, 2025
4fdf748
nit
MariemBaccari Dec 30, 2025
60f3db1
Update pkg/logging/http.go
MariemBaccari Jan 6, 2026
799821b
address review comment
MariemBaccari Jan 6, 2026
bcdebdb
nit
MariemBaccari Jan 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 4 additions & 23 deletions cmds/core-service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,12 +306,10 @@ func RunHTTPServer(ctx context.Context, ctxCanceler func(), address, locality st
multiRouter.Routers = append(multiRouter.Routers, &scdV1Router)
}

handler := logging.HTTPMiddleware(logger, *dumpRequests,
healthyEndpointMiddleware(logger,
&multiRouter,
))

handler = authDecoderMiddleware(authorizer, handler)
// the middlewares are wrapped and, therefore, executed in the opposite order
handler := healthyEndpointMiddleware(logger, &multiRouter)
handler = logging.HTTPMiddleware(logger, *dumpRequests, handler)
handler = authorizer.TokenMiddleware(handler)

httpServer := &http.Server{
Addr: address,
Expand Down Expand Up @@ -373,23 +371,6 @@ func healthyEndpointMiddleware(logger *zap.Logger, next http.Handler) http.Handl
})
}

// authDecoderMiddleware decodes the authentication token and adds the Subject claim to the context.
func authDecoderMiddleware(authorizer *auth.Authorizer, handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var ctx context.Context

claims, err := authorizer.ExtractClaims(r)
if err != nil {
//remove the stacktrace using the formatting specifier "%#s"
ctx = context.WithValue(r.Context(), logging.CtxAuthError{}, fmt.Sprintf("%#s", err))
} else {
ctx = context.WithValue(r.Context(), logging.CtxAuthSubject{}, claims.Subject)
}

handler.ServeHTTP(w, r.WithContext(ctx))
})
}

func SetDeprecatingHttpFlag(logger *zap.Logger, newFlag **bool, deprecatedFlag **bool) {
if **deprecatedFlag {
logger.Warn("DEPRECATED: enable_http has been renamed to allow_http_base_urls.")
Expand Down
34 changes: 25 additions & 9 deletions pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"time"

"github.com/interuss/dss/pkg/api"
"github.com/interuss/dss/pkg/auth/claims"
dsserr "github.com/interuss/dss/pkg/errors"
"github.com/interuss/dss/pkg/logging"
"github.com/interuss/stacktrace"
Expand Down Expand Up @@ -182,11 +183,26 @@ func (a *Authorizer) setKeys(keys []interface{}) {
a.keyGuard.Unlock()
}

// Authorize extracts and verifies bearer tokens from a http.Request.
// TokenMiddleware decodes the authentication token and passes the claims to the authorizer and to the context for logging.
func (a *Authorizer) TokenMiddleware(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
claimsValue, err := a.extractClaims(r)
if err != nil {
ctx = claims.NewContextFromError(ctx, err)
} else {
ctx = claims.NewContext(ctx, claimsValue)
}

handler.ServeHTTP(w, r.WithContext(ctx))
})
}

// Authorize extracts and verifies bearer tokens from a http.Request after it was validated by the TokenMiddleware.
func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptions []api.AuthorizationOption) api.AuthorizationResult {
keyClaims, err := a.ExtractClaims(r)
keyClaims, err := claims.FromContext(r.Context())
if err != nil {
return api.AuthorizationResult{Error: stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Failed to extract claims from access token")}
return api.AuthorizationResult{Error: stacktrace.Propagate(err, "Error retrieving claims from context")}
}

if !a.acceptedAudiences[keyClaims.Audience] {
Expand All @@ -205,21 +221,21 @@ func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptio
}
}

func (a *Authorizer) ExtractClaims(r *http.Request) (claims, error) {
func (a *Authorizer) extractClaims(r *http.Request) (claims.Claims, error) {
tknStr, ok := getToken(r)
if !ok {
return claims{}, stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing access token")
return claims.Claims{}, stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing access token")
}

a.keyGuard.RLock()
keys := a.keys
a.keyGuard.RUnlock()
validated := false
var err error
var keyClaims claims
var keyClaims claims.Claims

for _, key := range keys {
keyClaims = claims{}
keyClaims = claims.Claims{}
key := key
_, err = jwt.ParseWithClaims(tknStr, &keyClaims, func(token *jwt.Token) (interface{}, error) {
return key, nil
Expand All @@ -234,7 +250,7 @@ func (a *Authorizer) ExtractClaims(r *http.Request) (claims, error) {
if err == nil { // If we have no keys, errs may be nil
err = stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "No keys to validate against")
}
return claims{}, stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed")
return claims.Claims{}, stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed")
}

return keyClaims, nil
Expand Down Expand Up @@ -278,7 +294,7 @@ func describeAuthorizationExpectations(authOptions []api.AuthorizationOption) st
// validateScopes matches scopes against a set of authorization options. Validation against a single one of those is
// enough for the validation to succeed. Returns true if it succeeds, or returns false and a string describing the
// missing scopes if it fails. Empty authorization options means that the validation passes.
func validateScopes(authOptions []api.AuthorizationOption, clientScopes ScopeSet) (bool, string) {
func validateScopes(authOptions []api.AuthorizationOption, clientScopes claims.ScopeSet) (bool, string) {
if len(authOptions) == 0 {
return true, ""
}
Expand Down
23 changes: 16 additions & 7 deletions pkg/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/interuss/dss/pkg/api"
"github.com/interuss/dss/pkg/api/scdv1"
"github.com/interuss/dss/pkg/auth/claims"
dsserr "github.com/interuss/dss/pkg/errors"
"github.com/interuss/stacktrace"

Expand Down Expand Up @@ -52,7 +53,7 @@ func rsaTokenReqWithMissingIssuer(key *rsa.PrivateKey, exp, nbf int64) *http.Req
}

func TestNewRSAAuthClient(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(t.Context())
defer cancel()

tmpfile, err := os.CreateTemp("/tmp", "bad.pem")
Expand Down Expand Up @@ -103,7 +104,7 @@ func TestRSAAuthInterceptor(t *testing.T) {
{rsaTokenReq(key, 100, 50), dsserr.Unauthenticated},
}

a, err := NewRSAAuthorizer(context.Background(), Configuration{
a, err := NewRSAAuthorizer(t.Context(), Configuration{
KeyResolver: &fromMemoryKeyResolver{
Keys: []interface{}{&key.PublicKey},
},
Expand All @@ -115,7 +116,15 @@ func TestRSAAuthInterceptor(t *testing.T) {

for i, test := range authTests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
res := a.Authorize(nil, test.req, []api.AuthorizationOption{})
ctx := t.Context()
claimsValue, err := a.extractClaims(test.req)
if err != nil {
ctx = claims.NewContextFromError(ctx, err)
} else {
ctx = claims.NewContext(ctx, claimsValue)
}

res := a.Authorize(nil, test.req.WithContext(ctx), []api.AuthorizationOption{})
if test.code != stacktrace.ErrorCode(0) && stacktrace.GetCode(res.Error) != test.code {
t.Logf("%v", res.Error)
t.Errorf("expected: %v, got: %v, with message %s", test.code, stacktrace.GetCode(res.Error), res.Error.Error())
Expand Down Expand Up @@ -193,17 +202,17 @@ func TestMissingScopes(t *testing.T) {
}

func TestClaimsValidation(t *testing.T) {
Now = func() time.Time {
claims.Now = func() time.Time {
return time.Unix(42, 0)
}
jwt.TimeFunc = Now
jwt.TimeFunc = claims.Now

defer func() {
jwt.TimeFunc = time.Now
Now = time.Now
claims.Now = time.Now
}()

claims := &claims{}
claims := &claims.Claims{}

require.Error(t, claims.Valid())

Expand Down
35 changes: 32 additions & 3 deletions pkg/auth/claims.go → pkg/auth/claims/claims.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package auth
package claims

import (
"context"
"encoding/json"
"errors"
"strings"
Expand All @@ -18,6 +19,34 @@ var (
Now = time.Now
)

type ctxKey string

var (
claimsKey = ctxKey("claims")
errKey = ctxKey("error")
)

func NewContext(ctx context.Context, claims Claims) context.Context {
return context.WithValue(ctx, claimsKey, claims)
}

func NewContextFromError(ctx context.Context, err error) context.Context {
return context.WithValue(ctx, errKey, err)
}

func FromContext(ctx context.Context) (Claims, error) {
claims, ok := ctx.Value(claimsKey).(Claims)
if !ok {
err, ok := ctx.Value(errKey).(error)
if ok {
return Claims{}, err
}
return Claims{}, stacktrace.NewError("No claims or error in context")
}

return claims, nil
}

// ScopeSet models a set of scopes.
type ScopeSet map[string]struct{}

Expand Down Expand Up @@ -61,12 +90,12 @@ func (s *ScopeSet) ToStringSlice() []string {
return scopes
}

type claims struct {
type Claims struct {
jwt.StandardClaims
Scopes ScopeSet `json:"scope"`
}

func (c *claims) Valid() error {
func (c *Claims) Valid() error {
if c.Subject == "" {
return errMissingOrEmptySubject
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/auth/claims_test.go → pkg/auth/claims/claims_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package auth
package claims

import (
"encoding/json"
Expand All @@ -8,7 +8,7 @@ import (
)

func TestScopesJSONUnmarshaling(t *testing.T) {
claims := &claims{}
claims := &Claims{}
require.NoError(t, json.Unmarshal([]byte(`{"scope": "one two three"}`), claims))
require.Contains(t, claims.Scopes, "one")
require.Contains(t, claims.Scopes, "two")
Expand Down
12 changes: 3 additions & 9 deletions pkg/logging/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"time"

"github.com/interuss/dss/pkg/auth/claims"
"go.uber.org/zap"
)

Expand Down Expand Up @@ -39,9 +40,6 @@ func (w *tracingResponseWriter) WriteHeader(statusCode int) {
w.next.WriteHeader(statusCode)
}

type CtxAuthError struct{}
type CtxAuthSubject struct{}

// HTTPMiddleware installs a logging http.Handler that logs requests and
// selected aspects of responses to 'logger'.
func HTTPMiddleware(logger *zap.Logger, dump bool, handler http.Handler) http.Handler {
Expand Down Expand Up @@ -72,12 +70,8 @@ func HTTPMiddleware(logger *zap.Logger, dump bool, handler http.Handler) http.Ha
}
}

subject, ok := r.Context().Value(CtxAuthSubject{}).(string)
if !ok {
authErrorMsg := r.Context().Value(CtxAuthError{}).(string)
logger = logger.With(zap.String("resp_sub_err", authErrorMsg))
} else {
logger = logger.With(zap.String("req_sub", subject))
if claimsValue, _ := claims.FromContext(r.Context()); claimsValue.Subject != "" {
logger = logger.With(zap.String("req_sub", claimsValue.Subject))
}

handler.ServeHTTP(trw, r)
Expand Down
Loading