Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/golangci.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
version: "2"
run:
modules-download-mode: readonly
relative-path-mode: gomod
linters:
enable:
- bodyclose
Expand Down
8 changes: 6 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,15 @@ install-dev-tools:
deps:
@go mod tidy

GO_BUILD_TAGS := $(if $(BUILD_TAGS),-tags "$(BUILD_TAGS)")

agent_bin:
echo "ORB_VERSION: $(ORB_VERSION)-$(COMMIT_HASH)"
CGO_ENABLED=$(CGO_ENABLED) GOOS=$(GOOS) GOARCH=$(GOARCH) GOARM=$(GOARM) go build -mod=mod -ldflags="$(LDFLAGS)" -o ${BUILD_DIR}/orb-agent cmd/main.go
CGO_ENABLED=$(CGO_ENABLED) GOOS=$(GOOS) GOARCH=$(GOARCH) GOARM=$(GOARM) go build -mod=mod -ldflags="$(LDFLAGS)" $(GO_BUILD_TAGS) -o ${BUILD_DIR}/orb-agent cmd/main.go

.PHONY: test
test:
@go test -race ./...
@go test $(GO_BUILD_TAGS)-race ./...

.PHONY: test-timed
test-timed:
Expand Down Expand Up @@ -84,6 +86,7 @@ agent:
--build-arg GOARCH=$(GOARCH) \
--build-arg PKTVISOR_TAG=$(PKTVISOR_TAG) \
--build-arg SNMP_DISCOVERY_TAG=$(SNMP_DISCOVERY_TAG) \
--build-arg BUILD_TAGS=$(BUILD_TAGS) \
--tag=$(ORB_DOCKERHUB_REPO)/orb-agent:$(REF_TAG) \
--tag=$(ORB_DOCKERHUB_REPO)/orb-agent:$(ORB_VERSION) \
--tag=$(ORB_DOCKERHUB_REPO)/orb-agent:$(ORB_VERSION)-$(COMMIT_HASH) \
Expand All @@ -94,6 +97,7 @@ agent_fast:
--build-arg GOARCH=$(GOARCH) \
--build-arg PKTVISOR_TAG=$(PKTVISOR_TAG) \
--build-arg SNMP_DISCOVERY_TAG=$(SNMP_DISCOVERY_TAG) \
--build-arg BUILD_TAGS=$(BUILD_TAGS) \
--tag=$(ORB_DOCKERHUB_REPO)/orb-agent:$(REF_TAG) \
--tag=$(ORB_DOCKERHUB_REPO)/orb-agent:$(ORB_VERSION) \
--tag=$(ORB_DOCKERHUB_REPO)/orb-agent:$(ORB_VERSION)-$(COMMIT_HASH) \
Expand Down
50 changes: 50 additions & 0 deletions agent/configmgr/fleet.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,18 @@ func (fleetManager *FleetConfigManager) Start(cfg config.Config, backends map[st
}
fleetManager.logger.Info("OTLP bridge server started", slog.Int("grpc_port", grpcPort))

// Wire up token refresher so autopaho's ConnectPacketBuilder can fetch a fresh JWT
// on every reconnect attempt, eliminating the stale-token race window.
if mqttConn, ok := fleetManager.connection.(*fleet.MQTTConnection); ok {
mqttConn.SetTokenRefresher(func(ctx context.Context) (string, error) {
resp, err := fleetManager.authTokenManager.RefreshToken(ctx)
if err != nil {
return "", err
}
return resp.AccessToken, nil
})
}

err = fleetManager.connection.Connect(ctx, connectionDetails, backends, cfg.OrbAgent.Labels, string(configYaml))
if err != nil {
return err
Expand Down Expand Up @@ -228,6 +240,10 @@ func (fleetManager *FleetConfigManager) Start(cfg config.Config, backends map[st
// Start background goroutine to monitor token expiry and trigger proactive reconnection
go fleetManager.monitorTokenExpiry()

// Start debug handler (no-op unless built with -tags debug).
// The handler can force a token rotation or log token status on demand
fleet.StartDebugTrigger(fleetManager.monitorCtx, fleetManager.logger, fleetManager)

return nil
}

Expand Down Expand Up @@ -398,6 +414,40 @@ func (fleetManager *FleetConfigManager) configToSafeString(cfg config.Config) (s
return string(configYaml), nil
}

// RotateCredentials refreshes the JWT token and signals the reconnect worker.
// Implements fleet.DebugCredentials.
func (fleetManager *FleetConfigManager) RotateCredentials(ctx context.Context) error {
oldExpiry := fleetManager.authTokenManager.GetTokenExpiryTime()
_, err := fleetManager.authTokenManager.RefreshToken(ctx)
if err != nil {
return err
}
newExpiry := fleetManager.authTokenManager.GetTokenExpiryTime()
fleetManager.logger.Warn("credentials rotated",
"previous_expiry", oldExpiry,
"new_expiry", newExpiry,
"new_time_until_expiry", time.Until(newExpiry).Truncate(time.Second))

select {
case fleetManager.reconnectChan <- struct{}{}:
fleetManager.logger.Debug("reconnect signal sent after credential rotation")
default:
fleetManager.logger.Debug("reconnect already in progress, skipping signal")
}
return nil
}

// LogCredentials logs current token age and status.
// Implements fleet.DebugCredentials.
func (fleetManager *FleetConfigManager) LogCredentials() {
expiry := fleetManager.authTokenManager.GetTokenExpiryTime()
fleetManager.logger.Warn("token status",
"expires_at", expiry,
"time_until_expiry", time.Until(expiry).Truncate(time.Second),
"expired", fleetManager.authTokenManager.IsTokenExpired(),
"expiring_soon", fleetManager.authTokenManager.IsTokenExpiringSoon(2*time.Minute))
}

// GetContext returns the context for the Fleet configuration manager
func (fleetManager *FleetConfigManager) GetContext(ctx context.Context) context.Context {
// Empty implementation for now - just return the context as-is
Expand Down
56 changes: 56 additions & 0 deletions agent/configmgr/fleet/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
groupMembershipFailCount int
heartbeatFailCount int
mu sync.Mutex
tokenRefresher func(ctx context.Context) (string, error) // returns fresh JWT on reconnect
}

// NewMQTTConnection creates a new MQTTConnection
Expand All @@ -65,6 +66,13 @@
}
}

// SetTokenRefresher sets a callback that returns a fresh JWT. When set, the MQTT
// connection will call this before every CONNECT packet (including auto-reconnects)
// to ensure the broker always receives a valid token.
func (connection *MQTTConnection) SetTokenRefresher(fn func(ctx context.Context) (string, error)) {
connection.tokenRefresher = fn
}

// AddOnReadyHook registers a callback to be invoked when MQTT connection is ready.
func (connection *MQTTConnection) AddOnReadyHook(fn func(cm *autopaho.ConnectionManager, topics TokenResponseTopics)) {
connection.onReadyHooks = append(connection.onReadyHooks, fn)
Expand Down Expand Up @@ -352,6 +360,15 @@
cfg.ConnectPassword = []byte(details.Token)
}

// On every reconnect, refresh the JWT before sending CONNECT so autopaho's
// auto-reconnect never presents a stale token to the broker. When initialToken
// is non-empty (managed reconnect via Reconnect), the first CONNECT uses that
// token as-is to stay consistent with the topics derived from it; subsequent
// auto-reconnects call tokenRefresher normally.
if builder := buildConnectPacketBuilder(connection, details.Token); builder != nil {
cfg.ConnectPacketBuilder = builder
}

// Create and start the connection manager using the long-lived context.
connection.connectionManager, err = autopaho.NewConnection(ctx, cfg)
if err != nil {
Expand Down Expand Up @@ -461,3 +478,42 @@
connection.heartbeatFailCount = 0
return nil
}

// buildConnectPacketBuilder returns a ConnectPacketBuilder callback that refreshes
// the JWT before every CONNECT packet. Returns nil when no tokenRefresher is set.
//
// initialToken is the token already used to derive topics/zone for this Connect call.
// The first invocation of the returned closure uses initialToken as-is (keeping
// password and topics consistent). Subsequent invocations (autopaho auto-reconnects)
// call tokenRefresher to obtain a fresh JWT. The "first call" state is scoped to the
// closure instance, not to the connection struct, so each Connect creates an
// independent builder with its own lifecycle.
func buildConnectPacketBuilder(connection *MQTTConnection, initialToken string) func(*paho.Connect, *url.URL) (*paho.Connect, error) {

Check failure on line 491 in agent/configmgr/fleet/connection.go

View workflow job for this annotation

GitHub Actions / golangci

unused-parameter: parameter 'initialToken' seems to be unused, consider removing or renaming it as _ (revive)
if connection.tokenRefresher == nil {
return nil
}
firstCall := true
return func(cp *paho.Connect, _ *url.URL) (*paho.Connect, error) {
// First call: use the token that was already placed in ConnectPassword
// and that topics were derived from — no extra refresh needed.
if firstCall {
firstCall = false
connection.logger.Debug("using initial token for CONNECT")
return cp, nil
}

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

freshJWT, err := connection.tokenRefresher(ctx)
if err != nil {
connection.logger.Error("failed to refresh token for MQTT reconnect", "error", err)
// Fall through with existing credentials — broker will reject if truly expired,
// and autopaho will retry (calling this builder again).
return cp, nil
}
connection.logger.Info("JWT refreshed for MQTT reconnect")
cp.Password = []byte(freshJWT)
return cp, nil
}
}
12 changes: 12 additions & 0 deletions agent/configmgr/fleet/debug.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package fleet

import "context"

// DebugCredentials is the interface consumed by the debug trigger to force
// token rotation and inspect token state. It is implemented outside this
// package (by FleetConfigManager) so that fleet debug code has no dependency
// on the configmgr package.
type DebugCredentials interface {
RotateCredentials(ctx context.Context) error
LogCredentials()
}
44 changes: 44 additions & 0 deletions agent/configmgr/fleet/debug_trigger.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//go:build debug

package fleet

import (
"context"
"log/slog"
"os"
"os/signal"
"syscall"
)

// StartDebugTrigger listens for OS signals to trigger debug actions.
// Only compiled when built with "-tags debug".
//
// SIGUSR1 → force token rotation + reconnect
// SIGUSR2 → log current token age/status
//
// The goroutine exits when ctx is cancelled; no explicit stop needed.
func StartDebugTrigger(ctx context.Context, logger *slog.Logger, dc DebugCredentials) {
sigRotate := make(chan os.Signal, 1)
sigPeek := make(chan os.Signal, 1)
signal.Notify(sigRotate, syscall.SIGUSR1)
signal.Notify(sigPeek, syscall.SIGUSR2)

go func() {
logger.Info("debug triggers active (SIGUSR1=rotate, SIGUSR2=peek)")
for {
select {
case <-ctx.Done():
signal.Stop(sigRotate)
signal.Stop(sigPeek)
return
case <-sigRotate:
logger.Warn("debug: SIGUSR1 received — forcing token rotation")
if err := dc.RotateCredentials(ctx); err != nil {
logger.Error("debug: token rotation failed", "error", err)
}
case <-sigPeek:
dc.LogCredentials()
}
}
}()
}
11 changes: 11 additions & 0 deletions agent/configmgr/fleet/debug_trigger_off.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
//go:build !debug

package fleet

import (
"context"
"log/slog"
)

// StartDebugTrigger is a no-op when built without "-tags debug".
func StartDebugTrigger(_ context.Context, _ *slog.Logger, _ DebugCredentials) {}
90 changes: 90 additions & 0 deletions agent/configmgr/fleet/debug_trigger_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
//go:build debug

package fleet

import (
"context"
"log/slog"
"os"
"syscall"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

// fakeDebugCredentials records calls for test assertions.
type fakeDebugCredentials struct {
rotateCalled chan struct{}
logCalled chan struct{}
rotateErr error
}

func (f *fakeDebugCredentials) RotateCredentials(_ context.Context) error {
f.rotateCalled <- struct{}{}
return f.rotateErr
}

func (f *fakeDebugCredentials) LogCredentials() {
f.logCalled <- struct{}{}
}

func newFakeDC() *fakeDebugCredentials {
return &fakeDebugCredentials{
rotateCalled: make(chan struct{}, 1),
logCalled: make(chan struct{}, 1),
}
}

func TestDebugTrigger_SIGUSR1_CallsRotate(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
dc := newFakeDC()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

StartDebugTrigger(ctx, logger, dc)
time.Sleep(50 * time.Millisecond)

_ = syscall.Kill(os.Getpid(), syscall.SIGUSR1)

select {
case <-dc.rotateCalled:
// expected
case <-time.After(time.Second):
t.Fatal("expected RotateCredentials to be called")
}
}

func TestDebugTrigger_SIGUSR2_CallsLog(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
dc := newFakeDC()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

StartDebugTrigger(ctx, logger, dc)
time.Sleep(50 * time.Millisecond)

_ = syscall.Kill(os.Getpid(), syscall.SIGUSR2)

select {
case <-dc.logCalled:
// expected
case <-time.After(time.Second):
t.Fatal("expected LogCredentials to be called")
}
}

func TestDebugTrigger_ContextCancel(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
dc := newFakeDC()
ctx, cancel := context.WithCancel(context.Background())

StartDebugTrigger(ctx, logger, dc)
time.Sleep(50 * time.Millisecond)

cancel()
time.Sleep(50 * time.Millisecond)

// After cancel, signals should not be delivered
assert.NotNil(t, dc)
}
Loading
Loading