diff --git a/cmd/fleetint/enroll.go b/cmd/fleetint/enroll.go index 06779c81..d12ade5e 100644 --- a/cmd/fleetint/enroll.go +++ b/cmd/fleetint/enroll.go @@ -30,7 +30,7 @@ import ( ) var ( - performEnrollWorkflow = enrollment.EnrollWithConfig + performEnrollWorkflow = enrollment.EnrollWithConfigAndMetadata fleetintEnvFilePath = config.DefaultEnvFilePath ) @@ -84,6 +84,10 @@ func resolveToken(cliContext *cli.Context) (string, error) { func enrollCommand(cliContext *cli.Context) error { baseEndpoint := cliContext.String("endpoint") force := cliContext.Bool("force") + metadata := &enrollment.EnrollMetadata{ + NodeGroup: optionalFlagValue(cliContext, "node-group"), + ComputeZone: optionalFlagValue(cliContext, "compute-zone"), + } sakToken, err := resolveToken(cliContext) if err != nil { @@ -117,5 +121,13 @@ func enrollCommand(cliContext *cli.Context) error { return fmt.Errorf("failed to configure loop settings from environment variables: %w", err) } - return performEnrollWorkflow(ctx, baseEndpoint, sakToken, cfg) + return performEnrollWorkflow(ctx, baseEndpoint, sakToken, cfg, metadata) +} + +func optionalFlagValue(cliContext *cli.Context, name string) *string { + if !cliContext.IsSet(name) { + return nil + } + value := strings.TrimSpace(cliContext.String(name)) + return &value } diff --git a/cmd/fleetint/enroll_test.go b/cmd/fleetint/enroll_test.go index ed85bfac..e9d9fa55 100644 --- a/cmd/fleetint/enroll_test.go +++ b/cmd/fleetint/enroll_test.go @@ -28,6 +28,7 @@ import ( "github.com/stretchr/testify/require" "github.com/NVIDIA/fleet-intelligence-agent/internal/config" + "github.com/NVIDIA/fleet-intelligence-agent/internal/enrollment" "github.com/NVIDIA/fleet-intelligence-agent/internal/precheck" ) @@ -107,8 +108,11 @@ func TestEnrollCommandBlocksOnFailedPrecheck(t *testing.T) { }, }, nil } - performEnrollWorkflow = func(ctx context.Context, baseEndpoint, sakToken string, cfg *config.Config) error { + performEnrollWorkflow = func(ctx context.Context, baseEndpoint, 
sakToken string, cfg *config.Config, metadata *enrollment.EnrollMetadata) error { enrollmentCalled = true + require.NotNil(t, metadata) + require.Nil(t, metadata.NodeGroup) + require.Nil(t, metadata.ComputeZone) return nil } @@ -142,8 +146,11 @@ func TestEnrollCommandForceBypassesFailedPrecheck(t *testing.T) { }, }, nil } - performEnrollWorkflow = func(ctx context.Context, baseEndpoint, sakToken string, cfg *config.Config) error { + performEnrollWorkflow = func(ctx context.Context, baseEndpoint, sakToken string, cfg *config.Config, metadata *enrollment.EnrollMetadata) error { enrollmentCalled = true + require.NotNil(t, metadata) + require.Nil(t, metadata.NodeGroup) + require.Nil(t, metadata.ComputeZone) return nil } @@ -174,11 +181,14 @@ func TestEnrollCommandPassesTimeoutContext(t *testing.T) { }, nil } - performEnrollWorkflow = func(ctx context.Context, baseEndpoint, sakToken string, cfg *config.Config) error { + performEnrollWorkflow = func(ctx context.Context, baseEndpoint, sakToken string, cfg *config.Config, metadata *enrollment.EnrollMetadata) error { deadline, ok := ctx.Deadline() require.True(t, ok) require.LessOrEqual(t, time.Until(deadline), defaultEnrollTimeout) require.Greater(t, time.Until(deadline), 55*time.Second) + require.NotNil(t, metadata) + require.Nil(t, metadata.NodeGroup) + require.Nil(t, metadata.ComputeZone) return nil } @@ -217,7 +227,7 @@ FLEETINT_ATTESTATION_INTERVAL="6h" `), 0o600)) fleetintEnvFilePath = envFilePath - performEnrollWorkflow = func(ctx context.Context, baseEndpoint, sakToken string, cfg *config.Config) error { + performEnrollWorkflow = func(ctx context.Context, baseEndpoint, sakToken string, cfg *config.Config, metadata *enrollment.EnrollMetadata) error { require.NotNil(t, cfg) require.NotNil(t, cfg.Inventory) require.False(t, cfg.Inventory.Enabled) @@ -225,6 +235,9 @@ FLEETINT_ATTESTATION_INTERVAL="6h" require.NotNil(t, cfg.Attestation) require.True(t, cfg.Attestation.Enabled) require.Equal(t, 6*time.Hour, 
cfg.Attestation.Interval.Duration) + require.NotNil(t, metadata) + require.Nil(t, metadata.NodeGroup) + require.Nil(t, metadata.ComputeZone) return nil } @@ -234,3 +247,83 @@ FLEETINT_ATTESTATION_INTERVAL="6h" err := app.Run([]string{"fleetint", "enroll", "--endpoint", "https://example.com", "--token", "token"}) require.NoError(t, err) } + +func TestEnrollCommandPassesOptionalMetadata(t *testing.T) { + useMissingFleetintEnvFile(t) + + originalRunPrecheck := runPrecheck + originalEnrollWorkflow := performEnrollWorkflow + t.Cleanup(func() { + runPrecheck = originalRunPrecheck + performEnrollWorkflow = originalEnrollWorkflow + }) + + runPrecheck = func() (precheck.Result, error) { + return precheck.Result{ + Checks: []precheck.Check{ + {Name: "gpu-present", Message: "ok", Passed: true}, + }, + }, nil + } + + performEnrollWorkflow = func(ctx context.Context, baseEndpoint, sakToken string, cfg *config.Config, metadata *enrollment.EnrollMetadata) error { + require.NotNil(t, metadata) + require.NotNil(t, metadata.NodeGroup) + require.Equal(t, "prod-group", *metadata.NodeGroup) + require.NotNil(t, metadata.ComputeZone) + require.Equal(t, "us-east-1c", *metadata.ComputeZone) + return nil + } + + app := App() + app.Writer = &bytes.Buffer{} + + err := app.Run([]string{ + "fleetint", "enroll", + "--endpoint", "https://example.com", + "--token", "token", + "--node-group", "prod-group", + "--compute-zone", "us-east-1c", + }) + require.NoError(t, err) +} + +func TestEnrollCommandTreatsExplicitEmptyMetadataAsClear(t *testing.T) { + useMissingFleetintEnvFile(t) + + originalRunPrecheck := runPrecheck + originalEnrollWorkflow := performEnrollWorkflow + t.Cleanup(func() { + runPrecheck = originalRunPrecheck + performEnrollWorkflow = originalEnrollWorkflow + }) + + runPrecheck = func() (precheck.Result, error) { + return precheck.Result{ + Checks: []precheck.Check{ + {Name: "gpu-present", Message: "ok", Passed: true}, + }, + }, nil + } + + performEnrollWorkflow = func(ctx 
context.Context, baseEndpoint, sakToken string, cfg *config.Config, metadata *enrollment.EnrollMetadata) error { + require.NotNil(t, metadata) + require.NotNil(t, metadata.NodeGroup) + require.Empty(t, *metadata.NodeGroup) + require.NotNil(t, metadata.ComputeZone) + require.Empty(t, *metadata.ComputeZone) + return nil + } + + app := App() + app.Writer = &bytes.Buffer{} + + err := app.Run([]string{ + "fleetint", "enroll", + "--endpoint", "https://example.com", + "--token", "token", + "--node-group=", + "--compute-zone=", + }) + require.NoError(t, err) +} diff --git a/cmd/fleetint/root.go b/cmd/fleetint/root.go index c876e51a..fc78c00a 100644 --- a/cmd/fleetint/root.go +++ b/cmd/fleetint/root.go @@ -214,6 +214,14 @@ func App() *cli.App { Name: "force", Usage: "continue enrollment even when precheck fails", }, + &cli.StringFlag{ + Name: "node-group", + Usage: "optional node group metadata associated with this node", + }, + &cli.StringFlag{ + Name: "compute-zone", + Usage: "optional compute zone metadata associated with this node", + }, }, }, { diff --git a/cmd/fleetint/unenroll.go b/cmd/fleetint/unenroll.go index 1d6decf8..27b28f9a 100644 --- a/cmd/fleetint/unenroll.go +++ b/cmd/fleetint/unenroll.go @@ -71,6 +71,8 @@ func removeEnrollmentMetadata(ctx context.Context, dbRW *sql.DB) error { agentstate.MetadataKeySAKToken, agentstate.MetadataKeyBackendBaseURL, agentstate.MetadataKeyEnrolledAt, + agentstate.MetadataKeyNodeGroup, + agentstate.MetadataKeyComputeZone, "enroll_endpoint", "metrics_endpoint", "logs_endpoint", diff --git a/cmd/fleetint/unenroll_test.go b/cmd/fleetint/unenroll_test.go index b52c9e5e..34ebe06a 100644 --- a/cmd/fleetint/unenroll_test.go +++ b/cmd/fleetint/unenroll_test.go @@ -23,6 +23,8 @@ import ( pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata" "github.com/NVIDIA/fleet-intelligence-sdk/pkg/sqlite" "github.com/stretchr/testify/require" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/agentstate" ) func 
TestRemoveEnrollmentMetadata(t *testing.T) { @@ -36,14 +38,16 @@ func TestRemoveEnrollmentMetadata(t *testing.T) { require.NoError(t, pkgmetadata.CreateTableMetadata(ctx, db)) for key, value := range map[string]string{ - pkgmetadata.MetadataKeyToken: "jwt-token", - "sak_token": "sak-token", - "backend_base_url": "https://backend.example.com", - "enroll_endpoint": "https://backend.example.com/api/v1/enroll", - "metrics_endpoint": "https://backend.example.com/api/v1/health/metrics", - "logs_endpoint": "https://backend.example.com/api/v1/health/logs", - "nonce_endpoint": "https://backend.example.com/api/v1/attest/nonce", - "keep_me": "still-here", + pkgmetadata.MetadataKeyToken: "jwt-token", + agentstate.MetadataKeySAKToken: "sak-token", + agentstate.MetadataKeyBackendBaseURL: "https://backend.example.com", + agentstate.MetadataKeyNodeGroup: "group-a", + agentstate.MetadataKeyComputeZone: "zone-a", + "enroll_endpoint": "https://backend.example.com/api/v1/enroll", + "metrics_endpoint": "https://backend.example.com/api/v1/health/metrics", + "logs_endpoint": "https://backend.example.com/api/v1/health/logs", + "nonce_endpoint": "https://backend.example.com/api/v1/attest/nonce", + "keep_me": "still-here", } { require.NoError(t, pkgmetadata.SetMetadata(ctx, db, key, value)) } @@ -52,8 +56,10 @@ func TestRemoveEnrollmentMetadata(t *testing.T) { for _, key := range []string{ pkgmetadata.MetadataKeyToken, - "sak_token", - "backend_base_url", + agentstate.MetadataKeySAKToken, + agentstate.MetadataKeyBackendBaseURL, + agentstate.MetadataKeyNodeGroup, + agentstate.MetadataKeyComputeZone, "enroll_endpoint", "metrics_endpoint", "logs_endpoint", diff --git a/docs/usage.md b/docs/usage.md index 3bc5d9f0..0dfc5cc4 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -128,12 +128,19 @@ One of `--token` or `--token-file` is required. 
**Optional Flags:** - `--force`: Continue enrollment even if `fleetint precheck` fails +- `--node-group`: Optional node group metadata persisted in local agent metadata and attached to inventory and telemetry exports +- `--compute-zone`: Optional compute zone metadata persisted in local agent metadata and attached to inventory and telemetry exports + +Metadata update behavior for `--node-group` and `--compute-zone`: +- If the flag is omitted, the existing stored value is preserved. +- If the flag is provided, the stored value is overwritten with the provided value (leading and trailing whitespace is trimmed). +- Providing an empty or whitespace-only value (for example `--node-group=""`) clears the stored value. **What it does:** 1. Runs the same prerequisite validation as `fleetint precheck` 2. Validates the endpoint URL (must be HTTPS) 3. Makes an enrollment request to exchange the SAK token for a JWT token -4. Stores the JWT token and backend endpoints (metrics, logs, nonce) in the local metadata database +4. Stores the JWT token, backend endpoints (metrics, logs, nonce), and optional enrollment metadata (stored under the `nodegroup` and `compute_zone` keys) in the local metadata database 5. 
The stored credentials are used automatically by the agent for data export **Example output:** diff --git a/go.mod b/go.mod index dbbb0b4a..8a6f8fd5 100644 --- a/go.mod +++ b/go.mod @@ -99,12 +99,12 @@ require ( go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect golang.org/x/arch v0.22.0 // indirect - golang.org/x/crypto v0.50.0 // indirect + golang.org/x/crypto v0.52.0 // indirect golang.org/x/exp v0.0.0-20240119083558-1b970713d09a // indirect - golang.org/x/net v0.53.0 // indirect + golang.org/x/net v0.55.0 // indirect golang.org/x/sync v0.20.0 // indirect - golang.org/x/sys v0.43.0 // indirect - golang.org/x/text v0.36.0 // indirect + golang.org/x/sys v0.45.0 // indirect + golang.org/x/text v0.37.0 // indirect golang.org/x/time v0.12.0 // indirect golang.zx2c4.com/wireguard/windows v0.5.3 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect diff --git a/go.sum b/go.sum index b624e48a..67fd4380 100644 --- a/go.sum +++ b/go.sum @@ -248,20 +248,20 @@ golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= -golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988= +golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc= golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA= golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod 
h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= -golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= +golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM= +golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= -golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= +golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8= +golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -278,12 +278,12 @@ golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.43.0 
h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= -golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY= +golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= -golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= +golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= +golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/internal/agentstate/sqlite.go b/internal/agentstate/sqlite.go index 84aa3b3c..523c682f 100644 --- a/internal/agentstate/sqlite.go +++ b/internal/agentstate/sqlite.go @@ -112,6 +112,22 @@ func (s *sqliteState) SetNodeUUID(ctx context.Context, value string) error { return s.setMetadata(ctx, pkgmetadata.MetadataKeyMachineID, value) } +func (s *sqliteState) GetNodeGroup(ctx context.Context) (string, bool, error) { + return s.getMetadata(ctx, MetadataKeyNodeGroup) +} + +func (s *sqliteState) SetNodeGroup(ctx context.Context, value string) error { + return s.setMetadata(ctx, MetadataKeyNodeGroup, value) +} + +func (s *sqliteState) GetComputeZone(ctx context.Context) (string, bool, error) { + return s.getMetadata(ctx, MetadataKeyComputeZone) +} + +func (s *sqliteState) SetComputeZone(ctx context.Context, value string) error { + return s.setMetadata(ctx, MetadataKeyComputeZone, value) +} + func (s *sqliteState) GetEnrollmentTime(ctx context.Context) 
(time.Time, bool, error) { value, ok, err := s.getMetadata(ctx, MetadataKeyEnrolledAt) if err != nil || !ok { diff --git a/internal/agentstate/sqlite_test.go b/internal/agentstate/sqlite_test.go index 4c302cde..af28c849 100644 --- a/internal/agentstate/sqlite_test.go +++ b/internal/agentstate/sqlite_test.go @@ -51,6 +51,10 @@ func TestSQLiteStateRoundTrip(t *testing.T) { require.NoError(t, err) err = state.SetNodeUUID(ctx, "node-1") require.NoError(t, err) + err = state.SetNodeGroup(ctx, "group-a") + require.NoError(t, err) + err = state.SetComputeZone(ctx, "us-west-2a") + require.NoError(t, err) enrollmentTime := time.Date(2026, 5, 6, 15, 0, 0, 123456789, time.UTC) err = state.SetEnrollmentTime(ctx, enrollmentTime) require.NoError(t, err) @@ -75,6 +79,16 @@ func TestSQLiteStateRoundTrip(t *testing.T) { require.True(t, ok) require.Equal(t, "node-1", value) + value, ok, err = state.GetNodeGroup(ctx) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, "group-a", value) + + value, ok, err = state.GetComputeZone(ctx) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, "us-west-2a", value) + gotEnrollmentTime, ok, err := state.GetEnrollmentTime(ctx) require.NoError(t, err) require.True(t, ok) @@ -115,6 +129,8 @@ func TestSQLiteStateMissingMetadataTableIsTreatedAsAbsent(t *testing.T) { state.GetJWT, state.GetSAK, state.GetNodeUUID, + state.GetNodeGroup, + state.GetComputeZone, } { value, ok, err := get(ctx) require.NoError(t, err) diff --git a/internal/agentstate/state.go b/internal/agentstate/state.go index bc89caa0..23da576a 100644 --- a/internal/agentstate/state.go +++ b/internal/agentstate/state.go @@ -25,6 +25,8 @@ const ( MetadataKeyBackendBaseURL = "backend_base_url" MetadataKeySAKToken = "sak_token" MetadataKeyEnrolledAt = "enrolled_at" + MetadataKeyNodeGroup = "nodegroup" + MetadataKeyComputeZone = "compute_zone" ) // State provides local persisted metadata/state access for backend workflows. 
@@ -43,4 +45,10 @@ type State interface { GetEnrollmentTime(ctx context.Context) (value time.Time, ok bool, err error) SetEnrollmentTime(ctx context.Context, value time.Time) error + + GetNodeGroup(ctx context.Context) (value string, ok bool, err error) + SetNodeGroup(ctx context.Context, value string) error + + GetComputeZone(ctx context.Context) (value string, ok bool, err error) + SetComputeZone(ctx context.Context, value string) error } diff --git a/internal/attestation/backend_test.go b/internal/attestation/backend_test.go index 1d6471f7..86b109ae 100644 --- a/internal/attestation/backend_test.go +++ b/internal/attestation/backend_test.go @@ -51,6 +51,14 @@ func (s *stubState) GetNodeUUID(context.Context) (string, bool, error) { return s.nodeUUID, s.nodeOK, s.nodeErr } func (s *stubState) SetNodeUUID(context.Context, string) error { return nil } +func (s *stubState) GetNodeGroup(context.Context) (string, bool, error) { + return "", false, nil +} +func (s *stubState) SetNodeGroup(context.Context, string) error { return nil } +func (s *stubState) GetComputeZone(context.Context) (string, bool, error) { + return "", false, nil +} +func (s *stubState) SetComputeZone(context.Context, string) error { return nil } func (s *stubState) GetEnrollmentTime(context.Context) (time.Time, bool, error) { return time.Time{}, false, nil } diff --git a/internal/backendclient/types.go b/internal/backendclient/types.go index 5a25e0e7..dfad2a6c 100644 --- a/internal/backendclient/types.go +++ b/internal/backendclient/types.go @@ -36,6 +36,8 @@ type NodeUpsertRequest struct { Uptime *time.Time `json:"uptime,omitempty"` EnrolledAt *time.Time `json:"enrolledAt,omitempty"` NetPrivateIP string `json:"netPrivateIP,omitempty"` + NodeGroup string `json:"nodeGroup"` + ComputeZone string `json:"computeZone"` } type NodeResources struct { diff --git a/internal/enrollment/enrollment.go b/internal/enrollment/enrollment.go index 625c87d3..7e50303c 100644 --- a/internal/enrollment/enrollment.go +++ 
b/internal/enrollment/enrollment.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "net/url" + "strings" "time" "github.com/NVIDIA/fleet-intelligence-sdk/pkg/log" @@ -44,6 +45,12 @@ var ( postEnrollInventorySyncTimeout = time.Minute ) +// EnrollMetadata contains optional enrollment metadata values persisted for runtime use. +type EnrollMetadata struct { + NodeGroup *string + ComputeZone *string +} + // Enroll runs the full enrollment workflow and performs a best-effort initial inventory sync. func Enroll(ctx context.Context, baseEndpoint, sakToken string) error { return EnrollWithConfig(ctx, baseEndpoint, sakToken, nil) @@ -51,6 +58,11 @@ func Enroll(ctx context.Context, baseEndpoint, sakToken string) error { // EnrollWithConfig runs the full enrollment workflow and uses cfg for best-effort inventory metadata. func EnrollWithConfig(ctx context.Context, baseEndpoint, sakToken string, cfg *config.Config) error { + return EnrollWithConfigAndMetadata(ctx, baseEndpoint, sakToken, cfg, nil) +} + +// EnrollWithConfigAndMetadata runs the full enrollment workflow and persists optional metadata values. 
+func EnrollWithConfigAndMetadata(ctx context.Context, baseEndpoint, sakToken string, cfg *config.Config, metadata *EnrollMetadata) error { baseURL, err := normalizeBackendBaseURL(baseEndpoint) if err != nil { return fmt.Errorf("invalid enrollment endpoint: %w", err) @@ -65,7 +77,7 @@ func EnrollWithConfig(ctx context.Context, baseEndpoint, sakToken string, cfg *c return err } enrolledAt := time.Now().UTC() - if err := storeConfigInMetadata(ctx, baseURL.String(), jwtToken, sakToken, enrolledAt); err != nil { + if err := storeConfigInMetadata(ctx, baseURL.String(), jwtToken, sakToken, enrolledAt, normalizedEnrollMetadata(metadata)); err != nil { return fmt.Errorf("failed to store configuration: %w", err) } syncCtx, cancel := context.WithTimeout(ctx, postEnrollInventorySyncTimeout) @@ -112,7 +124,25 @@ func normalizeBackendBaseURL(raw string) (*url.URL, error) { return endpoint.ValidateBackendEndpoint(normalized) } -func storeConfigInMetadata(ctx context.Context, baseURL, jwtToken, sakToken string, enrolledAt time.Time) error { +func normalizedEnrollMetadata(metadata *EnrollMetadata) EnrollMetadata { + if metadata == nil { + return EnrollMetadata{} + } + return EnrollMetadata{ + NodeGroup: trimmedOptionalString(metadata.NodeGroup), + ComputeZone: trimmedOptionalString(metadata.ComputeZone), + } +} + +func trimmedOptionalString(value *string) *string { + if value == nil { + return nil + } + trimmed := strings.TrimSpace(*value) + return &trimmed +} + +func storeConfigInMetadata(ctx context.Context, baseURL, jwtToken, sakToken string, enrolledAt time.Time, metadata EnrollMetadata) error { stateFile, err := config.DefaultStateFile() if err != nil { return fmt.Errorf("failed to get state file path: %w", err) @@ -143,6 +173,16 @@ func storeConfigInMetadata(ctx context.Context, baseURL, jwtToken, sakToken stri if err := pkgmetadata.SetMetadata(ctx, dbRW, agentstate.MetadataKeyEnrolledAt, enrolledAt.Format(time.RFC3339Nano)); err != nil { return fmt.Errorf("failed to set 
enrollment time: %w", err) } + if metadata.NodeGroup != nil { + if err := pkgmetadata.SetMetadata(ctx, dbRW, agentstate.MetadataKeyNodeGroup, *metadata.NodeGroup); err != nil { + return fmt.Errorf("failed to set nodegroup: %w", err) + } + } + if metadata.ComputeZone != nil { + if err := pkgmetadata.SetMetadata(ctx, dbRW, agentstate.MetadataKeyComputeZone, *metadata.ComputeZone); err != nil { + return fmt.Errorf("failed to set compute zone: %w", err) + } + } return nil } diff --git a/internal/enrollment/enrollment_test.go b/internal/enrollment/enrollment_test.go index d65bd509..a01f7b72 100644 --- a/internal/enrollment/enrollment_test.go +++ b/internal/enrollment/enrollment_test.go @@ -25,6 +25,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/NVIDIA/fleet-intelligence-agent/internal/agentstate" "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" "github.com/NVIDIA/fleet-intelligence-agent/internal/config" ) @@ -35,6 +36,10 @@ type fakeBackendClient struct { enrollErr error } +func strPtr(value string) *string { + return &value +} + func (f *fakeBackendClient) Enroll(_ context.Context, sakToken string) (string, error) { f.enrollSAK = sakToken return f.enrollJWT, f.enrollErr @@ -110,6 +115,137 @@ func TestEnrollWorkflowPassesConfigToInventorySync(t *testing.T) { require.NoError(t, err) } +func TestEnrollWorkflowStoresOptionalMetadata(t *testing.T) { + originalFactory := newBackendClient + originalSync := syncInventoryAfterEnroll + t.Cleanup(func() { + newBackendClient = originalFactory + syncInventoryAfterEnroll = originalSync + }) + + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + return &fakeBackendClient{enrollJWT: "jwt-token"}, nil + } + syncInventoryAfterEnroll = func(context.Context, *config.Config) error { return nil } + + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + + err := EnrollWithConfigAndMetadata( + context.Background(), + "https://example.com", + 
"sak-token", + nil, + &EnrollMetadata{ + NodeGroup: strPtr(" nodegroup-a "), + ComputeZone: strPtr(" zone-a "), + }, + ) + require.NoError(t, err) + + state := agentstate.NewSQLite() + nodeGroup, ok, err := state.GetNodeGroup(context.Background()) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, "nodegroup-a", nodeGroup) + + computeZone, ok, err := state.GetComputeZone(context.Background()) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, "zone-a", computeZone) +} + +func TestStoreConfigInMetadataPreservesOptionalMetadataWhenUnset(t *testing.T) { + if os.Geteuid() == 0 { + t.Skip("test expects non-root default state path resolution") + } + + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + ctx := context.Background() + + err := storeConfigInMetadata( + ctx, + "https://example.com", + "jwt-token", + "sak-token", + time.Now().UTC(), + EnrollMetadata{ + NodeGroup: strPtr("group-a"), + ComputeZone: strPtr("zone-a"), + }, + ) + require.NoError(t, err) + + err = storeConfigInMetadata( + ctx, + "https://example.com", + "jwt-token-2", + "sak-token-2", + time.Now().UTC(), + EnrollMetadata{}, + ) + require.NoError(t, err) + + state := agentstate.NewSQLite() + nodeGroup, ok, err := state.GetNodeGroup(ctx) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, "group-a", nodeGroup) + + computeZone, ok, err := state.GetComputeZone(ctx) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, "zone-a", computeZone) +} + +func TestStoreConfigInMetadataClearsOptionalMetadataWhenSetToEmpty(t *testing.T) { + if os.Geteuid() == 0 { + t.Skip("test expects non-root default state path resolution") + } + + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + ctx := context.Background() + + err := storeConfigInMetadata( + ctx, + "https://example.com", + "jwt-token", + "sak-token", + time.Now().UTC(), + EnrollMetadata{ + NodeGroup: strPtr("group-a"), + ComputeZone: strPtr("zone-a"), + }, + ) + require.NoError(t, err) + + err = 
storeConfigInMetadata( + ctx, + "https://example.com", + "jwt-token-2", + "sak-token-2", + time.Now().UTC(), + EnrollMetadata{ + NodeGroup: strPtr(""), + ComputeZone: strPtr(""), + }, + ) + require.NoError(t, err) + + state := agentstate.NewSQLite() + nodeGroup, ok, err := state.GetNodeGroup(ctx) + require.NoError(t, err) + require.False(t, ok) + require.Empty(t, nodeGroup) + + computeZone, ok, err := state.GetComputeZone(ctx) + require.NoError(t, err) + require.False(t, ok) + require.Empty(t, computeZone) +} + func TestEnrollWorkflowDefaultWrapperUsesNilConfig(t *testing.T) { originalFactory := newBackendClient originalSync := syncInventoryAfterEnroll @@ -336,6 +472,7 @@ func TestStoreConfigInMetadataSecuresFreshStateFile(t *testing.T) { "jwt-token", "sak-token", time.Now().UTC(), + EnrollMetadata{}, ) require.NoError(t, err) diff --git a/internal/exporter/collector/collector.go b/internal/exporter/collector/collector.go index 79957242..25035605 100644 --- a/internal/exporter/collector/collector.go +++ b/internal/exporter/collector/collector.go @@ -50,6 +50,8 @@ func GenerateEventID() string { type HealthData struct { CollectionID string MachineID string + NodeGroup string + ComputeZone string Timestamp time.Time MachineInfo *machineinfo.MachineInfo GPUUUIDToIndex map[string]string diff --git a/internal/exporter/converter/otlp.go b/internal/exporter/converter/otlp.go index 08c22cbd..b210f282 100644 --- a/internal/exporter/converter/otlp.go +++ b/internal/exporter/converter/otlp.go @@ -126,6 +126,22 @@ func (c *otlpConverter) createOTLPResource(data *collector.HealthData) *resource }, }) } + if data.NodeGroup != "" { + attributes = append(attributes, &commonv1.KeyValue{ + Key: "node_group", + Value: &commonv1.AnyValue{ + Value: &commonv1.AnyValue_StringValue{StringValue: data.NodeGroup}, + }, + }) + } + if data.ComputeZone != "" { + attributes = append(attributes, &commonv1.KeyValue{ + Key: "compute_zone", + Value: &commonv1.AnyValue{ + Value: 
&commonv1.AnyValue_StringValue{StringValue: data.ComputeZone}, + }, + }) + } return &resourcev1.Resource{ Attributes: attributes, diff --git a/internal/exporter/converter/otlp_test.go b/internal/exporter/converter/otlp_test.go index 92d2f192..8ae04f94 100644 --- a/internal/exporter/converter/otlp_test.go +++ b/internal/exporter/converter/otlp_test.go @@ -747,8 +747,10 @@ func TestOTLPConverter_UpMetric(t *testing.T) { func TestOTLPConverter_ResourceAttributes(t *testing.T) { data := &collector.HealthData{ - Timestamp: time.Now(), - MachineID: "test-machine-123", + Timestamp: time.Now(), + MachineID: "test-machine-123", + NodeGroup: "group-a", + ComputeZone: "zone-a", ComponentData: map[string]interface{}{ "comp1": map[string]any{}, "comp2": map[string]any{}, @@ -771,6 +773,38 @@ func TestOTLPConverter_ResourceAttributes(t *testing.T) { assert.Equal(t, "fleet-intelligence-agent", attrMap["service.name"]) assert.Equal(t, "test-machine-123", attrMap["machine.id"]) + assert.Equal(t, "group-a", attrMap["node_group"]) + assert.Equal(t, "zone-a", attrMap["compute_zone"]) + + logResourceAttrMap := make(map[string]string) + for _, attr := range otlpData.Logs.ResourceLogs[0].Resource.Attributes { + if attr.Value.GetStringValue() != "" { + logResourceAttrMap[attr.Key] = attr.Value.GetStringValue() + } + } + assert.Equal(t, "group-a", logResourceAttrMap["node_group"]) + assert.Equal(t, "zone-a", logResourceAttrMap["compute_zone"]) +} + +func TestOTLPConverter_ResourceAttributesOmitEmptyOptionalValues(t *testing.T) { + data := &collector.HealthData{ + Timestamp: time.Now(), + MachineID: "test-machine-123", + } + + converter := NewOTLPConverter() + otlpData := converter.Convert(data) + + rm := otlpData.Metrics.ResourceMetrics[0] + attrMap := make(map[string]string) + for _, attr := range rm.Resource.Attributes { + attrMap[attr.Key] = attr.Value.GetStringValue() + } + + _, nodeGroupExists := attrMap["node_group"] + _, computeZoneExists := attrMap["compute_zone"] + assert.False(t, 
nodeGroupExists) + assert.False(t, computeZoneExists) } func TestOTLPConverter_Interface(t *testing.T) { diff --git a/internal/exporter/exporter.go b/internal/exporter/exporter.go index 59605203..9c42d51e 100644 --- a/internal/exporter/exporter.go +++ b/internal/exporter/exporter.go @@ -30,6 +30,7 @@ import ( "github.com/NVIDIA/fleet-intelligence-sdk/pkg/log" pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata" + "github.com/NVIDIA/fleet-intelligence-agent/internal/agentstate" "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" "github.com/NVIDIA/fleet-intelligence-agent/internal/endpoint" "github.com/NVIDIA/fleet-intelligence-agent/internal/exporter/collector" @@ -165,6 +166,7 @@ func (e *healthExporter) export() error { if err != nil { return fmt.Errorf("collection failed: %w", err) } + e.populateOptionalResourceMetadata(collectionCtx, healthData) // Export data based on mode if e.options.config.OfflineMode { @@ -176,6 +178,26 @@ func (e *healthExporter) export() error { } } +func (e *healthExporter) populateOptionalResourceMetadata(ctx context.Context, data *collector.HealthData) { + if data == nil || e.options.dbRO == nil { + return + } + + nodeGroup, err := pkgmetadata.ReadMetadata(ctx, e.options.dbRO, agentstate.MetadataKeyNodeGroup) + if err != nil { + log.Logger.Debugw("nodegroup metadata not available for telemetry resource", "error", err) + } else { + data.NodeGroup = nodeGroup + } + + computeZone, err := pkgmetadata.ReadMetadata(ctx, e.options.dbRO, agentstate.MetadataKeyComputeZone) + if err != nil { + log.Logger.Debugw("compute zone metadata not available for telemetry resource", "error", err) + } else { + data.ComputeZone = computeZone + } +} + // exportToFile writes health data to files func (e *healthExporter) exportToFile(data *collector.HealthData) error { outputFormat := e.options.config.OutputFormat diff --git a/internal/inventory/sink/backend.go b/internal/inventory/sink/backend.go index ba997b97..45a3f90a 100644 
--- a/internal/inventory/sink/backend.go +++ b/internal/inventory/sink/backend.go @@ -81,6 +81,18 @@ func (s *backendSink) Export(ctx context.Context, snap *inventory.Snapshot) erro return fmt.Errorf("create backend client: %w", err) } req := mapper.ToNodeUpsertRequest(snap) + nodeGroup, ok, err := s.state.GetNodeGroup(ctx) + if err != nil { + log.Logger.Warnw("inventory export continuing without nodegroup metadata", "error", err) + } else if ok { + req.NodeGroup = nodeGroup + } + computeZone, ok, err := s.state.GetComputeZone(ctx) + if err != nil { + log.Logger.Warnw("inventory export continuing without compute zone metadata", "error", err) + } else if ok { + req.ComputeZone = computeZone + } enrollmentTime, ok, err := s.state.GetEnrollmentTime(ctx) if err != nil { log.Logger.Warnw("inventory export continuing without enrollment time", "error", err) @@ -89,6 +101,7 @@ func (s *backendSink) Export(ctx context.Context, snap *inventory.Snapshot) erro req.EnrolledAt = &normalized } outbound.LogIssues("inventory-backend-sink", "NodeUpsertRequest", outbound.ValidateNodeUpsertRequest(req), "node_uuid", nodeUUID) + if err := client.UpsertNode(ctx, nodeUUID, req, jwt); err != nil { return err } diff --git a/internal/inventory/sink/backend_test.go b/internal/inventory/sink/backend_test.go index 89cfcbd5..12713457 100644 --- a/internal/inventory/sink/backend_test.go +++ b/internal/inventory/sink/backend_test.go @@ -31,6 +31,10 @@ type fakeState struct { baseURL string jwt string nodeUUID string + nodeGroup string + computeZone string + nodeGroupErr error + computeErr error enrolled time.Time enrollmentErr error err error @@ -59,6 +63,26 @@ func (f *fakeState) GetNodeUUID(context.Context) (string, bool, error) { return f.nodeUUID, f.nodeUUID != "", nil } func (f *fakeState) SetNodeUUID(context.Context, string) error { return nil } +func (f *fakeState) GetNodeGroup(context.Context) (string, bool, error) { + if f.nodeGroupErr != nil { + return "", false, f.nodeGroupErr + } + if 
f.err != nil { + return "", false, f.err + } + return f.nodeGroup, f.nodeGroup != "", nil +} +func (f *fakeState) SetNodeGroup(context.Context, string) error { return nil } +func (f *fakeState) GetComputeZone(context.Context) (string, bool, error) { + if f.computeErr != nil { + return "", false, f.computeErr + } + if f.err != nil { + return "", false, f.err + } + return f.computeZone, f.computeZone != "", nil +} +func (f *fakeState) SetComputeZone(context.Context, string) error { return nil } func (f *fakeState) GetEnrollmentTime(context.Context) (time.Time, bool, error) { if f.enrollmentErr != nil { return time.Time{}, false, f.enrollmentErr @@ -133,10 +157,12 @@ func TestBackendSinkExportUsesState(t *testing.T) { client := &fakeClient{} s := &backendSink{ state: &fakeState{ - baseURL: "https://example.com", - jwt: "jwt-token", - nodeUUID: "node-1", - enrolled: enrollmentTime, + baseURL: "https://example.com", + jwt: "jwt-token", + nodeUUID: "node-1", + nodeGroup: "group-a", + computeZone: "zone-a", + enrolled: enrollmentTime, }, clientFactory: func(string) (backendclient.Client, error) { return client, nil @@ -152,10 +178,35 @@ func TestBackendSinkExportUsesState(t *testing.T) { require.Equal(t, "jwt-token", client.jwt) require.NotNil(t, client.req) require.Equal(t, "host-a", client.req.Hostname) + require.Equal(t, "group-a", client.req.NodeGroup) + require.Equal(t, "zone-a", client.req.ComputeZone) require.NotNil(t, client.req.EnrolledAt) require.Equal(t, enrollmentTime, *client.req.EnrolledAt) } +func TestBackendSinkExportWithoutOptionalMetadataUsesEmptyStrings(t *testing.T) { + client := &fakeClient{} + s := &backendSink{ + state: &fakeState{ + baseURL: "https://example.com", + jwt: "jwt-token", + nodeUUID: "node-1", + }, + clientFactory: func(string) (backendclient.Client, error) { + return client, nil + }, + } + + err := s.Export(context.Background(), &inventory.Snapshot{ + Hostname: "host-a", + MachineID: "machine-id", + }) + require.NoError(t, err) + 
require.NotNil(t, client.req) + require.Equal(t, "", client.req.NodeGroup) + require.Equal(t, "", client.req.ComputeZone) +} + func TestBackendSinkExportEnrollmentTimeErrorIsNonFatal(t *testing.T) { client := &fakeClient{} s := &backendSink{ @@ -179,6 +230,31 @@ func TestBackendSinkExportEnrollmentTimeErrorIsNonFatal(t *testing.T) { require.Nil(t, client.req.EnrolledAt) } +func TestBackendSinkExportOptionalMetadataErrorsAreNonFatal(t *testing.T) { + client := &fakeClient{} + s := &backendSink{ + state: &fakeState{ + baseURL: "https://example.com", + jwt: "jwt-token", + nodeUUID: "node-1", + nodeGroupErr: errors.New("failed to read nodegroup"), + computeErr: errors.New("failed to read compute zone"), + }, + clientFactory: func(string) (backendclient.Client, error) { + return client, nil + }, + } + + err := s.Export(context.Background(), &inventory.Snapshot{ + Hostname: "host-a", + MachineID: "machine-id", + }) + require.NoError(t, err) + require.NotNil(t, client.req) + require.Empty(t, client.req.NodeGroup) + require.Empty(t, client.req.ComputeZone) +} + func TestBackendSinkValidationDoesNotBlockExport(t *testing.T) { client := &fakeClient{} s := &backendSink{ diff --git a/internal/validation/outbound/validator.go b/internal/validation/outbound/validator.go index 4824c202..11074ab7 100644 --- a/internal/validation/outbound/validator.go +++ b/internal/validation/outbound/validator.go @@ -89,6 +89,8 @@ func ValidateNodeUpsertRequest(req *backendclient.NodeUpsertRequest) []Issue { validateLen(&issues, "length", "osImage", req.OSImage, 1024) validateLen(&issues, "length", "agentVersion", req.AgentVersion, 255) validateIP(&issues, "format", "netPrivateIP", req.NetPrivateIP) + validateLen(&issues, "length", "nodeGroup", req.NodeGroup, 255) + validateLen(&issues, "length", "computeZone", req.ComputeZone, 255) validateNonNegative(&issues, "numeric", "agentConfig.totalComponents", req.AgentConfig.TotalComponents) validateNonNegative(&issues, "numeric", 
"agentConfig.retentionPeriodSeconds", req.AgentConfig.RetentionPeriodSeconds)