diff --git a/internal/cmd/backup.go b/internal/cmd/backup.go index 4a745986..3f7ceb4c 100644 --- a/internal/cmd/backup.go +++ b/internal/cmd/backup.go @@ -15,7 +15,6 @@ import ( "github.com/jzelinskie/cobrautil/v2" "github.com/rodaine/table" "github.com/rs/zerolog/log" - "github.com/samber/lo" "github.com/spf13/cobra" "golang.org/x/exp/constraints" "golang.org/x/exp/maps" @@ -23,8 +22,9 @@ import ( "google.golang.org/grpc/status" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" - corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" + schemapkg "github.com/authzed/spicedb/pkg/schema" "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" "github.com/authzed/spicedb/pkg/tuple" "github.com/authzed/zed/internal/client" @@ -33,7 +33,138 @@ import ( "github.com/authzed/zed/pkg/backupformat" ) -const doNotReturnIfExists = false +const ( + returnIfExists = true + doNotReturnIfExists = false +) + +// BackupConfig holds the configuration for creating a backup. +type BackupConfig struct { + // PrefixFilter filters relationships to only include those with this prefix. + PrefixFilter string + // PageLimit defines the number of relationships to be read per page during backup. + PageLimit uint32 + // RewriteLegacy indicates whether to rewrite legacy schema syntax. + RewriteLegacy bool +} + +// ProgressTracker tracks backup progress for resumability. +type ProgressTracker interface { + // GetCursor returns the stored cursor, or nil if no progress exists. + GetCursor() *v1.Cursor + // WriteCursor writes the current cursor to storage. + WriteCursor(cursor *v1.Cursor) error + // MarkComplete marks the backup as complete (e.g., removes progress file). + MarkComplete() error + // Close closes any underlying resources. + Close() error +} + +// fileProgressTracker implements ProgressTracker using a file. +type fileProgressTracker struct { + file *os.File + cursor *v1.Cursor +} + +func newFileProgressTracker(backupFileName string, backupAlreadyExisted bool) (*fileProgressTracker, error) { + progressFileName := toLockFileName(backupFileName) + var cursor *v1.Cursor + var fileMode int + + readCursor, readErr := os.ReadFile(progressFileName) + if backupAlreadyExisted { + // Backup exists - we need a valid progress file to resume + // Check for errors first (except not-exist) to avoid masking permission/I/O errors + if readErr != nil && !os.IsNotExist(readErr) { + return nil, fmt.Errorf("failed to read progress file for existing backup: %w", readErr) + } + if os.IsNotExist(readErr) || len(readCursor) == 0 { + return nil, fmt.Errorf("backup file %s already exists", backupFileName) + } + // Successfully read the cursor + cursor = &v1.Cursor{ + Token: string(readCursor), + } + // if backup existed and there is a progress marker, the latter should not be truncated + fileMode = os.O_WRONLY | os.O_CREATE + log.Info().Str("filename", backupFileName).Msg("backup file already exists, will resume") + } else { + // if a backup did not exist, make sure to truncate the progress file + fileMode = os.O_WRONLY | os.O_CREATE | os.O_TRUNC + } + + progressFile, err := os.OpenFile(progressFileName, fileMode, 0o644) + if err != nil { + return nil, fmt.Errorf("failed to open progress file: %w", err) + } + + return &fileProgressTracker{ + file: progressFile, + cursor: cursor, + }, nil +} + +func (f *fileProgressTracker) GetCursor() *v1.Cursor { + return f.cursor +} + +func (f *fileProgressTracker) WriteCursor(cursor *v1.Cursor) error { + if cursor == nil { + return errors.New("cannot write nil cursor to progress file") + } + + if err := f.file.Truncate(0); err != nil { + return fmt.Errorf("unable to truncate backup progress file: %w", err) + } + + if _, err := f.file.Seek(0, 0); err != nil { + return fmt.Errorf("unable to seek backup progress file: %w", err) + } + + if _, err := f.file.WriteString(cursor.Token); err != nil { + return fmt.Errorf("unable to write result cursor to backup progress file: %w", err) + } + + // Sync to ensure cursor is durably persisted before continuing + if err := f.file.Sync(); err != nil { + return fmt.Errorf("unable to sync backup progress file: %w", err) + } + + // Update in-memory cursor to keep it consistent with persisted state + f.cursor = cursor + + return nil +} + +func (f *fileProgressTracker) MarkComplete() error { + // Check if already closed/completed + if f.file == nil { + return nil + } + + // Close the file handle. The lock file itself will be cleaned up + // by OcfFileEncoder.Close() when it detects that MarkComplete was called. + if err := f.file.Sync(); err != nil { + return fmt.Errorf("failed to sync progress file: %w", err) + } + if err := f.file.Close(); err != nil { + return fmt.Errorf("failed to close progress file: %w", err) + } + f.file = nil // Mark as closed so Close() becomes a no-op + + return nil +} + +func (f *fileProgressTracker) Close() error { + // Check if file is already closed (e.g., by MarkComplete) + if f.file == nil { + return nil + } + syncErr := f.file.Sync() + closeErr := f.file.Close() + f.file = nil + return errors.Join(syncErr, closeErr) +} // cobraRunEFunc is the signature of a cobra.Command.RunE function. type cobraRunEFunc = func(cmd *cobra.Command, args []string) (err error) @@ -157,8 +288,9 @@ func registerBackupRestoreFlags(cmd *cobra.Command) { } func registerBackupCreateFlags(cmd *cobra.Command) { + cmd.Flags().String("prefix-filter", "", "include only schema and relationships with a given prefix") + cmd.Flags().Bool("rewrite-legacy", false, "potentially modify the schema to exclude legacy/broken syntax") cmd.Flags().Uint32("page-limit", 0, "defines the number of relationships to be read by requested page during backup") - backupformat.RegisterRewriterFlags(cmd) } func createBackupFile(filename string, returnIfExists bool) (*os.File, bool, error) { @@ -190,39 +322,77 @@ func createBackupFile(filename string, returnIfExists bool) (*os.File, bool, err return f, false, nil } -// revisionForServerless determines the latest revision to use for the backup -// because Serverless doesn't return a revision in the ReadSchema response. -func revisionForServerless(ctx context.Context, spiceClient client.Client, schema *compiler.CompiledSchema) (*v1.ZedToken, error) { - var finalErr error - for _, def := range schema.ObjectDefinitions { - stream, err := spiceClient.ReadRelationships(ctx, &v1.ReadRelationshipsRequest{ - RelationshipFilter: &v1.RelationshipFilter{ResourceType: def.Name}, - OptionalLimit: 1, - Consistency: &v1.Consistency{ - Requirement: &v1.Consistency_FullyConsistent{ - FullyConsistent: true, - }, - }, - }) - if err != nil { - finalErr = errors.Join(finalErr, err) - continue +var ( + missingAllowedTypes = regexp.MustCompile(`(\s*)(relation)(.+)(/\* missing allowed types \*/)(.*)`) //nolint:gocritic + shortRelations = regexp.MustCompile(`(\s*)relation [a-z][a-z0-9_]:(.+)`) +) + +func partialPrefixMatch(name, prefix string) bool { + return strings.HasPrefix(name, prefix+"/") +} + +func filterSchemaDefs(schema, prefix string) (filteredSchema string, err error) { + if schema == "" || prefix == "" { + return schema, nil + } + + compiledSchema, err := compiler.Compile( + compiler.InputSchema{Source: "schema", SchemaString: schema}, + compiler.AllowUnprefixedObjectType(), + compiler.SkipValidation(), + ) + if err != nil { + return "", fmt.Errorf("error reading schema: %w", err) + } + + var prefixedDefs []compiler.SchemaDefinition + for _, def := range compiledSchema.ObjectDefinitions { + if partialPrefixMatch(def.Name, prefix) { + prefixedDefs = append(prefixedDefs, def) } + } + for _, def := range compiledSchema.CaveatDefinitions { + if partialPrefixMatch(def.Name, prefix) { + prefixedDefs = append(prefixedDefs, def) + } + } + + if len(prefixedDefs) == 0 { + return "", errors.New("filtered all definitions from schema") + } + + filteredSchema, _, err = generator.GenerateSchema(prefixedDefs) + if err != nil { + return "", fmt.Errorf("error generating filtered schema: %w", err) + } + + // Validate that the type system for the generated schema is comprehensive. + compiledFilteredSchema, err := compiler.Compile( + compiler.InputSchema{Source: "generated-schema", SchemaString: filteredSchema}, + compiler.AllowUnprefixedObjectType(), + ) + if err != nil { + return "", fmt.Errorf("generated invalid schema: %w", err) + } - msg, err := stream.Recv() + for _, rawDef := range compiledFilteredSchema.ObjectDefinitions { + ts := schemapkg.NewTypeSystem(schemapkg.ResolverForCompiledSchema(*compiledFilteredSchema)) + def, err := schemapkg.NewDefinition(ts, rawDef) if err != nil { - if errors.Is(err, io.EOF) { - finalErr = errors.Join(finalErr, fmt.Errorf("no relationships found for definition %q", def.Name)) - } else { - finalErr = errors.Join(finalErr, err) - } - continue + return "", fmt.Errorf("generated invalid schema: %w", err) + } + if _, err := def.Validate(context.Background()); err != nil { + return "", fmt.Errorf("generated invalid schema: %w", err) } - log.Trace().Str("revision", msg.GetReadAt().GetToken()).Msg("determined serverless revision") - return msg.GetReadAt(), nil } - return nil, finalErr + return filteredSchema, nil +} + +func hasRelPrefix(rel *v1.Relationship, prefix string) bool { + // Skip any relationships without the prefix on both sides. + return strings.HasPrefix(rel.Resource.ObjectType, prefix) && + strings.HasPrefix(rel.Subject.Object.ObjectType, prefix) } // CloseAndJoin attempts to close the provided arguement and joins the error @@ -237,191 +407,209 @@ func CloseAndJoin(e *error, maybeCloser any) { } func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { + config := BackupConfig{ + PrefixFilter: cobrautil.MustGetString(cmd, "prefix-filter"), + PageLimit: cobrautil.MustGetUint32(cmd, "page-limit"), + RewriteLegacy: cobrautil.MustGetBool(cmd, "rewrite-legacy"), + } + backupFileName, err := computeBackupFileName(cmd, args) if err != nil { return err } - spiceClient, err := client.NewClient(cmd) + fencoder, backupExisted, err := backupformat.NewFileEncoder(backupFileName) if err != nil { - return fmt.Errorf("unable to initialize client: %w", err) + return err } + encoder := backupformat.WithProgress(fencoder) + defer CloseAndJoin(&err, encoder) - return takeBackup( - cmd.Context(), - spiceClient, - nil, - backupFileName, - backupformat.RewriterFromFlags(cmd), - cobrautil.MustGetUint32(cmd, "page-limit"), - ) -} + progressTracker, err := newFileProgressTracker(backupFileName, backupExisted) + if err != nil { + return err + } + defer func(e *error) { *e = errors.Join(*e, progressTracker.Close()) }(&err) -func takeBackup(ctx context.Context, spiceClient client.Client, encoder backupformat.Encoder, backupFileName string, rw backupformat.Rewriter, pageLimit uint32) error { - schemaResp, err := spiceClient.ReadSchema(ctx, &v1.ReadSchemaRequest{}) + spiceClient, err := client.NewClient(cmd) if err != nil { - return fmt.Errorf("error reading schema: %w", err) - } - - // Determine if the server supports modern APIs for backups and if not, - // fallback to using ReadSchema and ReadRelationships. - // This codepath can be removed when AuthZed Serverless is fully sunset. - if bulkOpsUnsupported := schemaResp.ReadAt == nil; bulkOpsUnsupported { - compiledSchema, err := compiler.Compile( - compiler.InputSchema{Source: "schema", SchemaString: schemaResp.SchemaText}, - compiler.AllowUnprefixedObjectType(), - compiler.SkipValidation(), - ) - if err != nil { - return err - } + return fmt.Errorf("unable to initialize client: %w", err) + } - revision, err := revisionForServerless(ctx, spiceClient, compiledSchema) + var zedToken *v1.ZedToken + if !backupExisted { + zedToken, err = writeSchemaForNewBackup(cmd.Context(), spiceClient, encoder, config) if err != nil { return err } + } - var cursor string - if encoder == nil { - fencoder, backupExisted, err := backupformat.NewFileEncoder(backupFileName) - if err != nil { - return err - } - encoder = backupformat.WithProgress(backupformat.WithRewriter(rw, fencoder)) - defer CloseAndJoin(&err, encoder) - if backupExisted { - cursor, err = fencoder.Cursor() - if err != nil { - return err - } - } else { - if err := encoder.WriteSchema(schemaResp.SchemaText, revision.Token); err != nil { - return err - } - } + backupCompleted, err := backupCreateImpl(cmd.Context(), spiceClient, encoder, progressTracker, config, zedToken) + if err != nil { + return err + } + + if backupCompleted { + encoder.MarkComplete() + if markErr := progressTracker.MarkComplete(); markErr != nil { + err = errors.Join(err, markErr) } - defer CloseAndJoin(&err, encoder) + } - log.Trace().Strs("definitions", lo.Map(compiledSchema.ObjectDefinitions, func(def *corev1.NamespaceDefinition, _ int) string { - return def.Name - })).Msg("parsed object definitions") + return err +} - var cursorObj string // Tracks the definition the cursor was on - for _, def := range compiledSchema.ObjectDefinitions { - req := &v1.ReadRelationshipsRequest{ - RelationshipFilter: &v1.RelationshipFilter{ResourceType: def.Name}, - OptionalLimit: pageLimit, - } - if cursor != "" && cursorObj == def.Name { - req.OptionalCursor = &v1.Cursor{Token: cursor} - } else { - req.Consistency = &v1.Consistency{ - Requirement: &v1.Consistency_AtExactSnapshot{ - AtExactSnapshot: revision, - }, - } - } - log.Trace().Str("resource", def.Name).Str("cursor", cursor).Str("revision", revision.Token).Msg("iterated over definition") +// backupCreateImpl performs the core backup logic. It is designed to be testable +// by accepting dependencies as parameters rather than creating them internally. +func backupCreateImpl( + ctx context.Context, + spiceClient client.Client, + encoder backupformat.Encoder, + progressTracker ProgressTracker, + config BackupConfig, + zedToken *v1.ZedToken, +) (backupCompleted bool, err error) { + cursor := progressTracker.GetCursor() - stream, err := spiceClient.ReadRelationships(ctx, req) - if err != nil { - return err - } + if zedToken == nil && cursor == nil { + return false, errors.New("malformed existing backup, consider recreating it") + } - for msg, err := stream.Recv(); !errors.Is(err, io.EOF); msg, err = stream.Recv() { - switch { - case isCanceled(err) || isCanceled(ctx.Err()): - return context.Canceled - case isRetryableError(err): - newReq := req.CloneVT() - newReq.OptionalCursor = &v1.Cursor{Token: cursor} - stream, err = spiceClient.ReadRelationships(ctx, newReq) - if err != nil { - return errors.New("failed to retry request") - } - case err != nil: - return err - case ctx.Err() != nil: - return fmt.Errorf("aborted backup: %w", err) - default: - cursor = msg.AfterResultCursor.Token - cursorObj = def.Name - log.Trace().Str("cursor", cursor).Stringer("relationship", msg.Relationship).Msg("appending relationship") - if err := encoder.Append(msg.Relationship, cursor); err != nil { - return err - } + req := &v1.ExportBulkRelationshipsRequest{ + OptionalLimit: config.PageLimit, + } + + var cursorToken string + if cursor != nil { + req.OptionalCursor = cursor + cursorToken = cursor.Token + } else { + req.Consistency = &v1.Consistency{ + Requirement: &v1.Consistency_AtExactSnapshot{ + AtExactSnapshot: zedToken, + }, + } + } + + err = takeBackup(ctx, spiceClient, req, func(response *v1.ExportBulkRelationshipsResponse) error { + if response.AfterResultCursor != nil { + cursorToken = response.AfterResultCursor.Token + } + for _, rel := range response.Relationships { + if hasRelPrefix(rel, config.PrefixFilter) { + if err := encoder.Append(rel, cursorToken); err != nil { + return fmt.Errorf("error storing relationship: %w", err) } } } - encoder.MarkComplete() - } else { - var cursor string - if encoder == nil { - fencoder, backupExisted, err := backupformat.NewFileEncoder(backupFileName) - if err != nil { + + if response.AfterResultCursor != nil { + if err := progressTracker.WriteCursor(response.AfterResultCursor); err != nil { return err } - encoder = backupformat.WithProgress(backupformat.WithRewriter(rw, fencoder)) - defer CloseAndJoin(&err, encoder) - if backupExisted { - cursor, err = fencoder.Cursor() - if err != nil { - return err - } - } else { - if err := encoder.WriteSchema(schemaResp.SchemaText, schemaResp.ReadAt.Token); err != nil { - return err - } - } } + return nil + }) + if err != nil { + return false, err + } - req := &v1.ExportBulkRelationshipsRequest{OptionalLimit: pageLimit} - if cursor != "" { - req.OptionalCursor = &v1.Cursor{Token: cursor} - } else { - req.Consistency = &v1.Consistency{ - Requirement: &v1.Consistency_AtExactSnapshot{ - AtExactSnapshot: schemaResp.ReadAt, - }, + return true, nil +} + +func takeBackup(ctx context.Context, spiceClient client.Client, req *v1.ExportBulkRelationshipsRequest, processResponse func(*v1.ExportBulkRelationshipsResponse) error) error { + relationshipStream, err := spiceClient.ExportBulkRelationships(ctx, req) + if err != nil { + return fmt.Errorf("error exporting relationships: %w", err) + } + var lastResponse *v1.ExportBulkRelationshipsResponse + for { + if err := ctx.Err(); err != nil { + if isCanceled(err) { + return context.Canceled } + + return fmt.Errorf("aborted backup: %w", err) } - stream, err := spiceClient.ExportBulkRelationships(ctx, req) + relsResp, err := relationshipStream.Recv() if err != nil { - return err - } + if errors.Is(err, io.EOF) { + break + } - for msg, err := stream.Recv(); !errors.Is(err, io.EOF); msg, err = stream.Recv() { - switch { - case isCanceled(err) || isCanceled(ctx.Err()): + if isCanceled(err) { return context.Canceled - case isRetryableError(err): + } + + if isRetryableError(err) { newReq := req.CloneVT() - newReq.OptionalCursor = &v1.Cursor{Token: cursor} - stream, err = spiceClient.ExportBulkRelationships(ctx, newReq) - if err != nil { - return errors.New("failed to retry request") - } - case err != nil: - return err - case ctx.Err() != nil: - return fmt.Errorf("aborted backup: %w", err) - default: - cursor = msg.AfterResultCursor.Token - for _, r := range msg.Relationships { - if err := encoder.Append(r, cursor); err != nil { - return err - } + cursorToken := "undefined" + if lastResponse != nil && lastResponse.AfterResultCursor != nil { + newReq.OptionalCursor = lastResponse.AfterResultCursor + cursorToken = lastResponse.AfterResultCursor.Token } + + relationshipStream, err = spiceClient.ExportBulkRelationships(ctx, newReq) + log.Info().Err(err).Str("cursor-token", cursorToken).Msg("encountered retryable error, resuming after last known cursor") + continue } + + return fmt.Errorf("error receiving relationships: %w", err) + } + + lastResponse = relsResp + + if err := processResponse(relsResp); err != nil { + return err } - encoder.MarkComplete() } - // NOTE: err is returned here because there's cleanup being done - // in the `defer` blocks that will modify the `err` if the cleanup - // fails - return err + return nil +} + +// writeSchemaForNewBackup reads the schema from SpiceDB and writes it to the encoder. +// It returns the ZedToken at which the backup must be taken. +func writeSchemaForNewBackup(ctx context.Context, c client.Client, encoder backupformat.Encoder, config BackupConfig) (*v1.ZedToken, error) { + schemaResp, err := c.ReadSchema(ctx, &v1.ReadSchemaRequest{}) + if err != nil { + return nil, fmt.Errorf("error reading schema: %w", err) + } + if schemaResp.ReadAt == nil { + return nil, errors.New("`backup` is not supported on this version of SpiceDB") + } + schema := schemaResp.SchemaText + + // Remove any invalid relations generated from old, backwards-incompat + // Serverless permission systems. + if config.RewriteLegacy { + schema = rewriteLegacy(schema) + } + + // Skip any definitions without the provided prefix + if config.PrefixFilter != "" { + schema, err = filterSchemaDefs(schema, config.PrefixFilter) + if err != nil { + return nil, err + } + } + + zedToken := schemaResp.ReadAt + + if err := encoder.WriteSchema(schema, zedToken.Token); err != nil { + return nil, fmt.Errorf("error writing schema to backup: %w", err) + } + + return zedToken, nil +} + +func toLockFileName(backupFileName string) string { + return backupFileName + ".lock" +} + +func rewriteLegacy(schema string) string { + schema = string(missingAllowedTypes.ReplaceAll([]byte(schema), []byte("\n/* deleted missing allowed type error */"))) + return string(shortRelations.ReplaceAll([]byte(schema), []byte("\n/* deleted short relation name */"))) } // computeBackupFileName computes the backup file name based. diff --git a/internal/cmd/backup_test.go b/internal/cmd/backup_test.go index bdb69677..7ce10012 100644 --- a/internal/cmd/backup_test.go +++ b/internal/cmd/backup_test.go @@ -20,7 +20,6 @@ import ( v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" "github.com/authzed/spicedb/pkg/genutil/mapz" - "github.com/authzed/spicedb/pkg/schemadsl/compiler" "github.com/authzed/spicedb/pkg/tuple" "github.com/authzed/zed/internal/client" @@ -688,79 +687,19 @@ func TestTakeBackupRecoversFromRetryableErrors(t *testing.T) { }, } - encoder := &backupformat.MockEncoder{} - rw := &backupformat.NoopRewriter{} - - err := takeBackup(t.Context(), client, encoder, "ignored", rw, 0) - require.NoError(t, err) - - require.True(t, encoder.Complete, "expecting encoder to be marked complete") - require.Len(t, encoder.Relationships, 2, "expecting two rels in the realized list") - require.Equal(t, "foo", encoder.Relationships[0].Resource.ObjectId) - require.Equal(t, "bar", encoder.Relationships[1].Resource.ObjectId) - - client.assertAllRecvCalls() -} - -func TestRevisionForServerless(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // create a spicedb server - srv := zedtesting.NewTestServer(ctx, t) - errCh := make(chan error, 1) - go func() { - errCh <- srv.Run(ctx) - }() - conn, err := srv.GRPCDialContext(ctx) - require.NoError(t, err) - - c, err := zedtesting.ClientFromConn(conn)(nil) - require.NoError(t, err) - - // write a schema - schemaText := `definition user {} -definition document { - relation view: user -}` - schema, err := compiler.Compile( - compiler.InputSchema{Source: "schema", SchemaString: schemaText}, - compiler.AllowUnprefixedObjectType(), - compiler.SkipValidation(), - ) - require.NoError(t, err) - _, err = c.WriteSchema(ctx, &v1.WriteSchemaRequest{Schema: schemaText}) - require.NoError(t, err) - - // query for serverless revision when there are no relationships in the system should return error - res, err := revisionForServerless(ctx, c, schema) - require.ErrorContains(t, err, "no relationships found") - require.Nil(t, res) - - // write relationships for the *second* object definition - _, err = c.WriteRelationships(ctx, &v1.WriteRelationshipsRequest{ - Updates: []*v1.RelationshipUpdate{ - { - Operation: v1.RelationshipUpdate_OPERATION_CREATE, - Relationship: &v1.Relationship{ - Resource: &v1.ObjectReference{ObjectType: "document", ObjectId: "1"}, - Relation: "view", - Subject: &v1.SubjectReference{ - Object: &v1.ObjectReference{ObjectType: "user", ObjectId: "maria"}, - }, - }, - }, - }, + req := &v1.ExportBulkRelationshipsRequest{} + var collectedRels []*v1.Relationship + err := takeBackup(t.Context(), client, req, func(response *v1.ExportBulkRelationshipsResponse) error { + collectedRels = append(collectedRels, response.Relationships...) + return nil }) require.NoError(t, err) - // now, we should have a result - res, err = revisionForServerless(ctx, c, schema) - require.NoError(t, err) - require.NotNil(t, res) + require.Len(t, collectedRels, 2, "expecting two rels in the realized list") + require.Equal(t, "foo", collectedRels[0].Resource.ObjectId) + require.Equal(t, "bar", collectedRels[1].Resource.ObjectId) - cancel() - require.NoError(t, <-errCh) + client.assertAllRecvCalls() } type mockClientForBackup struct { @@ -821,3 +760,171 @@ func (m *mockClientForBackup) ExportBulkRelationships(_ context.Context, req *v1 func (m *mockClientForBackup) assertAllRecvCalls() { require.Equal(m.t, len(m.recvCalls), m.recvCallIndex, "the number of provided recvCalls should match the number of invocations") } + +type mockProgressTracker struct { + cursor *v1.Cursor + writtenCursors []*v1.Cursor + completed bool +} + +func (m *mockProgressTracker) GetCursor() *v1.Cursor { + return m.cursor +} + +func (m *mockProgressTracker) WriteCursor(cursor *v1.Cursor) error { + m.writtenCursors = append(m.writtenCursors, cursor) + m.cursor = cursor + return nil +} + +func (m *mockProgressTracker) MarkComplete() error { + m.completed = true + return nil +} + +func (m *mockProgressTracker) Close() error { + return nil +} + +func TestBackupCreateImpl(t *testing.T) { + t.Parallel() + + testRels := []*v1.Relationship{ + { + Resource: &v1.ObjectReference{ObjectType: "test/resource", ObjectId: "1"}, + Relation: "reader", + Subject: &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: "test/user", ObjectId: "1"}}, + }, + { + Resource: &v1.ObjectReference{ObjectType: "test/resource", ObjectId: "2"}, + Relation: "reader", + Subject: &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: "test/user", ObjectId: "2"}}, + }, + } + + t.Run("successful backup with relationships", func(t *testing.T) { + t.Parallel() + + cursor := &v1.Cursor{Token: "after-cursor"} + mockClient := &mockClientForBackup{ + t: t, + recvCalls: []func() (*v1.ExportBulkRelationshipsResponse, error){ + func() (*v1.ExportBulkRelationshipsResponse, error) { + return &v1.ExportBulkRelationshipsResponse{ + Relationships: testRels, + AfterResultCursor: cursor, + }, nil + }, + }, + } + + progressTracker := &mockProgressTracker{} + zedToken := &v1.ZedToken{Token: "test-token"} + encoder := &backupformat.MockEncoder{} + config := BackupConfig{PrefixFilter: "test"} + + completed, err := backupCreateImpl(t.Context(), mockClient, encoder, progressTracker, config, zedToken) + + require.NoError(t, err) + require.True(t, completed, "backup should be marked as completed") + require.Len(t, progressTracker.writtenCursors, 1, "should have written one cursor") + require.Equal(t, cursor.Token, progressTracker.writtenCursors[0].Token) + + mockClient.assertAllRecvCalls() + }) + + t.Run("returns error when both zedToken and cursor are nil", func(t *testing.T) { + t.Parallel() + + mockClient := &mockClientForBackup{t: t} + progressTracker := &mockProgressTracker{cursor: nil} + encoder := &backupformat.MockEncoder{} + config := BackupConfig{} + + completed, err := backupCreateImpl(t.Context(), mockClient, encoder, progressTracker, config, nil) + + require.Error(t, err) + require.False(t, completed) + require.Contains(t, err.Error(), "malformed existing backup") + }) + + t.Run("resumes from cursor when zedToken is nil", func(t *testing.T) { + t.Parallel() + + existingCursor := &v1.Cursor{Token: "existing-cursor"} + newCursor := &v1.Cursor{Token: "new-cursor"} + mockClient := &mockClientForBackup{ + t: t, + recvCalls: []func() (*v1.ExportBulkRelationshipsResponse, error){ + func() (*v1.ExportBulkRelationshipsResponse, error) { + return &v1.ExportBulkRelationshipsResponse{ + Relationships: testRels, + AfterResultCursor: newCursor, + }, nil + }, + }, + exportCalls: []func(t *testing.T, req *v1.ExportBulkRelationshipsRequest){ + func(t *testing.T, req *v1.ExportBulkRelationshipsRequest) { + require.NotNil(t, req.OptionalCursor) + require.Equal(t, existingCursor.Token, req.OptionalCursor.Token) + }, + }, + } + + progressTracker := &mockProgressTracker{cursor: existingCursor} + encoder := &backupformat.MockEncoder{} + config := BackupConfig{PrefixFilter: "test"} + + completed, err := backupCreateImpl(t.Context(), mockClient, encoder, progressTracker, config, nil) + + require.NoError(t, err) + require.True(t, completed) + + mockClient.assertAllRecvCalls() + }) + + t.Run("filters relationships by prefix", func(t *testing.T) { + t.Parallel() + + mixedRels := []*v1.Relationship{ + { + Resource: &v1.ObjectReference{ObjectType: "test/resource", ObjectId: "1"}, + Relation: "reader", + Subject: &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: "test/user", ObjectId: "1"}}, + }, + { + Resource: &v1.ObjectReference{ObjectType: "other/resource", ObjectId: "2"}, + Relation: "reader", + Subject: &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: "other/user", ObjectId: "2"}}, + }, + } + + cursor := &v1.Cursor{Token: "cursor-token"} + mockClient := &mockClientForBackup{ + t: t, + recvCalls: []func() (*v1.ExportBulkRelationshipsResponse, error){ + func() (*v1.ExportBulkRelationshipsResponse, error) { + return &v1.ExportBulkRelationshipsResponse{ + Relationships: mixedRels, + AfterResultCursor: cursor, + }, nil + }, + }, + } + + progressTracker := &mockProgressTracker{} + zedToken := &v1.ZedToken{Token: "test-token"} + encoder := &backupformat.MockEncoder{} + config := BackupConfig{PrefixFilter: "test"} + + completed, err := backupCreateImpl(t.Context(), mockClient, encoder, progressTracker, config, zedToken) + + require.NoError(t, err) + require.True(t, completed) + // Only the "test/" prefixed relationship should be stored + require.Len(t, encoder.Relationships, 1) + require.Equal(t, "test/resource", encoder.Relationships[0].Resource.ObjectType) + + mockClient.assertAllRecvCalls() + }) +}