From e625a6940494a6ac0f7b0381701b7df39b51a657 Mon Sep 17 00:00:00 2001 From: Nick Z <2420177+nickzelei@users.noreply.github.com> Date: Fri, 28 Mar 2025 13:29:39 -0700 Subject: [PATCH 01/12] linter updates --- backend/internal/auth/apikey/client.go | 10 +++---- backend/internal/auth/apikey/client_test.go | 2 +- backend/internal/auth/authmw/auth.go | 4 +-- backend/internal/auth/authmw/auth_test.go | 2 +- .../internal/cmds/mgmt/serve/connect/cmd.go | 5 ++-- backend/internal/loki/loki.go | 2 +- backend/pkg/integration-test/mux.go | 5 +--- .../sqlmanager/postgres/postgres-manager.go | 12 ++++---- .../mgmt/v1alpha1/job-service/jobs.go | 3 +- cli/internal/auth/account-id.go | 2 +- cli/internal/cmds/neosync/accounts/switch.go | 5 ++-- cli/internal/cmds/neosync/jobs/trigger.go | 2 +- cli/internal/cmds/neosync/sync/sync.go | 28 +++++++++---------- cli/internal/cmds/neosync/version/version.go | 7 +++-- .../benthos-builder/builders/dynamodb.go | 2 +- .../benthos-builder/builders/mongodb.go | 2 +- .../benthos/benthos-builder/builders/sql.go | 2 +- internal/ee/mssql-manager/generate-sql.go | 2 +- internal/neosync-types/registry.go | 2 +- internal/runconfigs/runconfigs.go | 2 +- .../testutil/testcontainers/mysql/mysql.go | 9 +++--- .../testcontainers/postgres/postgres.go | 9 +++--- .../testutil/testcontainers/redis/redis.go | 7 ++--- .../openai_generate/openai_generate.go | 2 +- .../transformers/utils/string_utils.go | 2 +- .../tsql/query-qualifier.go | 7 +++-- .../sync-redis-clean-up/activity.go | 2 +- .../workflows/datasync/workflow/workflow.go | 6 ++-- .../datasync/workflow/workflow_test.go | 4 +-- .../tablesync/shared/identity-allocator.go | 5 ++-- 30 files changed, 76 insertions(+), 78 deletions(-) diff --git a/backend/internal/auth/apikey/client.go b/backend/internal/auth/apikey/client.go index 021c1fa00b..30db9aca80 100644 --- a/backend/internal/auth/apikey/client.go +++ b/backend/internal/auth/apikey/client.go @@ -24,8 +24,8 @@ type TokenContextData struct { } var ( - 
InvalidApiKeyErr = errors.New("token is not a valid neosync api key") - ApiKeyExpiredErr = nucleuserrors.NewUnauthenticated("token is expired") + ErrInvalidApiKey = errors.New("token is not a valid neosync api key") + ErrApiKeyExpired = nucleuserrors.NewUnauthenticated("token is expired") ) type Queries interface { @@ -66,11 +66,11 @@ func (c *Client) InjectTokenCtx(ctx context.Context, header http.Header, spec co if err != nil && !neosyncdb.IsNoRows(err) { return nil, err } else if err != nil && neosyncdb.IsNoRows(err) { - return nil, InvalidApiKeyErr + return nil, ErrInvalidApiKey } if time.Now().After(apiKey.ExpiresAt.Time) { - return nil, ApiKeyExpiredErr + return nil, ErrApiKeyExpired } return SetTokenData(ctx, &TokenContextData{ @@ -87,7 +87,7 @@ func (c *Client) InjectTokenCtx(ctx context.Context, header http.Header, spec co ApiKeyType: apikey.WorkerApiKey, }), nil } - return nil, InvalidApiKeyErr + return nil, ErrInvalidApiKey } func GetTokenDataFromCtx(ctx context.Context) (*TokenContextData, error) { diff --git a/backend/internal/auth/apikey/client_test.go b/backend/internal/auth/apikey/client_test.go index 47e5aa2a95..1b5a8593cf 100644 --- a/backend/internal/auth/apikey/client_test.go +++ b/backend/internal/auth/apikey/client_test.go @@ -88,7 +88,7 @@ func Test_Client_InjectTokenCtx_Account_Expired(t *testing.T) { "Authorization": []string{fmt.Sprintf("Bearer %s", fakeToken)}, }, connect.Spec{}) assert.Error(t, err) - assert.True(t, errors.Is(err, ApiKeyExpiredErr)) + assert.True(t, errors.Is(err, ErrApiKeyExpired)) assert.Nil(t, newctx) } diff --git a/backend/internal/auth/authmw/auth.go b/backend/internal/auth/authmw/auth.go index 2e1ec4eaf4..3cd23d27fe 100644 --- a/backend/internal/auth/authmw/auth.go +++ b/backend/internal/auth/authmw/auth.go @@ -27,9 +27,9 @@ func New( func (n *AuthMiddleware) InjectTokenCtx(ctx context.Context, header http.Header, spec connect.Spec) (context.Context, error) { apiKeyCtx, err := n.apiKeyClient.InjectTokenCtx(ctx, 
header, spec) - if err != nil && !errors.Is(err, auth_apikey.InvalidApiKeyErr) { + if err != nil && !errors.Is(err, auth_apikey.ErrInvalidApiKey) { return nil, err - } else if err != nil && errors.Is(err, auth_apikey.InvalidApiKeyErr) { + } else if err != nil && errors.Is(err, auth_apikey.ErrInvalidApiKey) { return n.jwtClient.InjectTokenCtx(ctx, header, spec) } return apiKeyCtx, nil diff --git a/backend/internal/auth/authmw/auth_test.go b/backend/internal/auth/authmw/auth_test.go index e5318856cd..a63703eb4f 100644 --- a/backend/internal/auth/authmw/auth_test.go +++ b/backend/internal/auth/authmw/auth_test.go @@ -56,7 +56,7 @@ func Test_AuthMiddleware_InjectTokenCtx_ApiKey_JwtFallback(t *testing.T) { ctx := context.Background() mockApiKey.On("InjectTokenCtx", ctx, mock.Anything, mock.Anything). - Return(nil, auth_apikey.InvalidApiKeyErr) + Return(nil, auth_apikey.ErrInvalidApiKey) mockJwt.On("InjectTokenCtx", ctx, mock.Anything, mock.Anything). Return(context.Background(), nil) diff --git a/backend/internal/cmds/mgmt/serve/connect/cmd.go b/backend/internal/cmds/mgmt/serve/connect/cmd.go index bd87b040de..1f6e62d0c0 100644 --- a/backend/internal/cmds/mgmt/serve/connect/cmd.go +++ b/backend/internal/cmds/mgmt/serve/connect/cmd.go @@ -1049,9 +1049,10 @@ func getAuthAdminClient(ctx context.Context, authclient auth_client.Interface, l authApiClientId := getAuthApiClientId() authApiClientSecret := getAuthApiClientSecret() provider := getAuthApiProvider() - if provider == "" || provider == "auth0" { + switch provider { + case "", "auth0": return auth0.New(authApiBaseUrl, authApiClientId, authApiClientSecret) - } else if provider == "keycloak" { + case "keycloak": tokenurl, err := authclient.GetTokenEndpoint(ctx) if err != nil { return nil, err diff --git a/backend/internal/loki/loki.go b/backend/internal/loki/loki.go index 3bc6b67d18..fbc658093b 100644 --- a/backend/internal/loki/loki.go +++ b/backend/internal/loki/loki.go @@ -105,7 +105,7 @@ func 
GetStreamsFromResponseData(data *QueryResponseData) (Streams, error) { } streams, ok := data.Result.(Streams) if !ok { - return nil, fmt.Errorf("Result data type was not Streams, got: %T", data.Result) + return nil, fmt.Errorf("result data type was not Streams, got: %T", data.Result) } return streams, nil } diff --git a/backend/pkg/integration-test/mux.go b/backend/pkg/integration-test/mux.go index ac828fbafa..870b2f144f 100644 --- a/backend/pkg/integration-test/mux.go +++ b/backend/pkg/integration-test/mux.go @@ -126,10 +126,7 @@ func (s *NeosyncApiTestClient) setupMux( rbacClient rbac.Interface, logger *slog.Logger, ) (*http.ServeMux, error) { - isPresidioEnabled := false - if isLicensed || isNeosyncCloud { - isPresidioEnabled = true - } + isPresidioEnabled := isLicensed || isNeosyncCloud maxAllowed := int64(10000) var license *testutil.FakeEELicense diff --git a/backend/pkg/sqlmanager/postgres/postgres-manager.go b/backend/pkg/sqlmanager/postgres/postgres-manager.go index 89dae79a4e..971cc9e8e7 100644 --- a/backend/pkg/sqlmanager/postgres/postgres-manager.go +++ b/backend/pkg/sqlmanager/postgres/postgres-manager.go @@ -1157,11 +1157,12 @@ func buildTableCol(record *buildTableColRequest) string { pieces := []string{EscapePgColumn(record.ColumnName), record.DataType, buildNullableText(record.IsNullable)} if record.IsSerial { - if record.DataType == "smallint" { + switch record.DataType { + case "smallint": pieces[1] = "SMALLSERIAL" - } else if record.DataType == "bigint" { + case "bigint": pieces[1] = "BIGSERIAL" - } else { + default: pieces[1] = "SERIAL" } } else if record.SequenceDefinition != nil && *record.SequenceDefinition != "" { @@ -1178,9 +1179,10 @@ func buildTableCol(record *buildTableColRequest) string { func buildSequenceDefinition(identityType string, seqConfig *SequenceConfiguration) string { var seqStr string - if identityType == "d" { + switch identityType { + case "d": seqStr = seqConfig.ToGeneratedDefaultIdentity() - } else if identityType == 
"a" { + case "a": seqStr = seqConfig.ToGeneratedAlwaysIdentity() } return seqStr diff --git a/backend/services/mgmt/v1alpha1/job-service/jobs.go b/backend/services/mgmt/v1alpha1/job-service/jobs.go index ad13146fb7..7e32ee6000 100644 --- a/backend/services/mgmt/v1alpha1/job-service/jobs.go +++ b/backend/services/mgmt/v1alpha1/job-service/jobs.go @@ -20,7 +20,6 @@ import ( connectionmanager "github.com/nucleuscloud/neosync/internal/connection-manager" "github.com/nucleuscloud/neosync/internal/ee/rbac" nucleuserrors "github.com/nucleuscloud/neosync/internal/errors" - "github.com/nucleuscloud/neosync/internal/job" job_util "github.com/nucleuscloud/neosync/internal/job" "github.com/nucleuscloud/neosync/internal/neosyncdb" datasync_workflow "github.com/nucleuscloud/neosync/worker/pkg/workflows/datasync/workflow" @@ -1614,7 +1613,7 @@ func (s *Service) ValidateJobMappings( } } - validator := job_util.NewJobMappingsValidator(req.Msg.Mappings, job.WithJobSourceOptions(sqlSourceOpts)) + validator := job_util.NewJobMappingsValidator(req.Msg.Mappings, job_util.WithJobSourceOptions(sqlSourceOpts)) result, err := validator.Validate(colInfoMap, req.Msg.VirtualForeignKeys, tableConstraints) if err != nil { return nil, err diff --git a/cli/internal/auth/account-id.go b/cli/internal/auth/account-id.go index fe8338045c..f217731640 100644 --- a/cli/internal/auth/account-id.go +++ b/cli/internal/auth/account-id.go @@ -44,7 +44,7 @@ func ResolveAccountIdFromFlag( } accountId, err := userconfig.GetAccountId() if err != nil { - return "", fmt.Errorf(`unable to resolve account id from account context, please use the "neosync accounts switch" command to set an active account context.`) + return "", fmt.Errorf(`unable to resolve account id from account context, please use the "neosync accounts switch" command to set an active account context: %w`, err) } logger.Debug(fmt.Sprintf("account id %q resolved from user config", accountId)) return accountId, nil diff --git 
a/cli/internal/cmds/neosync/accounts/switch.go b/cli/internal/cmds/neosync/accounts/switch.go index b89d936209..ce7be7f29c 100644 --- a/cli/internal/cmds/neosync/accounts/switch.go +++ b/cli/internal/cmds/neosync/accounts/switch.go @@ -110,9 +110,10 @@ func switchAccount( personalAccounts := []*mgmtv1alpha1.UserAccount{} teamAccounts := []*mgmtv1alpha1.UserAccount{} for _, a := range accounts { - if a.Type == mgmtv1alpha1.UserAccountType_USER_ACCOUNT_TYPE_PERSONAL { + switch a.Type { + case mgmtv1alpha1.UserAccountType_USER_ACCOUNT_TYPE_PERSONAL: personalAccounts = append(personalAccounts, a) - } else if a.Type == mgmtv1alpha1.UserAccountType_USER_ACCOUNT_TYPE_TEAM { + case mgmtv1alpha1.UserAccountType_USER_ACCOUNT_TYPE_TEAM: teamAccounts = append(teamAccounts, a) } } diff --git a/cli/internal/cmds/neosync/jobs/trigger.go b/cli/internal/cmds/neosync/jobs/trigger.go index 3d754146ca..277caf2383 100644 --- a/cli/internal/cmds/neosync/jobs/trigger.go +++ b/cli/internal/cmds/neosync/jobs/trigger.go @@ -86,7 +86,7 @@ func triggerJob( return err } if job.Msg.GetJob().GetAccountId() != accountId { - return fmt.Errorf("Unable to trigger job run. Job not found. AccountId: %s", accountId) + return fmt.Errorf("unable to trigger job run. job not found. 
accountId: %s", accountId) } _, err = jobclient.CreateJobRun(ctx, connect.NewRequest[mgmtv1alpha1.CreateJobRunRequest](&mgmtv1alpha1.CreateJobRunRequest{ JobId: jobId, diff --git a/cli/internal/cmds/neosync/sync/sync.go b/cli/internal/cmds/neosync/sync/sync.go index 78b2c71523..690a32aa0b 100644 --- a/cli/internal/cmds/neosync/sync/sync.go +++ b/cli/internal/cmds/neosync/sync/sync.go @@ -17,7 +17,6 @@ import ( "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager" sqlmanager_mysql "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/mysql" sqlmanager_postgres "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/postgres" - sql_manager "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared" sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared" tabledependency "github.com/nucleuscloud/neosync/backend/pkg/table-dependency" "github.com/nucleuscloud/neosync/cli/internal/auth" @@ -736,14 +735,15 @@ func (c *clisync) runDestinationInitStatements( if len(block.Statements) == 0 { continue } - err = db.Db().BatchExec(c.ctx, batchSize, block.Statements, &sql_manager.BatchExecOpts{}) + err = db.Db().BatchExec(c.ctx, batchSize, block.Statements, &sqlmanager_shared.BatchExecOpts{}) if err != nil { c.logger.Error(fmt.Sprintf("Error creating tables: %v", err)) return fmt.Errorf("unable to exec pg %s statements: %w", block.Label, err) } } } - if c.cmd.Destination.Driver == postgresDriver { + switch c.cmd.Destination.Driver { + case postgresDriver: if c.cmd.Destination.TruncateCascade { truncateCascadeStmts := []string{} for _, syncCfg := range syncConfigs { @@ -752,7 +752,7 @@ func (c *clisync) runDestinationInitStatements( truncateCascadeStmts = append(truncateCascadeStmts, stmt) } } - err = db.Db().BatchExec(c.ctx, batchSize, truncateCascadeStmts, &sql_manager.BatchExecOpts{}) + err = db.Db().BatchExec(c.ctx, batchSize, truncateCascadeStmts, &sqlmanager_shared.BatchExecOpts{}) if err != nil { c.logger.Error(fmt.Sprintf("Error truncate 
cascade tables: %v", err)) return err @@ -772,7 +772,7 @@ func (c *clisync) runDestinationInitStatements( return err } } - } else if c.cmd.Destination.Driver == mysqlDriver { + case mysqlDriver: orderedTablesResp, err := tabledependency.GetTablesOrderedByDependency(dependencyMap) if err != nil { return err @@ -781,8 +781,8 @@ func (c *clisync) runDestinationInitStatements( for _, t := range orderedTablesResp.OrderedTables { orderedTableTruncateStatements = append(orderedTableTruncateStatements, schemaConfig.TruncateTableStatementsMap[t.String()]) } - disableFkChecks := sql_manager.DisableForeignKeyChecks - err = db.Db().BatchExec(c.ctx, batchSize, orderedTableTruncateStatements, &sql_manager.BatchExecOpts{Prefix: &disableFkChecks}) + disableFkChecks := sqlmanager_shared.DisableForeignKeyChecks + err = db.Db().BatchExec(c.ctx, batchSize, orderedTableTruncateStatements, &sqlmanager_shared.BatchExecOpts{Prefix: &disableFkChecks}) if err != nil { c.logger.Error(fmt.Sprintf("Error truncating tables: %v", err)) return err @@ -863,7 +863,7 @@ func getTableInitStatementMap( type schemaConfig struct { Schemas []*mgmtv1alpha1.DatabaseColumn - TableConstraints map[string][]*sql_manager.ForeignConstraint + TableConstraints map[string][]*sqlmanager_shared.ForeignConstraint TablePrimaryKeys map[string]*mgmtv1alpha1.PrimaryConstraint TruncateTableStatementsMap map[string]string InitSchemaStatements []*mgmtv1alpha1.SchemaInitStatements @@ -955,18 +955,18 @@ func (c *clisync) getSourceConnectionSqlSchemaConfig( if err := errgrp.Wait(); err != nil { return nil, err } - tc := map[string][]*sql_manager.ForeignConstraint{} + tc := map[string][]*sqlmanager_shared.ForeignConstraint{} for table, constraints := range tableConstraints { - fkConstraints := []*sql_manager.ForeignConstraint{} + fkConstraints := []*sqlmanager_shared.ForeignConstraint{} for _, fk := range constraints.GetConstraints() { - var foreignKey *sql_manager.ForeignKey + var foreignKey *sqlmanager_shared.ForeignKey if 
fk.ForeignKey != nil { - foreignKey = &sql_manager.ForeignKey{ + foreignKey = &sqlmanager_shared.ForeignKey{ Table: fk.GetForeignKey().GetTable(), Columns: fk.GetForeignKey().GetColumns(), } } - fkConstraints = append(fkConstraints, &sql_manager.ForeignConstraint{ + fkConstraints = append(fkConstraints, &sqlmanager_shared.ForeignConstraint{ Columns: fk.GetColumns(), NotNullable: fk.GetNotNullable(), ForeignKey: foreignKey, @@ -1082,7 +1082,7 @@ func (c *clisync) getDestinationSchemaConfig( }, nil } -func (c *clisync) getDestinationTableConstraints(schemas []string) (*sql_manager.TableConstraints, error) { +func (c *clisync) getDestinationTableConstraints(schemas []string) (*sqlmanager_shared.TableConstraints, error) { cctx, cancel := context.WithDeadline(c.ctx, time.Now().Add(5*time.Second)) defer cancel() destConnection := cmdConfigToDestinationConnection(c.cmd) diff --git a/cli/internal/cmds/neosync/version/version.go b/cli/internal/cmds/neosync/version/version.go index 080f4624c5..8bfade325d 100644 --- a/cli/internal/cmds/neosync/version/version.go +++ b/cli/internal/cmds/neosync/version/version.go @@ -24,19 +24,20 @@ func NewCmd() *cobra.Command { } versionInfo := version.Get() - if output == "json" { + switch output { + case "json": marshaled, err := json.MarshalIndent(&versionInfo, "", " ") if err != nil { return err } fmt.Println(string(marshaled)) //nolint:forbidigo - } else if output == "yaml" { + case "yaml": marshaled, err := yaml.Marshal(&versionInfo) if err != nil { return err } fmt.Println(string(marshaled)) //nolint:forbidigo - } else { + default: fmt.Println("Git Version:", versionInfo.GitVersion) //nolint:forbidigo fmt.Println("Git Commit:", versionInfo.GitCommit) //nolint:forbidigo fmt.Println("Build Date:", versionInfo.BuildDate) //nolint:forbidigo diff --git a/internal/benthos/benthos-builder/builders/dynamodb.go b/internal/benthos/benthos-builder/builders/dynamodb.go index 6a7a718eaf..4d21655a19 100644 --- 
a/internal/benthos/benthos-builder/builders/dynamodb.go +++ b/internal/benthos/benthos-builder/builders/dynamodb.go @@ -121,7 +121,7 @@ func (b *dyanmodbSyncBuilder) BuildSourceConfigs(ctx context.Context, params *bb return nil, err } for _, pc := range processorConfigs { - bc.StreamConfig.Pipeline.Processors = append(bc.StreamConfig.Pipeline.Processors, *pc) + bc.Pipeline.Processors = append(bc.Pipeline.Processors, *pc) } benthosConfigs = append(benthosConfigs, &bb_internal.BenthosSourceConfig{ diff --git a/internal/benthos/benthos-builder/builders/mongodb.go b/internal/benthos/benthos-builder/builders/mongodb.go index aa095d385e..d5648f46c8 100644 --- a/internal/benthos/benthos-builder/builders/mongodb.go +++ b/internal/benthos/benthos-builder/builders/mongodb.go @@ -90,7 +90,7 @@ func (b *mongodbSyncBuilder) BuildSourceConfigs(ctx context.Context, params *bb_ return nil, err } for _, pc := range processorConfigs { - bc.StreamConfig.Pipeline.Processors = append(bc.StreamConfig.Pipeline.Processors, *pc) + bc.Pipeline.Processors = append(bc.Pipeline.Processors, *pc) } benthosConfigs = append(benthosConfigs, &bb_internal.BenthosSourceConfig{ diff --git a/internal/benthos/benthos-builder/builders/sql.go b/internal/benthos/benthos-builder/builders/sql.go index 516ffb9a23..567a44f6f1 100644 --- a/internal/benthos/benthos-builder/builders/sql.go +++ b/internal/benthos/benthos-builder/builders/sql.go @@ -247,7 +247,7 @@ func buildBenthosSqlSourceConfigResponses( return nil, err } for _, pc := range processorConfigs { - bc.StreamConfig.Pipeline.Processors = append(bc.StreamConfig.Pipeline.Processors, *pc) + bc.Pipeline.Processors = append(bc.Pipeline.Processors, *pc) } cursors, err := buildIdentityCursors(ctx, transformerclient, mappings.Mappings) diff --git a/internal/ee/mssql-manager/generate-sql.go b/internal/ee/mssql-manager/generate-sql.go index fbb5ec4c7b..096d98a5ab 100644 --- a/internal/ee/mssql-manager/generate-sql.go +++ 
b/internal/ee/mssql-manager/generate-sql.go @@ -40,7 +40,7 @@ func generateCreateTableStatement(rows []*mssql_queries.GetDatabaseTableSchemasB sb.WriteString(fmt.Sprintf(" [%s] ", row.ColumnName)) - if !(row.IsComputed && row.GenerationExpression.Valid) { + if !row.IsComputed || !row.GenerationExpression.Valid { switch { case row.CharacterMaximumLength.Valid: if row.CharacterMaximumLength.Int32 == -1 { diff --git a/internal/neosync-types/registry.go b/internal/neosync-types/registry.go index 937f9080be..7cf00eba54 100644 --- a/internal/neosync-types/registry.go +++ b/internal/neosync-types/registry.go @@ -68,7 +68,7 @@ func (r *TypeRegistry) New(typeId string, version Version) (NeosyncAdapter, erro return newTypeFunc() } - return nil, fmt.Errorf("unknown version %d for type Id: %s. latest version not found.", version, typeId) + return nil, fmt.Errorf("unknown version %d for type Id: %s. latest version not found", version, typeId) } // UnmarshalAny deserializes a value of type any into an appropriate type based on the Neosync type system. diff --git a/internal/runconfigs/runconfigs.go b/internal/runconfigs/runconfigs.go index 1638d42ce7..28d7b784bc 100644 --- a/internal/runconfigs/runconfigs.go +++ b/internal/runconfigs/runconfigs.go @@ -230,7 +230,7 @@ func BuildRunConfigs( // check run path if !isValidRunOrder(configs) { - return nil, errors.New("Unsupported circular dependency detected. At least one foreign key in circular dependency must be nullable") + return nil, errors.New("unsupported circular dependency detected. 
at least one foreign key in circular dependency must be nullable") } return configs, nil diff --git a/internal/testutil/testcontainers/mysql/mysql.go b/internal/testutil/testcontainers/mysql/mysql.go index 885fffbf17..a09e71c008 100644 --- a/internal/testutil/testcontainers/mysql/mysql.go +++ b/internal/testutil/testcontainers/mysql/mysql.go @@ -12,7 +12,6 @@ import ( "github.com/nucleuscloud/neosync/internal/sshtunnel/connectors/mysqltunconnector" "github.com/nucleuscloud/neosync/internal/testutil" "github.com/testcontainers/testcontainers-go" - "github.com/testcontainers/testcontainers-go/modules/mysql" testmysql "github.com/testcontainers/testcontainers-go/modules/mysql" "github.com/testcontainers/testcontainers-go/wait" "golang.org/x/sync/errgroup" @@ -131,9 +130,9 @@ func WithTls() Option { // Creates and starts a MySQL test container and sets up the connection. func setup(ctx context.Context, cfg *mysqlTestContainerConfig) (*MysqlTestContainer, error) { tcopts := []testcontainers.ContainerCustomizer{ - mysql.WithDatabase(cfg.database), - mysql.WithUsername(cfg.username), - mysql.WithPassword(cfg.password), + testmysql.WithDatabase(cfg.database), + testmysql.WithUsername(cfg.username), + testmysql.WithPassword(cfg.password), testcontainers.WithWaitStrategy( wait.ForLog("port: 3306 MySQL Community Server").WithOccurrence(1).WithStartupTimeout(20 * time.Second), ), @@ -175,7 +174,7 @@ func setup(ctx context.Context, cfg *mysqlTestContainerConfig) (*MysqlTestContai })), ) } - mysqlContainer, err := mysql.Run( + mysqlContainer, err := testmysql.Run( ctx, "mysql:8.0.36", tcopts..., diff --git a/internal/testutil/testcontainers/postgres/postgres.go b/internal/testutil/testcontainers/postgres/postgres.go index f78b7c6101..024c7848f8 100644 --- a/internal/testutil/testcontainers/postgres/postgres.go +++ b/internal/testutil/testcontainers/postgres/postgres.go @@ -11,7 +11,6 @@ import ( "github.com/jackc/pgx/v5/pgxpool" 
"github.com/nucleuscloud/neosync/internal/testutil" "github.com/testcontainers/testcontainers-go" - "github.com/testcontainers/testcontainers-go/modules/postgres" testpg "github.com/testcontainers/testcontainers-go/modules/postgres" "github.com/testcontainers/testcontainers-go/wait" "golang.org/x/sync/errgroup" @@ -130,9 +129,9 @@ func WithTls() Option { // Creates and starts a PostgreSQL test container and sets up the connection. func setup(ctx context.Context, cfg *pgTestContainerConfig) (*PostgresTestContainer, error) { tcopts := []testcontainers.ContainerCustomizer{ - postgres.WithDatabase(cfg.database), - postgres.WithUsername(cfg.username), - postgres.WithPassword(cfg.password), + testpg.WithDatabase(cfg.database), + testpg.WithUsername(cfg.username), + testpg.WithPassword(cfg.password), testcontainers.WithWaitStrategy( wait.ForLog("database system is ready to accept connections"). WithOccurrence(2).WithStartupTimeout(20 * time.Second), @@ -175,7 +174,7 @@ func setup(ctx context.Context, cfg *pgTestContainerConfig) (*PostgresTestContai })), ) } - pgContainer, err := postgres.Run( + pgContainer, err := testpg.Run( ctx, "postgres:15", tcopts..., diff --git a/internal/testutil/testcontainers/redis/redis.go b/internal/testutil/testcontainers/redis/redis.go index c041179db4..9dfdfa9119 100644 --- a/internal/testutil/testcontainers/redis/redis.go +++ b/internal/testutil/testcontainers/redis/redis.go @@ -4,7 +4,6 @@ import ( "context" "github.com/testcontainers/testcontainers-go" - "github.com/testcontainers/testcontainers-go/modules/redis" testredis "github.com/testcontainers/testcontainers-go/modules/redis" "github.com/testcontainers/testcontainers-go/wait" ) @@ -29,11 +28,11 @@ func NewRedisTestContainer(ctx context.Context, opts ...Option) (*RedisTestConta // Creates and starts a Redis test container func (r *RedisTestContainer) Setup(ctx context.Context) (*RedisTestContainer, error) { - redisContainer, err := redis.Run( + redisContainer, err := testredis.Run( 
ctx, "docker.io/redis:7", - redis.WithSnapshotting(10, 1), - redis.WithLogLevel(redis.LogLevelVerbose), + testredis.WithSnapshotting(10, 1), + testredis.WithLogLevel(testredis.LogLevelVerbose), testcontainers.WithWaitStrategy( wait.ForLog("* Ready to accept connections"), wait.ForExposedPort(), diff --git a/worker/pkg/benthos/openai_generate/openai_generate.go b/worker/pkg/benthos/openai_generate/openai_generate.go index 912aa5dbca..d777dc93a9 100644 --- a/worker/pkg/benthos/openai_generate/openai_generate.go +++ b/worker/pkg/benthos/openai_generate/openai_generate.go @@ -249,7 +249,7 @@ func (b *generateReader) ReadBatch(ctx context.Context) (service.MessageBatch, s b.count -= 1 } if len(messageBatch) == 0 { - return nil, nil, errors.New("openai_generate: received response from openai but was unable to successfully process records to a structured format. see logs for more details.") + return nil, nil, errors.New("openai_generate: received response from openai but was unable to successfully process records to a structured format. 
see logs for more details") } return messageBatch, emptyAck, nil } diff --git a/worker/pkg/benthos/transformers/utils/string_utils.go b/worker/pkg/benthos/transformers/utils/string_utils.go index d788376dcb..8c56b64e67 100644 --- a/worker/pkg/benthos/transformers/utils/string_utils.go +++ b/worker/pkg/benthos/transformers/utils/string_utils.go @@ -223,7 +223,7 @@ func WithoutCharacters(input string, invalidChars []rune) string { } func GetRandomCharacterString(randomizer rng.Rand, size int64) string { - var stringBuilder []rune = make([]rune, size) + var stringBuilder = make([]rune, size) for i := int64(0); i < size; i++ { num := randomizer.Intn(26) stringBuilder[i] = rune('a' + num) diff --git a/worker/pkg/select-query-builder/tsql/query-qualifier.go b/worker/pkg/select-query-builder/tsql/query-qualifier.go index 956e0f6f7c..1eb5385606 100644 --- a/worker/pkg/select-query-builder/tsql/query-qualifier.go +++ b/worker/pkg/select-query-builder/tsql/query-qualifier.go @@ -139,17 +139,18 @@ func (l *tsqlListener) setToken(startToken, stopToken antlr.Token, text string) func (l *tsqlListener) addNodeText(node antlr.TerminalNode) { if node.GetSymbol().GetTokenType() != antlr.TokenEOF { text := node.GetText() - if text == "," { + switch text { + case ",": // add space after commas l.pop() l.push(text) l.push(" ") - } else if text == "." { + case ".": // remove space before periods // should be table.column not table . 
column l.pop() l.push(text) - } else { + default: // add space after each node text l.push(text) l.push(" ") diff --git a/worker/pkg/workflows/datasync/activities/sync-redis-clean-up/activity.go b/worker/pkg/workflows/datasync/activities/sync-redis-clean-up/activity.go index 897a2de9cf..388592e55f 100644 --- a/worker/pkg/workflows/datasync/activities/sync-redis-clean-up/activity.go +++ b/worker/pkg/workflows/datasync/activities/sync-redis-clean-up/activity.go @@ -62,7 +62,7 @@ func (a *Activity) DeleteRedisHash( ) if a.redisclient == nil { - return nil, fmt.Errorf("missing redis client. this operation requires redis.") + return nil, fmt.Errorf("missing redis client. this operation requires redis") } slogger.Debug("redis client provided") diff --git a/worker/pkg/workflows/datasync/workflow/workflow.go b/worker/pkg/workflows/datasync/workflow/workflow.go index fe289b7390..b5c260b902 100644 --- a/worker/pkg/workflows/datasync/workflow/workflow.go +++ b/worker/pkg/workflows/datasync/workflow/workflow.go @@ -46,7 +46,7 @@ func New(eelicense license.EEInterface) *Workflow { } var ( - invalidAccountStatusError = errors.New("exiting workflow due to invalid account status") + errInvalidAccountStatusError = errors.New("exiting workflow due to invalid account status") ) func withGenerateBenthosConfigsActivityOptions(ctx workflow.Context) workflow.Context { @@ -141,7 +141,7 @@ func executeWorkflow(wfctx workflow.Context, req *WorkflowRequest) (*WorkflowRes if initialCheckAccountStatusResponse.Reason != nil { reason = *initialCheckAccountStatusResponse.Reason } - return nil, fmt.Errorf("halting job run due to account in invalid state. Reason: %q: %w", reason, invalidAccountStatusError) + return nil, fmt.Errorf("halting job run due to account in invalid state. 
Reason: %q: %w", reason, errInvalidAccountStatusError) } info := workflow.GetInfo(ctx) @@ -237,7 +237,7 @@ func executeWorkflow(wfctx workflow.Context, req *WorkflowRequest) (*WorkflowRes workselector.AddReceive(stopChan, func(c workflow.ReceiveChannel, more bool) { // Stop signal received, exit the routing logger.Warn("received signal to stop workflow based on account status") - activityErr = invalidAccountStatusError + activityErr = errInvalidAccountStatusError cancelHandler() }) diff --git a/worker/pkg/workflows/datasync/workflow/workflow_test.go b/worker/pkg/workflows/datasync/workflow/workflow_test.go index 4d3311605e..23d4d43a72 100644 --- a/worker/pkg/workflows/datasync/workflow/workflow_test.go +++ b/worker/pkg/workflows/datasync/workflow/workflow_test.go @@ -706,7 +706,7 @@ func Test_Workflow_Halts_Activities_On_InvalidAccountStatus(t *testing.T) { require.Error(t, err) var applicationErr *temporal.ApplicationError require.True(t, errors.As(err, &applicationErr)) - require.ErrorContains(t, applicationErr, invalidAccountStatusError.Error()) + require.ErrorContains(t, applicationErr, errInvalidAccountStatusError.Error()) env.AssertExpectations(t) } @@ -1036,7 +1036,7 @@ func Test_Workflow_Initial_AccountStatus(t *testing.T) { assert.Error(t, err) var applicationErr *temporal.ApplicationError assert.True(t, errors.As(err, &applicationErr)) - assert.ErrorContains(t, applicationErr, invalidAccountStatusError.Error()) + assert.ErrorContains(t, applicationErr, errInvalidAccountStatusError.Error()) env.AssertExpectations(t) } diff --git a/worker/pkg/workflows/tablesync/shared/identity-allocator.go b/worker/pkg/workflows/tablesync/shared/identity-allocator.go index 564576525a..d97b32e990 100644 --- a/worker/pkg/workflows/tablesync/shared/identity-allocator.go +++ b/worker/pkg/workflows/tablesync/shared/identity-allocator.go @@ -8,7 +8,6 @@ import ( "sync" "github.com/nucleuscloud/neosync/worker/pkg/rng" - "go.temporal.io/sdk/client" temporalclient 
"go.temporal.io/sdk/client" ) @@ -51,7 +50,7 @@ func NewTemporalBlockAllocator(temporalclient temporalclient.Client, workflowId, } func (i *TemporalBlockAllocator) GetNextBlock(ctx context.Context, token string, blockSize uint) (*IdentityRange, error) { - handle, err := i.temporalclient.UpdateWorkflow(ctx, client.UpdateWorkflowOptions{ + handle, err := i.temporalclient.UpdateWorkflow(ctx, temporalclient.UpdateWorkflowOptions{ WorkflowID: i.workflowId, RunID: i.runId, UpdateName: AllocateIdentityBlock, @@ -59,7 +58,7 @@ func (i *TemporalBlockAllocator) GetNextBlock(ctx context.Context, token string, Id: token, BlockSize: blockSize, }}, - WaitForStage: client.WorkflowUpdateStageCompleted, + WaitForStage: temporalclient.WorkflowUpdateStageCompleted, }) if err != nil { return nil, fmt.Errorf("unable to send update to get next block size for identity %s: %w", token, err) From b11d55a08f7b49af40d283fe3bb06afb88ba45ca Mon Sep 17 00:00:00 2001 From: Nick Z <2420177+nickzelei@users.noreply.github.com> Date: Fri, 28 Mar 2025 13:31:11 -0700 Subject: [PATCH 02/12] more linter fixes --- backend/internal/cmds/mgmt/serve/connect/cmd.go | 2 +- cli/internal/cmds/neosync/sync/config.go | 2 +- cli/internal/cmds/neosync/sync/util.go | 4 ++-- internal/json-anonymizer/json-anonymizer.go | 2 +- internal/schema-manager/mssql/mssql.go | 2 +- worker/pkg/benthos/transformers/generator_utils.go | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/backend/internal/cmds/mgmt/serve/connect/cmd.go b/backend/internal/cmds/mgmt/serve/connect/cmd.go index 1f6e62d0c0..2c8d11ae07 100644 --- a/backend/internal/cmds/mgmt/serve/connect/cmd.go +++ b/backend/internal/cmds/mgmt/serve/connect/cmd.go @@ -1140,7 +1140,7 @@ func getRunLogConfig() (*v1alpha1_jobservice.RunLogConfig, error) { }, }, nil default: - return nil, errors.New("unsupported or no run log type configured, but run logs are enabled.") + return nil, errors.New("unsupported or no run log type configured, but run logs are 
enabled") } } diff --git a/cli/internal/cmds/neosync/sync/config.go b/cli/internal/cmds/neosync/sync/config.go index 1639946a52..81a981439e 100644 --- a/cli/internal/cmds/neosync/sync/config.go +++ b/cli/internal/cmds/neosync/sync/config.go @@ -258,7 +258,7 @@ func isConfigValid(cmd *cmdConfig, logger *slog.Logger, sourceConnection *mgmtv1 } if sourceConnection.AccountId != *cmd.AccountId { - return fmt.Errorf("Connection not found. AccountId: %s", *cmd.AccountId) + return fmt.Errorf("connection not found. accountId: %s", *cmd.AccountId) } var destinationDriver *DriverType diff --git a/cli/internal/cmds/neosync/sync/util.go b/cli/internal/cmds/neosync/sync/util.go index 67ca3297a6..bf4539ce73 100644 --- a/cli/internal/cmds/neosync/sync/util.go +++ b/cli/internal/cmds/neosync/sync/util.go @@ -123,11 +123,11 @@ func areSourceAndDestCompatible(connection *mgmtv1alpha1.Connection, destination switch connection.ConnectionConfig.Config.(type) { case *mgmtv1alpha1.ConnectionConfig_PgConfig: if destinationDriver != nil && *destinationDriver != postgresDriver { - return fmt.Errorf("Connection and destination types are incompatible [postgres, %s]", *destinationDriver) + return fmt.Errorf("connection and destination types are incompatible [postgres, %s]", *destinationDriver) } case *mgmtv1alpha1.ConnectionConfig_MysqlConfig: if destinationDriver != nil && *destinationDriver != mysqlDriver { - return fmt.Errorf("Connection and destination types are incompatible [mysql, %s]", *destinationDriver) + return fmt.Errorf("connection and destination types are incompatible [mysql, %s]", *destinationDriver) } case *mgmtv1alpha1.ConnectionConfig_AwsS3Config, *mgmtv1alpha1.ConnectionConfig_GcpCloudstorageConfig, *mgmtv1alpha1.ConnectionConfig_DynamodbConfig: default: diff --git a/internal/json-anonymizer/json-anonymizer.go b/internal/json-anonymizer/json-anonymizer.go index c808ace9e4..7930bf7825 100644 --- a/internal/json-anonymizer/json-anonymizer.go +++ 
b/internal/json-anonymizer/json-anonymizer.go @@ -54,7 +54,7 @@ func NewAnonymizer(opts ...Option) (*JsonAnonymizer, error) { } if len(a.transformerMappings) == 0 && a.defaultTransformers == nil { - return nil, fmt.Errorf("failed to initialize JSON anonymizer. must provide either default transformers or transformer mappings.") + return nil, fmt.Errorf("failed to initialize JSON anonymizer. must provide either default transformers or transformer mappings") } // Initialize transformerExecutors diff --git a/internal/schema-manager/mssql/mssql.go b/internal/schema-manager/mssql/mssql.go index f38d1bf089..a6ef9a4852 100644 --- a/internal/schema-manager/mssql/mssql.go +++ b/internal/schema-manager/mssql/mssql.go @@ -67,7 +67,7 @@ func (d *MssqlSchemaManager) InitializeSchema(ctx context.Context, uniqueTables return initErrors, nil } if !d.eelicense.IsValid() { - return nil, fmt.Errorf("invalid or non-existent Neosync License. SQL Server schema init requires valid Enterprise license.") + return nil, fmt.Errorf("invalid or non-existent Neosync License. 
SQL Server schema init requires valid Enterprise license") } tables := []*sqlmanager_shared.SchemaTable{} for tableKey := range uniqueTables { diff --git a/worker/pkg/benthos/transformers/generator_utils.go b/worker/pkg/benthos/transformers/generator_utils.go index a66843bde1..751604b03e 100644 --- a/worker/pkg/benthos/transformers/generator_utils.go +++ b/worker/pkg/benthos/transformers/generator_utils.go @@ -69,7 +69,7 @@ func ExtractBenthosSpec(fileSet *token.FileSet) ([]*BenthosSpec, error) { if !d.IsDir() && filepath.Ext(path) == ".go" { node, err := parser.ParseFile(fileSet, path, nil, parser.ParseComments) if err != nil { - return fmt.Errorf("Failed to parse file %s: %v", path, err) + return fmt.Errorf("failed to parse file %s: %v", path, err) } for _, cgroup := range node.Comments { for _, comment := range cgroup.List { From f06d343a261699c5777e09e690e33b3f6b8f6408 Mon Sep 17 00:00:00 2001 From: Nick Z <2420177+nickzelei@users.noreply.github.com> Date: Fri, 28 Mar 2025 13:35:49 -0700 Subject: [PATCH 03/12] Fixes all linter errors --- cli/internal/cmds/neosync/sync/config.go | 4 ++-- worker/pkg/benthos/transformers/utils/string_utils.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cli/internal/cmds/neosync/sync/config.go b/cli/internal/cmds/neosync/sync/config.go index 81a981439e..b164256a96 100644 --- a/cli/internal/cmds/neosync/sync/config.go +++ b/cli/internal/cmds/neosync/sync/config.go @@ -216,10 +216,10 @@ func newCobraCmdConfig( func isConfigValid(cmd *cmdConfig, logger *slog.Logger, sourceConnection *mgmtv1alpha1.Connection, sourceConnectionType benthosbuilder_shared.ConnectionType) error { if sourceConnectionType == benthosbuilder_shared.ConnectionTypeAwsS3 && (cmd.Source.ConnectionOpts.JobId == nil || *cmd.Source.ConnectionOpts.JobId == "") && (cmd.Source.ConnectionOpts.JobRunId == nil || *cmd.Source.ConnectionOpts.JobRunId == "") { - return errors.New("S3 source connection type requires job-id or job-run-id.") + return 
errors.New("s3 source connection type requires job-id or job-run-id") } if sourceConnectionType == benthosbuilder_shared.ConnectionTypeGCP && (cmd.Source.ConnectionOpts.JobId == nil || *cmd.Source.ConnectionOpts.JobId == "") && (cmd.Source.ConnectionOpts.JobRunId == nil || *cmd.Source.ConnectionOpts.JobRunId == "") { - return errors.New("GCP Cloud Storage source connection type requires job-id or job-run-id") + return errors.New("gcp cloud storage source connection type requires job-id or job-run-id") } if (sourceConnectionType == benthosbuilder_shared.ConnectionTypeAwsS3 || sourceConnectionType == benthosbuilder_shared.ConnectionTypeGCP) && cmd.Destination.InitSchema { diff --git a/worker/pkg/benthos/transformers/utils/string_utils.go b/worker/pkg/benthos/transformers/utils/string_utils.go index 8c56b64e67..cd6a3a752e 100644 --- a/worker/pkg/benthos/transformers/utils/string_utils.go +++ b/worker/pkg/benthos/transformers/utils/string_utils.go @@ -176,7 +176,7 @@ func IsValidEmail(email string) bool { // use MaxASCII to ensure that the unicode value is only within the ASCII block which only contains latin numbers, letters and characters. 
func IsValidChar(s string) bool { for _, r := range s { - if !(r <= unicode.MaxASCII && (unicode.IsNumber(r) || unicode.IsLetter(r) || unicode.IsSpace(r) || IsAllowedSpecialChar(r))) { + if r > unicode.MaxASCII || (!unicode.IsNumber(r) && !unicode.IsLetter(r) && !unicode.IsSpace(r) && !IsAllowedSpecialChar(r)) { return false } } From 4b43f733872f404475d84ab9c0276576dce60551 Mon Sep 17 00:00:00 2001 From: Nick Z <2420177+nickzelei@users.noreply.github.com> Date: Fri, 28 Mar 2025 13:39:14 -0700 Subject: [PATCH 04/12] reduces minimum complexity to 15 --- .golangci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.golangci.yaml b/.golangci.yaml index 5cc20a2150..c271d17711 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -47,7 +47,7 @@ linters: - performance - style gocyclo: - min-complexity: 20 + min-complexity: 15 lll: line-length: 140 misspell: From f7aae4cb781a6f95ddcac2bbfc7cbabbcc3e6fa7 Mon Sep 17 00:00:00 2001 From: Nick Z <2420177+nickzelei@users.noreply.github.com> Date: Fri, 28 Mar 2025 13:41:11 -0700 Subject: [PATCH 05/12] reduces cyclo complexity to 10 --- .golangci.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.golangci.yaml b/.golangci.yaml index c271d17711..8aa0b0ff1a 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -47,7 +47,7 @@ linters: - performance - style gocyclo: - min-complexity: 15 + min-complexity: 10 lll: line-length: 140 misspell: @@ -105,6 +105,7 @@ formatters: enable: - gofmt - goimports + - golines settings: gofmt: rewrite-rules: From 67c660be877eabe6449e3eb2ca6b6fe6637333c6 Mon Sep 17 00:00:00 2001 From: Nick Z <2420177+nickzelei@users.noreply.github.com> Date: Fri, 28 Mar 2025 13:43:08 -0700 Subject: [PATCH 06/12] adds golines formatter --- backend/internal/auth/apikey/client.go | 23 +- backend/internal/auth/authmw/auth.go | 12 +- backend/internal/auth/client/client.go | 33 +- .../auth/clientcred_token_provider/client.go | 16 +- backend/internal/auth/jwt/client.go | 23 +- 
.../internal/cmds/mgmt/migrate/down/down.go | 6 +- backend/internal/cmds/mgmt/migrate/up/up.go | 6 +- .../cmds/mgmt/run/stripe-usage/cmd.go | 103 ++- .../internal/cmds/mgmt/serve/connect/cmd.go | 143 ++- .../interceptors/accountid/interceptor.go | 8 +- .../connect/interceptors/auth/interceptor.go | 8 +- .../interceptors/auth_logging/interceptor.go | 8 +- .../interceptors/bookend/interceptor.go | 8 +- .../interceptors/logger/interceptor.go | 8 +- backend/internal/dtomaps/job-runs.go | 19 +- backend/internal/dtomaps/jobs.go | 11 +- backend/internal/dtomaps/transformers.go | 5 +- backend/internal/ee/hooks/accounts/service.go | 370 ++++++-- backend/internal/ee/hooks/jobs/service.go | 102 ++- backend/internal/loki/loki.go | 32 +- backend/internal/userdata/client.go | 15 +- backend/internal/userdata/entity_enforcer.go | 80 +- backend/internal/userdata/user.go | 13 +- backend/internal/version/version.go | 10 +- backend/pkg/clienttls/clienttls.go | 4 +- backend/pkg/integration-test/clients.go | 24 +- .../integration-test/integration-test-util.go | 23 +- .../pkg/integration-test/integration-test.go | 58 +- backend/pkg/integration-test/mux.go | 121 ++- backend/pkg/metrics/usage.go | 28 +- backend/pkg/metrics/util.go | 44 +- backend/pkg/mongoconnect/connector.go | 18 +- backend/pkg/mssql-querier/querier.go | 48 +- backend/pkg/mssql-querier/system.sql.go | 68 +- backend/pkg/sqlconnect/sql-connector.go | 74 +- backend/pkg/sqlmanager/mssql/mssql-manager.go | 114 ++- backend/pkg/sqlmanager/mysql/mysql-manager.go | 437 +++++++-- .../sqlmanager/postgres/postgres-manager.go | 414 +++++++-- backend/pkg/sqlmanager/sql-manager.go | 87 +- backend/pkg/sqlretry/dbtx_retry.go | 26 +- .../pkg/table-dependency/table-dependency.go | 10 +- .../v1alpha1/account-hooks-service/service.go | 60 +- .../anonymization-service/anonymization.go | 94 +- .../mgmt/v1alpha1/api-key-service/api-keys.go | 16 +- .../mgmt/v1alpha1/auth-service/tokens.go | 19 +- .../connection-data.go | 165 +++- 
.../v1alpha1/connection-service/connection.go | 63 +- .../mgmt/v1alpha1/job-service/jobs.go | 260 ++++-- .../mgmt/v1alpha1/job-service/runs.go | 280 ++++-- .../mgmt/v1alpha1/metrics-service/metrics.go | 34 +- .../v1alpha1/transformers-service/entities.go | 32 +- .../system_transformers.go | 826 ++++++++++++------ .../userdefined_transformers.go | 120 ++- .../account-onboarding.go | 24 +- .../account-temporal-config.go | 24 +- .../v1alpha1/user-account-service/billing.go | 88 +- .../v1alpha1/user-account-service/users.go | 296 +++++-- backend/sql/postgresql/models/models.go | 92 +- backend/sql/postgresql/models/transformers.go | 18 +- cli/internal/auth/account-id.go | 10 +- cli/internal/auth/tokens.go | 47 +- cli/internal/cmds/neosync/accounts/list.go | 4 +- cli/internal/cmds/neosync/accounts/switch.go | 42 +- cli/internal/cmds/neosync/connections/list.go | 12 +- cli/internal/cmds/neosync/jobs/list.go | 23 +- cli/internal/cmds/neosync/jobs/trigger.go | 21 +- cli/internal/cmds/neosync/login/login.go | 29 +- cli/internal/cmds/neosync/neosync.go | 12 +- cli/internal/cmds/neosync/sync/config.go | 51 +- cli/internal/cmds/neosync/sync/job.go | 19 +- cli/internal/cmds/neosync/sync/sync.go | 220 +++-- cli/internal/cmds/neosync/sync/ui.go | 46 +- cli/internal/cmds/neosync/sync/util.go | 15 +- cli/internal/output/output.go | 3 +- cli/internal/userconfig/folder.go | 4 +- cli/internal/version/version.go | 21 +- internal/authmgmt/auth0/admin-client.go | 5 +- internal/aws/aws-manager.go | 36 +- internal/aws/dynamodb-client.go | 29 +- internal/benthos-stream/client.go | 4 +- .../benthos-builder/benthos-builder.go | 95 +- .../benthos-builder/builders/aws-s3.go | 14 +- .../benthos-builder/builders/dynamodb.go | 50 +- .../builders/gcp-cloud-storage.go | 10 +- .../benthos-builder/builders/generate-ai.go | 55 +- .../benthos-builder/builders/generate.go | 77 +- .../benthos-builder/builders/mongodb.go | 54 +- .../builders/neosync-connection-data.go | 10 +- 
.../benthos-builder/builders/processors.go | 183 +++- .../benthos-builder/builders/sql-util.go | 195 ++++- .../benthos/benthos-builder/builders/sql.go | 178 +++- .../benthos-builder/generate-benthos.go | 21 +- .../benthos/benthos-builder/internal/types.go | 5 +- internal/benthos_slogger/logger.go | 5 +- internal/billing/stripe-billing.go | 42 +- internal/connection-manager/manager.go | 86 +- .../providers/mongo/mongo-pool-provider.go | 12 +- .../pool/providers/sql/sql-pool-provider.go | 12 +- .../providers/mongoprovider/provider.go | 5 +- .../providers/sqlprovider/provider.go | 11 +- internal/connectiondata/aws-s3.go | 116 ++- internal/connectiondata/connectiondata.go | 24 +- internal/connectiondata/dynamodb.go | 15 +- internal/connectiondata/gcp.go | 28 +- internal/connectiondata/mongodb.go | 15 +- internal/connectiondata/sql.go | 177 +++- .../interceptors/retry/interceptor.go | 16 +- internal/connectrpc/validate/validate.go | 8 +- .../database-record-mapper/builder/builder.go | 8 +- .../database-record-mapper.go | 4 +- .../database-record-mapper/dynamodb/mapper.go | 10 +- .../database-record-mapper/mongodb/mapper.go | 10 +- .../database-record-mapper/mssql/mapper.go | 4 +- .../database-record-mapper/mysql/mapper.go | 9 +- .../database-record-mapper/postgres/mapper.go | 33 +- internal/ee/cloud-license/license.go | 11 +- internal/ee/license/license.go | 5 +- internal/ee/mssql-manager/ee-mssql-manager.go | 124 ++- internal/ee/mssql-manager/generate-sql.go | 32 +- internal/ee/presidio/interface.go | 18 +- internal/ee/presidio/util.go | 4 +- internal/ee/rbac/allow_all_client.go | 62 +- internal/ee/rbac/enforcer/enforcer.go | 8 +- internal/ee/rbac/policy.go | 117 ++- internal/ee/rbac/roles.go | 5 +- internal/ee/slack/slack.go | 49 +- .../ee/transformers/functions/functions.go | 56 +- internal/ee/transformers/transformers.go | 6 +- internal/errors/errors.go | 3 +- internal/gcp/client.go | 42 +- .../worker/workflow/datasync-workflow.go | 45 +- 
.../javascript/functions/benthos/functions.go | 418 +++++---- internal/javascript/functions/functions.go | 15 +- .../javascript/functions/neosync/functions.go | 198 +++-- internal/javascript/vm/vm.go | 12 +- internal/job/jobmapping-validator.go | 280 +++++- internal/job/validate-schema.go | 29 +- internal/json-anonymizer/json-anonymizer.go | 101 ++- internal/json-anonymizer/neosync-operator.go | 24 +- internal/neosync-types/array.go | 20 +- internal/neosync-types/binary.go | 10 +- internal/neosync-types/bits.go | 10 +- internal/neosync-types/datetime.go | 6 +- internal/neosync-types/interval.go | 10 +- internal/neosync-types/registry.go | 20 +- internal/neosyncdb/db.go | 7 +- internal/neosyncdb/users.go | 16 +- internal/neosyncdb/util.go | 13 +- internal/otel/otel.go | 54 +- internal/pgx-slog/adapter.go | 7 +- internal/runconfigs/builder.go | 14 +- internal/runconfigs/circular-dependencies.go | 8 +- internal/runconfigs/runconfigs.go | 33 +- internal/schema-manager/mssql/mssql.go | 88 +- internal/schema-manager/mysql/mysql.go | 161 +++- .../not-supported/not-supported.go | 32 +- internal/schema-manager/postgres/postgres.go | 157 +++- internal/schema-manager/schema-manager.go | 34 +- .../schema-manager/shared/foreign-keys.go | 4 +- internal/schema-manager/shared/schema-diff.go | 20 +- internal/schema-manager/shared/shared.go | 5 +- .../postgrestunconnector/connector.go | 8 +- internal/sshtunnel/dialer.go | 30 +- internal/sshtunnel/utils.go | 4 +- .../temporal/clientmanager/client_factory.go | 17 +- .../temporal/clientmanager/config_provider.go | 17 +- internal/temporal/clientmanager/manager.go | 191 +++- .../testcontainers/dynamodb/dynamodb.go | 56 +- .../testcontainers/mongodb/mongodb.go | 21 +- .../testutil/testcontainers/mysql/mysql.go | 43 +- .../testcontainers/postgres/postgres.go | 29 +- .../testcontainers/sqlserver/sqlserver.go | 26 +- worker/internal/cmds/worker/serve/serve.go | 171 ++-- worker/internal/temporal-logger/logger.go | 5 +- 
worker/pkg/benthos/config.go | 322 +++---- .../benthos/default_transform/processor.go | 24 +- worker/pkg/benthos/dynamodb/input.go | 39 +- worker/pkg/benthos/dynamodb/output.go | 59 +- worker/pkg/benthos/environment/environment.go | 84 +- worker/pkg/benthos/error/output_error.go | 16 +- worker/pkg/benthos/error/processor_error.go | 15 +- worker/pkg/benthos/javascript/processor.go | 16 +- .../benthos/json/processor_neosync_json.go | 10 +- worker/pkg/benthos/metrics/otel_metrics.go | 15 +- worker/pkg/benthos/mongodb/common.go | 25 +- worker/pkg/benthos/mongodb/input.go | 11 +- worker/pkg/benthos/mongodb/output.go | 9 +- .../neosync_connection_data_input.go | 45 +- .../openai_generate/openai_generate.go | 83 +- worker/pkg/benthos/redis/output_hash.go | 13 +- worker/pkg/benthos/sql/input_sql_raw.go | 37 +- worker/pkg/benthos/sql/output_sql_insert.go | 41 +- worker/pkg/benthos/sql/output_sql_update.go | 21 +- .../benthos/sql/processor_neosync_mssql.go | 15 +- .../benthos/sql/processor_neosync_mysql.go | 28 +- .../pkg/benthos/sql/processor_neosync_pgx.go | 27 +- .../benthos/transformer_executor/executor.go | 26 +- .../pkg/benthos/transformers/generate_bool.go | 36 +- .../transformers/generate_business_name.go | 59 +- .../transformers/generate_card_number.go | 52 +- .../transformers/generate_categorical.go | 55 +- .../pkg/benthos/transformers/generate_city.go | 53 +- .../benthos/transformers/generate_country.go | 53 +- .../benthos/transformers/generate_email.go | 137 ++- .../transformers/generate_first_name.go | 59 +- .../benthos/transformers/generate_float.go | 96 +- .../transformers/generate_full_address.go | 71 +- .../transformers/generate_full_name.go | 71 +- .../benthos/transformers/generate_gender.go | 61 +- .../benthos/transformers/generate_int64.go | 89 +- .../generate_int64_phone_number.go | 57 +- .../generate_international_phone_number.go | 74 +- .../transformers/generate_ip_address.go | 61 +- .../transformers/generate_last_name.go | 57 +- 
.../transformers/generate_random_string.go | 76 +- .../transformers/generate_sha256hash.go | 29 +- .../pkg/benthos/transformers/generate_ssn.go | 34 +- .../benthos/transformers/generate_state.go | 53 +- .../transformers/generate_street_address.go | 53 +- .../generate_string_phone_number.go | 66 +- .../transformers/generate_unix_timestamp.go | 44 +- .../benthos/transformers/generate_username.go | 52 +- .../transformers/generate_utc_timestamp.go | 40 +- .../pkg/benthos/transformers/generate_uuid.go | 27 +- .../benthos/transformers/generate_zipcode.go | 44 +- .../benthos/transformers/generator_utils.go | 12 +- .../transform_character_scramble.go | 73 +- .../transform_e164_phone_number.go | 96 +- .../benthos/transformers/transform_email.go | 162 ++-- .../transformers/transform_first_name.go | 96 +- .../benthos/transformers/transform_float.go | 110 ++- .../transformers/transform_full_name.go | 106 ++- .../transform_identity_scramble.go | 60 +- .../benthos/transformers/transform_int64.go | 82 +- .../transform_int64_phone_number.go | 84 +- .../transformers/transform_lastname.go | 96 +- .../benthos/transformers/transform_string.go | 99 ++- .../transform_string_phone_number.go | 90 +- .../benthos/transformers/transform_uuid.go | 66 +- .../benthos/transformers/utils/float_utils.go | 5 +- .../transformers/utils/integer_utils.go | 25 +- .../benthos/transformers/utils/slice_utils.go | 4 +- .../transformers/utils/string_utils.go | 27 +- .../pkg/query-builder/insert-query-builder.go | 66 +- .../pkg/select-query-builder/querybuilder.go | 66 +- .../tsql/query-qualifier.go | 14 +- .../activities/account-status/activity.go | 17 +- .../gen-benthos-configs/activity.go | 10 +- .../gen-benthos-configs/benthos-builder.go | 49 +- .../activities/jobhooks-by-timing/activity.go | 31 +- .../activities/post-table-sync/activity.go | 42 +- .../datasync/activities/shared/shared.go | 13 +- .../activities/sync-activity-opts/activity.go | 16 +- .../datasync/workflow/register/register.go | 7 +- 
.../workflows/datasync/workflow/workflow.go | 150 +++- .../activities/execute/activity.go | 34 +- .../activities/hooks-by-event/activity.go | 11 +- .../ee/account_hooks/workflow/workflow.go | 5 +- .../workflows/job/activities/activities.go | 138 ++- .../ee/piidetect/workflows/job/workflow.go | 44 +- .../piidetect/workflows/register/register.go | 14 +- .../workflows/table/activities/activities.go | 69 +- .../ee/piidetect/workflows/table/workflow.go | 5 +- .../activities/init-schema/activity.go | 4 +- .../activities/init-schema/init-schema.go | 29 +- .../activities/reconcile-schema/activity.go | 4 +- .../reconcile-schema/reconcile-schema.go | 46 +- .../schemainit/workflow/register/register.go | 7 +- .../workflows/schemainit/workflow/workflow.go | 5 +- worker/pkg/workflows/shared/util.go | 37 +- .../tablesync/activities/sync/activity.go | 81 +- .../tablesync/shared/identity-allocator.go | 47 +- .../workflows/tablesync/workflow/workflow.go | 19 +- 273 files changed, 11294 insertions(+), 4151 deletions(-) diff --git a/backend/internal/auth/apikey/client.go b/backend/internal/auth/apikey/client.go index 30db9aca80..9820c25333 100644 --- a/backend/internal/auth/apikey/client.go +++ b/backend/internal/auth/apikey/client.go @@ -29,7 +29,11 @@ var ( ) type Queries interface { - GetAccountApiKeyByKeyValue(ctx context.Context, db db_queries.DBTX, apiKey string) (db_queries.NeosyncApiAccountApiKey, error) + GetAccountApiKeyByKeyValue( + ctx context.Context, + db db_queries.DBTX, + apiKey string, + ) (db_queries.NeosyncApiAccountApiKey, error) } type Client struct { @@ -49,10 +53,19 @@ func New( for _, procedure := range allowedWorkerProcedures { allowedWorkerProcedureSet[procedure] = struct{}{} } - return &Client{q: queries, db: db, allowedWorkerApiKeys: allowedWorkerApiKeys, allowedWorkerProcedures: allowedWorkerProcedureSet} + return &Client{ + q: queries, + db: db, + allowedWorkerApiKeys: allowedWorkerApiKeys, + allowedWorkerProcedures: allowedWorkerProcedureSet, + } } -func 
(c *Client) InjectTokenCtx(ctx context.Context, header http.Header, spec connect.Spec) (context.Context, error) { +func (c *Client) InjectTokenCtx( + ctx context.Context, + header http.Header, + spec connect.Spec, +) (context.Context, error) { token, err := utils.GetBearerTokenFromHeader(header, "Authorization") if err != nil { return nil, err @@ -93,7 +106,9 @@ func (c *Client) InjectTokenCtx(ctx context.Context, header http.Header, spec co func GetTokenDataFromCtx(ctx context.Context) (*TokenContextData, error) { data, ok := ctx.Value(TokenContextKey{}).(*TokenContextData) if !ok { - return nil, nucleuserrors.NewUnauthenticated("ctx does not contain TokenContextData or unable to cast struct") + return nil, nucleuserrors.NewUnauthenticated( + "ctx does not contain TokenContextData or unable to cast struct", + ) } return data, nil } diff --git a/backend/internal/auth/authmw/auth.go b/backend/internal/auth/authmw/auth.go index 3cd23d27fe..442317f6a0 100644 --- a/backend/internal/auth/authmw/auth.go +++ b/backend/internal/auth/authmw/auth.go @@ -10,7 +10,11 @@ import ( ) type AuthClient interface { - InjectTokenCtx(ctx context.Context, header http.Header, spec connect.Spec) (context.Context, error) + InjectTokenCtx( + ctx context.Context, + header http.Header, + spec connect.Spec, + ) (context.Context, error) } type AuthMiddleware struct { @@ -25,7 +29,11 @@ func New( return &AuthMiddleware{jwtClient: jwtClient, apiKeyClient: apiKeyClient} } -func (n *AuthMiddleware) InjectTokenCtx(ctx context.Context, header http.Header, spec connect.Spec) (context.Context, error) { +func (n *AuthMiddleware) InjectTokenCtx( + ctx context.Context, + header http.Header, + spec connect.Spec, +) (context.Context, error) { apiKeyCtx, err := n.apiKeyClient.InjectTokenCtx(ctx, header, spec) if err != nil && !errors.Is(err, auth_apikey.ErrInvalidApiKey) { return nil, err diff --git a/backend/internal/auth/client/client.go b/backend/internal/auth/client/client.go index 0d0fe15357..1c64b12735 
100644 --- a/backend/internal/auth/client/client.go +++ b/backend/internal/auth/client/client.go @@ -12,8 +12,17 @@ import ( ) type Interface interface { - GetTokenResponse(ctx context.Context, clientId string, code string, redirecturi string) (*AuthTokenResponse, error) - GetRefreshedAccessToken(ctx context.Context, clientId string, refreshToken string) (*AuthTokenResponse, error) + GetTokenResponse( + ctx context.Context, + clientId string, + code string, + redirecturi string, + ) (*AuthTokenResponse, error) + GetRefreshedAccessToken( + ctx context.Context, + clientId string, + refreshToken string, + ) (*AuthTokenResponse, error) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo, error) GetTokenEndpoint(ctx context.Context) (string, error) GetAuthorizationEndpoint(ctx context.Context) (string, error) @@ -147,7 +156,10 @@ func (c *Client) GetRefreshedAccessToken( clientSecret := c.clientIdSecretMap[clientId] payload := strings.NewReader( fmt.Sprintf( - "grant_type=refresh_token&client_id=%s&client_secret=%s&refresh_token=%s", clientId, clientSecret, refreshToken, + "grant_type=refresh_token&client_id=%s&client_secret=%s&refresh_token=%s", + clientId, + clientSecret, + refreshToken, ), ) tokenurl, err := c.GetTokenEndpoint(ctx) @@ -179,14 +191,20 @@ func (c *Client) GetRefreshedAccessToken( err = json.Unmarshal(body, &tokenResponse) if err != nil { - return nil, fmt.Errorf("unable to unmarshal token response from refresh token request: %w", err) + return nil, fmt.Errorf( + "unable to unmarshal token response from refresh token request: %w", + err, + ) } if tokenResponse.AccessToken == "" { var errorResponse AuthTokenErrorData err = json.Unmarshal(body, &errorResponse) if err != nil { - return nil, fmt.Errorf("unable to unmarshal error response from refresh token request: %w", err) + return nil, fmt.Errorf( + "unable to unmarshal error response from refresh token request: %w", + err, + ) } return &AuthTokenResponse{ Result: nil, @@ -274,7 +292,10 @@ 
type openIdConfiguration struct { } func (c *Client) getOpenIdConfiguration(ctx context.Context) (*openIdConfiguration, error) { - configUrl := fmt.Sprintf("%s/.well-known/openid-configuration", strings.TrimSuffix(c.authBaseUrl, "/")) + configUrl := fmt.Sprintf( + "%s/.well-known/openid-configuration", + strings.TrimSuffix(c.authBaseUrl, "/"), + ) req, err := http.NewRequestWithContext(ctx, http.MethodGet, configUrl, http.NoBody) if err != nil { diff --git a/backend/internal/auth/clientcred_token_provider/client.go b/backend/internal/auth/clientcred_token_provider/client.go index 0e8b284bda..b442c664aa 100644 --- a/backend/internal/auth/clientcred_token_provider/client.go +++ b/backend/internal/auth/clientcred_token_provider/client.go @@ -25,7 +25,9 @@ type tokenProviderClient struct { clientSecret string } -func (c *tokenProviderClient) GetToken(ctx context.Context) (*auth_client.AuthTokenResponse, error) { +func (c *tokenProviderClient) GetToken( + ctx context.Context, +) (*auth_client.AuthTokenResponse, error) { values := url.Values{ "grant_type": []string{"client_credentials"}, "client_id": []string{c.clientId}, @@ -87,9 +89,17 @@ type ClientCredentialsTokenProvider struct { expiresAt *time.Time } -func New(tokenurl, clientId, clientSecret string, tokenExpirationBuffer time.Duration, logger *slog.Logger) *ClientCredentialsTokenProvider { +func New( + tokenurl, clientId, clientSecret string, + tokenExpirationBuffer time.Duration, + logger *slog.Logger, +) *ClientCredentialsTokenProvider { return &ClientCredentialsTokenProvider{ - tokenprovider: &tokenProviderClient{tokenurl: tokenurl, clientId: clientId, clientSecret: clientSecret}, + tokenprovider: &tokenProviderClient{ + tokenurl: tokenurl, + clientId: clientId, + clientSecret: clientSecret, + }, tokenExpBuffer: tokenExpirationBuffer, logger: logger, } diff --git a/backend/internal/auth/jwt/client.go b/backend/internal/auth/jwt/client.go index 0af668cfe8..d9bc1391ca 100644 --- 
a/backend/internal/auth/jwt/client.go +++ b/backend/internal/auth/jwt/client.go @@ -73,14 +73,19 @@ func New( } // Validates and returns a parsed access token (if available) -func (j *Client) validateToken(ctx context.Context, accessToken string) (*validator.ValidatedClaims, error) { +func (j *Client) validateToken( + ctx context.Context, + accessToken string, +) (*validator.ValidatedClaims, error) { rawParsedToken, err := j.jwtValidator.ValidateToken(ctx, accessToken) if err != nil { return nil, nucleuserrors.NewUnauthenticated(err.Error()) } validatedClaims, ok := rawParsedToken.(*validator.ValidatedClaims) if !ok { - return nil, nucleuserrors.NewInternalError("unable to convert token claims what was expected") + return nil, nucleuserrors.NewInternalError( + "unable to convert token claims what was expected", + ) } return validatedClaims, nil } @@ -111,7 +116,11 @@ func hasScope(scopes []string, expectedScope string) bool { } // Validates the ctx is authenticated. Stuffs the parsed token onto the context -func (j *Client) InjectTokenCtx(ctx context.Context, header http.Header, spec connect.Spec) (context.Context, error) { +func (j *Client) InjectTokenCtx( + ctx context.Context, + header http.Header, + spec connect.Spec, +) (context.Context, error) { token, err := utils.GetBearerTokenFromHeader(header, "Authorization") if err != nil { return nil, err @@ -124,7 +133,9 @@ func (j *Client) InjectTokenCtx(ctx context.Context, header http.Header, spec co claims, ok := parsedToken.CustomClaims.(*CustomClaims) if !ok { - return nil, nucleuserrors.NewInternalError("unable to cast custom token claims to CustomClaims struct") + return nil, nucleuserrors.NewInternalError( + "unable to cast custom token claims to CustomClaims struct", + ) } scopes := getCombinedScopesAndPermissions(claims.Scope, claims.Permissions) @@ -161,7 +172,9 @@ func GetTokenDataFromCtx(ctx context.Context) (*TokenContextData, error) { val := ctx.Value(TokenContextKey{}) data, ok := 
val.(*TokenContextData) if !ok { - return nil, nucleuserrors.NewUnauthenticated(fmt.Sprintf("ctx does not contain TokenContextData or unable to cast struct: %T", val)) + return nil, nucleuserrors.NewUnauthenticated( + fmt.Sprintf("ctx does not contain TokenContextData or unable to cast struct: %T", val), + ) } return data, nil } diff --git a/backend/internal/cmds/mgmt/migrate/down/down.go b/backend/internal/cmds/mgmt/migrate/down/down.go index eb0a7f210e..7cec207389 100644 --- a/backend/internal/cmds/mgmt/migrate/down/down.go +++ b/backend/internal/cmds/mgmt/migrate/down/down.go @@ -47,8 +47,10 @@ func NewCmd() *cobra.Command { ) }, } - cmd.Flags().StringP("database", "d", "", "optionally set the database url, otherwise it will pull from the environment") - cmd.Flags().StringP("source", "s", "", "optionally set the migrations dir, otherwise pull from DB_SCHEMA_DIR env") + cmd.Flags(). + StringP("database", "d", "", "optionally set the database url, otherwise it will pull from the environment") + cmd.Flags(). + StringP("source", "s", "", "optionally set the migrations dir, otherwise pull from DB_SCHEMA_DIR env") return cmd } diff --git a/backend/internal/cmds/mgmt/migrate/up/up.go b/backend/internal/cmds/mgmt/migrate/up/up.go index 7622e24930..ef28f358b7 100644 --- a/backend/internal/cmds/mgmt/migrate/up/up.go +++ b/backend/internal/cmds/mgmt/migrate/up/up.go @@ -47,8 +47,10 @@ func NewCmd() *cobra.Command { ) }, } - cmd.Flags().StringP("database", "d", "", "optionally set the database url, otherwise it will pull from the environment") - cmd.Flags().StringP("source", "s", "", "optionally set the migrations dir, otherwise pull from DB_SCHEMA_DIR env") + cmd.Flags(). + StringP("database", "d", "", "optionally set the database url, otherwise it will pull from the environment") + cmd.Flags(). 
+ StringP("source", "s", "", "optionally set the migrations dir, otherwise pull from DB_SCHEMA_DIR env") return cmd } diff --git a/backend/internal/cmds/mgmt/run/stripe-usage/cmd.go b/backend/internal/cmds/mgmt/run/stripe-usage/cmd.go index 780e2fc5d2..2146733b00 100644 --- a/backend/internal/cmds/mgmt/run/stripe-usage/cmd.go +++ b/backend/internal/cmds/mgmt/run/stripe-usage/cmd.go @@ -56,7 +56,9 @@ func run(ctx context.Context) error { if meterName == "" { return errors.New("must provide valid meter name") } - eventIdSuffix := viper.GetString("EVENT_ID_SUFFIX") // optionally add an event id suffix to allow reprocessing + eventIdSuffix := viper.GetString( + "EVENT_ID_SUFFIX", + ) // optionally add an event id suffix to allow reprocessing ingestDate, err := getIngestDate(ingestDateStr, ingestDateOffset) if err != nil { @@ -77,21 +79,35 @@ func run(ctx context.Context) error { clientInterceptors = append(clientInterceptors, otelinterceptors...) defer func() { if err := otelshutdown(context.Background()); err != nil { - slogger.ErrorContext(ctx, fmt.Errorf("unable to gracefully shutdown otel providers: %w", err).Error()) + slogger.ErrorContext( + ctx, + fmt.Errorf("unable to gracefully shutdown otel providers: %w", err).Error(), + ) } }() } - usersclient := mgmtv1alpha1connect.NewUserAccountServiceClient(httpclient, neosyncurl, connect.WithInterceptors(clientInterceptors...)) - metricsclient := mgmtv1alpha1connect.NewMetricsServiceClient(httpclient, neosyncurl, connect.WithInterceptors(clientInterceptors...)) + usersclient := mgmtv1alpha1connect.NewUserAccountServiceClient( + httpclient, + neosyncurl, + connect.WithInterceptors(clientInterceptors...), + ) + metricsclient := mgmtv1alpha1connect.NewMetricsServiceClient( + httpclient, + neosyncurl, + connect.WithInterceptors(clientInterceptors...), + ) if len(accountIds) > 0 { slogger.DebugContext(ctx, fmt.Sprintf("%d accounts provided as input", len(accountIds))) } - accountsResp, err := 
usersclient.GetBillingAccounts(ctx, connect.NewRequest(&mgmtv1alpha1.GetBillingAccountsRequest{ - AccountIds: accountIds, - })) + accountsResp, err := usersclient.GetBillingAccounts( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetBillingAccountsRequest{ + AccountIds: accountIds, + }), + ) if err != nil { return err } @@ -108,7 +124,12 @@ func run(ctx context.Context) error { slogger.With("accountId", account.GetId()), ) if err != nil { - slogger.ErrorContext(ctx, fmt.Errorf("unable to process account: %w", err).Error(), "accountId", account.GetId()) + slogger.ErrorContext( + ctx, + fmt.Errorf("unable to process account: %w", err).Error(), + "accountId", + account.GetId(), + ) return fmt.Errorf("unable to process account: %w", err) } } @@ -117,10 +138,17 @@ func run(ctx context.Context) error { return nil } -func getOtelConfig(ctx context.Context, otelconfig neosyncotel.OtelEnvConfig, logger *slog.Logger) (interceptors []connect.Interceptor, shutdown func(context.Context) error, err error) { +func getOtelConfig( + ctx context.Context, + otelconfig neosyncotel.OtelEnvConfig, + logger *slog.Logger, +) (interceptors []connect.Interceptor, shutdown func(context.Context) error, err error) { logger.DebugContext(ctx, "otel is enabled") tmPropagator := neosyncotel.NewDefaultPropagator() - otelconnopts := []otelconnect.Option{otelconnect.WithoutServerPeerAttributes(), otelconnect.WithPropagator(tmPropagator)} + otelconnopts := []otelconnect.Option{ + otelconnect.WithoutServerPeerAttributes(), + otelconnect.WithPropagator(tmPropagator), + } meterProviders := []neosyncotel.MeterProvider{} traceProviders := []neosyncotel.TracerProvider{} @@ -187,14 +215,17 @@ func processAccount( logger *slog.Logger, ) error { logger.DebugContext(ctx, "retrieving daily metric count") - resp, err := metricsclient.GetDailyMetricCount(ctx, connect.NewRequest(&mgmtv1alpha1.GetDailyMetricCountRequest{ - Metric: mgmtv1alpha1.RangedMetricName_RANGED_METRIC_NAME_INPUT_RECEIVED, - Start: ingestdate, - 
End: ingestdate, - Identifier: &mgmtv1alpha1.GetDailyMetricCountRequest_AccountId{ - AccountId: account.GetId(), - }, - })) + resp, err := metricsclient.GetDailyMetricCount( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetDailyMetricCountRequest{ + Metric: mgmtv1alpha1.RangedMetricName_RANGED_METRIC_NAME_INPUT_RECEIVED, + Start: ingestdate, + End: ingestdate, + Identifier: &mgmtv1alpha1.GetDailyMetricCountRequest_AccountId{ + AccountId: account.GetId(), + }, + }), + ) if err != nil { return fmt.Errorf("unable to get daily metric count: %w", err) } @@ -206,13 +237,19 @@ func processAccount( if recordCount > 0 { ts := getEventTimestamp(ingestdate) logger.DebugContext(ctx, "record count was greater than 0, creating meter event") - _, err := userclient.SetBillingMeterEvent(ctx, connect.NewRequest(&mgmtv1alpha1.SetBillingMeterEventRequest{ - AccountId: account.GetId(), - EventName: meterName, - Value: strconv.FormatUint(recordCount, 10), - EventId: withSuffix(getEventId(account.GetId(), ingestdate), eventIdSuffix), - Timestamp: &ts, - })) + _, err := userclient.SetBillingMeterEvent( + ctx, + connect.NewRequest(&mgmtv1alpha1.SetBillingMeterEventRequest{ + AccountId: account.GetId(), + EventName: meterName, + Value: strconv.FormatUint(recordCount, 10), + EventId: withSuffix( + getEventId(account.GetId(), ingestdate), + eventIdSuffix, + ), + Timestamp: &ts, + }), + ) if err != nil { return fmt.Errorf("unable to set billing meter event: %w", err) } @@ -226,9 +263,19 @@ func processAccount( func getEventTimestamp(date *mgmtv1alpha1.Date) uint64 { now := time.Now().UTC() - inputDate := time.Date(int(date.GetYear()), time.Month(date.GetMonth()), int(date.GetDay()), 12, 0, 0, 0, time.UTC) - - if inputDate.Year() == now.Year() && inputDate.Month() == now.Month() && inputDate.Day() == now.Day() { + inputDate := time.Date( + int(date.GetYear()), + time.Month(date.GetMonth()), + int(date.GetDay()), + 12, + 0, + 0, + 0, + time.UTC, + ) + + if inputDate.Year() == now.Year() && 
inputDate.Month() == now.Month() && + inputDate.Day() == now.Day() { // If the input date is today, use the current time as Stripe does not allow timestamps more than 5min into the future return uint64(now.Unix()) //nolint:gosec } diff --git a/backend/internal/cmds/mgmt/serve/connect/cmd.go b/backend/internal/cmds/mgmt/serve/connect/cmd.go index 2c8d11ae07..f95a3e7cb1 100644 --- a/backend/internal/cmds/mgmt/serve/connect/cmd.go +++ b/backend/internal/cmds/mgmt/serve/connect/cmd.go @@ -122,7 +122,9 @@ func serve(ctx context.Context) error { slogger = slogger.With("nucleusEnv", neoEnv) } - slog.SetDefault(slogger) // set default logger for methods that can't easily access the configured logger + slog.SetDefault( + slogger, + ) // set default logger for methods that can't easily access the configured logger eelicense, err := license.NewFromEnv() if err != nil { @@ -216,7 +218,11 @@ func serve(ctx context.Context) error { if err != nil { return err } - slogger.Debug("DB_AUTO_MIGRATE is enabled, running migrations...", "migrationDir", schemaDir) + slogger.Debug( + "DB_AUTO_MIGRATE is enabled, running migrations...", + "migrationDir", + schemaDir, + ) if err := neomigrate.Up( ctx, neosyncdb.GetDbUrl(dbMigConfig), @@ -260,7 +266,10 @@ func serve(ctx context.Context) error { if otelconfig.IsEnabled { slogger.Debug("otel is enabled") tmPropagator := neosyncotel.NewDefaultPropagator() - otelconnopts := []otelconnect.Option{otelconnect.WithoutServerPeerAttributes(), otelconnect.WithPropagator(tmPropagator)} + otelconnopts := []otelconnect.Option{ + otelconnect.WithoutServerPeerAttributes(), + otelconnect.WithPropagator(tmPropagator), + } traceProviders := []neosyncotel.TracerProvider{} meterProviders := []neosyncotel.MeterProvider{} @@ -283,14 +292,19 @@ func serve(ctx context.Context) error { otelconnopts = append(otelconnopts, otelconnect.WithoutMetrics()) } - anonymizeMeterProvider, err := neosyncotel.NewMeterProvider(ctx, &neosyncotel.MeterProviderConfig{ - Exporter: 
otelconfig.MeterExporter, - AppVersion: otelconfig.ServiceVersion, - Opts: neosyncotel.MeterExporterOpts{ - Otlp: []otlpmetricgrpc.Option{neosyncotel.WithDefaultDeltaTemporalitySelector()}, - Console: []stdoutmetric.Option{stdoutmetric.WithPrettyPrint()}, + anonymizeMeterProvider, err := neosyncotel.NewMeterProvider( + ctx, + &neosyncotel.MeterProviderConfig{ + Exporter: otelconfig.MeterExporter, + AppVersion: otelconfig.ServiceVersion, + Opts: neosyncotel.MeterExporterOpts{ + Otlp: []otlpmetricgrpc.Option{ + neosyncotel.WithDefaultDeltaTemporalitySelector(), + }, + Console: []stdoutmetric.Option{stdoutmetric.WithPrettyPrint()}, + }, }, - }) + ) if err != nil { return err } @@ -332,7 +346,9 @@ func serve(ctx context.Context) error { }) defer func() { if err := otelshutdown(context.Background()); err != nil { - slogger.Error(fmt.Errorf("unable to gracefully shutdown otel providers: %w", err).Error()) + slogger.Error( + fmt.Errorf("unable to gracefully shutdown otel providers: %w", err).Error(), + ) } }() } @@ -377,24 +393,29 @@ func serve(ctx context.Context) error { if err != nil { return err } - apikeyClient := auth_apikey.New(db.Q, db.Db, getAllowedWorkerApiKeys(ncloudlicense.IsValid()), []string{ - mgmtv1alpha1connect.JobServiceGetJobProcedure, - mgmtv1alpha1connect.JobServiceGetRunContextProcedure, - mgmtv1alpha1connect.JobServiceSetRunContextProcedure, - mgmtv1alpha1connect.JobServiceSetRunContextsProcedure, - mgmtv1alpha1connect.ConnectionServiceGetConnectionProcedure, - mgmtv1alpha1connect.TransformersServiceGetUserDefinedTransformerByIdProcedure, - mgmtv1alpha1connect.ConnectionDataServiceGetConnectionInitStatementsProcedure, - mgmtv1alpha1connect.UserAccountServiceIsAccountStatusValidProcedure, - mgmtv1alpha1connect.UserAccountServiceGetBillingAccountsProcedure, - mgmtv1alpha1connect.UserAccountServiceSetBillingMeterEventProcedure, - mgmtv1alpha1connect.MetricsServiceGetDailyMetricCountProcedure, - 
mgmtv1alpha1connect.AnonymizationServiceAnonymizeManyProcedure, - mgmtv1alpha1connect.JobServiceGetActiveJobHooksByTimingProcedure, - mgmtv1alpha1connect.AccountHookServiceGetActiveAccountHooksByEventProcedure, - mgmtv1alpha1connect.AccountHookServiceGetAccountHookProcedure, - mgmtv1alpha1connect.AccountHookServiceSendSlackMessageProcedure, - }) + apikeyClient := auth_apikey.New( + db.Q, + db.Db, + getAllowedWorkerApiKeys(ncloudlicense.IsValid()), + []string{ + mgmtv1alpha1connect.JobServiceGetJobProcedure, + mgmtv1alpha1connect.JobServiceGetRunContextProcedure, + mgmtv1alpha1connect.JobServiceSetRunContextProcedure, + mgmtv1alpha1connect.JobServiceSetRunContextsProcedure, + mgmtv1alpha1connect.ConnectionServiceGetConnectionProcedure, + mgmtv1alpha1connect.TransformersServiceGetUserDefinedTransformerByIdProcedure, + mgmtv1alpha1connect.ConnectionDataServiceGetConnectionInitStatementsProcedure, + mgmtv1alpha1connect.UserAccountServiceIsAccountStatusValidProcedure, + mgmtv1alpha1connect.UserAccountServiceGetBillingAccountsProcedure, + mgmtv1alpha1connect.UserAccountServiceSetBillingMeterEventProcedure, + mgmtv1alpha1connect.MetricsServiceGetDailyMetricCountProcedure, + mgmtv1alpha1connect.AnonymizationServiceAnonymizeManyProcedure, + mgmtv1alpha1connect.JobServiceGetActiveJobHooksByTimingProcedure, + mgmtv1alpha1connect.AccountHookServiceGetActiveAccountHooksByEventProcedure, + mgmtv1alpha1connect.AccountHookServiceGetAccountHookProcedure, + mgmtv1alpha1connect.AccountHookServiceSendSlackMessageProcedure, + }, + ) stdAuthInterceptors = append( stdAuthInterceptors, auth_interceptor.NewInterceptor( @@ -527,20 +548,30 @@ func serve(ctx context.Context) error { accountHookOptions := []accounthooks.Option{accounthooks.WithAppBaseUrl(getAppBaseUrl())} var slackClient ee_slack.Interface if viper.GetBool("SLACK_ACCOUNT_HOOKS_ENABLED") { - encryptor, err := sym_encrypt.NewEncryptor(viper.GetString("NEOSYNC_SYM_ENCRYPTION_PASSWORD")) + encryptor, err := 
sym_encrypt.NewEncryptor( + viper.GetString("NEOSYNC_SYM_ENCRYPTION_PASSWORD"), + ) if err != nil { return err } slackClient = ee_slack.NewClient( encryptor, - ee_slack.WithAuthClientCreds(viper.GetString("SLACK_AUTH_CLIENT_ID"), viper.GetString("SLACK_AUTH_CLIENT_SECRET")), + ee_slack.WithAuthClientCreds( + viper.GetString("SLACK_AUTH_CLIENT_ID"), + viper.GetString("SLACK_AUTH_CLIENT_SECRET"), + ), ee_slack.WithScope(viper.GetString("SLACK_SCOPE")), ee_slack.WithRedirectUrl(viper.GetString("SLACK_REDIRECT_URL")), ) - accountHookOptions = append(accountHookOptions, accounthooks.WithSlackClient(slackClient)) + accountHookOptions = append( + accountHookOptions, + accounthooks.WithSlackClient(slackClient), + ) } - accountHookService := v1alpha1_accounthookservice.New(accounthooks.New(db, userdataclient, accountHookOptions...)) + accountHookService := v1alpha1_accounthookservice.New( + accounthooks.New(db, userdataclient, accountHookOptions...), + ) api.Handle( mgmtv1alpha1connect.NewAccountHookServiceHandler( @@ -779,7 +810,11 @@ func getPromClientFromEnvironment() (promapi.Client, error) { roundTripper := promapi.DefaultRoundTripper promApiKey := getPromApiKey() if promApiKey != nil { - roundTripper = promconfig.NewAuthorizationCredentialsRoundTripper("Bearer", promconfig.NewInlineSecret(*promApiKey), promapi.DefaultRoundTripper) + roundTripper = promconfig.NewAuthorizationCredentialsRoundTripper( + "Bearer", + promconfig.NewInlineSecret(*promApiKey), + promapi.DefaultRoundTripper, + ) } return promapi.NewClient(promapi.Config{ Address: getPromApiUrl(), @@ -1044,7 +1079,11 @@ func getAllowedWorkerApiKeys(isNeosyncCloud bool) []string { return []string{} } -func getAuthAdminClient(ctx context.Context, authclient auth_client.Interface, logger *slog.Logger) (authmgmt.Interface, error) { +func getAuthAdminClient( + ctx context.Context, + authclient auth_client.Interface, + logger *slog.Logger, +) (authmgmt.Interface, error) { authApiBaseUrl := getAuthApiBaseUrl() 
authApiClientId := getAuthApiClientId() authApiClientSecret := getAuthApiClientSecret() @@ -1057,10 +1096,21 @@ func getAuthAdminClient(ctx context.Context, authclient auth_client.Interface, l if err != nil { return nil, err } - tokenProvider := clientcredtokenprovider.New(tokenurl, authApiClientId, authApiClientSecret, keycloak.DefaultTokenExpirationBuffer, logger) + tokenProvider := clientcredtokenprovider.New( + tokenurl, + authApiClientId, + authApiClientSecret, + keycloak.DefaultTokenExpirationBuffer, + logger, + ) return keycloak.New(authApiBaseUrl, tokenProvider, logger) } - logger.Warn(fmt.Sprintf("unable to initialize auth admin client due to unsupported provider: %q", provider)) + logger.Warn( + fmt.Sprintf( + "unable to initialize auth admin client due to unsupported provider: %q", + provider, + ), + ) return &authmgmt.UnimplementedClient{}, nil } @@ -1123,7 +1173,9 @@ func getRunLogConfig() (*v1alpha1_jobservice.RunLogConfig, error) { case v1alpha1_jobservice.LokiRunLogType: lokibaseurl := viper.GetString("RUN_LOGS_LOKICONFIG_BASEURL") if lokibaseurl == "" { - return nil, errors.New("must provide loki baseurl when loki run log type has been configured") + return nil, errors.New( + "must provide loki baseurl when loki run log type has been configured", + ) } labelsQuery := viper.GetString("RUN_LOGS_LOKICONFIG_LABELSQUERY") if labelsQuery == "" { @@ -1140,7 +1192,9 @@ func getRunLogConfig() (*v1alpha1_jobservice.RunLogConfig, error) { }, }, nil default: - return nil, errors.New("unsupported or no run log type configured, but run logs are enabled") + return nil, errors.New( + "unsupported or no run log type configured, but run logs are enabled", + ) } } @@ -1209,7 +1263,11 @@ func getStripePriceLookupMap() (billing.PriceQuantity, error) { } quantity, err := strconv.Atoi(v) if err != nil { - return nil, fmt.Errorf("unable to parse value as int for billing quantity %q: %w", v, err) + return nil, fmt.Errorf( + "unable to parse value as int for billing 
quantity %q: %w", + v, + err, + ) } output[k] = quantity } @@ -1239,7 +1297,10 @@ func getPresidioAnonymizeClient() (*presidioapi.ClientWithResponses, bool, error func getPresidioClient(endpoint string) (*presidioapi.ClientWithResponses, bool, error) { httpclient := http_client.WithHeaders(&http.Client{}, getPresidioHttpHeaders()) - client, err := presidioapi.NewClientWithResponses(endpoint, presidioapi.WithHTTPClient(httpclient)) + client, err := presidioapi.NewClientWithResponses( + endpoint, + presidioapi.WithHTTPClient(httpclient), + ) if err != nil { return nil, false, err } diff --git a/backend/internal/connect/interceptors/accountid/interceptor.go b/backend/internal/connect/interceptors/accountid/interceptor.go index a241f02204..c1a02ab5b4 100644 --- a/backend/internal/connect/interceptors/accountid/interceptor.go +++ b/backend/internal/connect/interceptors/accountid/interceptor.go @@ -34,13 +34,17 @@ func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { } } -func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { +func (i *Interceptor) WrapStreamingClient( + next connect.StreamingClientFunc, +) connect.StreamingClientFunc { return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { return next(ctx, spec) } } -func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { +func (i *Interceptor) WrapStreamingHandler( + next connect.StreamingHandlerFunc, +) connect.StreamingHandlerFunc { return func(ctx context.Context, conn connect.StreamingHandlerConn) error { return next(ctx, conn) } diff --git a/backend/internal/connect/interceptors/auth/interceptor.go b/backend/internal/connect/interceptors/auth/interceptor.go index b5956e6f5b..b855b8dfdd 100644 --- a/backend/internal/connect/interceptors/auth/interceptor.go +++ b/backend/internal/connect/interceptors/auth/interceptor.go @@ -40,13 +40,17 @@ func (i *Interceptor) 
WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { } } -func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { +func (i *Interceptor) WrapStreamingClient( + next connect.StreamingClientFunc, +) connect.StreamingClientFunc { return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { return next(ctx, spec) } } -func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { +func (i *Interceptor) WrapStreamingHandler( + next connect.StreamingHandlerFunc, +) connect.StreamingHandlerFunc { return func(ctx context.Context, conn connect.StreamingHandlerConn) error { if _, ok := i.excludedProcedures[conn.Spec().Procedure]; ok { return next(ctx, conn) diff --git a/backend/internal/connect/interceptors/auth_logging/interceptor.go b/backend/internal/connect/interceptors/auth_logging/interceptor.go index 64017bec2d..972d6913a6 100644 --- a/backend/internal/connect/interceptors/auth_logging/interceptor.go +++ b/backend/internal/connect/interceptors/auth_logging/interceptor.go @@ -23,13 +23,17 @@ func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { } } -func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { +func (i *Interceptor) WrapStreamingClient( + next connect.StreamingClientFunc, +) connect.StreamingClientFunc { return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { return next(ctx, spec) } } -func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { +func (i *Interceptor) WrapStreamingHandler( + next connect.StreamingHandlerFunc, +) connect.StreamingHandlerFunc { return func(ctx context.Context, conn connect.StreamingHandlerConn) error { return next(setAuthValues(ctx, i.db), conn) } diff --git a/backend/internal/connect/interceptors/bookend/interceptor.go 
b/backend/internal/connect/interceptors/bookend/interceptor.go index c6617e7cec..ed602b2f58 100644 --- a/backend/internal/connect/interceptors/bookend/interceptor.go +++ b/backend/internal/connect/interceptors/bookend/interceptor.go @@ -130,13 +130,17 @@ func getCliAttr(header http.Header) *slog.Attr { return &cliGroup } -func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { +func (i *Interceptor) WrapStreamingClient( + next connect.StreamingClientFunc, +) connect.StreamingClientFunc { return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { return next(ctx, spec) } } -func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { +func (i *Interceptor) WrapStreamingHandler( + next connect.StreamingHandlerFunc, +) connect.StreamingHandlerFunc { return func(ctx context.Context, conn connect.StreamingHandlerConn) error { logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) logger = logger.With("procedure", conn.Spec().Procedure) diff --git a/backend/internal/connect/interceptors/logger/interceptor.go b/backend/internal/connect/interceptors/logger/interceptor.go index 76ca1ef246..04c7d790d0 100644 --- a/backend/internal/connect/interceptors/logger/interceptor.go +++ b/backend/internal/connect/interceptors/logger/interceptor.go @@ -24,13 +24,17 @@ func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { } } -func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { +func (i *Interceptor) WrapStreamingClient( + next connect.StreamingClientFunc, +) connect.StreamingClientFunc { return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { return next(ctx, spec) } } -func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { +func (i *Interceptor) WrapStreamingHandler( + next connect.StreamingHandlerFunc, 
+) connect.StreamingHandlerFunc { return func(ctx context.Context, conn connect.StreamingHandlerConn) error { newCtx := SetLoggerContext(ctx, clonelogger(i.logger)) return next(newCtx, conn) diff --git a/backend/internal/dtomaps/job-runs.go b/backend/internal/dtomaps/job-runs.go index cb8e39e088..bea7b36d95 100644 --- a/backend/internal/dtomaps/job-runs.go +++ b/backend/internal/dtomaps/job-runs.go @@ -26,7 +26,10 @@ func ToJobRunDto( } // returns a job run without any pending activities -func ToJobRunDtoFromWorkflowExecutionInfo(workflow *workflowpb.WorkflowExecutionInfo, logger *slog.Logger) *mgmtv1alpha1.JobRun { +func ToJobRunDtoFromWorkflowExecutionInfo( + workflow *workflowpb.WorkflowExecutionInfo, + logger *slog.Logger, +) *mgmtv1alpha1.JobRun { var completedTime *timestamppb.Timestamp if workflow.GetCloseTime() != nil { completedTime = workflow.GetCloseTime() @@ -53,7 +56,10 @@ func GetJobIdFromWorkflow(logger *slog.Logger, searchAttributes *commonpb.Search return scheduledByID } -func ToJobRunEventTaskDto(event *history.HistoryEvent, taskError *mgmtv1alpha1.JobRunEventTaskError) *mgmtv1alpha1.JobRunEventTask { +func ToJobRunEventTaskDto( + event *history.HistoryEvent, + taskError *mgmtv1alpha1.JobRunEventTaskError, +) *mgmtv1alpha1.JobRunEventTask { return &mgmtv1alpha1.JobRunEventTask{ Id: event.GetEventId(), Type: event.GetEventType().String(), @@ -62,7 +68,10 @@ func ToJobRunEventTaskDto(event *history.HistoryEvent, taskError *mgmtv1alpha1.J } } -func ToJobRunEventTaskErrorDto(failure *temporalfailure.Failure, retryState enums.RetryState) *mgmtv1alpha1.JobRunEventTaskError { +func ToJobRunEventTaskErrorDto( + failure *temporalfailure.Failure, + retryState enums.RetryState, +) *mgmtv1alpha1.JobRunEventTaskError { msg := failure.Message if failure.GetCause() != nil { msg = fmt.Sprintf("%s: %s", failure.GetMessage(), failure.GetCause().GetMessage()) @@ -73,7 +82,9 @@ func ToJobRunEventTaskErrorDto(failure *temporalfailure.Failure, retryState enum } } -func 
toPendingActivitiesDto(activities []*workflowpb.PendingActivityInfo) []*mgmtv1alpha1.PendingActivity { +func toPendingActivitiesDto( + activities []*workflowpb.PendingActivityInfo, +) []*mgmtv1alpha1.PendingActivity { dtos := []*mgmtv1alpha1.PendingActivity{} for _, activity := range activities { var lastFailure *mgmtv1alpha1.ActivityFailure diff --git a/backend/internal/dtomaps/jobs.go b/backend/internal/dtomaps/jobs.go index 788dca2e14..92a4cb4446 100644 --- a/backend/internal/dtomaps/jobs.go +++ b/backend/internal/dtomaps/jobs.go @@ -42,7 +42,8 @@ func ToJobDto( } jobTypeConfig := &mgmtv1alpha1.JobTypeConfig{} - if inputJob.JobtypeConfig != nil && string(inputJob.JobtypeConfig) != "{}" && string(inputJob.JobtypeConfig) != "null" { + if inputJob.JobtypeConfig != nil && string(inputJob.JobtypeConfig) != "{}" && + string(inputJob.JobtypeConfig) != "null" { err := json.Unmarshal(inputJob.JobtypeConfig, jobTypeConfig) if err != nil { return nil, fmt.Errorf("unable to unmarshal job type config: %w", err) @@ -70,7 +71,9 @@ func ToJobDto( }, nil } -func toDestinationDto(input *db_queries.NeosyncApiJobDestinationConnectionAssociation) *mgmtv1alpha1.JobDestination { +func toDestinationDto( + input *db_queries.NeosyncApiJobDestinationConnectionAssociation, +) *mgmtv1alpha1.JobDestination { return &mgmtv1alpha1.JobDestination{ ConnectionId: neosyncdb.UUIDString(input.ConnectionID), Options: input.Options.ToDto(), @@ -85,7 +88,9 @@ func ToJobStatus(inputSchedule *temporalclient.ScheduleDescription) mgmtv1alpha1 return mgmtv1alpha1.JobStatus_JOB_STATUS_ENABLED } -func ToJobRecentRunsDto(inputSchedule *temporalclient.ScheduleDescription) []*mgmtv1alpha1.JobRecentRun { +func ToJobRecentRunsDto( + inputSchedule *temporalclient.ScheduleDescription, +) []*mgmtv1alpha1.JobRecentRun { recentRuns := []*mgmtv1alpha1.JobRecentRun{} if inputSchedule == nil { return nil diff --git a/backend/internal/dtomaps/transformers.go b/backend/internal/dtomaps/transformers.go index 
d84858e506..94ec2a4a8d 100644 --- a/backend/internal/dtomaps/transformers.go +++ b/backend/internal/dtomaps/transformers.go @@ -19,7 +19,10 @@ func ToUserDefinedTransformerDto( source := mgmtv1alpha1.TransformerSource(input.Source) transformer, ok := systemTransformers[source] if !ok { - return nil, fmt.Errorf("source %d is valid, but was not found in system transformers map", input.Source) + return nil, fmt.Errorf( + "source %d is valid, but was not found in system transformers map", + input.Source, + ) } return &mgmtv1alpha1.UserDefinedTransformer{ Id: neosyncdb.UUIDString(input.ID), diff --git a/backend/internal/ee/hooks/accounts/service.go b/backend/internal/ee/hooks/accounts/service.go index 667ec3ca97..8de39dc5ae 100644 --- a/backend/internal/ee/hooks/accounts/service.go +++ b/backend/internal/ee/hooks/accounts/service.go @@ -30,18 +30,54 @@ type Service struct { var _ Interface = (*Service)(nil) type Interface interface { - GetAccountHooks(ctx context.Context, req *mgmtv1alpha1.GetAccountHooksRequest) (*mgmtv1alpha1.GetAccountHooksResponse, error) - GetAccountHook(ctx context.Context, req *mgmtv1alpha1.GetAccountHookRequest) (*mgmtv1alpha1.GetAccountHookResponse, error) - CreateAccountHook(ctx context.Context, req *mgmtv1alpha1.CreateAccountHookRequest) (*mgmtv1alpha1.CreateAccountHookResponse, error) - UpdateAccountHook(ctx context.Context, req *mgmtv1alpha1.UpdateAccountHookRequest) (*mgmtv1alpha1.UpdateAccountHookResponse, error) - DeleteAccountHook(ctx context.Context, req *mgmtv1alpha1.DeleteAccountHookRequest) (*mgmtv1alpha1.DeleteAccountHookResponse, error) - IsAccountHookNameAvailable(ctx context.Context, req *mgmtv1alpha1.IsAccountHookNameAvailableRequest) (*mgmtv1alpha1.IsAccountHookNameAvailableResponse, error) - SetAccountHookEnabled(ctx context.Context, req *mgmtv1alpha1.SetAccountHookEnabledRequest) (*mgmtv1alpha1.SetAccountHookEnabledResponse, error) - GetActiveAccountHooksByEvent(ctx context.Context, req 
*mgmtv1alpha1.GetActiveAccountHooksByEventRequest) (*mgmtv1alpha1.GetActiveAccountHooksByEventResponse, error) - GetSlackConnectionUrl(ctx context.Context, req *mgmtv1alpha1.GetSlackConnectionUrlRequest) (*mgmtv1alpha1.GetSlackConnectionUrlResponse, error) - HandleSlackOAuthCallback(ctx context.Context, req *mgmtv1alpha1.HandleSlackOAuthCallbackRequest) (*mgmtv1alpha1.HandleSlackOAuthCallbackResponse, error) - TestSlackConnection(ctx context.Context, req *mgmtv1alpha1.TestSlackConnectionRequest) (*mgmtv1alpha1.TestSlackConnectionResponse, error) - SendSlackMessage(ctx context.Context, req *mgmtv1alpha1.SendSlackMessageRequest) (*mgmtv1alpha1.SendSlackMessageResponse, error) + GetAccountHooks( + ctx context.Context, + req *mgmtv1alpha1.GetAccountHooksRequest, + ) (*mgmtv1alpha1.GetAccountHooksResponse, error) + GetAccountHook( + ctx context.Context, + req *mgmtv1alpha1.GetAccountHookRequest, + ) (*mgmtv1alpha1.GetAccountHookResponse, error) + CreateAccountHook( + ctx context.Context, + req *mgmtv1alpha1.CreateAccountHookRequest, + ) (*mgmtv1alpha1.CreateAccountHookResponse, error) + UpdateAccountHook( + ctx context.Context, + req *mgmtv1alpha1.UpdateAccountHookRequest, + ) (*mgmtv1alpha1.UpdateAccountHookResponse, error) + DeleteAccountHook( + ctx context.Context, + req *mgmtv1alpha1.DeleteAccountHookRequest, + ) (*mgmtv1alpha1.DeleteAccountHookResponse, error) + IsAccountHookNameAvailable( + ctx context.Context, + req *mgmtv1alpha1.IsAccountHookNameAvailableRequest, + ) (*mgmtv1alpha1.IsAccountHookNameAvailableResponse, error) + SetAccountHookEnabled( + ctx context.Context, + req *mgmtv1alpha1.SetAccountHookEnabledRequest, + ) (*mgmtv1alpha1.SetAccountHookEnabledResponse, error) + GetActiveAccountHooksByEvent( + ctx context.Context, + req *mgmtv1alpha1.GetActiveAccountHooksByEventRequest, + ) (*mgmtv1alpha1.GetActiveAccountHooksByEventResponse, error) + GetSlackConnectionUrl( + ctx context.Context, + req *mgmtv1alpha1.GetSlackConnectionUrlRequest, + ) 
(*mgmtv1alpha1.GetSlackConnectionUrlResponse, error) + HandleSlackOAuthCallback( + ctx context.Context, + req *mgmtv1alpha1.HandleSlackOAuthCallbackRequest, + ) (*mgmtv1alpha1.HandleSlackOAuthCallbackResponse, error) + TestSlackConnection( + ctx context.Context, + req *mgmtv1alpha1.TestSlackConnectionRequest, + ) (*mgmtv1alpha1.TestSlackConnectionResponse, error) + SendSlackMessage( + ctx context.Context, + req *mgmtv1alpha1.SendSlackMessageRequest, + ) (*mgmtv1alpha1.SendSlackMessageResponse, error) } type config struct { @@ -78,7 +114,10 @@ func New( return &Service{cfg: cfg, db: db, userdataclient: userdataclient} } -func (s *Service) GetAccountHooks(ctx context.Context, req *mgmtv1alpha1.GetAccountHooksRequest) (*mgmtv1alpha1.GetAccountHooksResponse, error) { +func (s *Service) GetAccountHooks( + ctx context.Context, + req *mgmtv1alpha1.GetAccountHooksRequest, +) (*mgmtv1alpha1.GetAccountHooksResponse, error) { logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) logger = logger.With("accountId", req.GetAccountId()) @@ -115,7 +154,10 @@ func (s *Service) GetAccountHooks(ctx context.Context, req *mgmtv1alpha1.GetAcco }, nil } -func (s *Service) GetAccountHook(ctx context.Context, req *mgmtv1alpha1.GetAccountHookRequest) (*mgmtv1alpha1.GetAccountHookResponse, error) { +func (s *Service) GetAccountHook( + ctx context.Context, + req *mgmtv1alpha1.GetAccountHookRequest, +) (*mgmtv1alpha1.GetAccountHookResponse, error) { logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) logger = logger.With("hookId", req.GetId()) @@ -151,7 +193,10 @@ func (s *Service) GetAccountHook(ctx context.Context, req *mgmtv1alpha1.GetAccou }, nil } -func (s *Service) DeleteAccountHook(ctx context.Context, req *mgmtv1alpha1.DeleteAccountHookRequest) (*mgmtv1alpha1.DeleteAccountHookResponse, error) { +func (s *Service) DeleteAccountHook( + ctx context.Context, + req *mgmtv1alpha1.DeleteAccountHookRequest, +) (*mgmtv1alpha1.DeleteAccountHookResponse, error) { logger := 
logger_interceptor.GetLoggerFromContextOrDefault(ctx) logger = logger.With("hookId", req.GetId()) @@ -184,7 +229,10 @@ func (s *Service) DeleteAccountHook(ctx context.Context, req *mgmtv1alpha1.Delet return &mgmtv1alpha1.DeleteAccountHookResponse{}, nil } -func (s *Service) IsAccountHookNameAvailable(ctx context.Context, req *mgmtv1alpha1.IsAccountHookNameAvailableRequest) (*mgmtv1alpha1.IsAccountHookNameAvailableResponse, error) { +func (s *Service) IsAccountHookNameAvailable( + ctx context.Context, + req *mgmtv1alpha1.IsAccountHookNameAvailableRequest, +) (*mgmtv1alpha1.IsAccountHookNameAvailableResponse, error) { logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) logger = logger.With("hookName", req.GetName(), "accountId", req.GetAccountId()) @@ -202,10 +250,14 @@ func (s *Service) IsAccountHookNameAvailable(ctx context.Context, req *mgmtv1alp } logger.Debug("checking if hook name is available") - ok, err := s.db.Q.IsAccountHookNameAvailable(ctx, s.db.Db, db_queries.IsAccountHookNameAvailableParams{ - AccountID: accountId, - Name: req.GetName(), - }) + ok, err := s.db.Q.IsAccountHookNameAvailable( + ctx, + s.db.Db, + db_queries.IsAccountHookNameAvailableParams{ + AccountID: accountId, + Name: req.GetName(), + }, + ) if err != nil { return nil, err } @@ -214,7 +266,10 @@ func (s *Service) IsAccountHookNameAvailable(ctx context.Context, req *mgmtv1alp }, nil } -func (s *Service) SetAccountHookEnabled(ctx context.Context, req *mgmtv1alpha1.SetAccountHookEnabledRequest) (*mgmtv1alpha1.SetAccountHookEnabledResponse, error) { +func (s *Service) SetAccountHookEnabled( + ctx context.Context, + req *mgmtv1alpha1.SetAccountHookEnabledRequest, +) (*mgmtv1alpha1.SetAccountHookEnabledResponse, error) { getResp, err := s.GetAccountHook(ctx, &mgmtv1alpha1.GetAccountHookRequest{Id: req.GetId()}) if err != nil { return nil, err @@ -243,12 +298,22 @@ func (s *Service) SetAccountHookEnabled(ctx context.Context, req *mgmtv1alpha1.S return nil, err } - 
logger.Debug(fmt.Sprintf("attempting to update account hook enabled status from %v to %v", getResp.GetHook().GetEnabled(), req.GetEnabled())) - updatedHook, err := s.db.Q.SetAccountHookEnabled(ctx, s.db.Db, db_queries.SetAccountHookEnabledParams{ - ID: hookuuid, - Enabled: req.GetEnabled(), - UpdatedByUserID: user.PgId(), - }) + logger.Debug( + fmt.Sprintf( + "attempting to update account hook enabled status from %v to %v", + getResp.GetHook().GetEnabled(), + req.GetEnabled(), + ), + ) + updatedHook, err := s.db.Q.SetAccountHookEnabled( + ctx, + s.db.Db, + db_queries.SetAccountHookEnabledParams{ + ID: hookuuid, + Enabled: req.GetEnabled(), + UpdatedByUserID: user.PgId(), + }, + ) if err != nil { return nil, err } @@ -263,7 +328,10 @@ func (s *Service) SetAccountHookEnabled(ctx context.Context, req *mgmtv1alpha1.S }, nil } -func (s *Service) GetActiveAccountHooksByEvent(ctx context.Context, req *mgmtv1alpha1.GetActiveAccountHooksByEventRequest) (*mgmtv1alpha1.GetActiveAccountHooksByEventResponse, error) { +func (s *Service) GetActiveAccountHooksByEvent( + ctx context.Context, + req *mgmtv1alpha1.GetActiveAccountHooksByEventRequest, +) (*mgmtv1alpha1.GetActiveAccountHooksByEventResponse, error) { logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) logger = logger.With("event", req.GetEvent()) @@ -292,10 +360,14 @@ func (s *Service) GetActiveAccountHooksByEvent(ctx context.Context, req *mgmtv1a } logger.Debug(fmt.Sprintf("searching for active account hooks by events %v", eventStrings)) - hooks, err := s.db.Q.GetActiveAccountHooksByEvent(ctx, s.db.Db, db_queries.GetActiveAccountHooksByEventParams{ - AccountID: accountId, - Events: validEvents, - }) + hooks, err := s.db.Q.GetActiveAccountHooksByEvent( + ctx, + s.db.Db, + db_queries.GetActiveAccountHooksByEventParams{ + AccountID: accountId, + Events: validEvents, + }, + ) if err != nil { return nil, err } @@ -310,7 +382,10 @@ func (s *Service) GetActiveAccountHooksByEvent(ctx context.Context, req *mgmtv1a }, 
nil } -func (s *Service) CreateAccountHook(ctx context.Context, req *mgmtv1alpha1.CreateAccountHookRequest) (*mgmtv1alpha1.CreateAccountHookResponse, error) { +func (s *Service) CreateAccountHook( + ctx context.Context, + req *mgmtv1alpha1.CreateAccountHookRequest, +) (*mgmtv1alpha1.CreateAccountHookResponse, error) { logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) logger = logger.With("accountId", req.GetAccountId()) @@ -370,7 +445,11 @@ func (s *Service) CreateAccountHook(ctx context.Context, req *mgmtv1alpha1.Creat }, nil } -func (s *Service) joinSlackChannel(ctx context.Context, hook *mgmtv1alpha1.AccountHook, logger *slog.Logger) { +func (s *Service) joinSlackChannel( + ctx context.Context, + hook *mgmtv1alpha1.AccountHook, + logger *slog.Logger, +) { defer func() { if r := recover(); r != nil { logger.Error("panic when attempting to join slack channel", "error", r) @@ -384,12 +463,20 @@ func (s *Service) joinSlackChannel(ctx context.Context, hook *mgmtv1alpha1.Accou channelId := slackConfig.GetChannelId() accountId, err := neosyncdb.ToUuid(hook.GetAccountId()) if err != nil { - logger.Error("unable to parse account id when attempting to join slack channel", "error", err) + logger.Error( + "unable to parse account id when attempting to join slack channel", + "error", + err, + ) return } accessToken, err := s.db.Q.GetSlackAccessToken(ctx, s.db.Db, accountId) if err != nil { - logger.Error("unable to get slack access token when attempting to join slack channel", "error", err) + logger.Error( + "unable to get slack access token when attempting to join slack channel", + "error", + err, + ) return } @@ -401,7 +488,10 @@ func (s *Service) joinSlackChannel(ctx context.Context, hook *mgmtv1alpha1.Accou logger.Debug("joined slack channel") } -func (s *Service) UpdateAccountHook(ctx context.Context, req *mgmtv1alpha1.UpdateAccountHookRequest) (*mgmtv1alpha1.UpdateAccountHookResponse, error) { +func (s *Service) UpdateAccountHook( + ctx context.Context, + 
req *mgmtv1alpha1.UpdateAccountHookRequest, +) (*mgmtv1alpha1.UpdateAccountHookResponse, error) { getResp, err := s.GetAccountHook(ctx, &mgmtv1alpha1.GetAccountHookRequest{Id: req.GetId()}) if err != nil { return nil, err @@ -481,7 +571,9 @@ func (s *Service) GetSlackConnectionUrl( req *mgmtv1alpha1.GetSlackConnectionUrlRequest, ) (*mgmtv1alpha1.GetSlackConnectionUrlResponse, error) { if !s.cfg.isSlackEnabled { - return nil, nucleuserrors.NewNotImplementedProcedure(mgmtv1alpha1connect.AccountHookServiceGetSlackConnectionUrlProcedure) + return nil, nucleuserrors.NewNotImplementedProcedure( + mgmtv1alpha1connect.AccountHookServiceGetSlackConnectionUrlProcedure, + ) } logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) @@ -511,7 +603,9 @@ func (s *Service) HandleSlackOAuthCallback( req *mgmtv1alpha1.HandleSlackOAuthCallbackRequest, ) (*mgmtv1alpha1.HandleSlackOAuthCallbackResponse, error) { if !s.cfg.isSlackEnabled { - return nil, nucleuserrors.NewNotImplementedProcedure(mgmtv1alpha1connect.AccountHookServiceHandleSlackOAuthCallbackProcedure) + return nil, nucleuserrors.NewNotImplementedProcedure( + mgmtv1alpha1connect.AccountHookServiceHandleSlackOAuthCallbackProcedure, + ) } logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) @@ -521,20 +615,25 @@ func (s *Service) HandleSlackOAuthCallback( return nil, fmt.Errorf("unable to get user: %w", err) } - oauthState, err := s.cfg.slackClient.ValidateState(ctx, req.GetState(), user.Id(), func(ctx context.Context, userId, accountId string) (bool, error) { - parsedAccountUuid, err := neosyncdb.ToUuid(accountId) - if err != nil { - return false, err - } - ok, err := s.db.Q.IsUserInAccount(ctx, s.db.Db, db_queries.IsUserInAccountParams{ - AccountId: parsedAccountUuid, - UserId: user.PgId(), - }) - if err != nil { - return false, fmt.Errorf("unable to check if user is in account: %w", err) - } - return ok != 0, nil - }) + oauthState, err := s.cfg.slackClient.ValidateState( + ctx, + req.GetState(), + 
user.Id(), + func(ctx context.Context, userId, accountId string) (bool, error) { + parsedAccountUuid, err := neosyncdb.ToUuid(accountId) + if err != nil { + return false, err + } + ok, err := s.db.Q.IsUserInAccount(ctx, s.db.Db, db_queries.IsUserInAccountParams{ + AccountId: parsedAccountUuid, + UserId: user.PgId(), + }) + if err != nil { + return false, fmt.Errorf("unable to check if user is in account: %w", err) + } + return ok != 0, nil + }, + ) if err != nil { return nil, fmt.Errorf("unable to validate slack oauth state: %w", err) } @@ -560,12 +659,16 @@ func (s *Service) HandleSlackOAuthCallback( return nil, fmt.Errorf("unable to convert account id to uuid: %w", err) } - _, err = s.db.Q.CreateSlackOAuthConnection(ctx, s.db.Db, db_queries.CreateSlackOAuthConnectionParams{ - AccountID: accountUuid, - OauthV2Response: oauthRespBytes, - CreatedByUserID: user.PgId(), - UpdatedByUserID: user.PgId(), - }) + _, err = s.db.Q.CreateSlackOAuthConnection( + ctx, + s.db.Db, + db_queries.CreateSlackOAuthConnectionParams{ + AccountID: accountUuid, + OauthV2Response: oauthRespBytes, + CreatedByUserID: user.PgId(), + UpdatedByUserID: user.PgId(), + }, + ) if err != nil { return nil, fmt.Errorf("unable to store slack access token: %w", err) } @@ -578,7 +681,9 @@ func (s *Service) TestSlackConnection( req *mgmtv1alpha1.TestSlackConnectionRequest, ) (*mgmtv1alpha1.TestSlackConnectionResponse, error) { if !s.cfg.isSlackEnabled { - return nil, nucleuserrors.NewNotImplementedProcedure(mgmtv1alpha1connect.AccountHookServiceTestSlackConnectionProcedure) + return nil, nucleuserrors.NewNotImplementedProcedure( + mgmtv1alpha1connect.AccountHookServiceTestSlackConnectionProcedure, + ) } logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) @@ -643,7 +748,9 @@ func (s *Service) SendSlackMessage( req *mgmtv1alpha1.SendSlackMessageRequest, ) (*mgmtv1alpha1.SendSlackMessageResponse, error) { if !s.cfg.isSlackEnabled { - return nil, 
nucleuserrors.NewNotImplementedProcedure(mgmtv1alpha1connect.AccountHookServiceSendSlackMessageProcedure) + return nil, nucleuserrors.NewNotImplementedProcedure( + mgmtv1alpha1connect.AccountHookServiceSendSlackMessageProcedure, + ) } logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) @@ -705,7 +812,12 @@ func (s *Service) SendSlackMessage( } logger.Debug("sending slack message") - err = s.cfg.slackClient.SendMessage(ctx, accessToken, slackChannelId, slack.MsgOptionBlocks(blocks...)) + err = s.cfg.slackClient.SendMessage( + ctx, + accessToken, + slackChannelId, + slack.MsgOptionBlocks(blocks...), + ) if err != nil { return nil, fmt.Errorf("unable to send slack message: %w", err) } @@ -714,18 +826,32 @@ func (s *Service) SendSlackMessage( } func buildJobIdUrlForSlack(appBaseUrl, accountName, jobId string) string { - return fmt.Sprintf("<%s/jobs/%s|%s>", buildAccountBaseUrl(appBaseUrl, accountName), jobId, jobId) + return fmt.Sprintf( + "<%s/jobs/%s|%s>", + buildAccountBaseUrl(appBaseUrl, accountName), + jobId, + jobId, + ) } func buildJobRunUrlForSlack(appBaseUrl, accountName, jobRunId string) string { - return fmt.Sprintf("<%s/runs/%s|%s>", buildAccountBaseUrl(appBaseUrl, accountName), jobRunId, jobRunId) + return fmt.Sprintf( + "<%s/runs/%s|%s>", + buildAccountBaseUrl(appBaseUrl, accountName), + jobRunId, + jobRunId, + ) } func buildAccountBaseUrl(appBaseUrl, accountName string) string { return fmt.Sprintf("%s/%s", appBaseUrl, accountName) } -func getSlackBlocksByEvent(event *accounthook_events.Event, appBaseUrl, accountName string, logger *slog.Logger) []slack.Block { +func getSlackBlocksByEvent( + event *accounthook_events.Event, + appBaseUrl, accountName string, + logger *slog.Logger, +) []slack.Block { switch event.Name { case mgmtv1alpha1.AccountHookEvent_ACCOUNT_HOOK_EVENT_JOB_RUN_CREATED: if event.JobRunCreated == nil { @@ -733,13 +859,43 @@ func getSlackBlocksByEvent(event *accounthook_events.Event, appBaseUrl, accountN return nil } - 
headerText := slack.NewTextBlockObject(slack.PlainTextType, "🚀 New Job Run Started", false, false)
+		headerText := slack.NewTextBlockObject(
+			slack.PlainTextType,
+			"🚀 New Job Run Started",
+			false,
+			false,
+		)
 		headerSection := slack.NewHeaderBlock(headerText)
 
 		jobFields := []*slack.TextBlockObject{
-			slack.NewTextBlockObject(slack.MarkdownType, fmt.Sprintf("*Job ID:*\n%s", buildJobIdUrlForSlack(appBaseUrl, accountName, event.JobRunCreated.JobId)), false, false),
-			slack.NewTextBlockObject(slack.MarkdownType, fmt.Sprintf("*Job Run ID:*\n%s", buildJobRunUrlForSlack(appBaseUrl, accountName, event.JobRunCreated.JobRunId)), false, false),
-			slack.NewTextBlockObject(slack.MarkdownType, fmt.Sprintf("*Started At:*\n", event.Timestamp.Unix(), event.Timestamp.Format(time.RFC3339)), false, false),
+			slack.NewTextBlockObject(
+				slack.MarkdownType,
+				fmt.Sprintf(
+					"*Job ID:*\n%s",
+					buildJobIdUrlForSlack(appBaseUrl, accountName, event.JobRunCreated.JobId),
+				),
+				false,
+				false,
+			),
+			slack.NewTextBlockObject(
+				slack.MarkdownType,
+				fmt.Sprintf(
+					"*Job Run ID:*\n%s",
+					buildJobRunUrlForSlack(appBaseUrl, accountName, event.JobRunCreated.JobRunId),
+				),
+				false,
+				false,
+			),
+			slack.NewTextBlockObject(
+				slack.MarkdownType,
+				fmt.Sprintf(
+					"*Started At:*\n<!date^%d^{date_num} {time_secs}|%s>",
+					event.Timestamp.Unix(),
+					event.Timestamp.Format(time.RFC3339),
+				),
+				false,
+				false,
+			),
 		}
 
 		fieldsSection := slack.NewSectionBlock(nil, jobFields, nil)
@@ -759,13 +915,43 @@ func getSlackBlocksByEvent(event *accounthook_events.Event, appBaseUrl, accountN
 			return nil
 		}
 
-		headerText := slack.NewTextBlockObject(slack.PlainTextType, "🔴 Job Run Failed", false, false)
+		headerText := slack.NewTextBlockObject(
+			slack.PlainTextType,
+			"🔴 Job Run Failed",
+			false,
+			false,
+		)
 		headerSection := slack.NewHeaderBlock(headerText)
 
 		jobFields := []*slack.TextBlockObject{
-			slack.NewTextBlockObject(slack.MarkdownType, fmt.Sprintf("*Job ID:*\n%s", buildJobIdUrlForSlack(appBaseUrl, accountName, event.JobRunFailed.JobId)), false, 
false),
-			slack.NewTextBlockObject(slack.MarkdownType, fmt.Sprintf("*Job Run ID:*\n%s", buildJobRunUrlForSlack(appBaseUrl, accountName, event.JobRunFailed.JobRunId)), false, false),
-			slack.NewTextBlockObject(slack.MarkdownType, fmt.Sprintf("*Failed At:*\n", event.Timestamp.Unix(), event.Timestamp.Format(time.RFC3339)), false, false),
+			slack.NewTextBlockObject(
+				slack.MarkdownType,
+				fmt.Sprintf(
+					"*Job ID:*\n%s",
+					buildJobIdUrlForSlack(appBaseUrl, accountName, event.JobRunFailed.JobId),
+				),
+				false,
+				false,
+			),
+			slack.NewTextBlockObject(
+				slack.MarkdownType,
+				fmt.Sprintf(
+					"*Job Run ID:*\n%s",
+					buildJobRunUrlForSlack(appBaseUrl, accountName, event.JobRunFailed.JobRunId),
+				),
+				false,
+				false,
+			),
+			slack.NewTextBlockObject(
+				slack.MarkdownType,
+				fmt.Sprintf(
+					"*Failed At:*\n<!date^%d^{date_num} {time_secs}|%s>",
+					event.Timestamp.Unix(),
+					event.Timestamp.Format(time.RFC3339),
+				),
+				false,
+				false,
+			),
 		}
 
 		fieldsSection := slack.NewSectionBlock(nil, jobFields, nil)
@@ -785,13 +971,43 @@ func getSlackBlocksByEvent(event *accounthook_events.Event, appBaseUrl, accountN
 			return nil
 		}
 
-		headerText := slack.NewTextBlockObject(slack.PlainTextType, "✅ Job Run Succeeded", false, false)
+		headerText := slack.NewTextBlockObject(
+			slack.PlainTextType,
+			"✅ Job Run Succeeded",
+			false,
+			false,
+		)
 		headerSection := slack.NewHeaderBlock(headerText)
 
 		jobFields := []*slack.TextBlockObject{
-			slack.NewTextBlockObject(slack.MarkdownType, fmt.Sprintf("*Job ID:*\n%s", buildJobIdUrlForSlack(appBaseUrl, accountName, event.JobRunSucceeded.JobId)), false, false),
-			slack.NewTextBlockObject(slack.MarkdownType, fmt.Sprintf("*Job Run ID:*\n%s", buildJobRunUrlForSlack(appBaseUrl, accountName, event.JobRunSucceeded.JobRunId)), false, false),
-			slack.NewTextBlockObject(slack.MarkdownType, fmt.Sprintf("*Succeeded At:*\n", event.Timestamp.Unix(), event.Timestamp.Format(time.RFC3339)), false, false),
+			slack.NewTextBlockObject(
+				slack.MarkdownType,
+				fmt.Sprintf(
+					"*Job ID:*\n%s",
+					
buildJobIdUrlForSlack(appBaseUrl, accountName, event.JobRunSucceeded.JobId),
+				),
+				false,
+				false,
+			),
+			slack.NewTextBlockObject(
+				slack.MarkdownType,
+				fmt.Sprintf(
+					"*Job Run ID:*\n%s",
+					buildJobRunUrlForSlack(appBaseUrl, accountName, event.JobRunSucceeded.JobRunId),
+				),
+				false,
+				false,
+			),
+			slack.NewTextBlockObject(
+				slack.MarkdownType,
+				fmt.Sprintf(
+					"*Succeeded At:*\n<!date^%d^{date_num} {time_secs}|%s>",
+					event.Timestamp.Unix(),
+					event.Timestamp.Format(time.RFC3339),
+				),
+				false,
+				false,
+			),
 		}
 
 		fieldsSection := slack.NewSectionBlock(nil, jobFields, nil)
diff --git a/backend/internal/ee/hooks/jobs/service.go b/backend/internal/ee/hooks/jobs/service.go
index 287b866907..9252af073d 100644
--- a/backend/internal/ee/hooks/jobs/service.go
+++ b/backend/internal/ee/hooks/jobs/service.go
@@ -27,14 +27,38 @@ type Service struct {
 var _ Interface = (*Service)(nil)
 
 type Interface interface {
-	GetJobHooks(ctx context.Context, req *mgmtv1alpha1.GetJobHooksRequest) (*mgmtv1alpha1.GetJobHooksResponse, error)
-	GetJobHook(ctx context.Context, req *mgmtv1alpha1.GetJobHookRequest) (*mgmtv1alpha1.GetJobHookResponse, error)
-	CreateJobHook(ctx context.Context, req *mgmtv1alpha1.CreateJobHookRequest) (*mgmtv1alpha1.CreateJobHookResponse, error)
-	DeleteJobHook(ctx context.Context, req *mgmtv1alpha1.DeleteJobHookRequest) (*mgmtv1alpha1.DeleteJobHookResponse, error)
-	IsJobHookNameAvailable(ctx context.Context, req *mgmtv1alpha1.IsJobHookNameAvailableRequest) (*mgmtv1alpha1.IsJobHookNameAvailableResponse, error)
-	UpdateJobHook(ctx context.Context, req *mgmtv1alpha1.UpdateJobHookRequest) (*mgmtv1alpha1.UpdateJobHookResponse, error)
-	SetJobHookEnabled(ctx context.Context, req *mgmtv1alpha1.SetJobHookEnabledRequest) (*mgmtv1alpha1.SetJobHookEnabledResponse, error)
-	GetActiveJobHooksByTiming(ctx context.Context, req *mgmtv1alpha1.GetActiveJobHooksByTimingRequest) (*mgmtv1alpha1.GetActiveJobHooksByTimingResponse, error)
+	GetJobHooks(
+		ctx context.Context,
+		req *mgmtv1alpha1.GetJobHooksRequest,
+	) 
(*mgmtv1alpha1.GetJobHooksResponse, error) + GetJobHook( + ctx context.Context, + req *mgmtv1alpha1.GetJobHookRequest, + ) (*mgmtv1alpha1.GetJobHookResponse, error) + CreateJobHook( + ctx context.Context, + req *mgmtv1alpha1.CreateJobHookRequest, + ) (*mgmtv1alpha1.CreateJobHookResponse, error) + DeleteJobHook( + ctx context.Context, + req *mgmtv1alpha1.DeleteJobHookRequest, + ) (*mgmtv1alpha1.DeleteJobHookResponse, error) + IsJobHookNameAvailable( + ctx context.Context, + req *mgmtv1alpha1.IsJobHookNameAvailableRequest, + ) (*mgmtv1alpha1.IsJobHookNameAvailableResponse, error) + UpdateJobHook( + ctx context.Context, + req *mgmtv1alpha1.UpdateJobHookRequest, + ) (*mgmtv1alpha1.UpdateJobHookResponse, error) + SetJobHookEnabled( + ctx context.Context, + req *mgmtv1alpha1.SetJobHookEnabledRequest, + ) (*mgmtv1alpha1.SetJobHookEnabledResponse, error) + GetActiveJobHooksByTiming( + ctx context.Context, + req *mgmtv1alpha1.GetActiveJobHooksByTimingRequest, + ) (*mgmtv1alpha1.GetActiveJobHooksByTimingResponse, error) } type config struct { @@ -67,7 +91,9 @@ func (s *Service) GetJobHooks( req *mgmtv1alpha1.GetJobHooksRequest, ) (*mgmtv1alpha1.GetJobHooksResponse, error) { if !s.cfg.isEnabled { - return nil, nucleuserrors.NewNotImplementedProcedure(mgmtv1alpha1connect.JobServiceGetJobHooksProcedure) + return nil, nucleuserrors.NewNotImplementedProcedure( + mgmtv1alpha1connect.JobServiceGetJobHooksProcedure, + ) } logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) @@ -98,7 +124,9 @@ func (s *Service) GetJobHook( req *mgmtv1alpha1.GetJobHookRequest, ) (*mgmtv1alpha1.GetJobHookResponse, error) { if !s.cfg.isEnabled { - return nil, nucleuserrors.NewNotImplementedProcedure(mgmtv1alpha1connect.JobServiceGetJobHookProcedure) + return nil, nucleuserrors.NewNotImplementedProcedure( + mgmtv1alpha1connect.JobServiceGetJobHookProcedure, + ) } logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) @@ -116,7 +144,11 @@ func (s *Service) GetJobHook( return nil, 
nucleuserrors.NewNotFound("unable to find job hook by id") } - verifyResp, err := s.verifyUserHasJob(ctx, neosyncdb.UUIDString(hook.JobID), rbac.JobAction_View) + verifyResp, err := s.verifyUserHasJob( + ctx, + neosyncdb.UUIDString(hook.JobID), + rbac.JobAction_View, + ) if err != nil { return nil, err } @@ -139,7 +171,9 @@ func (s *Service) DeleteJobHook( req *mgmtv1alpha1.DeleteJobHookRequest, ) (*mgmtv1alpha1.DeleteJobHookResponse, error) { if !s.cfg.isEnabled { - return nil, nucleuserrors.NewNotImplementedProcedure(mgmtv1alpha1connect.JobServiceGetJobHooksProcedure) + return nil, nucleuserrors.NewNotImplementedProcedure( + mgmtv1alpha1connect.JobServiceGetJobHooksProcedure, + ) } logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) @@ -158,7 +192,11 @@ func (s *Service) DeleteJobHook( return &mgmtv1alpha1.DeleteJobHookResponse{}, nil } - verifyResp, err := s.verifyUserHasJob(ctx, neosyncdb.UUIDString(hook.JobID), rbac.JobAction_Delete) + verifyResp, err := s.verifyUserHasJob( + ctx, + neosyncdb.UUIDString(hook.JobID), + rbac.JobAction_Delete, + ) if err != nil { return nil, err } @@ -179,7 +217,9 @@ func (s *Service) IsJobHookNameAvailable( req *mgmtv1alpha1.IsJobHookNameAvailableRequest, ) (*mgmtv1alpha1.IsJobHookNameAvailableResponse, error) { if !s.cfg.isEnabled { - return nil, nucleuserrors.NewNotImplementedProcedure(mgmtv1alpha1connect.JobServiceIsJobHookNameAvailableProcedure) + return nil, nucleuserrors.NewNotImplementedProcedure( + mgmtv1alpha1connect.JobServiceIsJobHookNameAvailableProcedure, + ) } logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) @@ -213,7 +253,9 @@ func (s *Service) CreateJobHook( req *mgmtv1alpha1.CreateJobHookRequest, ) (*mgmtv1alpha1.CreateJobHookResponse, error) { if !s.cfg.isEnabled { - return nil, nucleuserrors.NewNotImplementedProcedure(mgmtv1alpha1connect.JobServiceCreateJobHookProcedure) + return nil, nucleuserrors.NewNotImplementedProcedure( + mgmtv1alpha1connect.JobServiceCreateJobHookProcedure, + 
) } logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) @@ -245,7 +287,9 @@ func (s *Service) CreateJobHook( } if !isValid { logger.Debug("job hook creation did not pass connection id verification") - return nil, nucleuserrors.NewBadRequest("connection id specified in hook is not a part of job") + return nil, nucleuserrors.NewBadRequest( + "connection id specified in hook is not a part of job", + ) } config, err := hookReq.GetConfig().MarshalJSON() @@ -307,7 +351,9 @@ func (s *Service) UpdateJobHook( } if !isValid { logger.Debug("job hook creation did not pass connection id verification") - return nil, nucleuserrors.NewBadRequest("connection id specified in hook is not a part of job") + return nil, nucleuserrors.NewBadRequest( + "connection id specified in hook is not a part of job", + ) } config, err := req.GetConfig().MarshalJSON() @@ -382,7 +428,13 @@ func (s *Service) SetJobHookEnabled( return nil, err } - logger.Debug(fmt.Sprintf("attempting to update job hook enabled status from %v to %v", getResp.GetHook().GetEnabled(), req.GetEnabled())) + logger.Debug( + fmt.Sprintf( + "attempting to update job hook enabled status from %v to %v", + getResp.GetHook().GetEnabled(), + req.GetEnabled(), + ), + ) updatedHook, err := s.db.Q.SetJobHookEnabled(ctx, s.db.Db, db_queries.SetJobHookEnabledParams{ Enabled: req.GetEnabled(), UpdatedByUserID: verifyResp.UserUuid, @@ -405,7 +457,9 @@ func (s *Service) GetActiveJobHooksByTiming( req *mgmtv1alpha1.GetActiveJobHooksByTimingRequest, ) (*mgmtv1alpha1.GetActiveJobHooksByTimingResponse, error) { if !s.cfg.isEnabled { - return nil, nucleuserrors.NewNotImplementedProcedure(mgmtv1alpha1connect.JobServiceGetActiveJobHooksByTimingProcedure) + return nil, nucleuserrors.NewNotImplementedProcedure( + mgmtv1alpha1connect.JobServiceGetActiveJobHooksByTimingProcedure, + ) } logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) @@ -443,7 +497,9 @@ func (s *Service) GetActiveJobHooksByTiming( return nil, err } default: 
- return nil, nucleuserrors.NewBadRequest(fmt.Sprintf("invalid hook timing: %T", req.GetTiming())) + return nil, nucleuserrors.NewBadRequest( + fmt.Sprintf("invalid hook timing: %T", req.GetTiming()), + ) } logger.Debug(fmt.Sprintf("found %d hooks", len(hooks))) @@ -462,7 +518,11 @@ type verifyUserJobResponse struct { UserUuid pgtype.UUID } -func (s *Service) verifyUserHasJob(ctx context.Context, jobId string, permission rbac.JobAction) (*verifyUserJobResponse, error) { +func (s *Service) verifyUserHasJob( + ctx context.Context, + jobId string, + permission rbac.JobAction, +) (*verifyUserJobResponse, error) { jobuuid, err := neosyncdb.ToUuid(jobId) if err != nil { return nil, err diff --git a/backend/internal/loki/loki.go b/backend/internal/loki/loki.go index fbc658093b..1bcb2172bd 100644 --- a/backend/internal/loki/loki.go +++ b/backend/internal/loki/loki.go @@ -84,8 +84,18 @@ func (c *LokiClient) QueryRange( } if res.StatusCode > 399 { - logger.Error(fmt.Sprintf("received non 200 status code: %d when querying loki for logs", res.StatusCode), "body", string(body)) - return nil, fmt.Errorf("received non 200 status code for loki query_range: %d", res.StatusCode) + logger.Error( + fmt.Sprintf( + "received non 200 status code: %d when querying loki for logs", + res.StatusCode, + ), + "body", + string(body), + ) + return nil, fmt.Errorf( + "received non 200 status code for loki query_range: %d", + res.StatusCode, + ) } var typedResp QueryResponse @@ -114,13 +124,27 @@ func GetEntriesFromStreams(streams Streams) []*LabeledEntry { entries := []*LabeledEntry{} for _, stream := range streams { for _, entry := range stream.Entries { - entries = append(entries, &LabeledEntry{Entry: entry, Labels: getFilteredLabels(stream.Labels, allowedLabels)}) + entries = append( + entries, + &LabeledEntry{ + Entry: entry, + Labels: getFilteredLabels(stream.Labels, allowedLabels), + }, + ) } } return entries } -var allowedLabels = []string{"ActivityType", "Name", "Schema", "Table", 
"Attempt", "metadata_Schema", "metadata_Table"} +var allowedLabels = []string{ + "ActivityType", + "Name", + "Schema", + "Table", + "Attempt", + "metadata_Schema", + "metadata_Table", +} func getFilteredLabels(labels LabelSet, keepLabels []string) LabelSet { filteredLabels := LabelSet{} diff --git a/backend/internal/userdata/client.go b/backend/internal/userdata/client.go index 392794cb84..d2aaf21895 100644 --- a/backend/internal/userdata/client.go +++ b/backend/internal/userdata/client.go @@ -13,8 +13,14 @@ import ( ) type UserServiceClient interface { - GetUser(ctx context.Context, req *connect.Request[mgmtv1alpha1.GetUserRequest]) (*connect.Response[mgmtv1alpha1.GetUserResponse], error) - IsUserInAccount(ctx context.Context, req *connect.Request[mgmtv1alpha1.IsUserInAccountRequest]) (*connect.Response[mgmtv1alpha1.IsUserInAccountResponse], error) + GetUser( + ctx context.Context, + req *connect.Request[mgmtv1alpha1.GetUserRequest], + ) (*connect.Response[mgmtv1alpha1.GetUserResponse], error) + IsUserInAccount( + ctx context.Context, + req *connect.Request[mgmtv1alpha1.IsUserInAccountRequest], + ) (*connect.Response[mgmtv1alpha1.IsUserInAccountResponse], error) } type Client struct { @@ -44,7 +50,10 @@ func NewClient( } func (c *Client) GetUser(ctx context.Context) (*User, error) { - resp, err := c.userServiceClient.GetUser(ctx, connect.NewRequest(&mgmtv1alpha1.GetUserRequest{})) + resp, err := c.userServiceClient.GetUser( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetUserRequest{}), + ) if err != nil { return nil, fmt.Errorf("unable to get user: %w", err) } diff --git a/backend/internal/userdata/entity_enforcer.go b/backend/internal/userdata/entity_enforcer.go index 3aed31ecd6..0ac7306953 100644 --- a/backend/internal/userdata/entity_enforcer.go +++ b/backend/internal/userdata/entity_enforcer.go @@ -22,53 +22,105 @@ var _ EntityEnforcer = (*UserEntityEnforcer)(nil) type EntityEnforcer interface { EnforceJob(ctx context.Context, job DomainEntity, action 
rbac.JobAction) error Job(ctx context.Context, job DomainEntity, action rbac.JobAction) (bool, error) - EnforceConnection(ctx context.Context, connection DomainEntity, action rbac.ConnectionAction) error - Connection(ctx context.Context, connection DomainEntity, action rbac.ConnectionAction) (bool, error) + EnforceConnection( + ctx context.Context, + connection DomainEntity, + action rbac.ConnectionAction, + ) error + Connection( + ctx context.Context, + connection DomainEntity, + action rbac.ConnectionAction, + ) (bool, error) EnforceAccount(ctx context.Context, account Identifier, action rbac.AccountAction) error Account(ctx context.Context, account Identifier, action rbac.AccountAction) (bool, error) } -func (u *UserEntityEnforcer) EnforceJob(ctx context.Context, job DomainEntity, action rbac.JobAction) error { +func (u *UserEntityEnforcer) EnforceJob( + ctx context.Context, + job DomainEntity, + action rbac.JobAction, +) error { if err := u.enforceAccountAccess(ctx, job.GetAccountId()); err != nil { return err } if u.isApiKey { return nil } - return u.enforcer.EnforceJob(ctx, u.user, rbac.NewAccountIdEntity(job.GetAccountId()), rbac.NewJobIdEntity(job.GetId()), action) + return u.enforcer.EnforceJob( + ctx, + u.user, + rbac.NewAccountIdEntity(job.GetAccountId()), + rbac.NewJobIdEntity(job.GetId()), + action, + ) } -func (u *UserEntityEnforcer) Job(ctx context.Context, job DomainEntity, action rbac.JobAction) (bool, error) { +func (u *UserEntityEnforcer) Job( + ctx context.Context, + job DomainEntity, + action rbac.JobAction, +) (bool, error) { if err := u.enforceAccountAccess(ctx, job.GetAccountId()); err != nil { return false, err } if u.isApiKey { return true, nil } - return u.enforcer.Job(ctx, u.user, rbac.NewAccountIdEntity(job.GetAccountId()), rbac.NewJobIdEntity(job.GetId()), action) + return u.enforcer.Job( + ctx, + u.user, + rbac.NewAccountIdEntity(job.GetAccountId()), + rbac.NewJobIdEntity(job.GetId()), + action, + ) } -func (u *UserEntityEnforcer) 
EnforceConnection(ctx context.Context, connection DomainEntity, action rbac.ConnectionAction) error { +func (u *UserEntityEnforcer) EnforceConnection( + ctx context.Context, + connection DomainEntity, + action rbac.ConnectionAction, +) error { if err := u.enforceAccountAccess(ctx, connection.GetAccountId()); err != nil { return err } if u.isApiKey { return nil } - return u.enforcer.EnforceConnection(ctx, u.user, rbac.NewAccountIdEntity(connection.GetAccountId()), rbac.NewConnectionIdEntity(connection.GetId()), action) + return u.enforcer.EnforceConnection( + ctx, + u.user, + rbac.NewAccountIdEntity(connection.GetAccountId()), + rbac.NewConnectionIdEntity(connection.GetId()), + action, + ) } -func (u *UserEntityEnforcer) Connection(ctx context.Context, connection DomainEntity, action rbac.ConnectionAction) (bool, error) { +func (u *UserEntityEnforcer) Connection( + ctx context.Context, + connection DomainEntity, + action rbac.ConnectionAction, +) (bool, error) { if err := u.enforceAccountAccess(ctx, connection.GetAccountId()); err != nil { return false, err } if u.isApiKey { return true, nil } - return u.enforcer.Connection(ctx, u.user, rbac.NewAccountIdEntity(connection.GetAccountId()), rbac.NewConnectionIdEntity(connection.GetId()), action) + return u.enforcer.Connection( + ctx, + u.user, + rbac.NewAccountIdEntity(connection.GetAccountId()), + rbac.NewConnectionIdEntity(connection.GetId()), + action, + ) } -func (u *UserEntityEnforcer) EnforceAccount(ctx context.Context, account Identifier, action rbac.AccountAction) error { +func (u *UserEntityEnforcer) EnforceAccount( + ctx context.Context, + account Identifier, + action rbac.AccountAction, +) error { if err := u.enforceAccountAccess(ctx, account.GetId()); err != nil { return err } @@ -78,7 +130,11 @@ func (u *UserEntityEnforcer) EnforceAccount(ctx context.Context, account Identif return u.enforcer.EnforceAccount(ctx, u.user, rbac.NewAccountIdEntity(account.GetId()), action) } -func (u *UserEntityEnforcer) 
Account(ctx context.Context, account Identifier, action rbac.AccountAction) (bool, error) { +func (u *UserEntityEnforcer) Account( + ctx context.Context, + account Identifier, + action rbac.AccountAction, +) (bool, error) { if err := u.enforceAccountAccess(ctx, account.GetId()); err != nil { return false, err } diff --git a/backend/internal/userdata/user.go b/backend/internal/userdata/user.go index 9948f97648..6e554c3110 100644 --- a/backend/internal/userdata/user.go +++ b/backend/internal/userdata/user.go @@ -15,7 +15,10 @@ import ( ) type UserAccountServiceClient interface { - IsUserInAccount(ctx context.Context, req *connect.Request[mgmtv1alpha1.IsUserInAccountRequest]) (*connect.Response[mgmtv1alpha1.IsUserInAccountResponse], error) + IsUserInAccount( + ctx context.Context, + req *connect.Request[mgmtv1alpha1.IsUserInAccountRequest], + ) (*connect.Response[mgmtv1alpha1.IsUserInAccountResponse], error) } type User struct { @@ -84,13 +87,17 @@ func enforceAccountAccess(ctx context.Context, user *User, accountId string) err } // We first want to check to make sure the api key is valid and that it says it's in the account // However, we still want to make a DB request to ensure the DB still says it's in the account - if user.apiKeyData.ApiKey == nil || neosyncdb.UUIDString(user.apiKeyData.ApiKey.AccountID) != accountId { + if user.apiKeyData.ApiKey == nil || + neosyncdb.UUIDString(user.apiKeyData.ApiKey.AccountID) != accountId { return nucleuserrors.NewUnauthorized("api key is not valid for account") } } // Note: because we are calling to the user account service here, the ctx must still contain the user data - inAccountResp, err := user.userAccountServiceClient.IsUserInAccount(ctx, connect.NewRequest(&mgmtv1alpha1.IsUserInAccountRequest{AccountId: accountId})) + inAccountResp, err := user.userAccountServiceClient.IsUserInAccount( + ctx, + connect.NewRequest(&mgmtv1alpha1.IsUserInAccountRequest{AccountId: accountId}), + ) if err != nil { return fmt.Errorf("unable 
to check if user is in account: %w", err) } diff --git a/backend/internal/version/version.go b/backend/internal/version/version.go index 1b642043ae..1d40e46a85 100644 --- a/backend/internal/version/version.go +++ b/backend/internal/version/version.go @@ -7,11 +7,11 @@ import ( type VersionInfo struct { GitVersion string `json:"gitVersion" yaml:"gitVersion"` - GitCommit string `json:"gitCommit" yaml:"gitCommit"` - BuildDate string `json:"buildDate" yaml:"buildDate"` - GoVersion string `json:"goVersion" yaml:"goVersion"` - Compiler string `json:"compiler" yaml:"compiler"` - Platform string `json:"platform" yaml:"platform"` + GitCommit string `json:"gitCommit" yaml:"gitCommit"` + BuildDate string `json:"buildDate" yaml:"buildDate"` + GoVersion string `json:"goVersion" yaml:"goVersion"` + Compiler string `json:"compiler" yaml:"compiler"` + Platform string `json:"platform" yaml:"platform"` } func (info *VersionInfo) String() string { diff --git a/backend/pkg/clienttls/clienttls.go b/backend/pkg/clienttls/clienttls.go index de39ce9e39..d5e7206269 100644 --- a/backend/pkg/clienttls/clienttls.go +++ b/backend/pkg/clienttls/clienttls.go @@ -22,7 +22,9 @@ type ClientTlsFileConfig struct { type ClientTlsFileHandler func(config *mgmtv1alpha1.ClientTlsConfig) (*ClientTlsFileConfig, error) // Joins the client cert and key into a single file -func UpsertClientTlsFileSingleClient(config *mgmtv1alpha1.ClientTlsConfig) (*ClientTlsFileConfig, error) { +func UpsertClientTlsFileSingleClient( + config *mgmtv1alpha1.ClientTlsConfig, +) (*ClientTlsFileConfig, error) { if config == nil { return nil, errors.New("config was nil") } diff --git a/backend/pkg/integration-test/clients.go b/backend/pkg/integration-test/clients.go index 99ce4a3e91..e0d353ba79 100644 --- a/backend/pkg/integration-test/clients.go +++ b/backend/pkg/integration-test/clients.go @@ -29,17 +29,23 @@ func WithUserId(userId string) ClientConfigOption { } } -func (s *NeosyncClients) Users(opts ...ClientConfigOption) 
mgmtv1alpha1connect.UserAccountServiceClient { +func (s *NeosyncClients) Users( + opts ...ClientConfigOption, +) mgmtv1alpha1connect.UserAccountServiceClient { config := getHydratedClientConfig(opts...) return mgmtv1alpha1connect.NewUserAccountServiceClient(getHttpClient(config), s.httpUrl) } -func (s *NeosyncClients) Connections(opts ...ClientConfigOption) mgmtv1alpha1connect.ConnectionServiceClient { +func (s *NeosyncClients) Connections( + opts ...ClientConfigOption, +) mgmtv1alpha1connect.ConnectionServiceClient { config := getHydratedClientConfig(opts...) return mgmtv1alpha1connect.NewConnectionServiceClient(getHttpClient(config), s.httpUrl) } -func (s *NeosyncClients) Anonymize(opts ...ClientConfigOption) mgmtv1alpha1connect.AnonymizationServiceClient { +func (s *NeosyncClients) Anonymize( + opts ...ClientConfigOption, +) mgmtv1alpha1connect.AnonymizationServiceClient { config := getHydratedClientConfig(opts...) return mgmtv1alpha1connect.NewAnonymizationServiceClient(getHttpClient(config), s.httpUrl) } @@ -49,17 +55,23 @@ func (s *NeosyncClients) Jobs(opts ...ClientConfigOption) mgmtv1alpha1connect.Jo return mgmtv1alpha1connect.NewJobServiceClient(getHttpClient(config), s.httpUrl) } -func (s *NeosyncClients) Transformers(opts ...ClientConfigOption) mgmtv1alpha1connect.TransformersServiceClient { +func (s *NeosyncClients) Transformers( + opts ...ClientConfigOption, +) mgmtv1alpha1connect.TransformersServiceClient { config := getHydratedClientConfig(opts...) return mgmtv1alpha1connect.NewTransformersServiceClient(getHttpClient(config), s.httpUrl) } -func (s *NeosyncClients) ConnectionData(opts ...ClientConfigOption) mgmtv1alpha1connect.ConnectionDataServiceClient { +func (s *NeosyncClients) ConnectionData( + opts ...ClientConfigOption, +) mgmtv1alpha1connect.ConnectionDataServiceClient { config := getHydratedClientConfig(opts...) 
return mgmtv1alpha1connect.NewConnectionDataServiceClient(getHttpClient(config), s.httpUrl) } -func (s *NeosyncClients) AccountHooks(opts ...ClientConfigOption) mgmtv1alpha1connect.AccountHookServiceClient { +func (s *NeosyncClients) AccountHooks( + opts ...ClientConfigOption, +) mgmtv1alpha1connect.AccountHookServiceClient { config := getHydratedClientConfig(opts...) return mgmtv1alpha1connect.NewAccountHookServiceClient(getHttpClient(config), s.httpUrl) } diff --git a/backend/pkg/integration-test/integration-test-util.go b/backend/pkg/integration-test/integration-test-util.go index bb7899508e..6b8cbb1b96 100644 --- a/backend/pkg/integration-test/integration-test-util.go +++ b/backend/pkg/integration-test/integration-test-util.go @@ -15,7 +15,10 @@ func CreatePersonalAccount( t *testing.T, userclient mgmtv1alpha1connect.UserAccountServiceClient, ) string { - resp, err := userclient.SetPersonalAccount(ctx, connect.NewRequest(&mgmtv1alpha1.SetPersonalAccountRequest{})) + resp, err := userclient.SetPersonalAccount( + ctx, + connect.NewRequest(&mgmtv1alpha1.SetPersonalAccountRequest{}), + ) RequireNoErrResp(t, resp, err) return resp.Msg.AccountId } @@ -186,14 +189,26 @@ func CreateMongodbConnection( return resp.Msg.GetConnection() } -func SetUser(ctx context.Context, t *testing.T, client mgmtv1alpha1connect.UserAccountServiceClient) string { +func SetUser( + ctx context.Context, + t *testing.T, + client mgmtv1alpha1connect.UserAccountServiceClient, +) string { resp, err := client.SetUser(ctx, connect.NewRequest(&mgmtv1alpha1.SetUserRequest{})) RequireNoErrResp(t, resp, err) return resp.Msg.GetUserId() } -func CreateTeamAccount(ctx context.Context, t *testing.T, client mgmtv1alpha1connect.UserAccountServiceClient, name string) string { - resp, err := client.CreateTeamAccount(ctx, connect.NewRequest(&mgmtv1alpha1.CreateTeamAccountRequest{Name: name})) +func CreateTeamAccount( + ctx context.Context, + t *testing.T, + client mgmtv1alpha1connect.UserAccountServiceClient, + 
name string, +) string { + resp, err := client.CreateTeamAccount( + ctx, + connect.NewRequest(&mgmtv1alpha1.CreateTeamAccountRequest{Name: name}), + ) RequireNoErrResp(t, resp, err) return resp.Msg.AccountId } diff --git a/backend/pkg/integration-test/integration-test.go b/backend/pkg/integration-test/integration-test.go index 6b0fae786f..d1fcc2c65f 100644 --- a/backend/pkg/integration-test/integration-test.go +++ b/backend/pkg/integration-test/integration-test.go @@ -73,7 +73,11 @@ type NeosyncApiTestClient struct { // Option is a functional option for configuring Neosync Api Test Client type Option func(*NeosyncApiTestClient) -func NewNeosyncApiTestClient(ctx context.Context, t testing.TB, opts ...Option) (*NeosyncApiTestClient, error) { +func NewNeosyncApiTestClient( + ctx context.Context, + t testing.TB, + opts ...Option, +) (*NeosyncApiTestClient, error) { neoApi := &NeosyncApiTestClient{ migrationsDir: "../../../../sql/postgresql/schema", } @@ -133,25 +137,37 @@ func (s *NeosyncApiTestClient) Setup(ctx context.Context, t testing.TB) error { if err != nil { return fmt.Errorf("unable to setup oss unauthenticated licensed mux: %w", err) } - rootmux.Handle(openSourceUnauthenticatedLicensedPostfix+"/", http.StripPrefix(openSourceUnauthenticatedLicensedPostfix, ossUnauthLicensedMux)) + rootmux.Handle( + openSourceUnauthenticatedLicensedPostfix+"/", + http.StripPrefix(openSourceUnauthenticatedLicensedPostfix, ossUnauthLicensedMux), + ) ossAuthLicensedMux, err := s.setupOssLicensedAuthMux(ctx, pgcontainer, logger) if err != nil { return fmt.Errorf("unable to setup oss authenticated licensed mux: %w", err) } - rootmux.Handle(openSourceAuthenticatedLicensedPostfix+"/", http.StripPrefix(openSourceAuthenticatedLicensedPostfix, ossAuthLicensedMux)) + rootmux.Handle( + openSourceAuthenticatedLicensedPostfix+"/", + http.StripPrefix(openSourceAuthenticatedLicensedPostfix, ossAuthLicensedMux), + ) ossUnauthUnlicensedMux, err := s.setupOssUnlicensedMux(pgcontainer, logger) if 
err != nil { return fmt.Errorf("unable to setup oss unauthenticated unlicensed mux: %w", err) } - rootmux.Handle(openSourceUnauthenticatedUnlicensedPostfix+"/", http.StripPrefix(openSourceUnauthenticatedUnlicensedPostfix, ossUnauthUnlicensedMux)) + rootmux.Handle( + openSourceUnauthenticatedUnlicensedPostfix+"/", + http.StripPrefix(openSourceUnauthenticatedUnlicensedPostfix, ossUnauthUnlicensedMux), + ) neoCloudAuthdMux, err := s.setupNeoCloudMux(ctx, pgcontainer, logger) if err != nil { return fmt.Errorf("unable to setup neo cloud authenticated mux: %w", err) } - rootmux.Handle(neoCloudAuthenticatedLicensedPostfix+"/", http.StripPrefix(neoCloudAuthenticatedLicensedPostfix, neoCloudAuthdMux)) + rootmux.Handle( + neoCloudAuthenticatedLicensedPostfix+"/", + http.StripPrefix(neoCloudAuthenticatedLicensedPostfix, neoCloudAuthdMux), + ) s.httpsrv = startHTTPServer(t, rootmux) rootmux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { @@ -159,10 +175,18 @@ func (s *NeosyncApiTestClient) Setup(ctx context.Context, t testing.TB) error { http.NotFound(w, r) }) - s.OSSUnauthenticatedLicensedClients = newNeosyncClients(s.httpsrv.URL + openSourceUnauthenticatedLicensedPostfix) - s.OSSAuthenticatedLicensedClients = newNeosyncClients(s.httpsrv.URL + openSourceAuthenticatedLicensedPostfix) - s.OSSUnauthenticatedUnlicensedClients = newNeosyncClients(s.httpsrv.URL + openSourceUnauthenticatedUnlicensedPostfix) - s.NeosyncCloudAuthenticatedLicensedClients = newNeosyncClients(s.httpsrv.URL + neoCloudAuthenticatedLicensedPostfix) + s.OSSUnauthenticatedLicensedClients = newNeosyncClients( + s.httpsrv.URL + openSourceUnauthenticatedLicensedPostfix, + ) + s.OSSAuthenticatedLicensedClients = newNeosyncClients( + s.httpsrv.URL + openSourceAuthenticatedLicensedPostfix, + ) + s.OSSUnauthenticatedUnlicensedClients = newNeosyncClients( + s.httpsrv.URL + openSourceUnauthenticatedUnlicensedPostfix, + ) + s.NeosyncCloudAuthenticatedLicensedClients = newNeosyncClients( + s.httpsrv.URL + 
neoCloudAuthenticatedLicensedPostfix, + ) return nil } @@ -189,8 +213,11 @@ func (s *NeosyncApiTestClient) MockTemporalForCreateJob(returnId string) { } // Used for any API call that uses GetJobRun() as this mocks the response from Temporal for that execution -func (s *NeosyncApiTestClient) MockTemporalForDescribeWorkflowExecution(accountId, jobId, jobRunId, workflowName string) { - s.Mocks.TemporalClientManager.EXPECT().DescribeWorklowExecution(mock.Anything, accountId, jobRunId, mock.Anything). +func (s *NeosyncApiTestClient) MockTemporalForDescribeWorkflowExecution( + accountId, jobId, jobRunId, workflowName string, +) { + s.Mocks.TemporalClientManager.EXPECT(). + DescribeWorklowExecution(mock.Anything, accountId, jobRunId, mock.Anything). Return(&workflowservice.DescribeWorkflowExecutionResponse{ WorkflowExecutionInfo: &workflow.WorkflowExecutionInfo{ Execution: &common.WorkflowExecution{ @@ -205,13 +232,16 @@ func (s *NeosyncApiTestClient) MockTemporalForDescribeWorkflowExecution(accountI SearchAttributes: &common.SearchAttributes{ IndexedFields: map[string]*common.Payload{ "TemporalScheduledById": { - Data: []byte(jobId), - Metadata: map[string][]byte{"jobId": []byte(jobId)}, // this doesnt seem to work as it's not the correct format for what temporal expects + Data: []byte(jobId), + Metadata: map[string][]byte{ + "jobId": []byte(jobId), + }, // this doesnt seem to work as it's not the correct format for what temporal expects }, }, }, }, - }, nil).Once() + }, nil). 
+ Once() } func (s *NeosyncApiTestClient) InitializeTest(ctx context.Context, t testing.TB) error { err := neomigrate.Up(ctx, s.Pgcontainer.URL, s.migrationsDir, testutil.GetTestLogger(t)) diff --git a/backend/pkg/integration-test/mux.go b/backend/pkg/integration-test/mux.go index 870b2f144f..d3e59e1042 100644 --- a/backend/pkg/integration-test/mux.go +++ b/backend/pkg/integration-test/mux.go @@ -46,24 +46,26 @@ import ( var ( validAuthUser = &authmgmt.User{Name: "foo", Email: "bar", Picture: "baz"} - authinterceptor = auth_interceptor.NewInterceptor(func(ctx context.Context, header http.Header, spec connect.Spec) (context.Context, error) { - // will need to further fill this out as the tests grow - authuserid, err := utils.GetBearerTokenFromHeader(header, "Authorization") - if err != nil { - return nil, err - } - if apikey.IsValidV1WorkerKey(authuserid) { - return auth_apikey.SetTokenData(ctx, &auth_apikey.TokenContextData{ - RawToken: authuserid, - ApiKey: nil, - ApiKeyType: apikey.WorkerApiKey, + authinterceptor = auth_interceptor.NewInterceptor( + func(ctx context.Context, header http.Header, spec connect.Spec) (context.Context, error) { + // will need to further fill this out as the tests grow + authuserid, err := utils.GetBearerTokenFromHeader(header, "Authorization") + if err != nil { + return nil, err + } + if apikey.IsValidV1WorkerKey(authuserid) { + return auth_apikey.SetTokenData(ctx, &auth_apikey.TokenContextData{ + RawToken: authuserid, + ApiKey: nil, + ApiKeyType: apikey.WorkerApiKey, + }), nil + } + return auth_jwt.SetTokenData(ctx, &auth_jwt.TokenContextData{ + AuthUserId: authuserid, + Claims: &auth_jwt.CustomClaims{Email: &validAuthUser.Email}, }), nil - } - return auth_jwt.SetTokenData(ctx, &auth_jwt.TokenContextData{ - AuthUserId: authuserid, - Claims: &auth_jwt.CustomClaims{Email: &validAuthUser.Email}, - }), nil - }) + }, + ) ) const ( @@ -77,7 +79,11 @@ const ( neoCloudAuthenticatedLicensedPostfix = "/neosynccloud-authenticated" ) -func (s 
*NeosyncApiTestClient) setupOssUnauthenticatedLicensedMux(ctx context.Context, pgcontainer *tcpostgres.PostgresTestContainer, logger *slog.Logger) (*http.ServeMux, error) { +func (s *NeosyncApiTestClient) setupOssUnauthenticatedLicensedMux( + ctx context.Context, + pgcontainer *tcpostgres.PostgresTestContainer, + logger *slog.Logger, +) (*http.ServeMux, error) { isLicensed := true isAuthEnabled := false isNeosyncCloud := false @@ -85,10 +91,21 @@ func (s *NeosyncApiTestClient) setupOssUnauthenticatedLicensedMux(ctx context.Co if err != nil { return nil, fmt.Errorf("unable to get enforced rbac client: %w", err) } - return s.setupMux(pgcontainer, isAuthEnabled, isLicensed, isNeosyncCloud, enforcedRbacClient, logger) + return s.setupMux( + pgcontainer, + isAuthEnabled, + isLicensed, + isNeosyncCloud, + enforcedRbacClient, + logger, + ) } -func (s *NeosyncApiTestClient) setupOssLicensedAuthMux(ctx context.Context, pgcontainer *tcpostgres.PostgresTestContainer, logger *slog.Logger) (*http.ServeMux, error) { +func (s *NeosyncApiTestClient) setupOssLicensedAuthMux( + ctx context.Context, + pgcontainer *tcpostgres.PostgresTestContainer, + logger *slog.Logger, +) (*http.ServeMux, error) { isLicensed := true isAuthEnabled := true isNeosyncCloud := false @@ -96,18 +113,39 @@ func (s *NeosyncApiTestClient) setupOssLicensedAuthMux(ctx context.Context, pgco if err != nil { return nil, fmt.Errorf("unable to get enforced rbac client: %w", err) } - return s.setupMux(pgcontainer, isAuthEnabled, isLicensed, isNeosyncCloud, enforcedRbacClient, logger) + return s.setupMux( + pgcontainer, + isAuthEnabled, + isLicensed, + isNeosyncCloud, + enforcedRbacClient, + logger, + ) } -func (s *NeosyncApiTestClient) setupOssUnlicensedMux(pgcontainer *tcpostgres.PostgresTestContainer, logger *slog.Logger) (*http.ServeMux, error) { +func (s *NeosyncApiTestClient) setupOssUnlicensedMux( + pgcontainer *tcpostgres.PostgresTestContainer, + logger *slog.Logger, +) (*http.ServeMux, error) { isLicensed := 
false isAuthEnabled := false isNeosyncCloud := false permissiveRbacClient := rbac.NewAllowAllClient() - return s.setupMux(pgcontainer, isAuthEnabled, isLicensed, isNeosyncCloud, permissiveRbacClient, logger) + return s.setupMux( + pgcontainer, + isAuthEnabled, + isLicensed, + isNeosyncCloud, + permissiveRbacClient, + logger, + ) } -func (s *NeosyncApiTestClient) setupNeoCloudMux(ctx context.Context, pgcontainer *tcpostgres.PostgresTestContainer, logger *slog.Logger) (*http.ServeMux, error) { +func (s *NeosyncApiTestClient) setupNeoCloudMux( + ctx context.Context, + pgcontainer *tcpostgres.PostgresTestContainer, + logger *slog.Logger, +) (*http.ServeMux, error) { isLicensed := true isAuthEnabled := true isNeosyncCloud := true @@ -115,7 +153,14 @@ func (s *NeosyncApiTestClient) setupNeoCloudMux(ctx context.Context, pgcontainer if err != nil { return nil, fmt.Errorf("unable to get enforced rbac client: %w", err) } - return s.setupMux(pgcontainer, isAuthEnabled, isLicensed, isNeosyncCloud, enforcedRbacClient, logger) + return s.setupMux( + pgcontainer, + isAuthEnabled, + isLicensed, + isNeosyncCloud, + enforcedRbacClient, + logger, + ) } func (s *NeosyncApiTestClient) setupMux( @@ -146,7 +191,11 @@ func (s *NeosyncApiTestClient) setupMux( } userService := v1alpha1_useraccountservice.New( - &v1alpha1_useraccountservice.Config{IsAuthEnabled: isAuthEnabled, IsNeosyncCloud: isNeosyncCloud, DefaultMaxAllowedRecords: &maxAllowed}, + &v1alpha1_useraccountservice.Config{ + IsAuthEnabled: isAuthEnabled, + IsNeosyncCloud: isNeosyncCloud, + DefaultMaxAllowedRecords: &maxAllowed, + }, neosyncdb.New(pgcontainer.DB, db_queries.New()), s.Mocks.TemporalConfigProvider, s.Mocks.Authclient, @@ -228,12 +277,17 @@ func (s *NeosyncApiTestClient) setupMux( var presAnonClient presidioapi.AnonymizeInterface anonymizationService := v1alpha_anonymizationservice.New( - &v1alpha_anonymizationservice.Config{IsPresidioEnabled: isPresidioEnabled, IsAuthEnabled: isAuthEnabled, IsNeosyncCloud: 
isNeosyncCloud}, + &v1alpha_anonymizationservice.Config{ + IsPresidioEnabled: isPresidioEnabled, + IsAuthEnabled: isAuthEnabled, + IsNeosyncCloud: isNeosyncCloud, + }, nil, // meter userclient, userService, transformerService, - presAnalyzeClient, presAnonClient, + presAnalyzeClient, + presAnonClient, neosyncDb, ) @@ -299,8 +353,15 @@ func (s *NeosyncApiTestClient) setupMux( return mux, nil } -func (s *NeosyncApiTestClient) getEnforcedRbacClient(ctx context.Context, pgcontainer *tcpostgres.PostgresTestContainer) (rbac.Interface, error) { - rbacenforcer, err := enforcer.NewActiveEnforcer(ctx, stdlib.OpenDBFromPool(pgcontainer.DB), "neosync_api.casbin_rule") +func (s *NeosyncApiTestClient) getEnforcedRbacClient( + ctx context.Context, + pgcontainer *tcpostgres.PostgresTestContainer, +) (rbac.Interface, error) { + rbacenforcer, err := enforcer.NewActiveEnforcer( + ctx, + stdlib.OpenDBFromPool(pgcontainer.DB), + "neosync_api.casbin_rule", + ) if err != nil { return nil, fmt.Errorf("unable to create rbac enforcer: %w", err) } diff --git a/backend/pkg/metrics/usage.go b/backend/pkg/metrics/usage.go index f07fc7e467..2011dfd4a7 100644 --- a/backend/pkg/metrics/usage.go +++ b/backend/pkg/metrics/usage.go @@ -29,7 +29,10 @@ func GetDailyUsageFromProm( vector, ok := result.(model.Vector) if !ok { - return nil, -1, fmt.Errorf("error casting prometheus query result to model.Vector. Got %T", result) + return nil, -1, fmt.Errorf( + "error casting prometheus query result to model.Vector. 
Got %T", + result, + ) } dailyTotals := map[string]float64{} @@ -50,7 +53,12 @@ func GetDailyUsageFromProm( for _, day := range dates { date, err := time.Parse(NeosyncDateFormat, day) if err != nil { - return nil, -1, fmt.Errorf("unable to convert day back to usage date (%q) format (%q): %w", date, NeosyncDateFormat, err) + return nil, -1, fmt.Errorf( + "unable to convert day back to usage date (%q) format (%q): %w", + date, + NeosyncDateFormat, + err, + ) } mgmtDate := timeToDate(date) dailyResults = append(dailyResults, &mgmtv1alpha1.DayResult{ @@ -99,7 +107,13 @@ func sortUsageDates(a, b string) int { return 0 } -func GetTotalUsageFromProm(ctx context.Context, api promv1.API, query string, dayEnd time.Time, logger *slog.Logger) (float64, error) { +func GetTotalUsageFromProm( + ctx context.Context, + api promv1.API, + query string, + dayEnd time.Time, + logger *slog.Logger, +) (float64, error) { var overallTotal float64 result, warnings, err := api.Query(ctx, query, dayEnd) @@ -133,7 +147,13 @@ func GetPromQueryFromMetric( if err != nil { return "", err } - return fmt.Sprintf("sum(max_over_time(%s{%s}[%s])) by (%s)", metricName, labels.ToPromQueryString(), timeWindow, NeosyncDateLabel), nil + return fmt.Sprintf( + "sum(max_over_time(%s{%s}[%s])) by (%s)", + metricName, + labels.ToPromQueryString(), + timeWindow, + NeosyncDateLabel, + ), nil } const ( diff --git a/backend/pkg/metrics/util.go b/backend/pkg/metrics/util.go index 6b55492ce4..b93bb89553 100644 --- a/backend/pkg/metrics/util.go +++ b/backend/pkg/metrics/util.go @@ -13,8 +13,26 @@ import ( // // []string{"2024-09-.*", "2024-10-.*"} func GenerateMonthRegexRange(startDate, endDate *mgmtv1alpha1.Date) []string { - start := time.Date(int(startDate.Year), time.Month(startDate.Month), int(startDate.Day), 0, 0, 0, 0, time.UTC) - end := time.Date(int(endDate.Year), time.Month(endDate.Month), int(endDate.Day), 0, 0, 0, 0, time.UTC) + start := time.Date( + int(startDate.Year), + time.Month(startDate.Month), + 
int(startDate.Day), + 0, + 0, + 0, + 0, + time.UTC, + ) + end := time.Date( + int(endDate.Year), + time.Month(endDate.Month), + int(endDate.Day), + 0, + 0, + 0, + 0, + time.UTC, + ) patterns := []string{} has := map[string]any{} @@ -39,8 +57,26 @@ func GenerateMonthRegexRange(startDate, endDate *mgmtv1alpha1.Date) []string { // // 2024-09-14, 2024-09-15 == 2d func CalculatePromLookbackDuration(startDate, endDate *mgmtv1alpha1.Date) string { - start := time.Date(int(startDate.Year), time.Month(startDate.Month), int(startDate.Day), 0, 0, 0, 0, time.UTC) - end := time.Date(int(endDate.Year), time.Month(endDate.Month), int(endDate.Day), 0, 0, 0, 0, time.UTC) + start := time.Date( + int(startDate.Year), + time.Month(startDate.Month), + int(startDate.Day), + 0, + 0, + 0, + 0, + time.UTC, + ) + end := time.Date( + int(endDate.Year), + time.Month(endDate.Month), + int(endDate.Day), + 0, + 0, + 0, + 0, + time.UTC, + ) days := daysBetween(start, end) diff --git a/backend/pkg/mongoconnect/connector.go b/backend/pkg/mongoconnect/connector.go index e36b6578f4..0af19e1906 100644 --- a/backend/pkg/mongoconnect/connector.go +++ b/backend/pkg/mongoconnect/connector.go @@ -16,7 +16,10 @@ import ( ) type Interface interface { - NewFromConnectionConfig(cc *mgmtv1alpha1.ConnectionConfig, logger *slog.Logger) (DbContainer, error) + NewFromConnectionConfig( + cc *mgmtv1alpha1.ConnectionConfig, + logger *slog.Logger, + ) (DbContainer, error) } type DbContainer interface { @@ -107,7 +110,9 @@ func getConnectionDetails( mongoConfig := cc.GetMongoConfig() if mongoConfig == nil { - return nil, fmt.Errorf("mongo config was nil, expected ConnectionConfig to contain valid MongoConfig") + return nil, fmt.Errorf( + "mongo config was nil, expected ConnectionConfig to contain valid MongoConfig", + ) } if mongoConfig.GetClientTls() != nil { @@ -118,7 +123,10 @@ func getConnectionDetails( } tunnelCfg := mongoConfig.GetTunnel() if tunnelCfg != nil { - return nil, fmt.Errorf("tunneling in mongodb is not 
currently supported: %w", errors.ErrUnsupported) + return nil, fmt.Errorf( + "tunneling in mongodb is not currently supported: %w", + errors.ErrUnsupported, + ) } connDetails, err := getGeneralDbConnectConfigFromMongo(mongoConfig) @@ -130,7 +138,9 @@ func getConnectionDetails( }, nil } -func getGeneralDbConnectConfigFromMongo(config *mgmtv1alpha1.MongoConnectionConfig) (*connstring.ConnString, error) { +func getGeneralDbConnectConfigFromMongo( + config *mgmtv1alpha1.MongoConnectionConfig, +) (*connstring.ConnString, error) { dburl := config.GetUrl() if dburl == "" { return nil, fmt.Errorf("must provide valid mongoconfig url") diff --git a/backend/pkg/mssql-querier/querier.go b/backend/pkg/mssql-querier/querier.go index bccb99f53a..117a0f0887 100644 --- a/backend/pkg/mssql-querier/querier.go +++ b/backend/pkg/mssql-querier/querier.go @@ -9,16 +9,48 @@ import ( type Querier interface { GetAllSchemas(ctx context.Context, db mysql_queries.DBTX) ([]string, error) GetAllTables(ctx context.Context, db mysql_queries.DBTX) ([]*GetAllTablesRow, error) - GetCustomSequencesBySchemas(ctx context.Context, db mysql_queries.DBTX, schemas []string) ([]*GetCustomSequencesBySchemasRow, error) - GetCustomTriggersBySchemasAndTables(ctx context.Context, db mysql_queries.DBTX, schematables []string) ([]*GetCustomTriggersBySchemasAndTablesRow, error) - GetDataTypesBySchemas(ctx context.Context, db mysql_queries.DBTX, schematables []string) ([]*GetDataTypesBySchemasRow, error) + GetCustomSequencesBySchemas( + ctx context.Context, + db mysql_queries.DBTX, + schemas []string, + ) ([]*GetCustomSequencesBySchemasRow, error) + GetCustomTriggersBySchemasAndTables( + ctx context.Context, + db mysql_queries.DBTX, + schematables []string, + ) ([]*GetCustomTriggersBySchemasAndTablesRow, error) + GetDataTypesBySchemas( + ctx context.Context, + db mysql_queries.DBTX, + schematables []string, + ) ([]*GetDataTypesBySchemasRow, error) GetDatabaseSchema(ctx context.Context, db mysql_queries.DBTX) 
([]*GetDatabaseSchemaRow, error) - GetDatabaseTableSchemasBySchemasAndTables(ctx context.Context, db mysql_queries.DBTX, schematables []string) ([]*GetDatabaseTableSchemasBySchemasAndTablesRow, error) - GetIndicesBySchemasAndTables(ctx context.Context, db mysql_queries.DBTX, schematables []string) ([]*GetIndicesBySchemasAndTablesRow, error) + GetDatabaseTableSchemasBySchemasAndTables( + ctx context.Context, + db mysql_queries.DBTX, + schematables []string, + ) ([]*GetDatabaseTableSchemasBySchemasAndTablesRow, error) + GetIndicesBySchemasAndTables( + ctx context.Context, + db mysql_queries.DBTX, + schematables []string, + ) ([]*GetIndicesBySchemasAndTablesRow, error) GetRolePermissions(ctx context.Context, db mysql_queries.DBTX) ([]*GetRolePermissionsRow, error) - GetTableConstraintsBySchemas(ctx context.Context, db mysql_queries.DBTX, schemas []string) ([]*GetTableConstraintsBySchemasRow, error) - GetViewsAndFunctionsBySchemas(ctx context.Context, db mysql_queries.DBTX, schemas []string) ([]*GetViewsAndFunctionsBySchemasRow, error) - GetUniqueIndexesBySchema(ctx context.Context, db mysql_queries.DBTX, schemas []string) ([]*GetUniqueIndexesBySchemaRow, error) + GetTableConstraintsBySchemas( + ctx context.Context, + db mysql_queries.DBTX, + schemas []string, + ) ([]*GetTableConstraintsBySchemasRow, error) + GetViewsAndFunctionsBySchemas( + ctx context.Context, + db mysql_queries.DBTX, + schemas []string, + ) ([]*GetViewsAndFunctionsBySchemasRow, error) + GetUniqueIndexesBySchema( + ctx context.Context, + db mysql_queries.DBTX, + schemas []string, + ) ([]*GetUniqueIndexesBySchemaRow, error) } var _ Querier = (*Queries)(nil) diff --git a/backend/pkg/mssql-querier/system.sql.go b/backend/pkg/mssql-querier/system.sql.go index bb05a14794..6df5300666 100644 --- a/backend/pkg/mssql-querier/system.sql.go +++ b/backend/pkg/mssql-querier/system.sql.go @@ -70,7 +70,10 @@ type GetDatabaseSchemaRow struct { IdentityIncrement sql.NullInt32 } -func (q *Queries) 
GetDatabaseSchema(ctx context.Context, db mysql_queries.DBTX) ([]*GetDatabaseSchemaRow, error) { +func (q *Queries) GetDatabaseSchema( + ctx context.Context, + db mysql_queries.DBTX, +) ([]*GetDatabaseSchemaRow, error) { rows, err := db.QueryContext(ctx, getDatabaseSchema) if err != nil { return nil, err @@ -163,7 +166,10 @@ type GetAllTablesRow struct { TableName string } -func (q *Queries) GetAllTables(ctx context.Context, db mysql_queries.DBTX) ([]*GetAllTablesRow, error) { +func (q *Queries) GetAllTables( + ctx context.Context, + db mysql_queries.DBTX, +) ([]*GetAllTablesRow, error) { rows, err := db.QueryContext(ctx, getAllTables) if err != nil { return nil, err @@ -291,7 +297,11 @@ type GetDatabaseTableSchemasBySchemasAndTablesRow struct { IdentityIncrement sql.NullInt32 } -func (q *Queries) GetDatabaseTableSchemasBySchemasAndTables(ctx context.Context, db mysql_queries.DBTX, schematables []string) ([]*GetDatabaseTableSchemasBySchemasAndTablesRow, error) { +func (q *Queries) GetDatabaseTableSchemasBySchemasAndTables( + ctx context.Context, + db mysql_queries.DBTX, + schematables []string, +) ([]*GetDatabaseTableSchemasBySchemasAndTablesRow, error) { placeholders, args := createSchemaTableParams(schematables) query := fmt.Sprintf(getDatabaseTableSchemasBySchemasAndTables, placeholders) @@ -411,7 +421,10 @@ type GetRolePermissionsRow struct { PrivilegeType string } -func (q *Queries) GetRolePermissions(ctx context.Context, db mysql_queries.DBTX) ([]*GetRolePermissionsRow, error) { +func (q *Queries) GetRolePermissions( + ctx context.Context, + db mysql_queries.DBTX, +) ([]*GetRolePermissionsRow, error) { rows, err := db.QueryContext(ctx, getRolePermissions) if err != nil { return nil, err @@ -558,7 +571,11 @@ type GetTableConstraintsBySchemasRow struct { CheckClause sql.NullString } -func (q *Queries) GetTableConstraintsBySchemas(ctx context.Context, db mysql_queries.DBTX, schemas []string) ([]*GetTableConstraintsBySchemasRow, error) { +func (q *Queries) 
GetTableConstraintsBySchemas( + ctx context.Context, + db mysql_queries.DBTX, + schemas []string, +) ([]*GetTableConstraintsBySchemasRow, error) { placeholders, args := createSchemaTableParams(schemas) query := fmt.Sprintf(getTableConstraintsBySchemas, placeholders) @@ -705,7 +722,11 @@ type GetIndicesBySchemasAndTablesRow struct { IndexDefinition string } -func (q *Queries) GetIndicesBySchemasAndTables(ctx context.Context, db mysql_queries.DBTX, schematables []string) ([]*GetIndicesBySchemasAndTablesRow, error) { +func (q *Queries) GetIndicesBySchemasAndTables( + ctx context.Context, + db mysql_queries.DBTX, + schematables []string, +) ([]*GetIndicesBySchemasAndTablesRow, error) { placeholders, args := createSchemaTableParams(schematables) query := fmt.Sprintf(getIndicesBySchemasAndTable, placeholders) @@ -790,7 +811,11 @@ type GetViewsAndFunctionsBySchemasRow struct { Dependencies sql.NullString } -func (q *Queries) GetViewsAndFunctionsBySchemas(ctx context.Context, db mysql_queries.DBTX, schemas []string) ([]*GetViewsAndFunctionsBySchemasRow, error) { +func (q *Queries) GetViewsAndFunctionsBySchemas( + ctx context.Context, + db mysql_queries.DBTX, + schemas []string, +) ([]*GetViewsAndFunctionsBySchemasRow, error) { placeholders, args := createSchemaTableParams(schemas) query := fmt.Sprintf(getViewsAndFunctionsBySchemas, placeholders) rows, err := db.QueryContext(ctx, query, args...) @@ -854,7 +879,11 @@ type GetCustomSequencesBySchemasRow struct { Definition string } -func (q *Queries) GetCustomSequencesBySchemas(ctx context.Context, db mysql_queries.DBTX, schemas []string) ([]*GetCustomSequencesBySchemasRow, error) { +func (q *Queries) GetCustomSequencesBySchemas( + ctx context.Context, + db mysql_queries.DBTX, + schemas []string, +) ([]*GetCustomSequencesBySchemasRow, error) { placeholders, args := createSchemaTableParams(schemas) query := fmt.Sprintf(getCustomSequencesBySchemas, placeholders) rows, err := db.QueryContext(ctx, query, args...) 
@@ -903,7 +932,11 @@ type GetCustomTriggersBySchemasAndTablesRow struct { Definition sql.NullString } -func (q *Queries) GetCustomTriggersBySchemasAndTables(ctx context.Context, db mysql_queries.DBTX, schematables []string) ([]*GetCustomTriggersBySchemasAndTablesRow, error) { +func (q *Queries) GetCustomTriggersBySchemasAndTables( + ctx context.Context, + db mysql_queries.DBTX, + schematables []string, +) ([]*GetCustomTriggersBySchemasAndTablesRow, error) { placeholders, args := createSchemaTableParams(schematables) query := fmt.Sprintf(getCustomTriggersBySchemasAndTables, placeholders) rows, err := db.QueryContext(ctx, query, args...) @@ -996,9 +1029,16 @@ type GetDataTypesBySchemasRow struct { Definition string } -func (q *Queries) GetDataTypesBySchemas(ctx context.Context, db mysql_queries.DBTX, schemas []string) ([]*GetDataTypesBySchemasRow, error) { +func (q *Queries) GetDataTypesBySchemas( + ctx context.Context, + db mysql_queries.DBTX, + schemas []string, +) ([]*GetDataTypesBySchemasRow, error) { placeholders, args := createSchemaTableParams(schemas) - where := fmt.Sprintf("WHERE tt.is_user_defined = 1 AND SCHEMA_NAME(tt.schema_id) IN (%s);", placeholders) + where := fmt.Sprintf( + "WHERE tt.is_user_defined = 1 AND SCHEMA_NAME(tt.schema_id) IN (%s);", + placeholders, + ) query := getDataTypesBySchemasAndTables + " " + where rows, err := db.QueryContext(ctx, query, args...) if err != nil { @@ -1074,7 +1114,11 @@ type GetUniqueIndexesBySchemaRow struct { IndexColumns string } -func (q *Queries) GetUniqueIndexesBySchema(ctx context.Context, db mysql_queries.DBTX, schemas []string) ([]*GetUniqueIndexesBySchemaRow, error) { +func (q *Queries) GetUniqueIndexesBySchema( + ctx context.Context, + db mysql_queries.DBTX, + schemas []string, +) ([]*GetUniqueIndexesBySchemaRow, error) { placeholders, args := createSchemaTableParams(schemas) query := fmt.Sprintf(getUniqueIndexesBySchema, placeholders) rows, err := db.QueryContext(ctx, query, args...) 
diff --git a/backend/pkg/sqlconnect/sql-connector.go b/backend/pkg/sqlconnect/sql-connector.go index 3a76e4d1f5..6786990abf 100644 --- a/backend/pkg/sqlconnect/sql-connector.go +++ b/backend/pkg/sqlconnect/sql-connector.go @@ -51,12 +51,20 @@ func WithConnectionTimeout(timeoutSeconds uint32) SqlConnectorOption { } type SqlConnector interface { - NewDbFromConnectionConfig(connectionConfig *mgmtv1alpha1.ConnectionConfig, logger *slog.Logger, opts ...SqlConnectorOption) (SqlDbContainer, error) + NewDbFromConnectionConfig( + connectionConfig *mgmtv1alpha1.ConnectionConfig, + logger *slog.Logger, + opts ...SqlConnectorOption, + ) (SqlDbContainer, error) } type SqlOpenConnector struct{} -func (rc *SqlOpenConnector) NewDbFromConnectionConfig(cc *mgmtv1alpha1.ConnectionConfig, logger *slog.Logger, opts ...SqlConnectorOption) (SqlDbContainer, error) { +func (rc *SqlOpenConnector) NewDbFromConnectionConfig( + cc *mgmtv1alpha1.ConnectionConfig, + logger *slog.Logger, + opts ...SqlConnectorOption, +) (SqlDbContainer, error) { if cc == nil { return nil, errors.New("connectionConfig was nil, expected *mgmtv1alpha1.ConnectionConfig") } @@ -115,7 +123,11 @@ func (rc *SqlOpenConnector) NewDbFromConnectionConfig(cc *mgmtv1alpha1.Connectio } } -func getPgConnectorFn(dsn string, config *mgmtv1alpha1.PostgresConnectionConfig, logger *slog.Logger) stdlibConnectorGetter { +func getPgConnectorFn( + dsn string, + config *mgmtv1alpha1.PostgresConnectionConfig, + logger *slog.Logger, +) stdlibConnectorGetter { return func() (driver.Connector, func(), error) { connectorOpts := []postgrestunconnector.Option{ postgrestunconnector.WithLogger(logger), @@ -125,7 +137,10 @@ func getPgConnectorFn(dsn string, config *mgmtv1alpha1.PostgresConnectionConfig, if config.GetClientTls() != nil { tlsConfig, err := getTLSConfig(config.GetClientTls(), logger) if err != nil { - return nil, nil, fmt.Errorf("unable to construct postgres client tls config: %w", err) + return nil, nil, fmt.Errorf( + "unable to 
construct postgres client tls config: %w", + err, + ) } logger.Debug("constructed postgres client tls config") connectorOpts = append(connectorOpts, postgrestunconnector.WithTLSConfig(tlsConfig)) @@ -133,10 +148,18 @@ func getPgConnectorFn(dsn string, config *mgmtv1alpha1.PostgresConnectionConfig, if config.GetTunnel() != nil { cfg, err := tun.GetTunnelConfigFromSSHDto(config.GetTunnel()) if err != nil { - return nil, nil, fmt.Errorf("unable to construct postgres client tunnel config: %w", err) + return nil, nil, fmt.Errorf( + "unable to construct postgres client tunnel config: %w", + err, + ) } logger.Debug("constructed postgres tunnel config") - dialer := tun.NewLazySSHDialer(cfg.Addr, cfg.ClientConfig, tun.DefaultSSHDialerConfig(), logger) + dialer := tun.NewLazySSHDialer( + cfg.Addr, + cfg.ClientConfig, + tun.DefaultSSHDialerConfig(), + logger, + ) connectorOpts = append(connectorOpts, postgrestunconnector.WithDialer(dialer)) closers = append(closers, func() { logger.Debug("closing postgres ssh dialer") @@ -162,7 +185,11 @@ func getPgConnectorFn(dsn string, config *mgmtv1alpha1.PostgresConnectionConfig, } } -func getMysqlConnectorFn(dsn string, config *mgmtv1alpha1.MysqlConnectionConfig, logger *slog.Logger) stdlibConnectorGetter { +func getMysqlConnectorFn( + dsn string, + config *mgmtv1alpha1.MysqlConnectionConfig, + logger *slog.Logger, +) stdlibConnectorGetter { return func() (driver.Connector, func(), error) { connectorOpts := []mysqltunconnector.Option{} closers := []func(){} @@ -178,10 +205,18 @@ func getMysqlConnectorFn(dsn string, config *mgmtv1alpha1.MysqlConnectionConfig, if config.GetTunnel() != nil { cfg, err := tun.GetTunnelConfigFromSSHDto(config.GetTunnel()) if err != nil { - return nil, nil, fmt.Errorf("unable to construct mysql client tunnel config: %w", err) + return nil, nil, fmt.Errorf( + "unable to construct mysql client tunnel config: %w", + err, + ) } logger.Debug("constructed mysql tunnel config") - dialer := 
tun.NewLazySSHDialer(cfg.Addr, cfg.ClientConfig, tun.DefaultSSHDialerConfig(), logger) + dialer := tun.NewLazySSHDialer( + cfg.Addr, + cfg.ClientConfig, + tun.DefaultSSHDialerConfig(), + logger, + ) connectorOpts = append(connectorOpts, mysqltunconnector.WithDialer(dialer)) closers = append(closers, func() { logger.Debug("closing mysql ssh dialer") @@ -207,7 +242,11 @@ func getMysqlConnectorFn(dsn string, config *mgmtv1alpha1.MysqlConnectionConfig, } } -func getMssqlConnectorFn(dsn string, config *mgmtv1alpha1.MssqlConnectionConfig, logger *slog.Logger) stdlibConnectorGetter { +func getMssqlConnectorFn( + dsn string, + config *mgmtv1alpha1.MssqlConnectionConfig, + logger *slog.Logger, +) stdlibConnectorGetter { return func() (driver.Connector, func(), error) { connectorOpts := []mssqltunconnector.Option{} closers := []func(){} @@ -226,7 +265,12 @@ func getMssqlConnectorFn(dsn string, config *mgmtv1alpha1.MssqlConnectionConfig, return nil, nil, fmt.Errorf("unable to construct mssql tunnel config: %w", err) } logger.Debug("constructed mssql tunnel config") - dialer := tun.NewLazySSHDialer(cfg.Addr, cfg.ClientConfig, tun.DefaultSSHDialerConfig(), logger) + dialer := tun.NewLazySSHDialer( + cfg.Addr, + cfg.ClientConfig, + tun.DefaultSSHDialerConfig(), + logger, + ) connectorOpts = append(connectorOpts, mssqltunconnector.WithDialer(dialer)) closers = append(closers, func() { logger.Debug("closing mssql ssh dialer") @@ -252,7 +296,9 @@ func getMssqlConnectorFn(dsn string, config *mgmtv1alpha1.MssqlConnectionConfig, } } -func getConnectionOptsFromConnectionConfig(cc *mgmtv1alpha1.ConnectionConfig) (*DbConnectionOptions, error) { +func getConnectionOptsFromConnectionConfig( + cc *mgmtv1alpha1.ConnectionConfig, +) (*DbConnectionOptions, error) { switch config := cc.GetConfig().(type) { case *mgmtv1alpha1.ConnectionConfig_MysqlConfig: return sqlConnOptsToDbConnOpts(config.MysqlConfig.GetConnectionOptions()) @@ -412,7 +458,9 @@ func getTLSConfig(cfg 
*mgmtv1alpha1.ClientTlsConfig, logger *slog.Logger) (*tls. clientCert := cfg.GetClientCert() clientKey := cfg.GetClientKey() if clientCert != "" && clientKey != "" { - logger.Debug("client cert and key provided, adding to certificates for client tls connection") + logger.Debug( + "client cert and key provided, adding to certificates for client tls connection", + ) cert, err := tls.X509KeyPair([]byte(cfg.GetClientCert()), []byte(cfg.GetClientKey())) if err != nil { return nil, fmt.Errorf("failed to load client certificate and key: %w", err) diff --git a/backend/pkg/sqlmanager/mssql/mssql-manager.go b/backend/pkg/sqlmanager/mssql/mssql-manager.go index 84d95a4e71..7d556de393 100644 --- a/backend/pkg/sqlmanager/mssql/mssql-manager.go +++ b/backend/pkg/sqlmanager/mssql/mssql-manager.go @@ -27,13 +27,26 @@ type Manager struct { ee_sqlmanager_mssql.Manager } -func NewManager(querier mssql_queries.Querier, db mysql_queries.DBTX, closer func(), logger *slog.Logger) *Manager { - return &Manager{querier: querier, db: db, close: closer, logger: logger, Manager: *ee_sqlmanager_mssql.NewManager(querier, db, closer, logger)} +func NewManager( + querier mssql_queries.Querier, + db mysql_queries.DBTX, + closer func(), + logger *slog.Logger, +) *Manager { + return &Manager{ + querier: querier, + db: db, + close: closer, + logger: logger, + Manager: *ee_sqlmanager_mssql.NewManager(querier, db, closer, logger), + } } const defaultIdentity string = "IDENTITY(1,1)" -func (m *Manager) GetDatabaseSchema(ctx context.Context) ([]*sqlmanager_shared.DatabaseSchemaRow, error) { +func (m *Manager) GetDatabaseSchema( + ctx context.Context, +) ([]*sqlmanager_shared.DatabaseSchemaRow, error) { dbSchemas, err := m.querier.GetDatabaseSchema(ctx, m.db) if err != nil && !neosyncdb.IsNoRows(err) { return nil, err @@ -107,7 +120,10 @@ func isColumnUpdateAllowed(isIdentity, isComputed bool) bool { return true } -func (m *Manager) GetDatabaseTableSchemasBySchemasAndTables(ctx context.Context, tables 
[]*sqlmanager_shared.SchemaTable) ([]*sqlmanager_shared.DatabaseSchemaRow, error) { +func (m *Manager) GetDatabaseTableSchemasBySchemasAndTables( + ctx context.Context, + tables []*sqlmanager_shared.SchemaTable, +) ([]*sqlmanager_shared.DatabaseSchemaRow, error) { if len(tables) == 0 { return []*sqlmanager_shared.DatabaseSchemaRow{}, nil } @@ -187,19 +203,32 @@ func (m *Manager) GetDatabaseTableSchemasBySchemasAndTables(ctx context.Context, return output, nil } -func (m *Manager) GetColumnsByTables(ctx context.Context, tables []*sqlmanager_shared.SchemaTable) ([]*sqlmanager_shared.TableColumn, error) { + +func (m *Manager) GetColumnsByTables( + ctx context.Context, + tables []*sqlmanager_shared.SchemaTable, +) ([]*sqlmanager_shared.TableColumn, error) { return nil, errors.ErrUnsupported } -func (m *Manager) GetTableConstraintsByTables(ctx context.Context, schema string, tables []string) (map[string]*sqlmanager_shared.AllTableConstraints, error) { +func (m *Manager) GetTableConstraintsByTables( + ctx context.Context, + schema string, + tables []string, +) (map[string]*sqlmanager_shared.AllTableConstraints, error) { return nil, errors.ErrUnsupported } -func (m *Manager) GetFunctionsBySchemas(ctx context.Context, schemas []string) ([]*sqlmanager_shared.DataType, error) { +func (m *Manager) GetFunctionsBySchemas( + ctx context.Context, + schemas []string, +) ([]*sqlmanager_shared.DataType, error) { return nil, errors.ErrUnsupported } -func (m *Manager) GetAllSchemas(ctx context.Context) ([]*sqlmanager_shared.DatabaseSchemaNameRow, error) { +func (m *Manager) GetAllSchemas( + ctx context.Context, +) ([]*sqlmanager_shared.DatabaseSchemaNameRow, error) { rows, err := m.querier.GetAllSchemas(ctx, m.db) if err != nil { return nil, err @@ -228,7 +257,9 @@ func (m *Manager) GetAllTables(ctx context.Context) ([]*sqlmanager_shared.Databa return result, nil } -func (m *Manager) GetSchemaColumnMap(ctx context.Context) (map[string]map[string]*sqlmanager_shared.DatabaseSchemaRow, 
error) { +func (m *Manager) GetSchemaColumnMap( + ctx context.Context, +) (map[string]map[string]*sqlmanager_shared.DatabaseSchemaRow, error) { dbSchemas, err := m.GetDatabaseSchema(ctx) if err != nil { return nil, err @@ -237,7 +268,10 @@ func (m *Manager) GetSchemaColumnMap(ctx context.Context) (map[string]map[string return result, nil } -func (m *Manager) GetTableConstraintsBySchema(ctx context.Context, schemas []string) (*sqlmanager_shared.TableConstraints, error) { +func (m *Manager) GetTableConstraintsBySchema( + ctx context.Context, + schemas []string, +) (*sqlmanager_shared.TableConstraints, error) { if len(schemas) == 0 { return &sqlmanager_shared.TableConstraints{}, nil } @@ -289,31 +323,47 @@ func (m *Manager) GetTableConstraintsBySchema(ctx context.Context, schemas []str notNullable = append(notNullable, nullability == "NOT NULL") } if len(constraintCols) != len(fkCols) { - return nil, fmt.Errorf("length of columns was not equal to length of foreign key cols: %d %d", len(constraintCols), len(fkCols)) + return nil, fmt.Errorf( + "length of columns was not equal to length of foreign key cols: %d %d", + len(constraintCols), + len(fkCols), + ) } if len(constraintCols) != len(notNullable) { - return nil, fmt.Errorf("length of columns was not equal to length of not nullable cols: %d %d", len(constraintCols), len(notNullable)) + return nil, fmt.Errorf( + "length of columns was not equal to length of not nullable cols: %d %d", + len(constraintCols), + len(notNullable), + ) } if isInvalidCircularSelfReferencingFk(row, constraintCols, fkCols) { continue } - foreignKeyMap[tableName] = append(foreignKeyMap[tableName], &sqlmanager_shared.ForeignConstraint{ - Columns: constraintCols, - NotNullable: notNullable, - ForeignKey: &sqlmanager_shared.ForeignKey{ - Table: sqlmanager_shared.BuildTable(row.ReferencedSchema.String, row.ReferencedTable.String), - Columns: fkCols, + foreignKeyMap[tableName] = append( + foreignKeyMap[tableName], + 
&sqlmanager_shared.ForeignConstraint{ + Columns: constraintCols, + NotNullable: notNullable, + ForeignKey: &sqlmanager_shared.ForeignKey{ + Table: sqlmanager_shared.BuildTable( + row.ReferencedSchema.String, + row.ReferencedTable.String, + ), + Columns: fkCols, + }, }, - }) + ) } case "PRIMARY KEY": if _, exists := primaryKeyMap[tableName]; !exists { primaryKeyMap[tableName] = []string{} } - primaryKeyMap[tableName] = append(primaryKeyMap[tableName], sqlmanager_shared.DedupeSlice(constraintCols)...) + primaryKeyMap[tableName] = append( + primaryKeyMap[tableName], + sqlmanager_shared.DedupeSlice(constraintCols)...) case "UNIQUE": columns := sqlmanager_shared.DedupeSlice(constraintCols) uniqueConstraintsMap[tableName] = append(uniqueConstraintsMap[tableName], columns) @@ -322,7 +372,10 @@ func (m *Manager) GetTableConstraintsBySchema(ctx context.Context, schemas []str for _, row := range uniqueIndexes { tableName := sqlmanager_shared.BuildTable(row.TableSchema, row.TableName) - uniqueIndexesMap[tableName] = append(uniqueIndexesMap[tableName], splitAndStrip(row.IndexColumns, ", ")) + uniqueIndexesMap[tableName] = append( + uniqueIndexesMap[tableName], + splitAndStrip(row.IndexColumns, ", "), + ) } return &sqlmanager_shared.TableConstraints{ @@ -336,7 +389,10 @@ func (m *Manager) GetTableConstraintsBySchema(ctx context.Context, schemas []str // Checks if a foreign key constraint is self-referencing (points to the same table) // and all constraint columns match their referenced columns, indicating a circular reference. 
// example public.users.id has a foreign key to public.users.id -func isInvalidCircularSelfReferencingFk(row *mssql_queries.GetTableConstraintsBySchemasRow, constraintColumns, referencedColumns []string) bool { +func isInvalidCircularSelfReferencingFk( + row *mssql_queries.GetTableConstraintsBySchemasRow, + constraintColumns, referencedColumns []string, +) bool { // Check if the foreign key references the same table isSameTable := row.SchemaName == row.ReferencedSchema.String && row.TableName == row.ReferencedTable.String @@ -382,7 +438,12 @@ func splitAndStrip(input, delim string) []string { //nolint:unparam return output } -func (m *Manager) BatchExec(ctx context.Context, batchSize int, statements []string, opts *sqlmanager_shared.BatchExecOpts) error { +func (m *Manager) BatchExec( + ctx context.Context, + batchSize int, + statements []string, + opts *sqlmanager_shared.BatchExecOpts, +) error { // mssql does not support batching statements total := len(statements) for idx, stmt := range statements { @@ -429,7 +490,9 @@ func (m *Manager) Close() { } } -func GetMssqlColumnOverrideAndResetProperties(columnInfo *sqlmanager_shared.DatabaseSchemaRow) (needsOverride, needsReset bool) { +func GetMssqlColumnOverrideAndResetProperties( + columnInfo *sqlmanager_shared.DatabaseSchemaRow, +) (needsOverride, needsReset bool) { needsOverride = false needsReset = false @@ -441,7 +504,8 @@ func GetMssqlColumnOverrideAndResetProperties(columnInfo *sqlmanager_shared.Data } // check if column default is sequence - if columnInfo.ColumnDefault != "" && gotypeutil.CaseInsensitiveContains(columnInfo.ColumnDefault, "NEXT VALUE") { + if columnInfo.ColumnDefault != "" && + gotypeutil.CaseInsensitiveContains(columnInfo.ColumnDefault, "NEXT VALUE") { needsReset = true return } diff --git a/backend/pkg/sqlmanager/mysql/mysql-manager.go b/backend/pkg/sqlmanager/mysql/mysql-manager.go index fe1a5c4d78..78613cf95f 100644 --- a/backend/pkg/sqlmanager/mysql/mysql-manager.go +++ 
b/backend/pkg/sqlmanager/mysql/mysql-manager.go @@ -29,11 +29,17 @@ type MysqlManager struct { close func() } -func NewManager(querier mysql_queries.Querier, pool mysql_queries.DBTX, closer func()) *MysqlManager { +func NewManager( + querier mysql_queries.Querier, + pool mysql_queries.DBTX, + closer func(), +) *MysqlManager { return &MysqlManager{querier: querier, pool: pool, close: closer} } -func (m *MysqlManager) GetDatabaseSchema(ctx context.Context) ([]*sqlmanager_shared.DatabaseSchemaRow, error) { +func (m *MysqlManager) GetDatabaseSchema( + ctx context.Context, +) ([]*sqlmanager_shared.DatabaseSchemaRow, error) { dbSchemas, err := m.querier.GetDatabaseSchema(ctx, m.pool) if err != nil && !neosyncdb.IsNoRows(err) { return nil, err @@ -43,7 +49,8 @@ func (m *MysqlManager) GetDatabaseSchema(ctx context.Context) ([]*sqlmanager_sha result := []*sqlmanager_shared.DatabaseSchemaRow{} for _, row := range dbSchemas { var generatedType *string - if row.Extra.Valid && strings.Contains(row.Extra.String, "GENERATED") && !strings.Contains(row.Extra.String, "DEFAULT_GENERATED") { + if row.Extra.Valid && strings.Contains(row.Extra.String, "GENERATED") && + !strings.Contains(row.Extra.String, "DEFAULT_GENERATED") { generatedTypeCopy := row.Extra.String generatedType = &generatedTypeCopy } @@ -105,13 +112,17 @@ func (m *MysqlManager) GetDatabaseSchema(ctx context.Context) ([]*sqlmanager_sha func isColumnUpdateAllowed(generatedType sql.NullString) bool { // generated always stored columns cannot be updated - if generatedType.Valid && (strings.EqualFold(generatedType.String, "STORED GENERATED") || strings.EqualFold(generatedType.String, "VIRTUAL GENERATED")) { + if generatedType.Valid && + (strings.EqualFold(generatedType.String, "STORED GENERATED") || strings.EqualFold(generatedType.String, "VIRTUAL GENERATED")) { return false } return true } -func (m *MysqlManager) GetDatabaseTableSchemasBySchemasAndTables(ctx context.Context, tables []*sqlmanager_shared.SchemaTable) 
([]*sqlmanager_shared.DatabaseSchemaRow, error) { +func (m *MysqlManager) GetDatabaseTableSchemasBySchemasAndTables( + ctx context.Context, + tables []*sqlmanager_shared.SchemaTable, +) ([]*sqlmanager_shared.DatabaseSchemaRow, error) { if len(tables) == 0 { return []*sqlmanager_shared.DatabaseSchemaRow{}, nil } @@ -128,17 +139,24 @@ func (m *MysqlManager) GetDatabaseTableSchemasBySchemasAndTables(ctx context.Con var colDefMapMu sync.Mutex for schema, tables := range schemaset { errgrp.Go(func() error { - columnDefs, err := m.querier.GetDatabaseTableSchemasBySchemasAndTables(errctx, m.pool, &mysql_queries.GetDatabaseTableSchemasBySchemasAndTablesParams{ - Schema: schema, - Tables: tables, - }) + columnDefs, err := m.querier.GetDatabaseTableSchemasBySchemasAndTables( + errctx, + m.pool, + &mysql_queries.GetDatabaseTableSchemasBySchemasAndTablesParams{ + Schema: schema, + Tables: tables, + }, + ) if err != nil { return err } colDefMapMu.Lock() defer colDefMapMu.Unlock() for _, columnDefinition := range columnDefs { - key := sqlmanager_shared.SchemaTable{Schema: columnDefinition.SchemaName, Table: columnDefinition.TableName} + key := sqlmanager_shared.SchemaTable{ + Schema: columnDefinition.SchemaName, + Table: columnDefinition.TableName, + } dbSchemas[key.String()] = append(dbSchemas[key.String()], columnDefinition) } return nil @@ -152,7 +170,9 @@ func (m *MysqlManager) GetDatabaseTableSchemasBySchemasAndTables(ctx context.Con for _, rows := range dbSchemas { for _, row := range rows { var generatedType *string - if row.IdentityGeneration.Valid && strings.Contains(row.IdentityGeneration.String, "GENERATED") && !strings.Contains(row.IdentityGeneration.String, "DEFAULT_GENERATED") { + if row.IdentityGeneration.Valid && + strings.Contains(row.IdentityGeneration.String, "GENERATED") && + !strings.Contains(row.IdentityGeneration.String, "DEFAULT_GENERATED") { generatedTypeCopy := row.IdentityGeneration.String generatedType = &generatedTypeCopy } @@ -168,7 +188,8 @@ func 
(m *MysqlManager) GetDatabaseTableSchemasBySchemasAndTables(ctx context.Con } var columnDefaultType *string - if row.IdentityGeneration.Valid && columnDefaultStr != "" && row.IdentityGeneration.String == "" { + if row.IdentityGeneration.Valid && columnDefaultStr != "" && + row.IdentityGeneration.String == "" { val := columnDefaultString // With this type columnDefaultStr will be surrounded by quotes when translated to SQL columnDefaultType = &val } else if row.IdentityGeneration.Valid && columnDefaultStr != "" && row.IdentityGeneration.String != "" { @@ -207,7 +228,10 @@ func (m *MysqlManager) GetDatabaseTableSchemasBySchemasAndTables(ctx context.Con return result, nil } -func (m *MysqlManager) GetColumnsByTables(ctx context.Context, tables []*sqlmanager_shared.SchemaTable) ([]*sqlmanager_shared.TableColumn, error) { +func (m *MysqlManager) GetColumnsByTables( + ctx context.Context, + tables []*sqlmanager_shared.SchemaTable, +) ([]*sqlmanager_shared.TableColumn, error) { rows, err := m.GetDatabaseTableSchemasBySchemasAndTables(ctx, tables) if err != nil { return nil, err @@ -233,7 +257,9 @@ func (m *MysqlManager) GetColumnsByTables(ctx context.Context, tables []*sqlmana return columns, nil } -func (m *MysqlManager) GetAllSchemas(ctx context.Context) ([]*sqlmanager_shared.DatabaseSchemaNameRow, error) { +func (m *MysqlManager) GetAllSchemas( + ctx context.Context, +) ([]*sqlmanager_shared.DatabaseSchemaNameRow, error) { rows, err := m.querier.GetAllSchemas(ctx, m.pool) if err != nil { return nil, err @@ -247,7 +273,9 @@ func (m *MysqlManager) GetAllSchemas(ctx context.Context) ([]*sqlmanager_shared. 
return result, nil } -func (m *MysqlManager) GetAllTables(ctx context.Context) ([]*sqlmanager_shared.DatabaseTableRow, error) { +func (m *MysqlManager) GetAllTables( + ctx context.Context, +) ([]*sqlmanager_shared.DatabaseTableRow, error) { rows, err := m.querier.GetAllTables(ctx, m.pool) if err != nil { return nil, err @@ -263,7 +291,9 @@ func (m *MysqlManager) GetAllTables(ctx context.Context) ([]*sqlmanager_shared.D } // returns: {public.users: { id: struct{}{}, created_at: struct{}{}}} -func (m *MysqlManager) GetSchemaColumnMap(ctx context.Context) (map[string]map[string]*sqlmanager_shared.DatabaseSchemaRow, error) { +func (m *MysqlManager) GetSchemaColumnMap( + ctx context.Context, +) (map[string]map[string]*sqlmanager_shared.DatabaseSchemaRow, error) { dbSchemas, err := m.GetDatabaseSchema(ctx) if err != nil { return nil, err @@ -272,7 +302,10 @@ func (m *MysqlManager) GetSchemaColumnMap(ctx context.Context) (map[string]map[s return result, nil } -func (m *MysqlManager) GetTableConstraintsBySchema(ctx context.Context, schemas []string) (*sqlmanager_shared.TableConstraints, error) { +func (m *MysqlManager) GetTableConstraintsBySchema( + ctx context.Context, + schemas []string, +) (*sqlmanager_shared.TableConstraints, error) { if len(schemas) == 0 { return &sqlmanager_shared.TableConstraints{}, nil } @@ -309,25 +342,41 @@ func (m *MysqlManager) GetTableConstraintsBySchema(ctx context.Context, schemas notNullable = append(notNullable, notNullableInt == 1) } if len(constraintCols) != len(fkCols) { - return nil, fmt.Errorf("length of columns was not equal to length of foreign key cols: %d %d", len(constraintCols), len(fkCols)) + return nil, fmt.Errorf( + "length of columns was not equal to length of foreign key cols: %d %d", + len(constraintCols), + len(fkCols), + ) } if len(constraintCols) != len(notNullable) { - return nil, fmt.Errorf("length of columns was not equal to length of not nullable cols: %d %d", len(constraintCols), len(notNullable)) + return nil, 
fmt.Errorf( + "length of columns was not equal to length of not nullable cols: %d %d", + len(constraintCols), + len(notNullable), + ) } - foreignKeyMap[tableName] = append(foreignKeyMap[tableName], &sqlmanager_shared.ForeignConstraint{ - Columns: constraintCols, - NotNullable: notNullable, - ForeignKey: &sqlmanager_shared.ForeignKey{ - Table: sqlmanager_shared.BuildTable(row.ReferencedSchemaName, row.ReferencedTableName), - Columns: fkCols, + foreignKeyMap[tableName] = append( + foreignKeyMap[tableName], + &sqlmanager_shared.ForeignConstraint{ + Columns: constraintCols, + NotNullable: notNullable, + ForeignKey: &sqlmanager_shared.ForeignKey{ + Table: sqlmanager_shared.BuildTable( + row.ReferencedSchemaName, + row.ReferencedTableName, + ), + Columns: fkCols, + }, }, - }) + ) case "PRIMARY KEY": if _, exists := primaryKeyMap[tableName]; !exists { primaryKeyMap[tableName] = []string{} } - primaryKeyMap[tableName] = append(primaryKeyMap[tableName], sqlmanager_shared.DedupeSlice(constraintCols)...) + primaryKeyMap[tableName] = append( + primaryKeyMap[tableName], + sqlmanager_shared.DedupeSlice(constraintCols)...) 
case "UNIQUE": columns := sqlmanager_shared.DedupeSlice(constraintCols) uniqueConstraintsMap[tableName] = append(uniqueConstraintsMap[tableName], columns) @@ -376,7 +425,10 @@ type indexInfo struct { columns []string } -func (m *MysqlManager) GetTableInitStatements(ctx context.Context, tables []*sqlmanager_shared.SchemaTable) ([]*sqlmanager_shared.TableInitStatement, error) { +func (m *MysqlManager) GetTableInitStatements( + ctx context.Context, + tables []*sqlmanager_shared.SchemaTable, +) ([]*sqlmanager_shared.TableInitStatement, error) { if len(tables) == 0 { return []*sqlmanager_shared.TableInitStatement{}, nil } @@ -393,17 +445,27 @@ func (m *MysqlManager) GetTableInitStatements(ctx context.Context, tables []*sql var colDefMapMu sync.Mutex for schema, tables := range schemaset { errgrp.Go(func() error { - columnDefs, err := m.querier.GetDatabaseTableSchemasBySchemasAndTables(errctx, m.pool, &mysql_queries.GetDatabaseTableSchemasBySchemasAndTablesParams{ - Schema: schema, - Tables: tables, - }) + columnDefs, err := m.querier.GetDatabaseTableSchemasBySchemasAndTables( + errctx, + m.pool, + &mysql_queries.GetDatabaseTableSchemasBySchemasAndTablesParams{ + Schema: schema, + Tables: tables, + }, + ) if err != nil { - return fmt.Errorf("failed to build mysql database table schemas by schemas and tables: %w", err) + return fmt.Errorf( + "failed to build mysql database table schemas by schemas and tables: %w", + err, + ) } colDefMapMu.Lock() defer colDefMapMu.Unlock() for _, columnDefinition := range columnDefs { - key := sqlmanager_shared.SchemaTable{Schema: columnDefinition.SchemaName, Table: columnDefinition.TableName} + key := sqlmanager_shared.SchemaTable{ + Schema: columnDefinition.SchemaName, + Table: columnDefinition.TableName, + } colDefMap[key.String()] = append(colDefMap[key.String()], columnDefinition) } return nil @@ -414,17 +476,24 @@ func (m *MysqlManager) GetTableInitStatements(ctx context.Context, tables []*sql var constraintMapMu sync.Mutex for 
schema, tables := range schemaset { errgrp.Go(func() error { - constraints, err := m.querier.GetTableConstraints(errctx, m.pool, &mysql_queries.GetTableConstraintsParams{ - Schema: schema, - Tables: tables, - }) + constraints, err := m.querier.GetTableConstraints( + errctx, + m.pool, + &mysql_queries.GetTableConstraintsParams{ + Schema: schema, + Tables: tables, + }, + ) if err != nil { return fmt.Errorf("failed to build mysql table constraints: %w", err) } constraintMapMu.Lock() defer constraintMapMu.Unlock() for _, constraint := range constraints { - key := sqlmanager_shared.SchemaTable{Schema: constraint.SchemaName, Table: constraint.TableName} + key := sqlmanager_shared.SchemaTable{ + Schema: constraint.SchemaName, + Table: constraint.TableName, + } constraintmap[key.String()] = append(constraintmap[key.String()], constraint) } return nil @@ -435,10 +504,14 @@ func (m *MysqlManager) GetTableInitStatements(ctx context.Context, tables []*sql var indexMapMu sync.Mutex for schema, tables := range schemaset { errgrp.Go(func() error { - idxrecords, err := m.querier.GetIndicesBySchemasAndTables(errctx, m.pool, &mysql_queries.GetIndicesBySchemasAndTablesParams{ - Schema: schema, - Tables: tables, - }) + idxrecords, err := m.querier.GetIndicesBySchemasAndTables( + errctx, + m.pool, + &mysql_queries.GetIndicesBySchemasAndTablesParams{ + Schema: schema, + Tables: tables, + }, + ) if err != nil { return fmt.Errorf("failed to build mysql indices by schemas and tables: %w", err) } @@ -446,7 +519,10 @@ func (m *MysqlManager) GetTableInitStatements(ctx context.Context, tables []*sql indexMapMu.Lock() defer indexMapMu.Unlock() for _, record := range idxrecords { - key := sqlmanager_shared.SchemaTable{Schema: record.SchemaName, Table: record.TableName} + key := sqlmanager_shared.SchemaTable{ + Schema: record.SchemaName, + Table: record.TableName, + } if _, exists := indexmap[key.String()]; !exists { indexmap[key.String()] = make(map[string]*indexInfo) } @@ -533,14 +609,22 @@ 
func (m *MysqlManager) GetTableInitStatements(ctx context.Context, tables []*sql } info := &sqlmanager_shared.TableInitStatement{ - CreateTableStatement: fmt.Sprintf("CREATE TABLE IF NOT EXISTS `%s`.`%s` (%s);", tableData[0].SchemaName, tableData[0].TableName, strings.Join(columns, ", ")), + CreateTableStatement: fmt.Sprintf( + "CREATE TABLE IF NOT EXISTS `%s`.`%s` (%s);", + tableData[0].SchemaName, + tableData[0].TableName, + strings.Join(columns, ", "), + ), AlterTableStatements: []*sqlmanager_shared.AlterTableStatement{}, IndexStatements: []string{}, } for _, constraint := range constraintmap[key] { stmt, err := buildAlterStatementByConstraint(constraint) if err != nil { - return nil, fmt.Errorf("failed to build alter table statement by constraint: %w", err) + return nil, fmt.Errorf( + "failed to build alter table statement by constraint: %w", + err, + ) } info.AlterTableStatements = append(info.AlterTableStatements, stmt) } @@ -557,7 +641,11 @@ func (m *MysqlManager) GetTableInitStatements(ctx context.Context, tables []*sql return output, nil } -func (m *MysqlManager) GetSequencesByTables(ctx context.Context, schema string, tables []string) ([]*sqlmanager_shared.DataType, error) { +func (m *MysqlManager) GetSequencesByTables( + ctx context.Context, + schema string, + tables []string, +) ([]*sqlmanager_shared.DataType, error) { return nil, errors.ErrUnsupported } @@ -569,11 +657,19 @@ func convertUInt8ToString(value any) (string, error) { return string(convertedType), nil } -func (m *MysqlManager) GetTableConstraintsByTables(ctx context.Context, schema string, tables []string) (map[string]*sqlmanager_shared.AllTableConstraints, error) { - constraints, err := m.querier.GetTableConstraints(ctx, m.pool, &mysql_queries.GetTableConstraintsParams{ - Schema: schema, - Tables: tables, - }) +func (m *MysqlManager) GetTableConstraintsByTables( + ctx context.Context, + schema string, + tables []string, +) (map[string]*sqlmanager_shared.AllTableConstraints, error) { + 
constraints, err := m.querier.GetTableConstraints( + ctx, + m.pool, + &mysql_queries.GetTableConstraintsParams{ + Schema: schema, + Tables: tables, + }, + ) if err != nil { return nil, fmt.Errorf("failed to get table constraints by schemas: %w", err) } @@ -596,7 +692,10 @@ func (m *MysqlManager) GetTableConstraintsByTables(ctx context.Context, schema s for _, notNullableInt := range notNullableInts { notNullable = append(notNullable, notNullableInt == 1) } - key := sqlmanager_shared.SchemaTable{Schema: constraint.SchemaName, Table: constraint.TableName}.String() + key := sqlmanager_shared.SchemaTable{ + Schema: constraint.SchemaName, + Table: constraint.TableName, + }.String() if allConstraints[key] == nil { allConstraints[key] = &sqlmanager_shared.AllTableConstraints{ ForeignKeyConstraints: []*sqlmanager_shared.ForeignKeyConstraint{}, @@ -618,7 +717,10 @@ func (m *MysqlManager) GetTableConstraintsByTables(ctx context.Context, schema s DeleteRule: nullStringToPtr(constraint.DeleteRule), } fk.Fingerprint = sqlmanager_shared.BuildForeignKeyConstraintFingerprint(fk) - allConstraints[key].ForeignKeyConstraints = append(allConstraints[key].ForeignKeyConstraints, fk) + allConstraints[key].ForeignKeyConstraints = append( + allConstraints[key].ForeignKeyConstraints, + fk, + ) } else { checkStr, err := convertUInt8ToString(constraint.CheckClause) if err != nil { @@ -655,7 +757,10 @@ func BuildUpdateColumnStatement(column *sqlmanager_shared.TableColumn) (string, } func buildColumnStatement(keyword string, column *sqlmanager_shared.TableColumn) (string, error) { - columnDefaultStr, err := EscapeMysqlDefaultColumn(column.ColumnDefault, column.ColumnDefaultType) + columnDefaultStr, err := EscapeMysqlDefaultColumn( + column.ColumnDefault, + column.ColumnDefaultType, + ) if err != nil { return "", fmt.Errorf("failed to escape column default: %w", err) } @@ -682,32 +787,61 @@ func buildColumnStatement(keyword string, column *sqlmanager_shared.TableColumn) col = 
buildTableColForCreate(colReq) } - return fmt.Sprintf("ALTER TABLE %s.%s %s COLUMN %s;", EscapeMysqlColumn(column.Schema), EscapeMysqlColumn(column.Table), keyword, col), nil + return fmt.Sprintf( + "ALTER TABLE %s.%s %s COLUMN %s;", + EscapeMysqlColumn(column.Schema), + EscapeMysqlColumn(column.Table), + keyword, + col, + ), nil } func BuildDropColumnStatement(column *sqlmanager_shared.TableColumn) string { - return fmt.Sprintf("ALTER TABLE %s.%s DROP COLUMN %s;", EscapeMysqlColumn(column.Schema), EscapeMysqlColumn(column.Table), EscapeMysqlColumn(column.Name)) + return fmt.Sprintf( + "ALTER TABLE %s.%s DROP COLUMN %s;", + EscapeMysqlColumn(column.Schema), + EscapeMysqlColumn(column.Table), + EscapeMysqlColumn(column.Name), + ) } func BuildDropConstraintStatement(schema, table, constraintType, constraintName string) string { if strings.EqualFold(constraintType, "PRIMARY KEY") { - return fmt.Sprintf("ALTER TABLE %s.%s DROP PRIMARY KEY;", EscapeMysqlColumn(schema), EscapeMysqlColumn(table)) + return fmt.Sprintf( + "ALTER TABLE %s.%s DROP PRIMARY KEY;", + EscapeMysqlColumn(schema), + EscapeMysqlColumn(table), + ) } if strings.EqualFold(constraintType, "UNIQUE") { constraintType = "INDEX" } - return fmt.Sprintf("ALTER TABLE %s.%s DROP %s %s;", EscapeMysqlColumn(schema), EscapeMysqlColumn(table), constraintType, EscapeMysqlColumn(constraintName)) + return fmt.Sprintf( + "ALTER TABLE %s.%s DROP %s %s;", + EscapeMysqlColumn(schema), + EscapeMysqlColumn(table), + constraintType, + EscapeMysqlColumn(constraintName), + ) } func BuildDropTriggerStatement(schema *string, triggerName string) string { if schema == nil { return fmt.Sprintf("DROP TRIGGER IF EXISTS %s;", EscapeMysqlColumn(triggerName)) } - return fmt.Sprintf("DROP TRIGGER IF EXISTS %s.%s;", EscapeMysqlColumn(*schema), EscapeMysqlColumn(triggerName)) + return fmt.Sprintf( + "DROP TRIGGER IF EXISTS %s.%s;", + EscapeMysqlColumn(*schema), + EscapeMysqlColumn(triggerName), + ) } func BuildDropFunctionStatement(schema, 
functionName string) string { - return fmt.Sprintf("DROP FUNCTION IF EXISTS %s.%s;", EscapeMysqlColumn(schema), EscapeMysqlColumn(functionName)) + return fmt.Sprintf( + "DROP FUNCTION IF EXISTS %s.%s;", + EscapeMysqlColumn(schema), + EscapeMysqlColumn(functionName), + ) } func buildTableColForModifyColumn(record *buildTableColRequest) string { @@ -739,7 +873,10 @@ func buildTableCol(record *buildTableColRequest, isModifyColumn bool) string { } else if record.IdentityType != nil && *record.IdentityType == "VIRTUAL GENERATED" { genType = "VIRTUAL" } - pieces = append(pieces, fmt.Sprintf("GENERATED ALWAYS AS (%s) %s", record.GeneratedExpression, genType)) + pieces = append( + pieces, + fmt.Sprintf("GENERATED ALWAYS AS (%s) %s", record.GeneratedExpression, genType), + ) } else { pieces = append(pieces, buildNullableText(record.IsNullable)) } @@ -758,7 +895,10 @@ func buildTableCol(record *buildTableColRequest, isModifyColumn bool) string { } if record.Comment != nil && *record.Comment != "" { - pieces = append(pieces, fmt.Sprintf("COMMENT '%s'", strings.ReplaceAll(*record.Comment, "'", `\'`))) + pieces = append( + pieces, + fmt.Sprintf("COMMENT '%s'", strings.ReplaceAll(*record.Comment, "'", `\'`)), + ) } return strings.Join(pieces, " ") @@ -771,7 +911,9 @@ func buildNullableText(isNullable bool) string { return "NOT NULL" } -func buildAlterStatementByConstraint(c *mysql_queries.GetTableConstraintsRow) (*sqlmanager_shared.AlterTableStatement, error) { +func buildAlterStatementByConstraint( + c *mysql_queries.GetTableConstraintsRow, +) (*sqlmanager_shared.AlterTableStatement, error) { constraintCols, err := jsonRawToSlice[string](c.ConstraintColumns) if err != nil { return nil, err @@ -782,19 +924,41 @@ func buildAlterStatementByConstraint(c *mysql_queries.GetTableConstraintsRow) (* } switch c.ConstraintType { case "PRIMARY KEY": - stmt := fmt.Sprintf("ALTER TABLE `%s`.`%s` ADD PRIMARY KEY (%s);", c.SchemaName, c.TableName, 
strings.Join(EscapeMysqlColumns(constraintCols), ",")) + stmt := fmt.Sprintf( + "ALTER TABLE `%s`.`%s` ADD PRIMARY KEY (%s);", + c.SchemaName, + c.TableName, + strings.Join(EscapeMysqlColumns(constraintCols), ","), + ) return &sqlmanager_shared.AlterTableStatement{ - Statement: wrapIdempotentConstraint(c.SchemaName, c.TableName, c.ConstraintName, stmt), + Statement: wrapIdempotentConstraint( + c.SchemaName, + c.TableName, + c.ConstraintName, + stmt, + ), ConstraintType: sqlmanager_shared.PrimaryConstraintType, }, nil case "UNIQUE": - stmt := fmt.Sprintf("ALTER TABLE `%s`.`%s` ADD CONSTRAINT `%s` UNIQUE (%s);", c.SchemaName, c.TableName, c.ConstraintName, strings.Join(EscapeMysqlColumns(constraintCols), ",")) + stmt := fmt.Sprintf( + "ALTER TABLE `%s`.`%s` ADD CONSTRAINT `%s` UNIQUE (%s);", + c.SchemaName, + c.TableName, + c.ConstraintName, + strings.Join(EscapeMysqlColumns(constraintCols), ","), + ) return &sqlmanager_shared.AlterTableStatement{ - Statement: wrapIdempotentConstraint(c.SchemaName, c.TableName, c.ConstraintName, stmt), + Statement: wrapIdempotentConstraint( + c.SchemaName, + c.TableName, + c.ConstraintName, + stmt, + ), ConstraintType: sqlmanager_shared.UniqueConstraintType, }, nil case "FOREIGN KEY": - stmt := fmt.Sprintf("ALTER TABLE `%s`.`%s` ADD CONSTRAINT `%s` FOREIGN KEY (%s) REFERENCES `%s`.`%s`(%s) ON DELETE %s ON UPDATE %s;", + stmt := fmt.Sprintf( + "ALTER TABLE `%s`.`%s` ADD CONSTRAINT `%s` FOREIGN KEY (%s) REFERENCES `%s`.`%s`(%s) ON DELETE %s ON UPDATE %s;", c.SchemaName, c.TableName, c.ConstraintName, @@ -806,7 +970,12 @@ func buildAlterStatementByConstraint(c *mysql_queries.GetTableConstraintsRow) (* c.UpdateRule.String, ) return &sqlmanager_shared.AlterTableStatement{ - Statement: wrapIdempotentConstraint(c.SchemaName, c.TableName, c.ConstraintName, stmt), + Statement: wrapIdempotentConstraint( + c.SchemaName, + c.TableName, + c.ConstraintName, + stmt, + ), ConstraintType: sqlmanager_shared.ForeignConstraintType, }, nil case "CHECK": 
@@ -814,16 +983,30 @@ func buildAlterStatementByConstraint(c *mysql_queries.GetTableConstraintsRow) (* if err != nil { return nil, err } - stmt := fmt.Sprintf("ALTER TABLE `%s`.`%s` ADD CONSTRAINT %s CHECK (%s);", c.SchemaName, c.TableName, c.ConstraintName, checkStr) + stmt := fmt.Sprintf( + "ALTER TABLE `%s`.`%s` ADD CONSTRAINT %s CHECK (%s);", + c.SchemaName, + c.TableName, + c.ConstraintName, + checkStr, + ) return &sqlmanager_shared.AlterTableStatement{ - Statement: wrapIdempotentConstraint(c.SchemaName, c.TableName, c.ConstraintName, stmt), + Statement: wrapIdempotentConstraint( + c.SchemaName, + c.TableName, + c.ConstraintName, + stmt, + ), ConstraintType: sqlmanager_shared.CheckConstraintType, }, nil } return nil, errors.ErrUnsupported } -func (m *MysqlManager) GetSchemaTableDataTypes(ctx context.Context, tables []*sqlmanager_shared.SchemaTable) (*sqlmanager_shared.SchemaTableDataTypeResponse, error) { +func (m *MysqlManager) GetSchemaTableDataTypes( + ctx context.Context, + tables []*sqlmanager_shared.SchemaTable, +) (*sqlmanager_shared.SchemaTableDataTypeResponse, error) { if len(tables) == 0 { return &sqlmanager_shared.SchemaTableDataTypeResponse{}, nil } @@ -847,7 +1030,10 @@ func (m *MysqlManager) GetSchemaTableDataTypes(ctx context.Context, tables []*sq return output, nil } -func (m *MysqlManager) GetSchemaTableTriggers(ctx context.Context, tables []*sqlmanager_shared.SchemaTable) ([]*sqlmanager_shared.TableTrigger, error) { +func (m *MysqlManager) GetSchemaTableTriggers( + ctx context.Context, + tables []*sqlmanager_shared.SchemaTable, +) ([]*sqlmanager_shared.TableTrigger, error) { if len(tables) == 0 { return []*sqlmanager_shared.TableTrigger{}, nil } @@ -868,10 +1054,14 @@ func (m *MysqlManager) GetSchemaTableTriggers(ctx context.Context, tables []*sql schema := schema tables := tables errgrp.Go(func() error { - rows, err := m.querier.GetCustomTriggersBySchemaAndTables(errctx, m.pool, &mysql_queries.GetCustomTriggersBySchemaAndTablesParams{ - 
Schema: schema, - Tables: tables, - }) + rows, err := m.querier.GetCustomTriggersBySchemaAndTables( + errctx, + m.pool, + &mysql_queries.GetCustomTriggersBySchemaAndTablesParams{ + Schema: schema, + Tables: tables, + }, + ) if err != nil && !neosyncdb.IsNoRows(err) { return err } else if err != nil && neosyncdb.IsNoRows(err) { @@ -899,7 +1089,16 @@ func (m *MysqlManager) GetSchemaTableTriggers(ctx context.Context, tables []*sql Table: row.TableName, TriggerSchema: &row.TriggerSchema, TriggerName: row.TriggerName, - Definition: wrapIdempotentTrigger(row.SchemaName, row.TableName, row.TriggerName, row.TriggerSchema, row.Timing, row.EventType, row.Orientation, row.Statement), + Definition: wrapIdempotentTrigger( + row.SchemaName, + row.TableName, + row.TriggerName, + row.TriggerSchema, + row.Timing, + row.EventType, + row.Orientation, + row.Statement, + ), } trigger.Fingerprint = sqlmanager_shared.BuildTriggerFingerprint(trigger) output = append(output, trigger) @@ -922,7 +1121,10 @@ func (m *MysqlManager) GetSchemaInitStatements( uniqueSchemas[table.Schema] = struct{}{} } for schema := range uniqueSchemas { - schemaStmts = append(schemaStmts, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS `%s`;", schema)) + schemaStmts = append( + schemaStmts, + fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS `%s`;", schema), + ) } return nil }) @@ -987,7 +1189,10 @@ func (m *MysqlManager) GetSchemaInitStatements( }, nil } -func (m *MysqlManager) getFunctionsBySchemas(ctx context.Context, schemas []string) ([]*sqlmanager_shared.DataType, error) { +func (m *MysqlManager) getFunctionsBySchemas( + ctx context.Context, + schemas []string, +) ([]*sqlmanager_shared.DataType, error) { rows, err := m.querier.GetCustomFunctionsBySchemas(ctx, m.pool, schemas) if err != nil && !neosyncdb.IsNoRows(err) { return nil, err @@ -1002,17 +1207,31 @@ func (m *MysqlManager) getFunctionsBySchemas(ctx context.Context, schemas []stri return nil, err } function := &sqlmanager_shared.DataType{ - Schema: row.SchemaName, - 
Name: row.FunctionName, - Definition: wrapIdempotentFunction(row.SchemaName, row.FunctionName, functionSignatureStr, row.ReturnDataType, row.Definition, row.IsDeterministic == 1), + Schema: row.SchemaName, + Name: row.FunctionName, + Definition: wrapIdempotentFunction( + row.SchemaName, + row.FunctionName, + functionSignatureStr, + row.ReturnDataType, + row.Definition, + row.IsDeterministic == 1, + ), } - function.Fingerprint = sqlmanager_shared.BuildFingerprint(function.Schema, function.Name, function.Definition) + function.Fingerprint = sqlmanager_shared.BuildFingerprint( + function.Schema, + function.Name, + function.Definition, + ) output = append(output, function) } return output, nil } -func (m *MysqlManager) GetFunctionsByTables(ctx context.Context, tables []*sqlmanager_shared.SchemaTable) ([]*sqlmanager_shared.DataType, error) { +func (m *MysqlManager) GetFunctionsByTables( + ctx context.Context, + tables []*sqlmanager_shared.SchemaTable, +) ([]*sqlmanager_shared.DataType, error) { schemaMap := map[string]struct{}{} for _, t := range tables { schemaMap[t.Schema] = struct{}{} @@ -1062,10 +1281,25 @@ func hashInput(input ...string) string { } func createIndexStmt(schema, table string, idxInfo *indexInfo, columnInput []string) string { - if strings.EqualFold(idxInfo.indexType, "spatial") || strings.EqualFold(idxInfo.indexType, "fulltext") { - return fmt.Sprintf("ALTER TABLE %s.%s ADD %s INDEX %s (%s);", EscapeMysqlColumn(schema), EscapeMysqlColumn(table), idxInfo.indexType, EscapeMysqlColumn(idxInfo.indexName), strings.Join(columnInput, ", ")) + if strings.EqualFold(idxInfo.indexType, "spatial") || + strings.EqualFold(idxInfo.indexType, "fulltext") { + return fmt.Sprintf( + "ALTER TABLE %s.%s ADD %s INDEX %s (%s);", + EscapeMysqlColumn(schema), + EscapeMysqlColumn(table), + idxInfo.indexType, + EscapeMysqlColumn(idxInfo.indexName), + strings.Join(columnInput, ", "), + ) } - return fmt.Sprintf("ALTER TABLE %s.%s ADD INDEX %s (%s) USING %s;", 
EscapeMysqlColumn(schema), EscapeMysqlColumn(table), EscapeMysqlColumn(idxInfo.indexName), strings.Join(columnInput, ", "), idxInfo.indexType) + return fmt.Sprintf( + "ALTER TABLE %s.%s ADD INDEX %s (%s) USING %s;", + EscapeMysqlColumn(schema), + EscapeMysqlColumn(table), + EscapeMysqlColumn(idxInfo.indexName), + strings.Join(columnInput, ", "), + idxInfo.indexType, + ) } func wrapIdempotentIndex( @@ -1148,7 +1382,12 @@ FOR EACH %s return strings.TrimSpace(stmt) } -func (m *MysqlManager) BatchExec(ctx context.Context, batchSize int, statements []string, opts *sqlmanager_shared.BatchExecOpts) error { +func (m *MysqlManager) BatchExec( + ctx context.Context, + batchSize int, + statements []string, + opts *sqlmanager_shared.BatchExecOpts, +) error { for i := 0; i < len(statements); i += batchSize { end := i + batchSize if end > len(statements) { @@ -1232,7 +1471,10 @@ func EscapeMysqlColumn(col string) string { return fmt.Sprintf("`%s`", col) } -func EscapeMysqlDefaultColumn(defaultColumnValue string, defaultColumnType *string) (string, error) { +func EscapeMysqlDefaultColumn( + defaultColumnValue string, + defaultColumnType *string, +) (string, error) { defaultColumnTypes := []string{columnDefaultString, columnDefaultDefault} if defaultColumnType == nil { return defaultColumnValue, nil @@ -1243,10 +1485,19 @@ func EscapeMysqlDefaultColumn(defaultColumnValue string, defaultColumnType *stri if *defaultColumnType == columnDefaultDefault { return fmt.Sprintf("(%s)", defaultColumnValue), nil } - return fmt.Sprintf("(%s)", defaultColumnValue), fmt.Errorf("unsupported default column type: %s, currently supported types are: %v", *defaultColumnType, defaultColumnTypes) + return fmt.Sprintf( + "(%s)", + defaultColumnValue, + ), fmt.Errorf( + "unsupported default column type: %s, currently supported types are: %v", + *defaultColumnType, + defaultColumnTypes, + ) } -func GetMysqlColumnOverrideAndResetProperties(columnInfo *sqlmanager_shared.DatabaseSchemaRow) (needsOverride, 
needsReset bool) { +func GetMysqlColumnOverrideAndResetProperties( + columnInfo *sqlmanager_shared.DatabaseSchemaRow, +) (needsOverride, needsReset bool) { needsOverride = false needsReset = false return diff --git a/backend/pkg/sqlmanager/postgres/postgres-manager.go b/backend/pkg/sqlmanager/postgres/postgres-manager.go index 971cc9e8e7..2c81b290d3 100644 --- a/backend/pkg/sqlmanager/postgres/postgres-manager.go +++ b/backend/pkg/sqlmanager/postgres/postgres-manager.go @@ -26,7 +26,9 @@ func NewManager(querier pg_queries.Querier, db pg_queries.DBTX, closer func()) * return &PostgresManager{querier: querier, db: db, close: closer} } -func (p *PostgresManager) GetDatabaseSchema(ctx context.Context) ([]*sqlmanager_shared.DatabaseSchemaRow, error) { +func (p *PostgresManager) GetDatabaseSchema( + ctx context.Context, +) ([]*sqlmanager_shared.DatabaseSchemaRow, error) { dbSchemas, err := p.querier.GetDatabaseSchema(ctx, p.db) if err != nil && !neosyncdb.IsNoRows(err) { return nil, err @@ -59,7 +61,10 @@ func (p *PostgresManager) GetDatabaseSchema(ctx context.Context) ([]*sqlmanager_ OrdinalPosition: int(row.OrdinalPosition), GeneratedType: generatedType, IdentityGeneration: identityGeneration, - UpdateAllowed: isColumnUpdateAllowed(row.IdentityGeneration, row.GeneratedType), + UpdateAllowed: isColumnUpdateAllowed( + row.IdentityGeneration, + row.GeneratedType, + ), }) } return result, nil @@ -73,7 +78,10 @@ func isColumnUpdateAllowed(identityGeneration, generatedType string) bool { return true } -func (p *PostgresManager) GetDatabaseTableSchemasBySchemasAndTables(ctx context.Context, tables []*sqlmanager_shared.SchemaTable) ([]*sqlmanager_shared.DatabaseSchemaRow, error) { +func (p *PostgresManager) GetDatabaseTableSchemasBySchemasAndTables( + ctx context.Context, + tables []*sqlmanager_shared.SchemaTable, +) ([]*sqlmanager_shared.DatabaseSchemaRow, error) { schemaTables := make([]string, 0, len(tables)) for _, t := range tables { schemaTables = append(schemaTables, 
t.String()) @@ -97,13 +105,19 @@ func (p *PostgresManager) GetDatabaseTableSchemasBySchemasAndTables(ctx context. OrdinalPosition: int(row.OrdinalPosition), GeneratedType: sqlmanager_shared.Ptr(row.GeneratedType), IdentityGeneration: sqlmanager_shared.Ptr(row.IdentityGeneration), - UpdateAllowed: isColumnUpdateAllowed(row.IdentityGeneration, row.GeneratedType), + UpdateAllowed: isColumnUpdateAllowed( + row.IdentityGeneration, + row.GeneratedType, + ), }) } return result, nil } -func (p *PostgresManager) GetColumnsByTables(ctx context.Context, tables []*sqlmanager_shared.SchemaTable) ([]*sqlmanager_shared.TableColumn, error) { +func (p *PostgresManager) GetColumnsByTables( + ctx context.Context, + tables []*sqlmanager_shared.SchemaTable, +) ([]*sqlmanager_shared.TableColumn, error) { schemaTables := make([]string, 0, len(tables)) for _, t := range tables { schemaTables = append(schemaTables, t.String()) @@ -147,7 +161,11 @@ func (p *PostgresManager) GetColumnsByTables(ctx context.Context, tables []*sqlm return result, nil } -func (p *PostgresManager) GetTableConstraintsByTables(ctx context.Context, schema string, tables []string) (map[string]*sqlmanager_shared.AllTableConstraints, error) { +func (p *PostgresManager) GetTableConstraintsByTables( + ctx context.Context, + schema string, + tables []string, +) (map[string]*sqlmanager_shared.AllTableConstraints, error) { if len(tables) == 0 { return map[string]*sqlmanager_shared.AllTableConstraints{}, nil } @@ -156,10 +174,14 @@ func (p *PostgresManager) GetTableConstraintsByTables(ctx context.Context, schem var fkConstraints []*pg_queries.GetForeignKeyConstraintsBySchemasAndTablesRow errgrp.Go(func() error { var err error - constraints, err := p.querier.GetNonForeignKeyTableConstraintsBySchemaAndTables(errctx, p.db, &pg_queries.GetNonForeignKeyTableConstraintsBySchemaAndTablesParams{ - Schema: schema, - Tables: tables, - }) + constraints, err := p.querier.GetNonForeignKeyTableConstraintsBySchemaAndTables( + errctx, + 
p.db, + &pg_queries.GetNonForeignKeyTableConstraintsBySchemaAndTablesParams{ + Schema: schema, + Tables: tables, + }, + ) if err != nil { return err } @@ -169,10 +191,14 @@ func (p *PostgresManager) GetTableConstraintsByTables(ctx context.Context, schem errgrp.Go(func() error { var err error - constraints, err := p.querier.GetForeignKeyConstraintsBySchemasAndTables(errctx, p.db, &pg_queries.GetForeignKeyConstraintsBySchemasAndTablesParams{ - Schema: schema, - Tables: tables, - }) + constraints, err := p.querier.GetForeignKeyConstraintsBySchemasAndTables( + errctx, + p.db, + &pg_queries.GetForeignKeyConstraintsBySchemasAndTablesParams{ + Schema: schema, + Tables: tables, + }, + ) if err != nil { return err } @@ -205,8 +231,13 @@ func (p *PostgresManager) GetTableConstraintsByTables(ctx context.Context, schem Columns: row.ConstraintColumns, Definition: row.ConstraintDefinition, } - constraint.Fingerprint = sqlmanager_shared.BuildNonForeignKeyConstraintFingerprint(constraint) - result[key].NonForeignKeyConstraints = append(result[key].NonForeignKeyConstraints, constraint) + constraint.Fingerprint = sqlmanager_shared.BuildNonForeignKeyConstraintFingerprint( + constraint, + ) + result[key].NonForeignKeyConstraints = append( + result[key].NonForeignKeyConstraints, + constraint, + ) } for _, row := range fkConstraints { @@ -237,7 +268,9 @@ func (p *PostgresManager) GetTableConstraintsByTables(ctx context.Context, schem return result, nil } -func (p *PostgresManager) GetAllSchemas(ctx context.Context) ([]*sqlmanager_shared.DatabaseSchemaNameRow, error) { +func (p *PostgresManager) GetAllSchemas( + ctx context.Context, +) ([]*sqlmanager_shared.DatabaseSchemaNameRow, error) { rows, err := p.querier.GetAllSchemas(ctx, p.db) if err != nil { return nil, err @@ -251,7 +284,9 @@ func (p *PostgresManager) GetAllSchemas(ctx context.Context) ([]*sqlmanager_shar return result, nil } -func (p *PostgresManager) GetAllTables(ctx context.Context) ([]*sqlmanager_shared.DatabaseTableRow, 
error) { +func (p *PostgresManager) GetAllTables( + ctx context.Context, +) ([]*sqlmanager_shared.DatabaseTableRow, error) { rows, err := p.querier.GetAllTables(ctx, p.db) if err != nil { return nil, err @@ -267,7 +302,9 @@ func (p *PostgresManager) GetAllTables(ctx context.Context) ([]*sqlmanager_share } // returns: {public.users: { id: struct{}{}, created_at: struct{}{}}} -func (p *PostgresManager) GetSchemaColumnMap(ctx context.Context) (map[string]map[string]*sqlmanager_shared.DatabaseSchemaRow, error) { +func (p *PostgresManager) GetSchemaColumnMap( + ctx context.Context, +) (map[string]map[string]*sqlmanager_shared.DatabaseSchemaRow, error) { dbSchemas, err := p.GetDatabaseSchema(ctx) if err != nil { return nil, err @@ -276,7 +313,10 @@ func (p *PostgresManager) GetSchemaColumnMap(ctx context.Context) (map[string]ma return result, nil } -func (p *PostgresManager) GetTableConstraintsBySchema(ctx context.Context, schemas []string) (*sqlmanager_shared.TableConstraints, error) { +func (p *PostgresManager) GetTableConstraintsBySchema( + ctx context.Context, + schemas []string, +) (*sqlmanager_shared.TableConstraints, error) { if len(schemas) == 0 { return &sqlmanager_shared.TableConstraints{}, nil } @@ -325,7 +365,9 @@ func (p *PostgresManager) GetTableConstraintsBySchema(ctx context.Context, schem if _, exists := primaryKeyMap[tableName]; !exists { primaryKeyMap[tableName] = []string{} } - primaryKeyMap[tableName] = append(primaryKeyMap[tableName], sqlmanager_shared.DedupeSlice(row.ConstraintColumns)...) + primaryKeyMap[tableName] = append( + primaryKeyMap[tableName], + sqlmanager_shared.DedupeSlice(row.ConstraintColumns)...) 
case "u": columns := sqlmanager_shared.DedupeSlice(row.ConstraintColumns) uniqueConstraintsMap[tableName] = append(uniqueConstraintsMap[tableName], columns) @@ -336,20 +378,34 @@ func (p *PostgresManager) GetTableConstraintsBySchema(ctx context.Context, schem for _, row := range fkConstraints { tableName := sqlmanager_shared.BuildTable(row.ReferencingSchema, row.ReferencingTable) if len(row.ReferencingColumns) != len(row.ReferencedColumns) { - return nil, fmt.Errorf("length of columns was not equal to length of foreign key cols: %d %d", len(row.ReferencingColumns), len(row.ReferencedColumns)) + return nil, fmt.Errorf( + "length of columns was not equal to length of foreign key cols: %d %d", + len(row.ReferencingColumns), + len(row.ReferencedColumns), + ) } if len(row.ReferencingColumns) != len(row.NotNullable) { - return nil, fmt.Errorf("length of columns was not equal to length of not nullable cols: %d %d", len(row.ReferencingColumns), len(row.NotNullable)) + return nil, fmt.Errorf( + "length of columns was not equal to length of not nullable cols: %d %d", + len(row.ReferencingColumns), + len(row.NotNullable), + ) } - foreignKeyMap[tableName] = append(foreignKeyMap[tableName], &sqlmanager_shared.ForeignConstraint{ - Columns: row.ReferencingColumns, - NotNullable: row.NotNullable, - ForeignKey: &sqlmanager_shared.ForeignKey{ - Table: sqlmanager_shared.BuildTable(row.ReferencedSchema, row.ReferencedTable), - Columns: row.ReferencedColumns, + foreignKeyMap[tableName] = append( + foreignKeyMap[tableName], + &sqlmanager_shared.ForeignConstraint{ + Columns: row.ReferencingColumns, + NotNullable: row.NotNullable, + ForeignKey: &sqlmanager_shared.ForeignKey{ + Table: sqlmanager_shared.BuildTable( + row.ReferencedSchema, + row.ReferencedTable, + ), + Columns: row.ReferencedColumns, + }, }, - }) + ) } uniqueIndexesMap := map[string][][]string{} @@ -382,7 +438,10 @@ func (p *PostgresManager) GetRolePermissionsMap(ctx context.Context) (map[string return schemaTablePrivsMap, 
err } -func (p *PostgresManager) GetSchemaTableTriggers(ctx context.Context, tables []*sqlmanager_shared.SchemaTable) ([]*sqlmanager_shared.TableTrigger, error) { +func (p *PostgresManager) GetSchemaTableTriggers( + ctx context.Context, + tables []*sqlmanager_shared.SchemaTable, +) ([]*sqlmanager_shared.TableTrigger, error) { if len(tables) == 0 { return []*sqlmanager_shared.TableTrigger{}, nil } @@ -405,7 +464,12 @@ func (p *PostgresManager) GetSchemaTableTriggers(ctx context.Context, tables []* Schema: row.SchemaName, Table: row.TableName, TriggerName: row.TriggerName, - Definition: wrapPgIdempotentTrigger(row.SchemaName, row.TableName, row.TriggerName, row.Definition), + Definition: wrapPgIdempotentTrigger( + row.SchemaName, + row.TableName, + row.TriggerName, + row.Definition, + ), } trigger.Fingerprint = sqlmanager_shared.BuildTriggerFingerprint(trigger) output = append(output, trigger) @@ -414,7 +478,10 @@ func (p *PostgresManager) GetSchemaTableTriggers(ctx context.Context, tables []* } // Returns ansilary dependencies like sequences, datatypes, functions, etc that are used by tables, but live at the schema level -func (p *PostgresManager) GetSchemaTableDataTypes(ctx context.Context, tables []*sqlmanager_shared.SchemaTable) (*sqlmanager_shared.SchemaTableDataTypeResponse, error) { +func (p *PostgresManager) GetSchemaTableDataTypes( + ctx context.Context, + tables []*sqlmanager_shared.SchemaTable, +) (*sqlmanager_shared.SchemaTableDataTypeResponse, error) { if len(tables) == 0 { return &sqlmanager_shared.SchemaTableDataTypeResponse{}, nil } @@ -474,11 +541,19 @@ func (p *PostgresManager) GetSchemaTableDataTypes(ctx context.Context, tables [] return output, nil } -func (p *PostgresManager) GetSequencesByTables(ctx context.Context, schema string, tables []string) ([]*sqlmanager_shared.DataType, error) { - rows, err := p.querier.GetCustomSequencesBySchemaAndTables(ctx, p.db, &pg_queries.GetCustomSequencesBySchemaAndTablesParams{ - Schema: schema, - Tables: 
tables, - }) +func (p *PostgresManager) GetSequencesByTables( + ctx context.Context, + schema string, + tables []string, +) ([]*sqlmanager_shared.DataType, error) { + rows, err := p.querier.GetCustomSequencesBySchemaAndTables( + ctx, + p.db, + &pg_queries.GetCustomSequencesBySchemaAndTablesParams{ + Schema: schema, + Tables: tables, + }, + ) if err != nil && !neosyncdb.IsNoRows(err) { return nil, err } else if err != nil && neosyncdb.IsNoRows(err) { @@ -498,7 +573,10 @@ func (p *PostgresManager) GetSequencesByTables(ctx context.Context, schema strin return output, nil } -func (p *PostgresManager) getExtensionsBySchemas(ctx context.Context, schemas []string) ([]*sqlmanager_shared.ExtensionDataType, error) { +func (p *PostgresManager) getExtensionsBySchemas( + ctx context.Context, + schemas []string, +) ([]*sqlmanager_shared.ExtensionDataType, error) { rows, err := p.querier.GetExtensionsBySchemas(ctx, p.db, schemas) if err != nil && !neosyncdb.IsNoRows(err) { return nil, err @@ -509,8 +587,12 @@ func (p *PostgresManager) getExtensionsBySchemas(ctx context.Context, schemas [] output := make([]*sqlmanager_shared.ExtensionDataType, 0, len(rows)) for _, row := range rows { output = append(output, &sqlmanager_shared.ExtensionDataType{ - Name: row.ExtensionName, - Definition: wrapPgIdempotentExtension(row.SchemaName, row.ExtensionName, row.InstalledVersion), + Name: row.ExtensionName, + Definition: wrapPgIdempotentExtension( + row.SchemaName, + row.ExtensionName, + row.InstalledVersion, + ), }) } return output, nil @@ -524,14 +606,27 @@ func wrapPgIdempotentExtension( if schema.Valid && strings.EqualFold(schema.String, "public") { return fmt.Sprintf(`CREATE EXTENSION IF NOT EXISTS %q VERSION %q;`, extensionName, version) } - return fmt.Sprintf(`CREATE EXTENSION IF NOT EXISTS %q VERSION %q SCHEMA %q;`, extensionName, version, schema.String) + return fmt.Sprintf( + `CREATE EXTENSION IF NOT EXISTS %q VERSION %q SCHEMA %q;`, + extensionName, + version, + schema.String, + ) } 
-func (p *PostgresManager) getFunctionsByTables(ctx context.Context, schema string, tables []string) ([]*sqlmanager_shared.DataType, error) { - rows, err := p.querier.GetCustomFunctionsBySchemaAndTables(ctx, p.db, &pg_queries.GetCustomFunctionsBySchemaAndTablesParams{ - Schema: schema, - Tables: tables, - }) +func (p *PostgresManager) getFunctionsByTables( + ctx context.Context, + schema string, + tables []string, +) ([]*sqlmanager_shared.DataType, error) { + rows, err := p.querier.GetCustomFunctionsBySchemaAndTables( + ctx, + p.db, + &pg_queries.GetCustomFunctionsBySchemaAndTablesParams{ + Schema: schema, + Tables: tables, + }, + ) if err != nil && !neosyncdb.IsNoRows(err) { return nil, err } else if err != nil && neosyncdb.IsNoRows(err) { @@ -541,11 +636,20 @@ func (p *PostgresManager) getFunctionsByTables(ctx context.Context, schema strin output := make([]*sqlmanager_shared.DataType, 0, len(rows)) for _, row := range rows { function := &sqlmanager_shared.DataType{ - Schema: row.SchemaName, - Name: row.FunctionName, - Definition: wrapPgIdempotentFunction(row.SchemaName, row.FunctionName, row.FunctionSignature, row.Definition), + Schema: row.SchemaName, + Name: row.FunctionName, + Definition: wrapPgIdempotentFunction( + row.SchemaName, + row.FunctionName, + row.FunctionSignature, + row.Definition, + ), } - function.Fingerprint = sqlmanager_shared.BuildFingerprint(function.Schema, function.Name, function.Definition) + function.Fingerprint = sqlmanager_shared.BuildFingerprint( + function.Schema, + function.Name, + function.Definition, + ) output = append(output, function) } return output, nil @@ -557,11 +661,19 @@ type datatypes struct { Domains []*sqlmanager_shared.DataType } -func (p *PostgresManager) getDataTypesByTables(ctx context.Context, schema string, tables []string) (*datatypes, error) { - rows, err := p.querier.GetDataTypesBySchemaAndTables(ctx, p.db, &pg_queries.GetDataTypesBySchemaAndTablesParams{ - Schema: schema, - Tables: tables, - }) +func (p 
*PostgresManager) getDataTypesByTables( + ctx context.Context, + schema string, + tables []string, +) (*datatypes, error) { + rows, err := p.querier.GetDataTypesBySchemaAndTables( + ctx, + p.db, + &pg_queries.GetDataTypesBySchemaAndTablesParams{ + Schema: schema, + Tables: tables, + }, + ) if err != nil && !neosyncdb.IsNoRows(err) { return nil, err } else if err != nil && neosyncdb.IsNoRows(err) { @@ -589,7 +701,10 @@ func (p *PostgresManager) getDataTypesByTables(ctx context.Context, schema strin return output, nil } -func (p *PostgresManager) GetTableInitStatements(ctx context.Context, tables []*sqlmanager_shared.SchemaTable) ([]*sqlmanager_shared.TableInitStatement, error) { +func (p *PostgresManager) GetTableInitStatements( + ctx context.Context, + tables []*sqlmanager_shared.SchemaTable, +) ([]*sqlmanager_shared.TableInitStatement, error) { if len(tables) == 0 { return []*sqlmanager_shared.TableInitStatement{}, nil } @@ -609,12 +724,19 @@ func (p *PostgresManager) GetTableInitStatements(ctx context.Context, tables []* colDefMap := map[string][]*pg_queries.GetDatabaseTableSchemasBySchemasAndTablesRow{} errgrp.Go(func() error { - columnDefs, err := p.querier.GetDatabaseTableSchemasBySchemasAndTables(errctx, p.db, combined) + columnDefs, err := p.querier.GetDatabaseTableSchemasBySchemasAndTables( + errctx, + p.db, + combined, + ) if err != nil { return err } for _, columnDefinition := range columnDefs { - key := sqlmanager_shared.SchemaTable{Schema: columnDefinition.SchemaName, Table: columnDefinition.TableName} + key := sqlmanager_shared.SchemaTable{ + Schema: columnDefinition.SchemaName, + Table: columnDefinition.TableName, + } colDefMap[key.String()] = append(colDefMap[key.String()], columnDefinition) } return nil @@ -622,12 +744,19 @@ func (p *PostgresManager) GetTableInitStatements(ctx context.Context, tables []* constraintmap := map[string][]*pg_queries.GetNonForeignKeyTableConstraintsBySchemaRow{} errgrp.Go(func() error { - constraints, err := 
p.querier.GetNonForeignKeyTableConstraintsBySchema(errctx, p.db, schemas) // todo: update this to only grab what is necessary instead of entire schema + constraints, err := p.querier.GetNonForeignKeyTableConstraintsBySchema( + errctx, + p.db, + schemas, + ) // todo: update this to only grab what is necessary instead of entire schema if err != nil { return err } for _, constraint := range constraints { - key := sqlmanager_shared.SchemaTable{Schema: constraint.SchemaName, Table: constraint.TableName} + key := sqlmanager_shared.SchemaTable{ + Schema: constraint.SchemaName, + Table: constraint.TableName, + } constraintmap[key.String()] = append(constraintmap[key.String()], constraint) } return nil @@ -640,7 +769,10 @@ func (p *PostgresManager) GetTableInitStatements(ctx context.Context, tables []* return err } for _, constraint := range fkConstraints { - key := sqlmanager_shared.SchemaTable{Schema: constraint.ReferencingSchema, Table: constraint.ReferencingTable} + key := sqlmanager_shared.SchemaTable{ + Schema: constraint.ReferencingSchema, + Table: constraint.ReferencingTable, + } fkConstraintMap[key.String()] = append(fkConstraintMap[key.String()], constraint) } return nil @@ -654,7 +786,10 @@ func (p *PostgresManager) GetTableInitStatements(ctx context.Context, tables []* } for _, record := range idxrecords { key := sqlmanager_shared.SchemaTable{Schema: record.SchemaName, Table: record.TableName} - indexmap[key.String()] = append(indexmap[key.String()], wrapPgIdempotentIndex(record.SchemaName, record.IndexName, record.IndexDefinition)) + indexmap[key.String()] = append( + indexmap[key.String()], + wrapPgIdempotentIndex(record.SchemaName, record.IndexName, record.IndexDefinition), + ) } return nil }) @@ -675,7 +810,11 @@ func (p *PostgresManager) GetTableInitStatements(ctx context.Context, tables []* if !record.IsPartitioned { ks := key.String() errgrp.Go(func() error { - partitionhierarchy, err := p.querier.GetPartitionHierarchyByTable(errctx, p.db, ks) + 
partitionhierarchy, err := p.querier.GetPartitionHierarchyByTable( + errctx, + p.db, + ks, + ) if err != nil { return err } @@ -706,7 +845,9 @@ func (p *PostgresManager) GetTableInitStatements(ctx context.Context, tables []* record := record var seqDefinition *string if record.IdentityGeneration != "" && record.SeqStartValue.Valid && record.SeqMinValue.Valid && - record.SeqMaxValue.Valid && record.SeqIncrementBy.Valid && record.SeqCycleOption.Valid && record.SeqCacheValue.Valid { + record.SeqMaxValue.Valid && record.SeqIncrementBy.Valid && + record.SeqCycleOption.Valid && + record.SeqCacheValue.Valid { seqConfig := &SequenceConfiguration{ StartValue: record.SeqStartValue.Int64, MinValue: record.SeqMinValue.Int64, @@ -735,7 +876,13 @@ func (p *PostgresManager) GetTableInitStatements(ctx context.Context, tables []* partitionKey = fmt.Sprintf(" PARTITION BY %s", partition.PartitionKey) } info := &sqlmanager_shared.TableInitStatement{ - CreateTableStatement: fmt.Sprintf("CREATE TABLE IF NOT EXISTS %q.%q (%s)%s;", tableData[0].SchemaName, tableData[0].TableName, strings.Join(columns, ", "), partitionKey), + CreateTableStatement: fmt.Sprintf( + "CREATE TABLE IF NOT EXISTS %q.%q (%s)%s;", + tableData[0].SchemaName, + tableData[0].TableName, + strings.Join(columns, ", "), + partitionKey, + ), AlterTableStatements: []*sqlmanager_shared.AlterTableStatement{}, IndexStatements: indexmap[key], PartitionStatements: []string{}, @@ -747,22 +894,42 @@ func (p *PostgresManager) GetTableInitStatements(ctx context.Context, tables []* } constraintType, err := sqlmanager_shared.ToConstraintType(constraint.ConstraintType) if err != nil { - return nil, fmt.Errorf("failed to convert constraint type '%s': %w", constraint.ConstraintType, err) + return nil, fmt.Errorf( + "failed to convert constraint type '%s': %w", + constraint.ConstraintType, + err, + ) } - info.AlterTableStatements = append(info.AlterTableStatements, &sqlmanager_shared.AlterTableStatement{ - Statement: 
wrapPgIdempotentConstraint(constraint.SchemaName, constraint.TableName, constraint.ConstraintName, stmt), - ConstraintType: constraintType, - }) + info.AlterTableStatements = append( + info.AlterTableStatements, + &sqlmanager_shared.AlterTableStatement{ + Statement: wrapPgIdempotentConstraint( + constraint.SchemaName, + constraint.TableName, + constraint.ConstraintName, + stmt, + ), + ConstraintType: constraintType, + }, + ) } for _, constraint := range fkConstraintMap[key] { stmt, err := buildAlterStatementByForeignKeyConstraint(constraint) if err != nil { return nil, err } - info.AlterTableStatements = append(info.AlterTableStatements, &sqlmanager_shared.AlterTableStatement{ - Statement: wrapPgIdempotentConstraint(constraint.ReferencingSchema, constraint.ReferencingTable, constraint.ConstraintName, stmt), - ConstraintType: sqlmanager_shared.ForeignConstraintType, - }) + info.AlterTableStatements = append( + info.AlterTableStatements, + &sqlmanager_shared.AlterTableStatement{ + Statement: wrapPgIdempotentConstraint( + constraint.ReferencingSchema, + constraint.ReferencingTable, + constraint.ConstraintName, + stmt, + ), + ConstraintType: sqlmanager_shared.ForeignConstraintType, + }, + ) } for _, partition := range partitionHierarchy[key] { if !partition.ParentSchemaName.Valid || !partition.ParentTableName.Valid { @@ -774,7 +941,18 @@ func (p *PostgresManager) GetTableInitStatements(ctx context.Context, tables []* if ok && p.IsPartitioned && p.PartitionKey != "" { partitionKey = fmt.Sprintf(" PARTITION BY %s", p.PartitionKey) } - info.PartitionStatements = append(info.PartitionStatements, fmt.Sprintf("CREATE TABLE IF NOT EXISTS %q.%q PARTITION OF %q.%q %s %s;", partition.SchemaName, partition.TableName, partition.ParentSchemaName.String, partition.ParentTableName.String, partition.PartitionBound, partitionKey)) + info.PartitionStatements = append( + info.PartitionStatements, + fmt.Sprintf( + "CREATE TABLE IF NOT EXISTS %q.%q PARTITION OF %q.%q %s %s;", + 
partition.SchemaName, + partition.TableName, + partition.ParentSchemaName.String, + partition.ParentTableName.String, + partition.PartitionBound, + partitionKey, + ), + ) } output = append(output, info) } @@ -798,7 +976,10 @@ func (p *PostgresManager) GetSchemaInitStatements( schemaStmts := []string{} errgrp.Go(func() error { for schema := range uniqueSchemas { - schemaStmts = append(schemaStmts, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %q;", schema)) + schemaStmts = append( + schemaStmts, + fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %q;", schema), + ) } return nil }) @@ -883,7 +1064,10 @@ func (p *PostgresManager) GetSchemaInitStatements( } // Finds any schemas referenced in datatypes that don't exist in tables and returns the statements to create them -func getSchemaCreationStatementsFromDataTypes(tables []*sqlmanager_shared.SchemaTable, datatypes *sqlmanager_shared.SchemaTableDataTypeResponse) []string { +func getSchemaCreationStatementsFromDataTypes( + tables []*sqlmanager_shared.SchemaTable, + datatypes *sqlmanager_shared.SchemaTableDataTypeResponse, +) []string { schemaStmts := []string{} schemaSet := map[string]struct{}{} for _, table := range tables { @@ -893,21 +1077,30 @@ func getSchemaCreationStatementsFromDataTypes(tables []*sqlmanager_shared.Schema // Check each datatype schema against the table schemas for _, composite := range datatypes.Composites { if _, exists := schemaSet[composite.Schema]; !exists { - schemaStmts = append(schemaStmts, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %q;", composite.Schema)) + schemaStmts = append( + schemaStmts, + fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %q;", composite.Schema), + ) schemaSet[composite.Schema] = struct{}{} } } for _, enum := range datatypes.Enums { if _, exists := schemaSet[enum.Schema]; !exists { - schemaStmts = append(schemaStmts, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %q;", enum.Schema)) + schemaStmts = append( + schemaStmts, + fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %q;", enum.Schema), + ) 
schemaSet[enum.Schema] = struct{}{} } } for _, domain := range datatypes.Domains { if _, exists := schemaSet[domain.Schema]; !exists { - schemaStmts = append(schemaStmts, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %q;", domain.Schema)) + schemaStmts = append( + schemaStmts, + fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %q;", domain.Schema), + ) schemaSet[domain.Schema] = struct{}{} } } @@ -1074,8 +1267,13 @@ func buildAlterStatementByForeignKeyConstraint( } return fmt.Sprintf( "ALTER TABLE %q.%q ADD CONSTRAINT %q FOREIGN KEY (%s) REFERENCES %q.%q (%s);", - constraint.ReferencingSchema, constraint.ReferencingTable, constraint.ConstraintName, strings.Join(EscapePgColumns(constraint.ReferencingColumns), ", "), - constraint.ReferencedSchema, constraint.ReferencedTable, strings.Join(EscapePgColumns(constraint.ReferencedColumns), ", "), + constraint.ReferencingSchema, + constraint.ReferencingTable, + constraint.ConstraintName, + strings.Join(EscapePgColumns(constraint.ReferencingColumns), ", "), + constraint.ReferencedSchema, + constraint.ReferencedTable, + strings.Join(EscapePgColumns(constraint.ReferencedColumns), ", "), ), nil } @@ -1087,7 +1285,10 @@ func buildAlterStatementByConstraint( } return fmt.Sprintf( "ALTER TABLE %q.%q ADD CONSTRAINT %q %s;", - constraint.SchemaName, constraint.TableName, constraint.ConstraintName, constraint.ConstraintDefinition, + constraint.SchemaName, + constraint.TableName, + constraint.ConstraintName, + constraint.ConstraintDefinition, ), nil } @@ -1110,7 +1311,12 @@ func BuildDropColumnStatement(schema, table, column string) string { func BuildDropConstraintStatement(schema, table, constraintName string) string { // cascade is used to drop the constraint and any dependent objects (other constraints, indexes, triggers, etc) - return fmt.Sprintf("ALTER TABLE %q.%q DROP CONSTRAINT IF EXISTS %q CASCADE;", schema, table, constraintName) + return fmt.Sprintf( + "ALTER TABLE %q.%q DROP CONSTRAINT IF EXISTS %q CASCADE;", + schema, + table, + 
constraintName, + ) } type buildTableColRequest struct { @@ -1154,7 +1360,11 @@ func (s *SequenceConfiguration) toCycelText() string { } func buildTableCol(record *buildTableColRequest) string { - pieces := []string{EscapePgColumn(record.ColumnName), record.DataType, buildNullableText(record.IsNullable)} + pieces := []string{ + EscapePgColumn(record.ColumnName), + record.DataType, + buildNullableText(record.IsNullable), + } if record.IsSerial { switch record.DataType { @@ -1192,7 +1402,13 @@ func BuildUpdateCommentStatement(schema, table, column string, comment *string) if comment == nil || *comment == "" { return fmt.Sprintf("COMMENT ON COLUMN %q.%q.%q IS NULL;", schema, table, column) } - return fmt.Sprintf("COMMENT ON COLUMN %q.%q.%q IS '%s';", schema, table, column, strings.ReplaceAll(*comment, "'", "''")) + return fmt.Sprintf( + "COMMENT ON COLUMN %q.%q.%q IS '%s';", + schema, + table, + column, + strings.ReplaceAll(*comment, "'", "''"), + ) } func buildNullableText(isNullable bool) string { @@ -1202,7 +1418,12 @@ func buildNullableText(isNullable bool) string { return "NOT NULL" } -func (p *PostgresManager) BatchExec(ctx context.Context, batchSize int, statements []string, opts *sqlmanager_shared.BatchExecOpts) error { +func (p *PostgresManager) BatchExec( + ctx context.Context, + batchSize int, + statements []string, + opts *sqlmanager_shared.BatchExecOpts, +) error { for i := 0; i < len(statements); i += batchSize { end := i + batchSize if end > len(statements) { @@ -1308,7 +1529,15 @@ func EscapePgColumn(col string) string { func BuildPgIdentityColumnResetCurrentSql( schema, table, column string, ) string { - return fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%q.%q', '%s'), COALESCE((SELECT MAX(%q) FROM %q.%q), 1));", schema, table, column, column, schema, table) + return fmt.Sprintf( + "SELECT setval(pg_get_serial_sequence('%q.%q', '%s'), COALESCE((SELECT MAX(%q) FROM %q.%q), 1));", + schema, + table, + column, + column, + schema, + table, + ) } 
func BuildPgInsertIdentityAlwaysSql( @@ -1322,7 +1551,9 @@ func BuildPgResetSequenceSql(schema, sequenceName string) string { return fmt.Sprintf("ALTER SEQUENCE %q.%q RESTART;", schema, sequenceName) } -func GetPostgresColumnOverrideAndResetProperties(columnInfo *sqlmanager_shared.DatabaseSchemaRow) (needsOverride, needsReset bool) { +func GetPostgresColumnOverrideAndResetProperties( + columnInfo *sqlmanager_shared.DatabaseSchemaRow, +) (needsOverride, needsReset bool) { needsOverride = false needsReset = false @@ -1339,7 +1570,8 @@ func GetPostgresColumnOverrideAndResetProperties(columnInfo *sqlmanager_shared.D } // check if column default is sequence - if columnInfo.ColumnDefault != "" && gotypeutil.CaseInsensitiveContains(columnInfo.ColumnDefault, "nextVal") { + if columnInfo.ColumnDefault != "" && + gotypeutil.CaseInsensitiveContains(columnInfo.ColumnDefault, "nextVal") { needsReset = true return } diff --git a/backend/pkg/sqlmanager/sql-manager.go b/backend/pkg/sqlmanager/sql-manager.go index 68c0994323..e15092b963 100644 --- a/backend/pkg/sqlmanager/sql-manager.go +++ b/backend/pkg/sqlmanager/sql-manager.go @@ -22,26 +22,62 @@ import ( type SqlDatabase interface { // Schema level methods for managing and retrieving information about the database schema GetDatabaseSchema(ctx context.Context) ([]*sqlmanager_shared.DatabaseSchemaRow, error) - GetSchemaColumnMap(ctx context.Context) (map[string]map[string]*sqlmanager_shared.DatabaseSchemaRow, error) // ex: {public.users: { id: struct{}{}, created_at: struct{}{}}} - GetTableConstraintsBySchema(ctx context.Context, schemas []string) (*sqlmanager_shared.TableConstraints, error) - GetTableInitStatements(ctx context.Context, tables []*sqlmanager_shared.SchemaTable) ([]*sqlmanager_shared.TableInitStatement, error) + GetSchemaColumnMap( + ctx context.Context, + ) (map[string]map[string]*sqlmanager_shared.DatabaseSchemaRow, error) // ex: {public.users: { id: struct{}{}, created_at: struct{}{}}} + 
GetTableConstraintsBySchema( + ctx context.Context, + schemas []string, + ) (*sqlmanager_shared.TableConstraints, error) + GetTableInitStatements( + ctx context.Context, + tables []*sqlmanager_shared.SchemaTable, + ) ([]*sqlmanager_shared.TableInitStatement, error) GetRolePermissionsMap(ctx context.Context) (map[string][]string, error) GetAllSchemas(ctx context.Context) ([]*sqlmanager_shared.DatabaseSchemaNameRow, error) GetAllTables(ctx context.Context) ([]*sqlmanager_shared.DatabaseTableRow, error) // Table level methods for managing and retrieving information about the tables within the database - GetDatabaseTableSchemasBySchemasAndTables(ctx context.Context, tables []*sqlmanager_shared.SchemaTable) ([]*sqlmanager_shared.DatabaseSchemaRow, error) + GetDatabaseTableSchemasBySchemasAndTables( + ctx context.Context, + tables []*sqlmanager_shared.SchemaTable, + ) ([]*sqlmanager_shared.DatabaseSchemaRow, error) GetTableRowCount(ctx context.Context, schema, table string, whereClause *string) (int64, error) - GetSchemaTableDataTypes(ctx context.Context, tables []*sqlmanager_shared.SchemaTable) (*sqlmanager_shared.SchemaTableDataTypeResponse, error) - GetSchemaTableTriggers(ctx context.Context, tables []*sqlmanager_shared.SchemaTable) ([]*sqlmanager_shared.TableTrigger, error) - GetSchemaInitStatements(ctx context.Context, tables []*sqlmanager_shared.SchemaTable) ([]*sqlmanager_shared.InitSchemaStatements, error) - GetSequencesByTables(ctx context.Context, schema string, tables []string) ([]*sqlmanager_shared.DataType, error) + GetSchemaTableDataTypes( + ctx context.Context, + tables []*sqlmanager_shared.SchemaTable, + ) (*sqlmanager_shared.SchemaTableDataTypeResponse, error) + GetSchemaTableTriggers( + ctx context.Context, + tables []*sqlmanager_shared.SchemaTable, + ) ([]*sqlmanager_shared.TableTrigger, error) + GetSchemaInitStatements( + ctx context.Context, + tables []*sqlmanager_shared.SchemaTable, + ) ([]*sqlmanager_shared.InitSchemaStatements, error) + 
GetSequencesByTables( + ctx context.Context, + schema string, + tables []string, + ) ([]*sqlmanager_shared.DataType, error) // returns a map of schema.table to all constraints for that table - GetTableConstraintsByTables(ctx context.Context, schema string, tables []string) (map[string]*sqlmanager_shared.AllTableConstraints, error) - GetColumnsByTables(ctx context.Context, tables []*sqlmanager_shared.SchemaTable) ([]*sqlmanager_shared.TableColumn, error) + GetTableConstraintsByTables( + ctx context.Context, + schema string, + tables []string, + ) (map[string]*sqlmanager_shared.AllTableConstraints, error) + GetColumnsByTables( + ctx context.Context, + tables []*sqlmanager_shared.SchemaTable, + ) ([]*sqlmanager_shared.TableColumn, error) // Connection level methods for managing database connections and executing statements - BatchExec(ctx context.Context, batchSize int, statements []string, opts *sqlmanager_shared.BatchExecOpts) error + BatchExec( + ctx context.Context, + batchSize int, + statements []string, + opts *sqlmanager_shared.BatchExecOpts, + ) error Exec(ctx context.Context, statement string) error Close() } @@ -66,7 +102,9 @@ func NewSqlManager( pgQuerier: pg_queries.New(), mysqlQuerier: mysql_queries.New(), mssqlQuerier: mssql_queries.New(), - mgr: connectionmanager.NewConnectionManager(sqlprovider.NewProvider(&sqlconnect.SqlOpenConnector{})), + mgr: connectionmanager.NewConnectionManager( + sqlprovider.NewProvider(&sqlconnect.SqlOpenConnector{}), + ), } for _, opt := range opts { opt(config) @@ -76,7 +114,9 @@ func NewSqlManager( } } -func WithConnectionManager(manager connectionmanager.Interface[neosync_benthos_sql.SqlDbtx]) SqlManagerOption { +func WithConnectionManager( + manager connectionmanager.Interface[neosync_benthos_sql.SqlDbtx], +) SqlManagerOption { return func(smc *sqlManagerConfig) { smc.mgr = manager } @@ -85,7 +125,9 @@ func WithConnectionManager(manager connectionmanager.Interface[neosync_benthos_s // Initializes a default SQL-enabled 
connection manager, but allows for providing options func WithConnectionManagerOpts(opts ...connectionmanager.ManagerOption) SqlManagerOption { return func(smc *sqlManagerConfig) { - smc.mgr = connectionmanager.NewConnectionManager(sqlprovider.NewProvider(&sqlconnect.SqlOpenConnector{}), opts...) + smc.mgr = connectionmanager.NewConnectionManager( + sqlprovider.NewProvider(&sqlconnect.SqlOpenConnector{}), + opts...) } } @@ -146,16 +188,25 @@ func (s *SqlManager) NewSqlConnection( } } -func GetColumnOverrideAndResetProperties(driver string, cInfo *sqlmanager_shared.DatabaseSchemaRow) (needsOverride, needsReset bool, err error) { +func GetColumnOverrideAndResetProperties( + driver string, + cInfo *sqlmanager_shared.DatabaseSchemaRow, +) (needsOverride, needsReset bool, err error) { switch driver { case sqlmanager_shared.PostgresDriver: - needsOverride, needsReset := sqlmanager_postgres.GetPostgresColumnOverrideAndResetProperties(cInfo) + needsOverride, needsReset := sqlmanager_postgres.GetPostgresColumnOverrideAndResetProperties( + cInfo, + ) return needsOverride, needsReset, nil case sqlmanager_shared.MysqlDriver: - needsOverride, needsReset := sqlmanager_mysql.GetMysqlColumnOverrideAndResetProperties(cInfo) + needsOverride, needsReset := sqlmanager_mysql.GetMysqlColumnOverrideAndResetProperties( + cInfo, + ) return needsOverride, needsReset, nil case sqlmanager_shared.MssqlDriver: - needsOverride, needsReset := sqlmanager_mssql.GetMssqlColumnOverrideAndResetProperties(cInfo) + needsOverride, needsReset := sqlmanager_mssql.GetMssqlColumnOverrideAndResetProperties( + cInfo, + ) return needsOverride, needsReset, nil default: return false, false, fmt.Errorf("unsupported sql driver: %s", driver) diff --git a/backend/pkg/sqlretry/dbtx_retry.go b/backend/pkg/sqlretry/dbtx_retry.go index 491bb76d26..d36b0b77a4 100644 --- a/backend/pkg/sqlretry/dbtx_retry.go +++ b/backend/pkg/sqlretry/dbtx_retry.go @@ -46,7 +46,13 @@ func NewDefault(dbtx sqldbtx.DBTX, logger *slog.Logger) 
*RetryDBTX { backoff.WithMaxTries(25), backoff.WithMaxElapsedTime(5 * time.Minute), backoff.WithNotify(func(err error, d time.Duration) { - logger.Warn(fmt.Sprintf("sql error with retry: %s, retrying in %s", err.Error(), d.String())) + logger.Warn( + fmt.Sprintf( + "sql error with retry: %s, retrying in %s", + err.Error(), + d.String(), + ), + ) }), } }, @@ -71,7 +77,11 @@ func WithRetryOptions(getRetryOpts func() []backoff.RetryOption) Option { } } -func (r *RetryDBTX) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { +func (r *RetryDBTX) ExecContext( + ctx context.Context, + query string, + args ...any, +) (sql.Result, error) { operation := func() (sql.Result, error) { return r.dbtx.ExecContext(ctx, query, args...) } @@ -85,7 +95,11 @@ func (r *RetryDBTX) PrepareContext(ctx context.Context, query string) (*sql.Stmt return backoffutil.Retry(ctx, operation, r.config.getRetryOpts, isRetryableError) } -func (r *RetryDBTX) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { +func (r *RetryDBTX) QueryContext( + ctx context.Context, + query string, + args ...any, +) (*sql.Rows, error) { operation := func() (*sql.Rows, error) { return r.dbtx.QueryContext(ctx, query, args...) 
} @@ -111,7 +125,11 @@ func (r *RetryDBTX) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, return backoffutil.Retry(ctx, operation, r.config.getRetryOpts, isRetryableError) } -func (r *RetryDBTX) RetryTx(ctx context.Context, opts *sql.TxOptions, fn func(*sql.Tx) error) error { +func (r *RetryDBTX) RetryTx( + ctx context.Context, + opts *sql.TxOptions, + fn func(*sql.Tx) error, +) error { operation := func() (any, error) { tx, err := r.dbtx.BeginTx(ctx, opts) if err != nil { diff --git a/backend/pkg/table-dependency/table-dependency.go b/backend/pkg/table-dependency/table-dependency.go index 9a31fe28a3..c03ac02a20 100644 --- a/backend/pkg/table-dependency/table-dependency.go +++ b/backend/pkg/table-dependency/table-dependency.go @@ -41,7 +41,10 @@ func GetTablesOrderedByDependency(dependencyMap map[string][]string) (*OrderedTa dep, ok := dependencyMap[table] if !ok || len(dep) == 0 { s, t := sqlmanager_shared.SplitTableKey(table) - orderedTables = append(orderedTables, &sqlmanager_shared.SchemaTable{Schema: s, Table: t}) + orderedTables = append( + orderedTables, + &sqlmanager_shared.SchemaTable{Schema: s, Table: t}, + ) seenTables[table] = struct{}{} delete(tableMap, table) } @@ -58,7 +61,10 @@ func GetTablesOrderedByDependency(dependencyMap map[string][]string) (*OrderedTa deps := dependencyMap[table] if isReady(seenTables, deps, table, cycles) { s, t := sqlmanager_shared.SplitTableKey(table) - orderedTables = append(orderedTables, &sqlmanager_shared.SchemaTable{Schema: s, Table: t}) + orderedTables = append( + orderedTables, + &sqlmanager_shared.SchemaTable{Schema: s, Table: t}, + ) seenTables[table] = struct{}{} delete(tableMap, table) } diff --git a/backend/services/mgmt/v1alpha1/account-hooks-service/service.go b/backend/services/mgmt/v1alpha1/account-hooks-service/service.go index 4df401075c..3190dcc6ca 100644 --- a/backend/services/mgmt/v1alpha1/account-hooks-service/service.go +++ b/backend/services/mgmt/v1alpha1/account-hooks-service/service.go 
@@ -23,7 +23,10 @@ func New( } } -func (s *Service) GetAccountHooks(ctx context.Context, req *connect.Request[mgmtv1alpha1.GetAccountHooksRequest]) (*connect.Response[mgmtv1alpha1.GetAccountHooksResponse], error) { +func (s *Service) GetAccountHooks( + ctx context.Context, + req *connect.Request[mgmtv1alpha1.GetAccountHooksRequest], +) (*connect.Response[mgmtv1alpha1.GetAccountHooksResponse], error) { resp, err := s.hookservice.GetAccountHooks(ctx, req.Msg) if err != nil { return nil, err @@ -31,7 +34,10 @@ func (s *Service) GetAccountHooks(ctx context.Context, req *connect.Request[mgmt return connect.NewResponse(resp), nil } -func (s *Service) GetAccountHook(ctx context.Context, req *connect.Request[mgmtv1alpha1.GetAccountHookRequest]) (*connect.Response[mgmtv1alpha1.GetAccountHookResponse], error) { +func (s *Service) GetAccountHook( + ctx context.Context, + req *connect.Request[mgmtv1alpha1.GetAccountHookRequest], +) (*connect.Response[mgmtv1alpha1.GetAccountHookResponse], error) { resp, err := s.hookservice.GetAccountHook(ctx, req.Msg) if err != nil { return nil, err @@ -39,7 +45,10 @@ func (s *Service) GetAccountHook(ctx context.Context, req *connect.Request[mgmtv return connect.NewResponse(resp), nil } -func (s *Service) CreateAccountHook(ctx context.Context, req *connect.Request[mgmtv1alpha1.CreateAccountHookRequest]) (*connect.Response[mgmtv1alpha1.CreateAccountHookResponse], error) { +func (s *Service) CreateAccountHook( + ctx context.Context, + req *connect.Request[mgmtv1alpha1.CreateAccountHookRequest], +) (*connect.Response[mgmtv1alpha1.CreateAccountHookResponse], error) { resp, err := s.hookservice.CreateAccountHook(ctx, req.Msg) if err != nil { return nil, err @@ -47,7 +56,10 @@ func (s *Service) CreateAccountHook(ctx context.Context, req *connect.Request[mg return connect.NewResponse(resp), nil } -func (s *Service) DeleteAccountHook(ctx context.Context, req *connect.Request[mgmtv1alpha1.DeleteAccountHookRequest]) 
(*connect.Response[mgmtv1alpha1.DeleteAccountHookResponse], error) { +func (s *Service) DeleteAccountHook( + ctx context.Context, + req *connect.Request[mgmtv1alpha1.DeleteAccountHookRequest], +) (*connect.Response[mgmtv1alpha1.DeleteAccountHookResponse], error) { resp, err := s.hookservice.DeleteAccountHook(ctx, req.Msg) if err != nil { return nil, err @@ -55,7 +67,10 @@ func (s *Service) DeleteAccountHook(ctx context.Context, req *connect.Request[mg return connect.NewResponse(resp), nil } -func (s *Service) IsAccountHookNameAvailable(ctx context.Context, req *connect.Request[mgmtv1alpha1.IsAccountHookNameAvailableRequest]) (*connect.Response[mgmtv1alpha1.IsAccountHookNameAvailableResponse], error) { +func (s *Service) IsAccountHookNameAvailable( + ctx context.Context, + req *connect.Request[mgmtv1alpha1.IsAccountHookNameAvailableRequest], +) (*connect.Response[mgmtv1alpha1.IsAccountHookNameAvailableResponse], error) { resp, err := s.hookservice.IsAccountHookNameAvailable(ctx, req.Msg) if err != nil { return nil, err @@ -63,7 +78,10 @@ func (s *Service) IsAccountHookNameAvailable(ctx context.Context, req *connect.R return connect.NewResponse(resp), nil } -func (s *Service) UpdateAccountHook(ctx context.Context, req *connect.Request[mgmtv1alpha1.UpdateAccountHookRequest]) (*connect.Response[mgmtv1alpha1.UpdateAccountHookResponse], error) { +func (s *Service) UpdateAccountHook( + ctx context.Context, + req *connect.Request[mgmtv1alpha1.UpdateAccountHookRequest], +) (*connect.Response[mgmtv1alpha1.UpdateAccountHookResponse], error) { resp, err := s.hookservice.UpdateAccountHook(ctx, req.Msg) if err != nil { return nil, err @@ -71,7 +89,10 @@ func (s *Service) UpdateAccountHook(ctx context.Context, req *connect.Request[mg return connect.NewResponse(resp), nil } -func (s *Service) SetAccountHookEnabled(ctx context.Context, req *connect.Request[mgmtv1alpha1.SetAccountHookEnabledRequest]) (*connect.Response[mgmtv1alpha1.SetAccountHookEnabledResponse], error) { +func (s 
*Service) SetAccountHookEnabled( + ctx context.Context, + req *connect.Request[mgmtv1alpha1.SetAccountHookEnabledRequest], +) (*connect.Response[mgmtv1alpha1.SetAccountHookEnabledResponse], error) { resp, err := s.hookservice.SetAccountHookEnabled(ctx, req.Msg) if err != nil { return nil, err @@ -79,7 +100,10 @@ func (s *Service) SetAccountHookEnabled(ctx context.Context, req *connect.Reques return connect.NewResponse(resp), nil } -func (s *Service) GetActiveAccountHooksByEvent(ctx context.Context, req *connect.Request[mgmtv1alpha1.GetActiveAccountHooksByEventRequest]) (*connect.Response[mgmtv1alpha1.GetActiveAccountHooksByEventResponse], error) { +func (s *Service) GetActiveAccountHooksByEvent( + ctx context.Context, + req *connect.Request[mgmtv1alpha1.GetActiveAccountHooksByEventRequest], +) (*connect.Response[mgmtv1alpha1.GetActiveAccountHooksByEventResponse], error) { resp, err := s.hookservice.GetActiveAccountHooksByEvent(ctx, req.Msg) if err != nil { return nil, err @@ -87,7 +111,10 @@ func (s *Service) GetActiveAccountHooksByEvent(ctx context.Context, req *connect return connect.NewResponse(resp), nil } -func (s *Service) GetSlackConnectionUrl(ctx context.Context, req *connect.Request[mgmtv1alpha1.GetSlackConnectionUrlRequest]) (*connect.Response[mgmtv1alpha1.GetSlackConnectionUrlResponse], error) { +func (s *Service) GetSlackConnectionUrl( + ctx context.Context, + req *connect.Request[mgmtv1alpha1.GetSlackConnectionUrlRequest], +) (*connect.Response[mgmtv1alpha1.GetSlackConnectionUrlResponse], error) { resp, err := s.hookservice.GetSlackConnectionUrl(ctx, req.Msg) if err != nil { return nil, err @@ -95,7 +122,10 @@ func (s *Service) GetSlackConnectionUrl(ctx context.Context, req *connect.Reques return connect.NewResponse(resp), nil } -func (s *Service) HandleSlackOAuthCallback(ctx context.Context, req *connect.Request[mgmtv1alpha1.HandleSlackOAuthCallbackRequest]) (*connect.Response[mgmtv1alpha1.HandleSlackOAuthCallbackResponse], error) { +func (s *Service) 
HandleSlackOAuthCallback( + ctx context.Context, + req *connect.Request[mgmtv1alpha1.HandleSlackOAuthCallbackRequest], +) (*connect.Response[mgmtv1alpha1.HandleSlackOAuthCallbackResponse], error) { resp, err := s.hookservice.HandleSlackOAuthCallback(ctx, req.Msg) if err != nil { return nil, err @@ -103,7 +133,10 @@ func (s *Service) HandleSlackOAuthCallback(ctx context.Context, req *connect.Req return connect.NewResponse(resp), nil } -func (s *Service) TestSlackConnection(ctx context.Context, req *connect.Request[mgmtv1alpha1.TestSlackConnectionRequest]) (*connect.Response[mgmtv1alpha1.TestSlackConnectionResponse], error) { +func (s *Service) TestSlackConnection( + ctx context.Context, + req *connect.Request[mgmtv1alpha1.TestSlackConnectionRequest], +) (*connect.Response[mgmtv1alpha1.TestSlackConnectionResponse], error) { resp, err := s.hookservice.TestSlackConnection(ctx, req.Msg) if err != nil { return nil, err @@ -111,7 +144,10 @@ func (s *Service) TestSlackConnection(ctx context.Context, req *connect.Request[ return connect.NewResponse(resp), nil } -func (s *Service) SendSlackMessage(ctx context.Context, req *connect.Request[mgmtv1alpha1.SendSlackMessageRequest]) (*connect.Response[mgmtv1alpha1.SendSlackMessageResponse], error) { +func (s *Service) SendSlackMessage( + ctx context.Context, + req *connect.Request[mgmtv1alpha1.SendSlackMessageRequest], +) (*connect.Response[mgmtv1alpha1.SendSlackMessageResponse], error) { resp, err := s.hookservice.SendSlackMessage(ctx, req.Msg) if err != nil { return nil, err diff --git a/backend/services/mgmt/v1alpha1/anonymization-service/anonymization.go b/backend/services/mgmt/v1alpha1/anonymization-service/anonymization.go index a3c9b3e9e3..e6d295f03f 100644 --- a/backend/services/mgmt/v1alpha1/anonymization-service/anonymization.go +++ b/backend/services/mgmt/v1alpha1/anonymization-service/anonymization.go @@ -34,7 +34,13 @@ func (s *Service) AnonymizeMany( logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) if 
!s.cfg.IsNeosyncCloud { return nil, nucleuserrors.NewNotImplemented( - fmt.Sprintf("%s is not implemented in the OSS version of Neosync.", strings.TrimPrefix(mgmtv1alpha1connect.AnonymizationServiceAnonymizeManyProcedure, "/")), + fmt.Sprintf( + "%s is not implemented in the OSS version of Neosync.", + strings.TrimPrefix( + mgmtv1alpha1connect.AnonymizationServiceAnonymizeManyProcedure, + "/", + ), + ), ) } @@ -58,7 +64,13 @@ func (s *Service) AnonymizeMany( } if account.AccountType == int16(neosyncdb.AccountType_Personal) { return nil, nucleuserrors.NewForbidden( - fmt.Sprintf("%s is not implemented for personal accounts", strings.TrimPrefix(mgmtv1alpha1connect.AnonymizationServiceAnonymizeManyProcedure, "/")), + fmt.Sprintf( + "%s is not implemented for personal accounts", + strings.TrimPrefix( + mgmtv1alpha1connect.AnonymizationServiceAnonymizeManyProcedure, + "/", + ), + ), ) } @@ -69,23 +81,36 @@ func (s *Service) AnonymizeMany( } requestedCount := uint64(len(req.Msg.InputData)) - resp, err := s.useraccountService.IsAccountStatusValid(ctx, connect.NewRequest(&mgmtv1alpha1.IsAccountStatusValidRequest{ - AccountId: req.Msg.GetAccountId(), - RequestedRecordCount: &requestedCount, - })) + resp, err := s.useraccountService.IsAccountStatusValid( + ctx, + connect.NewRequest(&mgmtv1alpha1.IsAccountStatusValidRequest{ + AccountId: req.Msg.GetAccountId(), + RequestedRecordCount: &requestedCount, + }), + ) if err != nil { return nil, fmt.Errorf("unable to retrieve account status: %w", err) } if !resp.Msg.IsValid { - return nil, nucleuserrors.NewBadRequest(fmt.Sprintf("unable to anonymize due to account in invalid state. Reason: %q", *resp.Msg.Reason)) + return nil, nucleuserrors.NewBadRequest( + fmt.Sprintf( + "unable to anonymize due to account in invalid state. 
Reason: %q", + *resp.Msg.Reason, + ), + ) } anonymizer, err := jsonanonymizer.NewAnonymizer( jsonanonymizer.WithTransformerMappings(req.Msg.TransformerMappings), jsonanonymizer.WithDefaultTransformers(req.Msg.DefaultTransformers), jsonanonymizer.WithHaltOnFailure(req.Msg.HaltOnFailure), - jsonanonymizer.WithConditionalAnonymizeConfig(s.cfg.IsPresidioEnabled, s.analyze, s.anonymize, s.cfg.PresidioDefaultLanguage), + jsonanonymizer.WithConditionalAnonymizeConfig( + s.cfg.IsPresidioEnabled, + s.analyze, + s.anonymize, + s.cfg.PresidioDefaultLanguage, + ), jsonanonymizer.WithTransformerClient(s.transformerClient), jsonanonymizer.WithLogger(logger), ) @@ -168,14 +193,18 @@ func (s *Service) AnonymizeSingle( if !s.cfg.IsNeosyncCloud || account.AccountType == int16(neosyncdb.AccountType_Personal) { for _, mapping := range req.Msg.GetTransformerMappings() { if mapping.GetTransformer().GetTransformPiiTextConfig() != nil { - return nil, nucleuserrors.NewForbidden("TransformPiiText is not available for use. Please contact us to upgrade your account.") + return nil, nucleuserrors.NewForbidden( + "TransformPiiText is not available for use. Please contact us to upgrade your account.", + ) } } defaultTransforms := req.Msg.GetDefaultTransformers() if defaultTransforms.GetBoolean().GetTransformPiiTextConfig() != nil || defaultTransforms.GetN().GetTransformPiiTextConfig() != nil || defaultTransforms.GetS().GetTransformPiiTextConfig() != nil { - return nil, nucleuserrors.NewForbidden("TransformPiiText is not available for use. Please contact us to upgrade your account.") + return nil, nucleuserrors.NewForbidden( + "TransformPiiText is not available for use. 
Please contact us to upgrade your account.", + ) } } @@ -186,22 +215,35 @@ func (s *Service) AnonymizeSingle( } requestedCount := uint64(len(req.Msg.InputData)) - resp, err := s.useraccountService.IsAccountStatusValid(ctx, connect.NewRequest(&mgmtv1alpha1.IsAccountStatusValidRequest{ - AccountId: req.Msg.GetAccountId(), - RequestedRecordCount: &requestedCount, - })) + resp, err := s.useraccountService.IsAccountStatusValid( + ctx, + connect.NewRequest(&mgmtv1alpha1.IsAccountStatusValidRequest{ + AccountId: req.Msg.GetAccountId(), + RequestedRecordCount: &requestedCount, + }), + ) if err != nil { return nil, fmt.Errorf("unable to retrieve account status: %w", err) } if !resp.Msg.IsValid { - return nil, nucleuserrors.NewBadRequest(fmt.Sprintf("unable to anonymize due to account in invalid state. Reason: %q", *resp.Msg.Reason)) + return nil, nucleuserrors.NewBadRequest( + fmt.Sprintf( + "unable to anonymize due to account in invalid state. Reason: %q", + *resp.Msg.Reason, + ), + ) } anonymizer, err := jsonanonymizer.NewAnonymizer( jsonanonymizer.WithTransformerMappings(req.Msg.TransformerMappings), jsonanonymizer.WithDefaultTransformers(req.Msg.DefaultTransformers), - jsonanonymizer.WithConditionalAnonymizeConfig(s.cfg.IsPresidioEnabled, s.analyze, s.anonymize, s.cfg.PresidioDefaultLanguage), + jsonanonymizer.WithConditionalAnonymizeConfig( + s.cfg.IsPresidioEnabled, + s.analyze, + s.anonymize, + s.cfg.PresidioDefaultLanguage, + ), jsonanonymizer.WithTransformerClient(s.transformerClient), jsonanonymizer.WithLogger(logger), ) @@ -254,7 +296,10 @@ func getMetricLabels(ctx context.Context, requestName, accountId string) []attri attribute.String(metrics.AccountIdLabel, accountId), attribute.String(metrics.ApiRequestId, requestId), attribute.String(metrics.ApiRequestName, requestName), - attribute.String(metrics.NeosyncDateLabel, time.Now().UTC().Format(metrics.NeosyncDateFormat)), + attribute.String( + metrics.NeosyncDateLabel, + 
time.Now().UTC().Format(metrics.NeosyncDateFormat), + ), } } @@ -281,14 +326,21 @@ func validateTransformerConfig(cfg *mgmtv1alpha1.TransformerConfig) error { if defaultAnonymizer != nil { child := defaultAnonymizer.GetTransform().GetConfig().GetTransformPiiTextConfig() if child != nil { - return nucleuserrors.NewBadRequest("found nested TransformPiiText config in default anonymizer. TransformPiiText may not be used deeply nested within itself.") + return nucleuserrors.NewBadRequest( + "found nested TransformPiiText config in default anonymizer. TransformPiiText may not be used deeply nested within itself.", + ) } } entityAnonymizers := root.GetEntityAnonymizers() for entity, entityAnonymizer := range entityAnonymizers { child := entityAnonymizer.GetTransform().GetConfig().GetTransformPiiTextConfig() if child != nil { - return nucleuserrors.NewBadRequest(fmt.Sprintf("found nested TransformPiiText config in entity (%s) anonymizer. TransformPiiText may not be used deeply nested within itself.", entity)) + return nucleuserrors.NewBadRequest( + fmt.Sprintf( + "found nested TransformPiiText config in entity (%s) anonymizer. 
TransformPiiText may not be used deeply nested within itself.", + entity, + ), + ) } } return nil @@ -299,7 +351,9 @@ type transformerMsgToValidate interface { GetTransformerMappings() []*mgmtv1alpha1.TransformerMapping } -func getTransformerConfigsToValidate(msg transformerMsgToValidate) iter.Seq[*mgmtv1alpha1.TransformerConfig] { +func getTransformerConfigsToValidate( + msg transformerMsgToValidate, +) iter.Seq[*mgmtv1alpha1.TransformerConfig] { return func(yield func(*mgmtv1alpha1.TransformerConfig) bool) { if msg.GetDefaultTransformers().GetBoolean() != nil { if !yield(msg.GetDefaultTransformers().GetBoolean()) { diff --git a/backend/services/mgmt/v1alpha1/api-key-service/api-keys.go b/backend/services/mgmt/v1alpha1/api-key-service/api-keys.go index a9dfb601de..b5edbe0b2e 100644 --- a/backend/services/mgmt/v1alpha1/api-key-service/api-keys.go +++ b/backend/services/mgmt/v1alpha1/api-key-service/api-keys.go @@ -162,12 +162,16 @@ func (s *Service) RegenerateAccountApiKey( if err != nil { return nil, err } - updatedApiKey, err := s.db.Q.UpdateAccountApiKeyValue(ctx, s.db.Db, db_queries.UpdateAccountApiKeyValueParams{ - KeyValue: hashedKeyValue, - ExpiresAt: expiresAt, - UpdatedByID: user.PgId(), - ID: apiKeyUuid, - }) + updatedApiKey, err := s.db.Q.UpdateAccountApiKeyValue( + ctx, + s.db.Db, + db_queries.UpdateAccountApiKeyValueParams{ + KeyValue: hashedKeyValue, + ExpiresAt: expiresAt, + UpdatedByID: user.PgId(), + ID: apiKeyUuid, + }, + ) if err != nil { return nil, err } diff --git a/backend/services/mgmt/v1alpha1/auth-service/tokens.go b/backend/services/mgmt/v1alpha1/auth-service/tokens.go index 2437d41f1a..fd9cae42f8 100644 --- a/backend/services/mgmt/v1alpha1/auth-service/tokens.go +++ b/backend/services/mgmt/v1alpha1/auth-service/tokens.go @@ -25,13 +25,22 @@ func (s *Service) LoginCli( req *connect.Request[mgmtv1alpha1.LoginCliRequest], ) (*connect.Response[mgmtv1alpha1.LoginCliResponse], error) { logger := 
logger_interceptor.GetLoggerFromContextOrDefault(ctx) - resp, err := s.authclient.GetTokenResponse(ctx, s.cfg.CliClientId, req.Msg.Code, req.Msg.RedirectUri) + resp, err := s.authclient.GetTokenResponse( + ctx, + s.cfg.CliClientId, + req.Msg.Code, + req.Msg.RedirectUri, + ) if err != nil { return nil, err } if resp.Error != nil { logger.Error( - fmt.Sprintf("Unable to get access token. Title: %s -- Description: %s", resp.Error.Error, resp.Error.ErrorDescription), + fmt.Sprintf( + "Unable to get access token. Title: %s -- Description: %s", + resp.Error.Error, + resp.Error.ErrorDescription, + ), ) return nil, nucleuserrors.NewUnauthenticated("Request unauthenticated") } @@ -66,7 +75,11 @@ func (s *Service) RefreshCli( } if resp.Error != nil { logger.Error( - fmt.Sprintf("Unable to get refreshed token. Title: %s -- Description: %s", resp.Error.Error, resp.Error.ErrorDescription), + fmt.Sprintf( + "Unable to get refreshed token. Title: %s -- Description: %s", + resp.Error.Error, + resp.Error.ErrorDescription, + ), ) return nil, nucleuserrors.NewUnauthenticated("Unable to refresh access token") } diff --git a/backend/services/mgmt/v1alpha1/connection-data-service/connection-data.go b/backend/services/mgmt/v1alpha1/connection-data-service/connection-data.go index a753ccc6e6..131f713b9b 100644 --- a/backend/services/mgmt/v1alpha1/connection-data-service/connection-data.go +++ b/backend/services/mgmt/v1alpha1/connection-data-service/connection-data.go @@ -34,9 +34,12 @@ func (s *Service) GetConnectionDataStream( ) error { logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) logger = logger.With("connectionId", req.Msg.ConnectionId) - connResp, err := s.connectionService.GetConnection(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ - Id: req.Msg.ConnectionId, - })) + connResp, err := s.connectionService.GetConnection( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ + Id: req.Msg.ConnectionId, + }), + ) if err != nil { return err } @@ 
-46,7 +49,13 @@ func (s *Service) GetConnectionDataStream( if err != nil { return err } - err = connectiondatabuilder.StreamData(ctx, stream, req.Msg.StreamConfig, req.Msg.Schema, req.Msg.Table) + err = connectiondatabuilder.StreamData( + ctx, + stream, + req.Msg.StreamConfig, + req.Msg.Schema, + req.Msg.Table, + ) if err != nil { return err } @@ -96,16 +105,22 @@ func (s *Service) GetConnectionSchemaMap( ctx context.Context, req *connect.Request[mgmtv1alpha1.GetConnectionSchemaMapRequest], ) (*connect.Response[mgmtv1alpha1.GetConnectionSchemaMapResponse], error) { - schemaResp, err := s.GetConnectionSchema(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionSchemaRequest{ - ConnectionId: req.Msg.GetConnectionId(), - SchemaConfig: req.Msg.GetSchemaConfig(), - })) + schemaResp, err := s.GetConnectionSchema( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionSchemaRequest{ + ConnectionId: req.Msg.GetConnectionId(), + SchemaConfig: req.Msg.GetSchemaConfig(), + }), + ) if err != nil { return nil, err } outputMap := map[string]*mgmtv1alpha1.GetConnectionSchemaResponse{} for _, dbcol := range schemaResp.Msg.GetSchemas() { - schematableKey := sqlmanager_shared.SchemaTable{Schema: dbcol.Schema, Table: dbcol.Table}.String() + schematableKey := sqlmanager_shared.SchemaTable{ + Schema: dbcol.Schema, + Table: dbcol.Table, + }.String() resp, ok := outputMap[schematableKey] if !ok { resp = &mgmtv1alpha1.GetConnectionSchemaResponse{} @@ -124,9 +139,12 @@ func (s *Service) GetConnectionSchema( ) (*connect.Response[mgmtv1alpha1.GetConnectionSchemaResponse], error) { logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) logger = logger.With("connectionId", req.Msg.ConnectionId) - connResp, err := s.connectionService.GetConnection(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ - Id: req.Msg.ConnectionId, - })) + connResp, err := s.connectionService.GetConnection( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ + Id: req.Msg.ConnectionId, + 
}), + ) if err != nil { return nil, err } @@ -151,18 +169,27 @@ func (s *Service) GetConnectionInitStatements( req *connect.Request[mgmtv1alpha1.GetConnectionInitStatementsRequest], ) (*connect.Response[mgmtv1alpha1.GetConnectionInitStatementsResponse], error) { logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) - connection, err := s.connectionService.GetConnection(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ - Id: req.Msg.ConnectionId, - })) + connection, err := s.connectionService.GetConnection( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ + Id: req.Msg.ConnectionId, + }), + ) if err != nil { return nil, err } - connectiondatabuilder, err := s.connectiondatabuilder.NewDataConnection(logger, connection.Msg.GetConnection()) + connectiondatabuilder, err := s.connectiondatabuilder.NewDataConnection( + logger, + connection.Msg.GetConnection(), + ) if err != nil { return nil, err } - initStatementsResponse, err := connectiondatabuilder.GetInitStatements(ctx, req.Msg.GetOptions()) + initStatementsResponse, err := connectiondatabuilder.GetInitStatements( + ctx, + req.Msg.GetOptions(), + ) if err != nil { return nil, err } @@ -180,26 +207,39 @@ func (s *Service) GetAiGeneratedData( ) (*connect.Response[mgmtv1alpha1.GetAiGeneratedDataResponse], error) { logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) _ = logger - aiconnectionResp, err := s.connectionService.GetConnection(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ - Id: req.Msg.GetAiConnectionId(), - })) + aiconnectionResp, err := s.connectionService.GetConnection( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ + Id: req.Msg.GetAiConnectionId(), + }), + ) if err != nil { return nil, err } aiconnection := aiconnectionResp.Msg.GetConnection() - dbconnectionResp, err := s.connectionService.GetConnection(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ - Id: req.Msg.GetDataConnectionId(), - })) + dbconnectionResp, err := 
s.connectionService.GetConnection( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ + Id: req.Msg.GetDataConnectionId(), + }), + ) if err != nil { return nil, err } - connectiondatabuilder, err := s.connectiondatabuilder.NewDataConnection(logger, dbconnectionResp.Msg.GetConnection()) + connectiondatabuilder, err := s.connectiondatabuilder.NewDataConnection( + logger, + dbconnectionResp.Msg.GetConnection(), + ) if err != nil { return nil, err } - dbcols, err := connectiondatabuilder.GetTableSchema(ctx, req.Msg.GetTable().GetSchema(), req.Msg.GetTable().GetTable()) + dbcols, err := connectiondatabuilder.GetTableSchema( + ctx, + req.Msg.GetTable().GetSchema(), + req.Msg.GetTable().GetTable(), + ) if err != nil { return nil, err } @@ -214,17 +254,32 @@ func (s *Service) GetAiGeneratedData( return nil, nucleuserrors.NewBadRequest("connection must be a valid openai connection") } - client, err := azopenai.NewClientForOpenAI(openaiconfig.GetApiUrl(), azcore.NewKeyCredential(openaiconfig.GetApiKey()), &azopenai.ClientOptions{}) + client, err := azopenai.NewClientForOpenAI( + openaiconfig.GetApiUrl(), + azcore.NewKeyCredential(openaiconfig.GetApiKey()), + &azopenai.ClientOptions{}, + ) if err != nil { return nil, fmt.Errorf("unable to init openai client: %w", err) } conversation := []azopenai.ChatRequestMessageClassification{ &azopenai.ChatRequestSystemMessage{ - Content: azopenai.NewChatRequestSystemMessageContent(fmt.Sprintf("You generate data in JSON format. Generate %d records in a json array located on the data key", req.Msg.GetCount())), + Content: azopenai.NewChatRequestSystemMessageContent( + fmt.Sprintf( + "You generate data in JSON format. 
Generate %d records in a json array located on the data key", + req.Msg.GetCount(), + ), + ), }, &azopenai.ChatRequestUserMessage{ - Content: azopenai.NewChatRequestUserMessageContent(fmt.Sprintf("%s\n%s", req.Msg.GetUserPrompt(), fmt.Sprintf("Each record looks like this: %s", strings.Join(columns, ",")))), + Content: azopenai.NewChatRequestUserMessageContent( + fmt.Sprintf( + "%s\n%s", + req.Msg.GetUserPrompt(), + fmt.Sprintf("Each record looks like this: %s", strings.Join(columns, ",")), + ), + ), }, } @@ -252,7 +307,10 @@ func (s *Service) GetAiGeneratedData( var dataResponse completionResponse err = json.Unmarshal([]byte(*choice.Message.Content), &dataResponse) if err != nil { - return nil, fmt.Errorf("unable to unmarshal openai message content into expected response: %w", err) + return nil, fmt.Errorf( + "unable to unmarshal openai message content into expected response: %w", + err, + ) } dtoRecords := []*structpb.Struct{} @@ -276,14 +334,20 @@ func (s *Service) GetConnectionTableConstraints( req *connect.Request[mgmtv1alpha1.GetConnectionTableConstraintsRequest], ) (*connect.Response[mgmtv1alpha1.GetConnectionTableConstraintsResponse], error) { logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) - connection, err := s.connectionService.GetConnection(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ - Id: req.Msg.ConnectionId, - })) + connection, err := s.connectionService.GetConnection( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ + Id: req.Msg.ConnectionId, + }), + ) if err != nil { return nil, err } - connectiondatabuilder, err := s.connectiondatabuilder.NewDataConnection(logger, connection.Msg.GetConnection()) + connectiondatabuilder, err := s.connectiondatabuilder.NewDataConnection( + logger, + connection.Msg.GetConnection(), + ) if err != nil { return nil, err } @@ -300,17 +364,28 @@ func (s *Service) GetTableRowCount( req *connect.Request[mgmtv1alpha1.GetTableRowCountRequest], ) 
(*connect.Response[mgmtv1alpha1.GetTableRowCountResponse], error) { logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) - connection, err := s.connectionService.GetConnection(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ - Id: req.Msg.ConnectionId, - })) + connection, err := s.connectionService.GetConnection( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ + Id: req.Msg.ConnectionId, + }), + ) if err != nil { return nil, err } - connectiondatabuilder, err := s.connectiondatabuilder.NewDataConnection(logger, connection.Msg.GetConnection()) + connectiondatabuilder, err := s.connectiondatabuilder.NewDataConnection( + logger, + connection.Msg.GetConnection(), + ) if err != nil { return nil, err } - count, err := connectiondatabuilder.GetTableRowCount(ctx, req.Msg.Schema, req.Msg.Table, req.Msg.WhereClause) + count, err := connectiondatabuilder.GetTableRowCount( + ctx, + req.Msg.Schema, + req.Msg.Table, + req.Msg.WhereClause, + ) if err != nil { return nil, err } @@ -325,14 +400,20 @@ func (s *Service) GetAllSchemasAndTables( req *connect.Request[mgmtv1alpha1.GetAllSchemasAndTablesRequest], ) (*connect.Response[mgmtv1alpha1.GetAllSchemasAndTablesResponse], error) { logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) - connection, err := s.connectionService.GetConnection(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ - Id: req.Msg.ConnectionId, - })) + connection, err := s.connectionService.GetConnection( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ + Id: req.Msg.ConnectionId, + }), + ) if err != nil { return nil, err } - connectiondatabuilder, err := s.connectiondatabuilder.NewDataConnection(logger, connection.Msg.GetConnection()) + connectiondatabuilder, err := s.connectiondatabuilder.NewDataConnection( + logger, + connection.Msg.GetConnection(), + ) if err != nil { return nil, fmt.Errorf("unable to create connection data builder: %w", err) } diff --git 
a/backend/services/mgmt/v1alpha1/connection-service/connection.go b/backend/services/mgmt/v1alpha1/connection-service/connection.go index d2c6b5b9c0..adb69fa192 100644 --- a/backend/services/mgmt/v1alpha1/connection-service/connection.go +++ b/backend/services/mgmt/v1alpha1/connection-service/connection.go @@ -195,9 +195,12 @@ func (s *Service) CheckConnectionConfigById( return nil, err } - resp, err := s.CheckConnectionConfig(ctx, connect.NewRequest(&mgmtv1alpha1.CheckConnectionConfigRequest{ - ConnectionConfig: connResp.Msg.GetConnection().ConnectionConfig, - })) + resp, err := s.CheckConnectionConfig( + ctx, + connect.NewRequest(&mgmtv1alpha1.CheckConnectionConfigRequest{ + ConnectionConfig: connResp.Msg.GetConnection().ConnectionConfig, + }), + ) if err != nil { return nil, err } @@ -209,9 +212,14 @@ func (s *Service) CheckConnectionConfigById( }), nil } -func getDbRoleFromConnectionConfig(cconfig *mgmtv1alpha1.ConnectionConfig, logger *slog.Logger) (string, error) { +func getDbRoleFromConnectionConfig( + cconfig *mgmtv1alpha1.ConnectionConfig, + logger *slog.Logger, +) (string, error) { if cconfig == nil { - return "", errors.New("connection config was nil, unable to retrieve db role/user from config") + return "", errors.New( + "connection config was nil, unable to retrieve db role/user from config", + ) } switch typedconfig := cconfig.GetConfig().(type) { @@ -254,10 +262,14 @@ func (s *Service) IsConnectionNameAvailable( return nil, err } - count, err := s.db.Q.IsConnectionNameAvailable(ctx, s.db.Db, db_queries.IsConnectionNameAvailableParams{ - AccountId: accountUuid, - ConnectionName: req.Msg.ConnectionName, - }) + count, err := s.db.Q.IsConnectionNameAvailable( + ctx, + s.db.Db, + db_queries.IsConnectionNameAvailableParams{ + AccountId: accountUuid, + ConnectionName: req.Msg.ConnectionName, + }, + ) if err != nil { return nil, err } @@ -278,7 +290,11 @@ func (s *Service) GetConnections( if err := user.EnforceConnection(ctx, 
userdata.NewWildcardDomainEntity(req.Msg.GetAccountId()), rbac.ConnectionAction_View); err != nil { return nil, err } - canViewSensitive, err := user.Connection(ctx, userdata.NewWildcardDomainEntity(req.Msg.GetAccountId()), rbac.ConnectionAction_ViewSensitive) + canViewSensitive, err := user.Connection( + ctx, + userdata.NewWildcardDomainEntity(req.Msg.GetAccountId()), + rbac.ConnectionAction_ViewSensitive, + ) if err != nil { return nil, err } @@ -336,7 +352,11 @@ func (s *Service) GetConnection( if err := user.EnforceConnection(ctx, userdata.NewDbDomainEntity(connection.AccountID, connection.ID), rbac.ConnectionAction_View); err != nil { return nil, err } - canViewSensitive, err := user.Connection(ctx, userdata.NewDbDomainEntity(connection.AccountID, connection.ID), rbac.ConnectionAction_ViewSensitive) + canViewSensitive, err := user.Connection( + ctx, + userdata.NewDbDomainEntity(connection.AccountID, connection.ID), + rbac.ConnectionAction_ViewSensitive, + ) if err != nil { return nil, err } @@ -518,12 +538,19 @@ func (s *Service) CheckSqlQuery( ) (*connect.Response[mgmtv1alpha1.CheckSqlQueryResponse], error) { logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) logger = logger.With("connectionId", req.Msg.GetId()) - connection, err := s.GetConnection(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{Id: req.Msg.GetId()})) + connection, err := s.GetConnection( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{Id: req.Msg.GetId()}), + ) if err != nil { return nil, err } - conn, err := s.sqlConnector.NewDbFromConnectionConfig(connection.Msg.GetConnection().GetConnectionConfig(), logger, sqlconnect.WithConnectionTimeout(10)) + conn, err := s.sqlConnector.NewDbFromConnectionConfig( + connection.Msg.GetConnection().GetConnectionConfig(), + logger, + sqlconnect.WithConnectionTimeout(10), + ) if err != nil { return nil, err } @@ -575,7 +602,10 @@ func (s *Service) CheckSSHConnectionById( ) 
(*connect.Response[mgmtv1alpha1.CheckSSHConnectionByIdResponse], error) { logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) logger = logger.With("connectionId", req.Msg.GetId()) - connection, err := s.GetConnection(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{Id: req.Msg.GetId()})) + connection, err := s.GetConnection( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{Id: req.Msg.GetId()}), + ) if err != nil { return nil, err } @@ -601,7 +631,10 @@ func (s *Service) CheckSSHConnectionById( }), nil } -func checkSSHConnection(sshTunnel *mgmtv1alpha1.SSHTunnel, logger *slog.Logger) (*mgmtv1alpha1.CheckSSHConnectionResult, error) { +func checkSSHConnection( + sshTunnel *mgmtv1alpha1.SSHTunnel, + logger *slog.Logger, +) (*mgmtv1alpha1.CheckSSHConnectionResult, error) { if sshTunnel == nil { errorMsg := "no ssh tunnel config found" return &mgmtv1alpha1.CheckSSHConnectionResult{ diff --git a/backend/services/mgmt/v1alpha1/job-service/jobs.go b/backend/services/mgmt/v1alpha1/job-service/jobs.go index 7e32ee6000..1c0c675bdd 100644 --- a/backend/services/mgmt/v1alpha1/job-service/jobs.go +++ b/backend/services/mgmt/v1alpha1/job-service/jobs.go @@ -43,7 +43,11 @@ func (s *Service) GetJobs( if err != nil { return nil, err } - err = user.EnforceJob(ctx, userdata.NewWildcardDomainEntity(req.Msg.GetAccountId()), rbac.JobAction_View) + err = user.EnforceJob( + ctx, + userdata.NewWildcardDomainEntity(req.Msg.GetAccountId()), + rbac.JobAction_View, + ) if err != nil { return nil, err } @@ -70,7 +74,11 @@ func (s *Service) GetJobs( var destinationAssociations []db_queries.NeosyncApiJobDestinationConnectionAssociation if len(jobIds) > 0 { - destinationAssociations, err = s.db.Q.GetJobConnectionDestinationsByJobIds(ctx, s.db.Db, jobIds) + destinationAssociations, err = s.db.Q.GetJobConnectionDestinationsByJobIds( + ctx, + s.db.Db, + jobIds, + ) if err != nil && !neosyncdb.IsNoRows(err) { return nil, fmt.Errorf("unable to get job connection 
destinations by job ids: %w", err) } else if err != nil && neosyncdb.IsNoRows(err) { @@ -171,14 +179,25 @@ func (s *Service) GetJobStatus( ) (*connect.Response[mgmtv1alpha1.GetJobStatusResponse], error) { logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) logger = logger.With("jobId", req.Msg.GetJobId()) - jobResp, err := s.GetJob(ctx, connect.NewRequest(&mgmtv1alpha1.GetJobRequest{Id: req.Msg.GetJobId()})) + jobResp, err := s.GetJob( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetJobRequest{Id: req.Msg.GetJobId()}), + ) if err != nil { return nil, err } - schedule, err := s.temporalmgr.DescribeSchedule(ctx, jobResp.Msg.GetJob().GetAccountId(), jobResp.Msg.GetJob().GetId(), logger) + schedule, err := s.temporalmgr.DescribeSchedule( + ctx, + jobResp.Msg.GetJob().GetAccountId(), + jobResp.Msg.GetJob().GetId(), + logger, + ) if err != nil { - return nil, fmt.Errorf("unable to describe temporal schedule when retrieving job status: %w", err) + return nil, fmt.Errorf( + "unable to describe temporal schedule when retrieving job status: %w", + err, + ) } return connect.NewResponse(&mgmtv1alpha1.GetJobStatusResponse{ @@ -196,7 +215,11 @@ func (s *Service) GetJobStatuses( if err != nil { return nil, err } - err = user.EnforceJob(ctx, userdata.NewWildcardDomainEntity(req.Msg.GetAccountId()), rbac.JobAction_View) + err = user.EnforceJob( + ctx, + userdata.NewWildcardDomainEntity(req.Msg.GetAccountId()), + rbac.JobAction_View, + ) if err != nil { return nil, err } @@ -225,13 +248,19 @@ func (s *Service) GetJobStatuses( logger, ) if err != nil { - return nil, fmt.Errorf("unable to describe temporal schedules when retrieving job statuses: %w", err) + return nil, fmt.Errorf( + "unable to describe temporal schedules when retrieving job statuses: %w", + err, + ) } dtos := make([]*mgmtv1alpha1.JobStatusRecord, len(jobs)) for idx, resp := range responses { if resp.Error != nil { - logger.Warn(fmt.Errorf("unable to describe temporal schedule when retrieving job statuses: 
%w", resp.Error).Error()) + logger.Warn( + fmt.Errorf("unable to describe temporal schedule when retrieving job statuses: %w", resp.Error). + Error(), + ) } else if resp.Schedule != nil { dtos[idx] = &mgmtv1alpha1.JobStatusRecord{ JobId: scheduleIds[idx], @@ -252,14 +281,25 @@ func (s *Service) GetJobRecentRuns( logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) logger = logger.With("jobId", req.Msg.JobId) - jobResp, err := s.GetJob(ctx, connect.NewRequest(&mgmtv1alpha1.GetJobRequest{Id: req.Msg.GetJobId()})) + jobResp, err := s.GetJob( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetJobRequest{Id: req.Msg.GetJobId()}), + ) if err != nil { return nil, err } - schedule, err := s.temporalmgr.DescribeSchedule(ctx, jobResp.Msg.GetJob().GetAccountId(), jobResp.Msg.GetJob().GetId(), logger) + schedule, err := s.temporalmgr.DescribeSchedule( + ctx, + jobResp.Msg.GetJob().GetAccountId(), + jobResp.Msg.GetJob().GetId(), + logger, + ) if err != nil { - return nil, fmt.Errorf("unable to describe temporal schedule when retrieving job recent runs: %w", err) + return nil, fmt.Errorf( + "unable to describe temporal schedule when retrieving job recent runs: %w", + err, + ) } return connect.NewResponse(&mgmtv1alpha1.GetJobRecentRunsResponse{ @@ -274,14 +314,25 @@ func (s *Service) GetJobNextRuns( logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) logger = logger.With("jobId", req.Msg.GetJobId()) - jobResp, err := s.GetJob(ctx, connect.NewRequest(&mgmtv1alpha1.GetJobRequest{Id: req.Msg.GetJobId()})) + jobResp, err := s.GetJob( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetJobRequest{Id: req.Msg.GetJobId()}), + ) if err != nil { return nil, err } - schedule, err := s.temporalmgr.DescribeSchedule(ctx, jobResp.Msg.GetJob().GetAccountId(), jobResp.Msg.GetJob().GetId(), logger) + schedule, err := s.temporalmgr.DescribeSchedule( + ctx, + jobResp.Msg.GetJob().GetAccountId(), + jobResp.Msg.GetJob().GetId(), + logger, + ) if err != nil { - return nil, fmt.Errorf("unable 
to describe temporal schedule when retrieving job next runs: %w", err) + return nil, fmt.Errorf( + "unable to describe temporal schedule when retrieving job next runs: %w", + err, + ) } return connect.NewResponse(&mgmtv1alpha1.GetJobNextRunsResponse{ @@ -305,7 +356,11 @@ func (s *Service) CreateJob( if err != nil { return nil, err } - err = user.EnforceJob(ctx, userdata.NewWildcardDomainEntity(req.Msg.GetAccountId()), rbac.JobAction_Create) + err = user.EnforceJob( + ctx, + userdata.NewWildcardDomainEntity(req.Msg.GetAccountId()), + rbac.JobAction_Create, + ) if err != nil { return nil, err } @@ -333,10 +388,14 @@ func (s *Service) CreateJob( } logger.Debug("verifying connections") - count, err := s.db.Q.AreConnectionsInAccount(ctx, s.db.Db, db_queries.AreConnectionsInAccountParams{ - AccountId: accountUuid, - ConnectionIds: connectionUuids, - }) + count, err := s.db.Q.AreConnectionsInAccount( + ctx, + s.db.Db, + db_queries.AreConnectionsInAccountParams{ + AccountId: accountUuid, + ConnectionIds: connectionUuids, + }, + ) if err != nil { return nil, fmt.Errorf("unable to check if connections are in provided account: %w", err) } @@ -378,7 +437,12 @@ func (s *Service) CreateJob( if err != nil { return nil, err } - areConnectionsCompatible, err := verifyConnectionsAreCompatible(ctx, s.db, sourceUuid, destinations) + areConnectionsCompatible, err := verifyConnectionsAreCompatible( + ctx, + s.db, + sourceUuid, + destinations, + ) if err != nil { return nil, fmt.Errorf("unable to verify if all connections are compatible: %w", err) } @@ -437,7 +501,9 @@ func (s *Service) CreateJob( return nil, fmt.Errorf("unable to verify account's temporal workspace. 
error: %w", err) } if !hasNs { - return nil, nucleuserrors.NewBadRequest("must first configure temporal namespace in account settings") + return nil, nucleuserrors.NewBadRequest( + "must first configure temporal namespace in account settings", + ) } taskQueue, err := s.temporalmgr.GetSyncJobTaskQueue(ctx, req.Msg.GetAccountId(), logger) @@ -536,7 +602,11 @@ func (s *Service) CreateJob( logger.Debug("deleting newly created job") removeJobErr := s.db.Q.RemoveJobById(ctx, s.db.Db, cj.ID) if removeJobErr != nil { - return nil, fmt.Errorf("unable to create scheduled job and was unable to fully cleanup partially created resources: %w: %w", removeJobErr, err) + return nil, fmt.Errorf( + "unable to create scheduled job and was unable to fully cleanup partially created resources: %w: %w", + removeJobErr, + err, + ) } return nil, fmt.Errorf("unable to create scheduled job: %w", err) } @@ -546,7 +616,13 @@ func (s *Service) CreateJob( if req.Msg.InitiateJobRun { logger.Debug("triggering initial job run") // manually trigger job run - err := s.temporalmgr.TriggerSchedule(ctx, req.Msg.GetAccountId(), scheduleId, &temporalclient.ScheduleTriggerOptions{}, logger) + err := s.temporalmgr.TriggerSchedule( + ctx, + req.Msg.GetAccountId(), + scheduleId, + &temporalclient.ScheduleTriggerOptions{}, + logger, + ) if err != nil { // don't return error here logger.Error(fmt.Errorf("unable to trigger job: %w", err).Error()) @@ -591,7 +667,11 @@ func (s *Service) DeleteJob( if err != nil { return nil, err } - err = user.EnforceJob(ctx, userdata.NewDbDomainEntity(dbJob.AccountID, dbJob.ID), rbac.JobAction_Delete) + err = user.EnforceJob( + ctx, + userdata.NewDbDomainEntity(dbJob.AccountID, dbJob.ID), + rbac.JobAction_Delete, + ) if err != nil { return nil, err } @@ -672,7 +752,9 @@ func (s *Service) CreateJobDestinationConnections( return nil, err } if !isInSameAccount { - return nil, nucleuserrors.NewBadRequest("connections are not all within the provided account") + return nil, 
nucleuserrors.NewBadRequest( + "connections are not all within the provided account", + ) } logger.Debug("creating job destination connections", "connectionIds", connectionIds) @@ -909,9 +991,12 @@ func (s *Service) UpdateJobSourceConnection( } // retrieves the connection details - conn, err := s.connectionService.GetConnection(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ - Id: connectionIdToVerify, - })) + conn, err := s.connectionService.GetConnection( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ + Id: connectionIdToVerify, + }), + ) if err != nil { return nil, err } @@ -980,7 +1065,15 @@ func (s *Service) UpdateJobSourceConnection( vfkKeys := map[string]struct{}{} virtualForeignKeys := []*pg_models.VirtualForeignConstraint{} for _, fk := range req.Msg.GetVirtualForeignKeys() { - key := fmt.Sprintf("%s.%s.%s.%s.%s.%s", fk.GetSchema(), fk.GetTable(), strings.Join(fk.GetColumns(), "."), fk.GetForeignKey().GetSchema(), fk.GetForeignKey().GetTable(), strings.Join(fk.GetForeignKey().GetColumns(), ".")) + key := fmt.Sprintf( + "%s.%s.%s.%s.%s.%s", + fk.GetSchema(), + fk.GetTable(), + strings.Join(fk.GetColumns(), "."), + fk.GetForeignKey().GetSchema(), + fk.GetForeignKey().GetTable(), + strings.Join(fk.GetForeignKey().GetColumns(), "."), + ) if _, exists := vfkKeys[key]; exists { // skip duplicates continue @@ -1111,9 +1204,12 @@ func (s *Service) SetJobSourceSqlConnectionSubsets( return nil, nucleuserrors.NewInternalError("unable to find connection id") } - connectionResp, err := s.connectionService.GetConnection(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ - Id: *connectionId, - })) + connectionResp, err := s.connectionService.GetConnection( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ + Id: *connectionId, + }), + ) if err != nil { return nil, err } @@ -1121,7 +1217,9 @@ func (s *Service) SetJobSourceSqlConnectionSubsets( if connection.ConnectionConfig == nil || 
(connection.ConnectionConfig.GetPgConfig() == nil && connection.ConnectionConfig.GetMysqlConfig() == nil && connection.ConnectionConfig.GetDynamodbConfig() == nil && connection.ConnectionConfig.GetMssqlConfig() == nil) { - return nil, nucleuserrors.NewBadRequest("may only update subsets for select source connections") + return nil, nucleuserrors.NewBadRequest( + "may only update subsets for select source connections", + ) } if err := s.db.SetSourceSubsets( @@ -1193,11 +1291,15 @@ func (s *Service) UpdateJobDestinationConnection( // todo(NEOS-1281): need a lot more validation here for changing connection uuid, matching options, as well as creating a new destination // if that destination is not supported with the source type logger.Debug("updating job destination connection") - _, err = s.db.Q.UpdateJobConnectionDestination(ctx, s.db.Db, db_queries.UpdateJobConnectionDestinationParams{ - ID: destinationUuid, - ConnectionID: connectionUuid, - Options: options, - }) + _, err = s.db.Q.UpdateJobConnectionDestination( + ctx, + s.db.Db, + db_queries.UpdateJobConnectionDestinationParams{ + ID: destinationUuid, + ConnectionID: connectionUuid, + Options: options, + }, + ) if err != nil && !neosyncdb.IsNoRows(err) { return nil, err } else if err != nil && neosyncdb.IsNoRows(err) { @@ -1333,7 +1435,12 @@ func (s *Service) verifyConnectionInAccount( return nil } -func verifyConnectionsInAccount(ctx context.Context, db *neosyncdb.NeosyncDb, connectionUuids []pgtype.UUID, accountUuid pgtype.UUID) (bool, error) { +func verifyConnectionsInAccount( + ctx context.Context, + db *neosyncdb.NeosyncDb, + connectionUuids []pgtype.UUID, + accountUuid pgtype.UUID, +) (bool, error) { conns, err := db.Q.GetConnectionsByIds(ctx, db.Db, connectionUuids) if err != nil { return false, err @@ -1360,7 +1467,12 @@ func verifyConnectionIdsUnique(connectionIds []string) bool { return true } -func verifyConnectionsAreCompatible(ctx context.Context, db *neosyncdb.NeosyncDb, sourceConnId pgtype.UUID, 
destinations []*destination) (bool, error) { +func verifyConnectionsAreCompatible( + ctx context.Context, + db *neosyncdb.NeosyncDb, + sourceConnId pgtype.UUID, + destinations []*destination, +) (bool, error) { var sourceConnection db_queries.NeosyncApiConnection dests := make([]db_queries.NeosyncApiConnection, len(destinations)) group := new(errgroup.Group) @@ -1393,26 +1505,31 @@ func verifyConnectionsAreCompatible(ctx context.Context, db *neosyncdb.NeosyncDb for i := range dests { d := dests[i] // AWS S3 and GCP CloudStorage are always a valid destination regardless of source connection type - if d.ConnectionConfig.AwsS3Config != nil || d.ConnectionConfig.GcpCloudStorageConfig != nil { + if d.ConnectionConfig.AwsS3Config != nil || + d.ConnectionConfig.GcpCloudStorageConfig != nil { continue } if sourceConnection.ConnectionConfig.PgConfig != nil && d.ConnectionConfig.PgConfig == nil { // invalid Postgres source cannot have Mysql destination return false, nil } - if sourceConnection.ConnectionConfig.MysqlConfig != nil && d.ConnectionConfig.MysqlConfig == nil { + if sourceConnection.ConnectionConfig.MysqlConfig != nil && + d.ConnectionConfig.MysqlConfig == nil { // invalid Mysql source cannot have non-Mysql or non-AWS connection return false, nil } - if sourceConnection.ConnectionConfig.MongoConfig != nil && d.ConnectionConfig.MongoConfig == nil { + if sourceConnection.ConnectionConfig.MongoConfig != nil && + d.ConnectionConfig.MongoConfig == nil { // invalid Mongo source cannot have anything other than mongo to start return false, nil } - if sourceConnection.ConnectionConfig.DynamoDBConfig != nil && d.ConnectionConfig.DynamoDBConfig == nil { + if sourceConnection.ConnectionConfig.DynamoDBConfig != nil && + d.ConnectionConfig.DynamoDBConfig == nil { // invalid DynamoDB source cannot have anything other than dynamodb to start return false, nil } - if sourceConnection.ConnectionConfig.MssqlConfig != nil && d.ConnectionConfig.MssqlConfig == nil { + if 
sourceConnection.ConnectionConfig.MssqlConfig != nil && + d.ConnectionConfig.MssqlConfig == nil { return false, nil } } @@ -1498,7 +1615,9 @@ func (s *Service) SetJobWorkflowOptions( return nil, err } - return connect.NewResponse(&mgmtv1alpha1.SetJobWorkflowOptionsResponse{Job: updatedJob.Msg.Job}), nil + return connect.NewResponse( + &mgmtv1alpha1.SetJobWorkflowOptionsResponse{Job: updatedJob.Msg.Job}, + ), nil } func getDurationFromInt(input *int64) time.Duration { @@ -1552,7 +1671,9 @@ func (s *Service) SetJobSyncOptions( if err != nil { return nil, err } - return connect.NewResponse(&mgmtv1alpha1.SetJobSyncOptionsResponse{Job: updatedJob.Msg.Job}), nil + return connect.NewResponse( + &mgmtv1alpha1.SetJobSyncOptionsResponse{Job: updatedJob.Msg.Job}, + ), nil } func (s *Service) ValidateJobMappings( @@ -1562,9 +1683,12 @@ func (s *Service) ValidateJobMappings( logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) logger = logger.With("accountId", req.Msg.GetAccountId()) - connection, err := s.connectionService.GetConnection(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ - Id: req.Msg.GetConnectionId(), - })) + connection, err := s.connectionService.GetConnection( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ + Id: req.Msg.GetConnectionId(), + }), + ) if err != nil { return nil, err } @@ -1574,11 +1698,17 @@ func (s *Service) ValidateJobMappings( return nil, errors.New("connection config for connection was nil") } - if connConfig.GetAwsS3Config() != nil || connConfig.GetMongoConfig() != nil || connConfig.GetDynamodbConfig() != nil { + if connConfig.GetAwsS3Config() != nil || connConfig.GetMongoConfig() != nil || + connConfig.GetDynamodbConfig() != nil { return connect.NewResponse(&mgmtv1alpha1.ValidateJobMappingsResponse{}), nil } - db, err := s.sqlmanager.NewSqlConnection(ctx, connectionmanager.NewUniqueSession(), connection.Msg.GetConnection(), logger) + db, err := s.sqlmanager.NewSqlConnection( + ctx, + 
connectionmanager.NewUniqueSession(), + connection.Msg.GetConnection(), + logger, + ) if err != nil { return nil, err } @@ -1613,7 +1743,10 @@ func (s *Service) ValidateJobMappings( } } - validator := job_util.NewJobMappingsValidator(req.Msg.Mappings, job_util.WithJobSourceOptions(sqlSourceOpts)) + validator := job_util.NewJobMappingsValidator( + req.Msg.Mappings, + job_util.WithJobSourceOptions(sqlSourceOpts), + ) result, err := validator.Validate(colInfoMap, req.Msg.VirtualForeignKeys, tableConstraints) if err != nil { return nil, err @@ -1696,9 +1829,12 @@ func (s *Service) ValidateSchema( req *connect.Request[mgmtv1alpha1.ValidateSchemaRequest], ) (*connect.Response[mgmtv1alpha1.ValidateSchemaResponse], error) { logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) - schemaConn, err := s.connectionService.GetConnection(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ - Id: req.Msg.GetConnectionId(), - })) + schemaConn, err := s.connectionService.GetConnection( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ + Id: req.Msg.GetConnectionId(), + }), + ) if err != nil { return nil, err } @@ -1733,7 +1869,9 @@ func (s *Service) ValidateSchema( return connect.NewResponse(resp), nil } -func convertSchemaTables(tables []*sqlmanager_shared.SchemaTable) []*mgmtv1alpha1.ValidateSchemaResponse_Table { +func convertSchemaTables( + tables []*sqlmanager_shared.SchemaTable, +) []*mgmtv1alpha1.ValidateSchemaResponse_Table { var protoTables []*mgmtv1alpha1.ValidateSchemaResponse_Table for _, table := range tables { protoTables = append(protoTables, &mgmtv1alpha1.ValidateSchemaResponse_Table{ @@ -1745,7 +1883,9 @@ func convertSchemaTables(tables []*sqlmanager_shared.SchemaTable) []*mgmtv1alpha } func isConnectionSQLType(connection *mgmtv1alpha1.Connection) bool { - return connection.GetConnectionConfig().GetPgConfig() != nil || connection.GetConnectionConfig().GetMysqlConfig() != nil || connection.GetConnectionConfig().GetMssqlConfig() != nil + 
return connection.GetConnectionConfig().GetPgConfig() != nil || + connection.GetConnectionConfig().GetMysqlConfig() != nil || + connection.GetConnectionConfig().GetMssqlConfig() != nil } func getJobSourceConnectionId(jobSource *mgmtv1alpha1.JobSource) (*string, error) { @@ -1781,7 +1921,9 @@ func getJobSourceConnectionId(jobSource *mgmtv1alpha1.JobSource) (*string, error return connectionIdToVerify, nil } -func getConnectionSchemaConfigByConnectionType(connection *mgmtv1alpha1.Connection) (*mgmtv1alpha1.ConnectionSchemaConfig, error) { +func getConnectionSchemaConfigByConnectionType( + connection *mgmtv1alpha1.Connection, +) (*mgmtv1alpha1.ConnectionSchemaConfig, error) { switch conn := connection.GetConnectionConfig().GetConfig().(type) { case *mgmtv1alpha1.ConnectionConfig_PgConfig: return &mgmtv1alpha1.ConnectionSchemaConfig{ diff --git a/backend/services/mgmt/v1alpha1/job-service/runs.go b/backend/services/mgmt/v1alpha1/job-service/runs.go index b374949e2b..038fba0efd 100644 --- a/backend/services/mgmt/v1alpha1/job-service/runs.go +++ b/backend/services/mgmt/v1alpha1/job-service/runs.go @@ -90,7 +90,12 @@ func (s *Service) GetJobRuns( return nil, err } - workflows, err := s.temporalmgr.GetWorkflowExecutionsByScheduleIds(ctx, accountId, jobIds, logger) + workflows, err := s.temporalmgr.GetWorkflowExecutionsByScheduleIds( + ctx, + accountId, + jobIds, + logger, + ) if err != nil { return nil, err } @@ -112,7 +117,12 @@ func (s *Service) GetJobRun( logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) logger = logger.With("jobRunId", req.Msg.JobRunId) - res, err := s.temporalmgr.DescribeWorklowExecution(ctx, req.Msg.GetAccountId(), req.Msg.GetJobRunId(), logger) + res, err := s.temporalmgr.DescribeWorklowExecution( + ctx, + req.Msg.GetAccountId(), + req.Msg.GetJobRunId(), + logger, + ) if err != nil { return nil, fmt.Errorf("unable to describe workflow execution: %w", err) } @@ -157,7 +167,11 @@ func (s *Service) GetJobRunEvents( return 
connect.NewResponse(resp), nil } -func (s *Service) getEventsByWorkflowId(ctx context.Context, accountId, workflowId string, logger *slog.Logger) (*mgmtv1alpha1.GetJobRunEventsResponse, error) { +func (s *Service) getEventsByWorkflowId( + ctx context.Context, + accountId, workflowId string, + logger *slog.Logger, +) (*mgmtv1alpha1.GetJobRunEventsResponse, error) { isRunComplete := false activityOrder := []int64{} activityMap := map[int64]*mgmtv1alpha1.JobRunEvent{} @@ -195,9 +209,12 @@ func (s *Service) getEventsByWorkflowId(ctx context.Context, accountId, workflow } if len(attributes.Input.Payloads) > 1 { var rawMap map[string]string - err := converter.GetDefaultDataConverter().FromPayload(attributes.Input.Payloads[1], &rawMap) + err := converter.GetDefaultDataConverter(). + FromPayload(attributes.Input.Payloads[1], &rawMap) if err != nil { - logger.Error(fmt.Errorf("unable to convert to event input payload: %w", err).Error()) + logger.Error( + fmt.Errorf("unable to convert to event input payload: %w", err).Error(), + ) } schema, schemaExists := rawMap["Schema"] @@ -277,9 +294,12 @@ func (s *Service) getEventsByWorkflowId(ctx context.Context, accountId, workflow switch attributes.GetWorkflowType().GetName() { case "TableSync": var tableSyncRequest tablesync_workflow.TableSyncRequest - err := converter.GetDefaultDataConverter().FromPayload(attributes.Input.Payloads[0], &tableSyncRequest) + err := converter.GetDefaultDataConverter(). 
+ FromPayload(attributes.Input.Payloads[0], &tableSyncRequest) if err != nil { - logger.Error(fmt.Errorf("unable to convert to event input payload: %w", err).Error()) + logger.Error( + fmt.Errorf("unable to convert to event input payload: %w", err).Error(), + ) } metadata := &mgmtv1alpha1.JobRunEventMetadata{} @@ -292,9 +312,12 @@ func (s *Service) getEventsByWorkflowId(ctx context.Context, accountId, workflow jobRunEvent.Metadata = metadata case "TablePiiDetect": var piiDetectTableRequest piidetect_table_workflow.TablePiiDetectRequest - err := converter.GetDefaultDataConverter().FromPayload(attributes.Input.Payloads[0], &piiDetectTableRequest) + err := converter.GetDefaultDataConverter(). + FromPayload(attributes.Input.Payloads[0], &piiDetectTableRequest) if err != nil { - logger.Error(fmt.Errorf("unable to convert to event input payload: %w", err).Error()) + logger.Error( + fmt.Errorf("unable to convert to event input payload: %w", err).Error(), + ) } metadata := &mgmtv1alpha1.JobRunEventMetadata{} metadata.Metadata = &mgmtv1alpha1.JobRunEventMetadata_SyncMetadata{ @@ -369,9 +392,20 @@ func (s *Service) getEventsByWorkflowId(ctx context.Context, accountId, workflow continue } logger.Debug("child workflow is not complete, checking if it is closed") - info, err := s.temporalmgr.GetWorkflowExecutionById(ctx, accountId, childWorkflowId, logger) + info, err := s.temporalmgr.GetWorkflowExecutionById( + ctx, + accountId, + childWorkflowId, + logger, + ) if err != nil { - logger.Warn(fmt.Sprintf("unable to get workflow execution info for %s: %s", childWorkflowId, err)) + logger.Warn( + fmt.Sprintf( + "unable to get workflow execution info for %s: %s", + childWorkflowId, + err, + ), + ) continue } @@ -383,16 +417,25 @@ func (s *Service) getEventsByWorkflowId(ctx context.Context, accountId, workflow switch info.GetStatus() { case enums.WORKFLOW_EXECUTION_STATUS_COMPLETED: highestEventId++ - childEvent.Tasks = append(childEvent.Tasks, 
dtomaps.ToJobRunEventTaskDto(&history.HistoryEvent{ - EventId: highestEventId, - EventTime: info.GetCloseTime(), - EventType: enums.EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_COMPLETED, - Attributes: nil, - }, nil)) + childEvent.Tasks = append( + childEvent.Tasks, + dtomaps.ToJobRunEventTaskDto(&history.HistoryEvent{ + EventId: highestEventId, + EventTime: info.GetCloseTime(), + EventType: enums.EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_COMPLETED, + Attributes: nil, + }, nil), + ) case enums.WORKFLOW_EXECUTION_STATUS_FAILED: resp, err := s.getEventsByWorkflowId(ctx, accountId, childWorkflowId, logger) if err != nil { - logger.Warn(fmt.Sprintf("unable to get events by workflow id for %s: %s", childWorkflowId, err)) + logger.Warn( + fmt.Sprintf( + "unable to get events by workflow id for %s: %s", + childWorkflowId, + err, + ), + ) continue } var eventErr *mgmtv1alpha1.JobRunEventTaskError @@ -405,36 +448,48 @@ func (s *Service) getEventsByWorkflowId(ctx context.Context, accountId, workflow } } highestEventId++ - childEvent.Tasks = append(childEvent.Tasks, dtomaps.ToJobRunEventTaskDto(&history.HistoryEvent{ - EventId: highestEventId, - EventTime: info.GetCloseTime(), - EventType: enums.EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_FAILED, - Attributes: nil, - }, eventErr)) + childEvent.Tasks = append( + childEvent.Tasks, + dtomaps.ToJobRunEventTaskDto(&history.HistoryEvent{ + EventId: highestEventId, + EventTime: info.GetCloseTime(), + EventType: enums.EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_FAILED, + Attributes: nil, + }, eventErr), + ) case enums.WORKFLOW_EXECUTION_STATUS_TIMED_OUT: highestEventId++ - childEvent.Tasks = append(childEvent.Tasks, dtomaps.ToJobRunEventTaskDto(&history.HistoryEvent{ - EventId: highestEventId, - EventTime: info.GetCloseTime(), - EventType: enums.EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_TIMED_OUT, - Attributes: nil, - }, nil)) + childEvent.Tasks = append( + childEvent.Tasks, + dtomaps.ToJobRunEventTaskDto(&history.HistoryEvent{ + EventId: highestEventId, + EventTime: 
info.GetCloseTime(), + EventType: enums.EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_TIMED_OUT, + Attributes: nil, + }, nil), + ) case enums.WORKFLOW_EXECUTION_STATUS_CANCELED: highestEventId++ - childEvent.Tasks = append(childEvent.Tasks, dtomaps.ToJobRunEventTaskDto(&history.HistoryEvent{ - EventId: highestEventId, - EventTime: info.GetCloseTime(), - EventType: enums.EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_CANCELED, - Attributes: nil, - }, nil)) + childEvent.Tasks = append( + childEvent.Tasks, + dtomaps.ToJobRunEventTaskDto(&history.HistoryEvent{ + EventId: highestEventId, + EventTime: info.GetCloseTime(), + EventType: enums.EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_CANCELED, + Attributes: nil, + }, nil), + ) case enums.WORKFLOW_EXECUTION_STATUS_TERMINATED: highestEventId++ - childEvent.Tasks = append(childEvent.Tasks, dtomaps.ToJobRunEventTaskDto(&history.HistoryEvent{ - EventId: highestEventId, - EventTime: info.GetCloseTime(), - EventType: enums.EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_TERMINATED, - Attributes: nil, - }, nil)) + childEvent.Tasks = append( + childEvent.Tasks, + dtomaps.ToJobRunEventTaskDto(&history.HistoryEvent{ + EventId: highestEventId, + EventTime: info.GetCloseTime(), + EventType: enums.EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_TERMINATED, + Attributes: nil, + }, nil), + ) } } } @@ -542,7 +597,12 @@ func (s *Service) TerminateJobRun( return nil, err } - err = s.temporalmgr.TerminateWorkflow(ctx, req.Msg.GetAccountId(), req.Msg.GetJobRunId(), logger) + err = s.temporalmgr.TerminateWorkflow( + ctx, + req.Msg.GetAccountId(), + req.Msg.GetJobRunId(), + logger, + ) if err != nil { return nil, fmt.Errorf("unable to terminate job run: %w", err) } @@ -574,7 +634,12 @@ func (s *Service) DeleteJobRun( return nil, err } - err = s.temporalmgr.DeleteWorkflowExecution(ctx, req.Msg.GetAccountId(), req.Msg.GetJobRunId(), logger) + err = s.temporalmgr.DeleteWorkflowExecution( + ctx, + req.Msg.GetAccountId(), + req.Msg.GetJobRunId(), + logger, + ) if err != nil { return nil, 
fmt.Errorf("unable to delete job run: %w", err) } @@ -596,7 +661,12 @@ func (s *Service) GetJobRunLogsStream( logger = logger.With("jobRunId", req.Msg.GetJobRunId()) onLogLine := func(logline *mgmtv1alpha1.GetJobRunLogsResponse_LogLine) error { - return stream.Send(&mgmtv1alpha1.GetJobRunLogsStreamResponse{LogLine: logline.LogLine, Timestamp: logline.Timestamp}) + return stream.Send( + &mgmtv1alpha1.GetJobRunLogsStreamResponse{ + LogLine: logline.LogLine, + Timestamp: logline.Timestamp, + }, + ) } return s.streamLogs(ctx, req.Msg, &logLineStreamer{onLogLine: onLogLine}, logger) } @@ -640,8 +710,11 @@ func (s *Service) streamLogs( stream logStreamer, logger *slog.Logger, ) error { - if s.cfg.RunLogConfig == nil || !s.cfg.RunLogConfig.IsEnabled || s.cfg.RunLogConfig.RunLogType == nil { - return nucleuserrors.NewNotImplemented("job run logs is not enabled. please configure or contact system administrator to enable logs.") + if s.cfg.RunLogConfig == nil || !s.cfg.RunLogConfig.IsEnabled || + s.cfg.RunLogConfig.RunLogType == nil { + return nucleuserrors.NewNotImplemented( + "job run logs is not enabled. 
please configure or contact system administrator to enable logs.", + ) } jobRunResp, err := s.GetJobRun(ctx, connect.NewRequest(&mgmtv1alpha1.GetJobRunRequest{ @@ -674,7 +747,9 @@ func (s *Service) streamLogs( } return nil default: - return nucleuserrors.NewNotImplemented("streaming log pods not implemented for this container type") + return nucleuserrors.NewNotImplemented( + "streaming log pods not implemented for this container type", + ) } } @@ -708,7 +783,12 @@ func (s *Service) streamK8sWorkerPodLogs( if s.cfg.RunLogConfig.RunLogPodConfig == nil { return nucleuserrors.NewInternalError("run logs configured but no config provided") } - workflowExecution, err := s.temporalmgr.GetWorkflowExecutionById(ctx, req.GetAccountId(), req.GetJobRunId(), logger) + workflowExecution, err := s.temporalmgr.GetWorkflowExecutionById( + ctx, + req.GetAccountId(), + req.GetJobRunId(), + logger, + ) if err != nil { return err } @@ -723,7 +803,11 @@ func (s *Service) streamK8sWorkerPodLogs( return fmt.Errorf("unable to create kubernetes clientset: %w", err) } - appNameSelector, err := labels.NewRequirement("app", selection.Equals, []string{s.cfg.RunLogConfig.RunLogPodConfig.WorkerAppName}) + appNameSelector, err := labels.NewRequirement( + "app", + selection.Equals, + []string{s.cfg.RunLogConfig.RunLogPodConfig.WorkerAppName}, + ) if err != nil { return fmt.Errorf("unable to build label selector when finding k8s logs: %w", err) } @@ -798,13 +882,19 @@ func (s *Service) streamLokiWorkerLogs( stream logStreamer, logger *slog.Logger, ) error { - if s.cfg.RunLogConfig == nil || !s.cfg.RunLogConfig.IsEnabled || s.cfg.RunLogConfig.LokiRunLogConfig == nil { + if s.cfg.RunLogConfig == nil || !s.cfg.RunLogConfig.IsEnabled || + s.cfg.RunLogConfig.LokiRunLogConfig == nil { return nucleuserrors.NewInternalError("run logs configured but no config provided") } if s.cfg.RunLogConfig.LokiRunLogConfig.LabelsQuery == "" { return nucleuserrors.NewInternalError("must provide a labels query for loki to 
filter by") } - workflowExecution, err := s.temporalmgr.GetWorkflowExecutionById(ctx, req.GetAccountId(), req.GetJobRunId(), logger) + workflowExecution, err := s.temporalmgr.GetWorkflowExecutionById( + ctx, + req.GetAccountId(), + req.GetJobRunId(), + logger, + ) if err != nil { return fmt.Errorf("unable to retrieve workflow execution: %w", err) } @@ -863,7 +953,13 @@ func (s *Service) streamLokiWorkerLogs( } for _, entry := range entries { - err := stream.Send(&mgmtv1alpha1.GetJobRunLogsResponse_LogLine{LogLine: entry.Line, Labels: entry.Labels.Map(), Timestamp: timestamppb.New(entry.Timestamp)}) + err := stream.Send( + &mgmtv1alpha1.GetJobRunLogsResponse_LogLine{ + LogLine: entry.Line, + Labels: entry.Labels.Map(), + Timestamp: timestamppb.New(entry.Timestamp), + }, + ) if err != nil { return err } @@ -902,7 +998,12 @@ func logLevelToString(loglevel mgmtv1alpha1.LogLevel) string { } } -func buildLokiQuery(lokiLables string, keep []string, workflowId string, loglevels []string) string { +func buildLokiQuery( + lokiLables string, + keep []string, + workflowId string, + loglevels []string, +) string { query := fmt.Sprintf("{%s} | json", lokiLables) query = fmt.Sprintf("%s | JobRunId=%q", query, workflowId) @@ -981,7 +1082,9 @@ func (s *Service) SetRunContext( } if s.cfg.IsNeosyncCloud && !user.IsWorkerApiKey() { - return nil, nucleuserrors.NewUnauthenticated("must provide valid authentication credentials for this endpoint") + return nil, nucleuserrors.NewUnauthenticated( + "must provide valid authentication credentials for this endpoint", + ) } accountUuid, err := neosyncdb.ToUuid(id.GetAccountId()) @@ -1020,7 +1123,9 @@ func (s *Service) SetRunContexts( } if s.cfg.IsNeosyncCloud && !user.IsWorkerApiKey() { - return nil, nucleuserrors.NewUnauthenticated("must provide valid authentication credentials for this endpoint") + return nil, nucleuserrors.NewUnauthenticated( + "must provide valid authentication credentials for this endpoint", + ) } accountUuid, err := 
neosyncdb.ToUuid(id.GetAccountId()) @@ -1079,11 +1184,15 @@ func (s *Service) GetPiiDetectionReport( // this allows us to effectively stream in the latest reports while the job is running if len(tableRunContexts) == 0 { logger.Debug("no table run contexts found in job report, fetching table level reports") - runContexts, err := s.db.Q.GetRunContextsByExternalIdSuffix(ctx, s.db.Db, db_queries.GetRunContextsByExternalIdSuffixParams{ - WorkflowId: jobRun.GetId(), - ExternalIdSuffix: piidetect_table_activities.PiiTableReportSuffix, - AccountId: accountUuid, - }) + runContexts, err := s.db.Q.GetRunContextsByExternalIdSuffix( + ctx, + s.db.Db, + db_queries.GetRunContextsByExternalIdSuffixParams{ + WorkflowId: jobRun.GetId(), + ExternalIdSuffix: piidetect_table_activities.PiiTableReportSuffix, + AccountId: accountUuid, + }, + ) if err != nil && !neosyncdb.IsNoRows(err) { return nil, fmt.Errorf("unable to retrieve run contexts: %w", err) } @@ -1115,7 +1224,11 @@ func (s *Service) GetPiiDetectionReport( }), nil } -func (s *Service) getTableRunContextsFromJobReport(ctx context.Context, jobRun *mgmtv1alpha1.JobRun, accountUuid pgtype.UUID) ([]*db_queries.NeosyncApiRuncontext, error) { +func (s *Service) getTableRunContextsFromJobReport( + ctx context.Context, + jobRun *mgmtv1alpha1.JobRun, + accountUuid pgtype.UUID, +) ([]*db_queries.NeosyncApiRuncontext, error) { runContext, err := s.db.Q.GetRunContextByKey(ctx, s.db.Db, db_queries.GetRunContextByKeyParams{ WorkflowId: jobRun.GetId(), ExternalId: piidetect_job_activities.BuildJobReportExternalId(jobRun.GetJobId()), @@ -1131,7 +1244,11 @@ func (s *Service) getTableRunContextsFromJobReport(ctx context.Context, jobRun * if err != nil { return nil, fmt.Errorf("unable to unmarshal run context for job pii detect report: %w", err) } - tableRunContextKeys := make([]*mgmtv1alpha1.RunContextKey, 0, len(jobReport.SuccessfulTableReports)) + tableRunContextKeys := make( + []*mgmtv1alpha1.RunContextKey, + 0, + 
len(jobReport.SuccessfulTableReports), + ) for _, tableReport := range jobReport.SuccessfulTableReports { tableRunContextKeys = append(tableRunContextKeys, tableReport.ReportKey) } @@ -1142,7 +1259,10 @@ func (s *Service) getTableRunContextsFromJobReport(ctx context.Context, jobRun * return tableRunContexts, nil } -func (s *Service) getDbRunContextsFromKeys(ctx context.Context, keys []*mgmtv1alpha1.RunContextKey) ([]*db_queries.NeosyncApiRuncontext, error) { +func (s *Service) getDbRunContextsFromKeys( + ctx context.Context, + keys []*mgmtv1alpha1.RunContextKey, +) ([]*db_queries.NeosyncApiRuncontext, error) { errgrp, errctx := errgroup.WithContext(ctx) errgrp.SetLimit(10) runContexts := []*db_queries.NeosyncApiRuncontext{} @@ -1158,11 +1278,15 @@ func (s *Service) getDbRunContextsFromKeys(ctx context.Context, keys []*mgmtv1al if err != nil { return fmt.Errorf("unable to convert account id to uuid: %w", err) } - runContext, err := s.db.Q.GetRunContextByKey(errctx, s.db.Db, db_queries.GetRunContextByKeyParams{ - WorkflowId: key.GetJobRunId(), - ExternalId: key.GetExternalId(), - AccountId: accountUuid, - }) + runContext, err := s.db.Q.GetRunContextByKey( + errctx, + s.db.Db, + db_queries.GetRunContextByKeyParams{ + WorkflowId: key.GetJobRunId(), + ExternalId: key.GetExternalId(), + AccountId: accountUuid, + }, + ) if err != nil { return fmt.Errorf("unable to get run context: %w", err) } @@ -1178,7 +1302,9 @@ func (s *Service) getDbRunContextsFromKeys(ctx context.Context, keys []*mgmtv1al return runContexts, nil } -func getReportsFromTableContexts(tableContexts []*db_queries.NeosyncApiRuncontext) ([]*piidetect_table_activities.TableReport, error) { +func getReportsFromTableContexts( + tableContexts []*db_queries.NeosyncApiRuncontext, +) ([]*piidetect_table_activities.TableReport, error) { reports := make([]*piidetect_table_activities.TableReport, len(tableContexts)) for i := range tableContexts { runContext := tableContexts[i] @@ -1192,13 +1318,19 @@ func 
getReportsFromTableContexts(tableContexts []*db_queries.NeosyncApiRuncontex return reports, nil } -func getTableReportDtos(reports []*piidetect_table_activities.TableReport) []*mgmtv1alpha1.PiiDetectionReport_TableReport { +func getTableReportDtos( + reports []*piidetect_table_activities.TableReport, +) []*mgmtv1alpha1.PiiDetectionReport_TableReport { reportDtos := make([]*mgmtv1alpha1.PiiDetectionReport_TableReport, len(reports)) for i, report := range reports { reportDtos[i] = &mgmtv1alpha1.PiiDetectionReport_TableReport{ - Schema: report.TableSchema, - Table: report.TableName, - Columns: make([]*mgmtv1alpha1.PiiDetectionReport_TableReport_ColumnReport, 0, len(report.ColumnReports)), + Schema: report.TableSchema, + Table: report.TableName, + Columns: make( + []*mgmtv1alpha1.PiiDetectionReport_TableReport_ColumnReport, + 0, + len(report.ColumnReports), + ), } for _, columnReport := range report.ColumnReports { columnReportDto := &mgmtv1alpha1.PiiDetectionReport_TableReport_ColumnReport{ diff --git a/backend/services/mgmt/v1alpha1/metrics-service/metrics.go b/backend/services/mgmt/v1alpha1/metrics-service/metrics.go index 5624b33efa..9893ee8be6 100644 --- a/backend/services/mgmt/v1alpha1/metrics-service/metrics.go +++ b/backend/services/mgmt/v1alpha1/metrics-service/metrics.go @@ -41,12 +41,23 @@ func (s *Service) GetDailyMetricCount( timeDiff := end.Sub(start) if timeDiff > timeLimit { - return nil, nucleuserrors.NewBadRequest("duration between start and end must not exceed 60 days") + return nil, nucleuserrors.NewBadRequest( + "duration between start and end must not exceed 60 days", + ) } queryLabels := metrics.MetricLabels{ - metrics.NewNotEqLabel(metrics.IsUpdateConfigLabel, "true"), // we want to always exclude update configs - metrics.NewRegexMatchLabel(metrics.NeosyncDateLabel, strings.Join(metrics.GenerateMonthRegexRange(req.Msg.GetStart(), req.Msg.GetEnd()), metricDateSeparator)), + metrics.NewNotEqLabel( + metrics.IsUpdateConfigLabel, + "true", + ), // 
we want to always exclude update configs + metrics.NewRegexMatchLabel( + metrics.NeosyncDateLabel, + strings.Join( + metrics.GenerateMonthRegexRange(req.Msg.GetStart(), req.Msg.GetEnd()), + metricDateSeparator, + ), + ), } switch identifier := req.Msg.Identifier.(type) { @@ -126,12 +137,23 @@ func (s *Service) GetMetricCount( timeDiff := end.Sub(start) if timeDiff > timeLimit { - return nil, nucleuserrors.NewBadRequest("duration between start and end must not exceed 60 days") + return nil, nucleuserrors.NewBadRequest( + "duration between start and end must not exceed 60 days", + ) } queryLabels := metrics.MetricLabels{ - metrics.NewNotEqLabel(metrics.IsUpdateConfigLabel, "true"), // we want to always exclude update configs - metrics.NewRegexMatchLabel(metrics.NeosyncDateLabel, strings.Join(metrics.GenerateMonthRegexRange(req.Msg.GetStartDay(), req.Msg.GetEndDay()), metricDateSeparator)), + metrics.NewNotEqLabel( + metrics.IsUpdateConfigLabel, + "true", + ), // we want to always exclude update configs + metrics.NewRegexMatchLabel( + metrics.NeosyncDateLabel, + strings.Join( + metrics.GenerateMonthRegexRange(req.Msg.GetStartDay(), req.Msg.GetEndDay()), + metricDateSeparator, + ), + ), } switch identifier := req.Msg.Identifier.(type) { diff --git a/backend/services/mgmt/v1alpha1/transformers-service/entities.go b/backend/services/mgmt/v1alpha1/transformers-service/entities.go index 437c31a1f4..7e74baf7d4 100644 --- a/backend/services/mgmt/v1alpha1/transformers-service/entities.go +++ b/backend/services/mgmt/v1alpha1/transformers-service/entities.go @@ -23,7 +23,15 @@ func (s *Service) GetTransformPiiEntities( req *connect.Request[mgmtv1alpha1.GetTransformPiiEntitiesRequest], ) (*connect.Response[mgmtv1alpha1.GetTransformPiiEntitiesResponse], error) { if !s.cfg.IsPresidioEnabled { - return nil, nucleuserrors.NewNotImplemented(fmt.Sprintf("%s is not implemented", strings.TrimPrefix(mgmtv1alpha1connect.TransformersServiceGetTransformPiiEntitiesProcedure, "/"))) + return 
nil, nucleuserrors.NewNotImplemented( + fmt.Sprintf( + "%s is not implemented", + strings.TrimPrefix( + mgmtv1alpha1connect.TransformersServiceGetTransformPiiEntitiesProcedure, + "/", + ), + ), + ) } if s.entityclient == nil { return nil, nucleuserrors.NewInternalError("entity service is enabled but client was nil.") @@ -32,19 +40,31 @@ func (s *Service) GetTransformPiiEntities( if err != nil { return nil, err } - err = user.EnforceJob(ctx, userdata.NewWildcardDomainEntity(req.Msg.GetAccountId()), rbac.JobAction_View) + err = user.EnforceJob( + ctx, + userdata.NewWildcardDomainEntity(req.Msg.GetAccountId()), + rbac.JobAction_View, + ) if err != nil { return nil, err } - resp, err := s.entityclient.GetSupportedentitiesWithResponse(ctx, &presidioapi.GetSupportedentitiesParams{ - Language: &enLanguage, - }) + resp, err := s.entityclient.GetSupportedentitiesWithResponse( + ctx, + &presidioapi.GetSupportedentitiesParams{ + Language: &enLanguage, + }, + ) if err != nil { return nil, fmt.Errorf("unable to retrieve available entities: %w", err) } if resp.JSON200 == nil { - return nil, fmt.Errorf("received non-200 response from entity api: %s %d %s", resp.Status(), resp.StatusCode(), string(resp.Body)) + return nil, fmt.Errorf( + "received non-200 response from entity api: %s %d %s", + resp.Status(), + resp.StatusCode(), + string(resp.Body), + ) } entities := *resp.JSON200 diff --git a/backend/services/mgmt/v1alpha1/transformers-service/system_transformers.go b/backend/services/mgmt/v1alpha1/transformers-service/system_transformers.go index ed2d71b177..65ef921eeb 100644 --- a/backend/services/mgmt/v1alpha1/transformers-service/system_transformers.go +++ b/backend/services/mgmt/v1alpha1/transformers-service/system_transformers.go @@ -23,12 +23,18 @@ var ( baseSystemTransformers = []*mgmtv1alpha1.SystemTransformer{ { - Name: "Generate Email", - Description: "Generates a new randomized email address.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, 
- DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_EMAIL, + Name: "Generate Email", + Description: "Generates a new randomized email address.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_EMAIL, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateEmailConfig{ GenerateEmailConfig: &mgmtv1alpha1.GenerateEmail{ @@ -38,12 +44,17 @@ var ( }, }, { - Name: "Transform Email", - Description: "Transforms an existing email address.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_EMAIL, + Name: "Transform Email", + Description: "Transforms an existing email address.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + 
mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_EMAIL, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_TransformEmailConfig{ TransformEmailConfig: &mgmtv1alpha1.TransformEmail{ @@ -57,12 +68,18 @@ var ( }, }, { - Name: "Generate Boolean", - Description: "Generates a boolean value at random.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_BOOLEAN, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_BOOLEAN, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_BOOL, + Name: "Generate Boolean", + Description: "Generates a boolean value at random.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_BOOLEAN, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_BOOLEAN, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_BOOL, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateBoolConfig{ GenerateBoolConfig: &mgmtv1alpha1.GenerateBool{}, @@ -70,12 +87,18 @@ var ( }, }, { - Name: "Generate Card Number", - Description: "Generates a card number.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, - DataTypes: 
[]mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_CARD_NUMBER, + Name: "Generate Card Number", + Description: "Generates a card number.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_CARD_NUMBER, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateCardNumberConfig{ GenerateCardNumberConfig: &mgmtv1alpha1.GenerateCardNumber{ @@ -85,12 +108,18 @@ var ( }, }, { - Name: "Generate City", - Description: "Randomly selects a city from a list of predfined US cities.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_CITY, + Name: "Generate City", + Description: "Randomly selects a city from a list of predfined US cities.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: 
[]mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_CITY, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateCityConfig{ GenerateCityConfig: &mgmtv1alpha1.GenerateCity{}, @@ -98,12 +127,18 @@ var ( }, }, { - Name: "Use Column Default", - Description: "Defers to the database column default", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_ANY, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_DEFAULT, + Name: "Use Column Default", + Description: "Defers to the database column default", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_ANY, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_DEFAULT, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateDefaultConfig{ GenerateDefaultConfig: &mgmtv1alpha1.GenerateDefault{}, @@ -111,12 +146,18 @@ var ( }, }, { - Name: "Generate 
International Phone Number", - Description: "Generates a phone number in international format with the + character at the start of the phone number. Note that the + sign is not included in the min or max.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_E164_PHONE_NUMBER, + Name: "Generate International Phone Number", + Description: "Generates a phone number in international format with the + character at the start of the phone number. Note that the + sign is not included in the min or max.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_E164_PHONE_NUMBER, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateE164PhoneNumberConfig{ GenerateE164PhoneNumberConfig: &mgmtv1alpha1.GenerateE164PhoneNumber{ @@ -127,12 +168,18 @@ var ( }, }, { - Name: "Generate First Name", - Description: "Generates a random first name. 
", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_FIRST_NAME, + Name: "Generate First Name", + Description: "Generates a random first name. ", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_FIRST_NAME, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateFirstNameConfig{ GenerateFirstNameConfig: &mgmtv1alpha1.GenerateFirstName{}, @@ -140,12 +187,18 @@ var ( }, }, { - Name: "Generate Float64", - Description: "Generates a random float64 value.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_FLOAT64, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_FLOAT64, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_FLOAT64, + Name: "Generate Float64", + Description: "Generates a random float64 value.", + DataType: 
mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_FLOAT64, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_FLOAT64, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_FLOAT64, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateFloat64Config{ GenerateFloat64Config: &mgmtv1alpha1.GenerateFloat64{ @@ -158,12 +211,18 @@ var ( }, }, { - Name: "Generate Full Address", - Description: "Randomly generates a street address in the format: {street_num} {street_addresss} {street_descriptor} {city}, {state} {zipcode}. For example, 123 Main Street Boston, Massachusetts 02169.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_FULL_ADDRESS, + Name: "Generate Full Address", + Description: "Randomly generates a street address in the format: {street_num} {street_addresss} {street_descriptor} {city}, {state} {zipcode}. 
For example, 123 Main Street Boston, Massachusetts 02169.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_FULL_ADDRESS, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateFullAddressConfig{ GenerateFullAddressConfig: &mgmtv1alpha1.GenerateFullAddress{}, @@ -171,12 +230,18 @@ var ( }, }, { - Name: "Generate Full Name", - Description: "Generates a new full name consisting of a first and last name", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_FULL_NAME, + Name: "Generate Full Name", + Description: "Generates a new full name consisting of a first and last name", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: 
mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_FULL_NAME, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateFullNameConfig{ GenerateFullNameConfig: &mgmtv1alpha1.GenerateFullName{}, @@ -184,12 +249,18 @@ var ( }, }, { - Name: "Generate Gender", - Description: "Randomly generates one of the following genders: female, male, undefined, nonbinary.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_GENDER, + Name: "Generate Gender", + Description: "Randomly generates one of the following genders: female, male, undefined, nonbinary.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_GENDER, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateGenderConfig{ GenerateGenderConfig: &mgmtv1alpha1.GenerateGender{ @@ -199,12 +270,18 @@ var ( }, }, { - Name: "Generate Int64 Phone Number", - Description: "Generates a new phone number with a default length of 10.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, - DataTypes: 
[]mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_INT64_PHONE_NUMBER, + Name: "Generate Int64 Phone Number", + Description: "Generates a new phone number with a default length of 10.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_INT64_PHONE_NUMBER, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateInt64PhoneNumberConfig{ GenerateInt64PhoneNumberConfig: &mgmtv1alpha1.GenerateInt64PhoneNumber{}, @@ -212,12 +289,18 @@ var ( }, }, { - Name: "Generate Random Int64", - Description: "Generates a random int64 value.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_INT64, + Name: "Generate Random Int64", + Description: "Generates a random int64 value.", + DataType: 
mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_INT64, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateInt64Config{ GenerateInt64Config: &mgmtv1alpha1.GenerateInt64{ @@ -229,12 +312,18 @@ var ( }, }, { - Name: "Generate Last Name", - Description: "Generates a random last name.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_LAST_NAME, + Name: "Generate Last Name", + Description: "Generates a random last name.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_LAST_NAME, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateLastNameConfig{ GenerateLastNameConfig: 
&mgmtv1alpha1.GenerateLastName{}, @@ -242,12 +331,18 @@ var ( }, }, { - Name: "Generate SHA256 Hash", - Description: "SHA256 hashes a randomly generated value.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_SHA256HASH, + Name: "Generate SHA256 Hash", + Description: "SHA256 hashes a randomly generated value.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_SHA256HASH, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateSha256HashConfig{ GenerateSha256HashConfig: &mgmtv1alpha1.GenerateSha256Hash{}, @@ -255,12 +350,18 @@ var ( }, }, { - Name: "Generate SSN", - Description: "Generates a completely random social security numbers including the hyphens in the format ", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, 
mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_SSN, + Name: "Generate SSN", + Description: "Generates a completely random social security numbers including the hyphens in the format ", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_SSN, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateSsnConfig{ GenerateSsnConfig: &mgmtv1alpha1.GenerateSSN{}, @@ -268,12 +369,18 @@ var ( }, }, { - Name: "Generate State", - Description: "Randomly selects a US state.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_STATE, + Name: "Generate State", + Description: "Randomly selects a US state.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + 
mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_STATE, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateStateConfig{ GenerateStateConfig: &mgmtv1alpha1.GenerateState{ @@ -283,12 +390,18 @@ var ( }, }, { - Name: "Generate Street Address", - Description: "Randomly generates a street address in the format: {street_num} {street_addresss} {street_descriptor}. For example, 123 Main Street.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_STREET_ADDRESS, + Name: "Generate Street Address", + Description: "Randomly generates a street address in the format: {street_num} {street_addresss} {street_descriptor}. 
For example, 123 Main Street.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_STREET_ADDRESS, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateStreetAddressConfig{ GenerateStreetAddressConfig: &mgmtv1alpha1.GenerateStreetAddress{}, @@ -296,12 +409,18 @@ var ( }, }, { - Name: "Generate String Phone Number", - Description: "Generates a phone number and returns it as a string.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_STRING_PHONE_NUMBER, + Name: "Generate String Phone Number", + Description: "Generates a phone number and returns it as a string.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: 
mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_STRING_PHONE_NUMBER, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateStringPhoneNumberConfig{ GenerateStringPhoneNumberConfig: &mgmtv1alpha1.GenerateStringPhoneNumber{ @@ -312,12 +431,18 @@ var ( }, }, { - Name: "Generate Random String", - Description: "Creates a randomly ordered alphanumeric string between the specified range", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_RANDOM_STRING, + Name: "Generate Random String", + Description: "Creates a randomly ordered alphanumeric string between the specified range", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_RANDOM_STRING, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateStringConfig{ GenerateStringConfig: &mgmtv1alpha1.GenerateString{ @@ -328,12 +453,18 @@ var ( }, }, { - Name: "Generate Unix Timestamp", - Description: "Randomly generates a Unix timestamp", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, - DataTypes: 
[]mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_UNIXTIMESTAMP, + Name: "Generate Unix Timestamp", + Description: "Randomly generates a Unix timestamp", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_UNIXTIMESTAMP, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateUnixtimestampConfig{ GenerateUnixtimestampConfig: &mgmtv1alpha1.GenerateUnixTimestamp{}, @@ -341,12 +472,18 @@ var ( }, }, { - Name: "Generate Username", - Description: "Randomly generates a username in the format .", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_USERNAME, + Name: "Generate Username", + Description: "Randomly generates a username in the format .", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: 
[]mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_USERNAME, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateUsernameConfig{ GenerateUsernameConfig: &mgmtv1alpha1.GenerateUsername{}, @@ -354,12 +491,18 @@ var ( }, }, { - Name: "Generate UTC Timestamp", - Description: "Randomly generates a UTC timestamp.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_TIME, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_TIME, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_UTCTIMESTAMP, + Name: "Generate UTC Timestamp", + Description: "Randomly generates a UTC timestamp.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_TIME, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_TIME, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_UTCTIMESTAMP, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateUtctimestampConfig{ GenerateUtctimestampConfig: &mgmtv1alpha1.GenerateUtcTimestamp{}, @@ -367,12 
+510,19 @@ var ( }, }, { - Name: "Generate UUID", - Description: "Generates a new UUIDv4 identifier.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_UUID, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_UUID, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_UUID, + Name: "Generate UUID", + Description: "Generates a new UUIDv4 identifier.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_UUID, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_UUID, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_UUID, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateUuidConfig{ GenerateUuidConfig: &mgmtv1alpha1.GenerateUuid{ @@ -382,12 +532,18 @@ var ( }, }, { - Name: "Generate Zipcode", - Description: "Randomly selects a zip code from a list of predefined US zipcodes.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, 
mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_ZIPCODE, + Name: "Generate Zipcode", + Description: "Randomly selects a zip code from a list of predefined US zipcodes.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_ZIPCODE, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateZipcodeConfig{ GenerateZipcodeConfig: &mgmtv1alpha1.GenerateZipcode{}, @@ -395,12 +551,17 @@ var ( }, }, { - Name: "Transform E164 Phone Number", - Description: "Transforms an existing E164 formatted phone number.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_E164_PHONE_NUMBER, + Name: "Transform E164 Phone Number", + Description: "Transforms an existing E164 formatted phone number.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + 
Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_E164_PHONE_NUMBER, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_TransformE164PhoneNumberConfig{ TransformE164PhoneNumberConfig: &mgmtv1alpha1.TransformE164PhoneNumber{ @@ -410,12 +571,17 @@ var ( }, }, { - Name: "Transform First Name", - Description: "Transforms an existing first name", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_FIRST_NAME, + Name: "Transform First Name", + Description: "Transforms an existing first name", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_FIRST_NAME, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_TransformFirstNameConfig{ TransformFirstNameConfig: &mgmtv1alpha1.TransformFirstName{ @@ -425,12 +591,17 @@ var ( }, }, { - Name: "Transform Float64", - Description: "Transforms an existing float value.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_FLOAT64, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_FLOAT64, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: 
[]mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_FLOAT64, + Name: "Transform Float64", + Description: "Transforms an existing float value.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_FLOAT64, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_FLOAT64, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_FLOAT64, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_TransformFloat64Config{ TransformFloat64Config: &mgmtv1alpha1.TransformFloat64{ @@ -441,12 +612,17 @@ var ( }, }, { - Name: "Transform Full Name", - Description: "Transforms an existing full name.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_FULL_NAME, + Name: "Transform Full Name", + Description: "Transforms an existing full name.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_FULL_NAME, Config: &mgmtv1alpha1.TransformerConfig{ 
Config: &mgmtv1alpha1.TransformerConfig_TransformFullNameConfig{ TransformFullNameConfig: &mgmtv1alpha1.TransformFullName{ @@ -456,12 +632,17 @@ var ( }, }, { - Name: "Transform Int64 Phone Number", - Description: "Transforms an existing phone number that is typed as an integer", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_INT64_PHONE_NUMBER, + Name: "Transform Int64 Phone Number", + Description: "Transforms an existing phone number that is typed as an integer", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_INT64_PHONE_NUMBER, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_TransformInt64PhoneNumberConfig{ TransformInt64PhoneNumberConfig: &mgmtv1alpha1.TransformInt64PhoneNumber{ @@ -471,12 +652,17 @@ var ( }, }, { - Name: "Transform Int64", - Description: "Transforms an existing integer value.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: 
mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_INT64, + Name: "Transform Int64", + Description: "Transforms an existing integer value.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_INT64, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_TransformInt64Config{ TransformInt64Config: &mgmtv1alpha1.TransformInt64{ @@ -487,12 +673,17 @@ var ( }, }, { - Name: "Transform Last Name", - Description: "Transforms an existing last name.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_LAST_NAME, + Name: "Transform Last Name", + Description: "Transforms an existing last name.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_LAST_NAME, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_TransformLastNameConfig{ TransformLastNameConfig: 
&mgmtv1alpha1.TransformLastName{ @@ -502,12 +693,17 @@ var ( }, }, { - Name: "Transform String Phone Number", - Description: "Transforms an existing phone number that is typed as a string.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_PHONE_NUMBER, + Name: "Transform String Phone Number", + Description: "Transforms an existing phone number that is typed as a string.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_PHONE_NUMBER, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_TransformPhoneNumberConfig{ TransformPhoneNumberConfig: &mgmtv1alpha1.TransformPhoneNumber{ @@ -517,12 +713,17 @@ var ( }, }, { - Name: "Transform String", - Description: "Transforms an existing string value.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_STRING, + Name: "Transform String", + Description: "Transforms an 
existing string value.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_STRING, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_TransformStringConfig{ TransformStringConfig: &mgmtv1alpha1.TransformString{ @@ -532,12 +733,17 @@ var ( }, }, { - Name: "Passthrough", - Description: "Passes the input value through to the desination with no changes.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_ANY, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_ANY, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_PASSTHROUGH, + Name: "Passthrough", + Description: "Passes the input value through to the desination with no changes.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_ANY, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_ANY, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_PASSTHROUGH, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_PassthroughConfig{ PassthroughConfig: &mgmtv1alpha1.Passthrough{}, @@ -545,12 +751,17 @@ var ( }, }, { - Name: "Null", - Description: "Inserts a string instead of the 
source value.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_NULL, + Name: "Null", + Description: "Inserts a string instead of the source value.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_NULL, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_Nullconfig{ Nullconfig: &mgmtv1alpha1.Null{}, @@ -558,25 +769,38 @@ var ( }, }, { - Name: "Transform Javascript", - Description: "Write custom javascript to transform data", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_ANY, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_ANY, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_JAVASCRIPT, + Name: "Transform Javascript", + Description: "Write custom javascript to transform data", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_ANY, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_ANY, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + 
SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_JAVASCRIPT, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_TransformJavascriptConfig{ - TransformJavascriptConfig: &mgmtv1alpha1.TransformJavascript{Code: `return value + "test";`}, + TransformJavascriptConfig: &mgmtv1alpha1.TransformJavascript{ + Code: `return value + "test";`, + }, }, }, }, { - Name: "Generate Categorical", - Description: "Randomly selects a value from a predefined list of values", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_CATEGORICAL, + Name: "Generate Categorical", + Description: "Randomly selects a value from a predefined list of values", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_CATEGORICAL, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateCategoricalConfig{ GenerateCategoricalConfig: &mgmtv1alpha1.GenerateCategorical{ @@ -586,12 +810,17 @@ var ( }, }, { - Name: "Transform Character Scramble", - 
Description: "Transforms a string value by scrambling each character with another character in the same unicode block. Letters will be substituted with letters, numbers with numbers and special characters with special characters. Spaces and capitalization is preserved.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_CHARACTER_SCRAMBLE, + Name: "Transform Character Scramble", + Description: "Transforms a string value by scrambling each character with another character in the same unicode block. Letters will be substituted with letters, numbers with numbers and special characters with special characters. Spaces and capitalization is preserved.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_CHARACTER_SCRAMBLE, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_TransformCharacterScrambleConfig{ TransformCharacterScrambleConfig: &mgmtv1alpha1.TransformCharacterScramble{ @@ -601,25 +830,39 @@ var ( }, }, { - Name: "Generate Javascript", - Description: "Write custom Javascript to generate synthetic data.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_ANY, - DataTypes: 
[]mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_ANY, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_JAVASCRIPT, + Name: "Generate Javascript", + Description: "Write custom Javascript to generate synthetic data.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_ANY, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_ANY, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_JAVASCRIPT, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateJavascriptConfig{ - GenerateJavascriptConfig: &mgmtv1alpha1.GenerateJavascript{Code: `return "testvalue";`}, + GenerateJavascriptConfig: &mgmtv1alpha1.GenerateJavascript{ + Code: `return "testvalue";`, + }, }, }, }, { - Name: "Generate Country", - Description: "Randomly selects a Country.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_COUNTRY, + Name: "Generate Country", + Description: "Randomly selects a Country.", + DataType: 
mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_COUNTRY, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateCountryConfig{ GenerateCountryConfig: &mgmtv1alpha1.GenerateCountry{ @@ -629,12 +872,18 @@ var ( }, }, { - Name: "Generate Business Name", - Description: "Generates a random business name.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_BUSINESS_NAME, + Name: "Generate Business Name", + Description: "Generates a random business name.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_BUSINESS_NAME, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateBusinessNameConfig{ 
GenerateBusinessNameConfig: &mgmtv1alpha1.GenerateBusinessName{}, @@ -642,12 +891,18 @@ var ( }, }, { - Name: "Generate IP Address", - Description: "Generates a random IP address.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_IP_ADDRESS, + Name: "Generate IP Address", + Description: "Generates a random IP address.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_GENERATE, + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_IP_ADDRESS, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateIpAddressConfig{ GenerateIpAddressConfig: &mgmtv1alpha1.GenerateIpAddress{ @@ -657,12 +912,18 @@ var ( }, }, { - Name: "Transform UUID", - Description: "Transforms a UUID", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_UUID, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_UUID, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: 
mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_UUID, + Name: "Transform UUID", + Description: "Transforms a UUID", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_UUID, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_UUID, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_UUID, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_TransformUuidConfig{ TransformUuidConfig: &mgmtv1alpha1.TransformUuid{}, @@ -670,12 +931,17 @@ var ( }, }, { - Name: "Scramble Identity", - Description: "Scrambles an integer while keeping it unique.", - DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, - DataTypes: []mgmtv1alpha1.TransformerDataType{mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL}, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_SCRAMBLE_IDENTITY, + Name: "Scramble Identity", + Description: "Scrambles an integer while keeping it unique.", + DataType: mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, + DataTypes: []mgmtv1alpha1.TransformerDataType{ + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_INT64, + mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, + }, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_SCRAMBLE_IDENTITY, Config: &mgmtv1alpha1.TransformerConfig{ Config: 
&mgmtv1alpha1.TransformerConfig_TransformScrambleIdentityConfig{ TransformScrambleIdentityConfig: &mgmtv1alpha1.TransformScrambleIdentity{}, @@ -730,7 +996,9 @@ func (s *Service) GetSystemTransformerBySource( transformer, ok := transformerMap[req.Msg.GetSource()] if !ok { - return nil, nucleuserrors.NewNotFound("unable to find system transformer with provided source") + return nil, nucleuserrors.NewNotFound( + "unable to find system transformer with provided source", + ) } return connect.NewResponse(&mgmtv1alpha1.GetSystemTransformerBySourceResponse{ Transformer: transformer, diff --git a/backend/services/mgmt/v1alpha1/transformers-service/userdefined_transformers.go b/backend/services/mgmt/v1alpha1/transformers-service/userdefined_transformers.go index 05511b1755..c8a53de1d3 100644 --- a/backend/services/mgmt/v1alpha1/transformers-service/userdefined_transformers.go +++ b/backend/services/mgmt/v1alpha1/transformers-service/userdefined_transformers.go @@ -26,7 +26,11 @@ func (s *Service) GetUserDefinedTransformers( if err != nil { return nil, err } - err = user.EnforceJob(ctx, userdata.NewWildcardDomainEntity(req.Msg.GetAccountId()), rbac.JobAction_View) + err = user.EnforceJob( + ctx, + userdata.NewWildcardDomainEntity(req.Msg.GetAccountId()), + rbac.JobAction_View, + ) if err != nil { return nil, err } @@ -43,9 +47,17 @@ func (s *Service) GetUserDefinedTransformers( dtoTransformers := []*mgmtv1alpha1.UserDefinedTransformer{} for idx := range transformers { transformer := transformers[idx] - dto, err := dtomaps.ToUserDefinedTransformerDto(&transformer, s.getSystemTransformerSourceMap()) + dto, err := dtomaps.ToUserDefinedTransformerDto( + &transformer, + s.getSystemTransformerSourceMap(), + ) if err != nil { - return nil, fmt.Errorf("failed to map user defined transformer %s with source %d: %w", neosyncdb.UUIDString(transformer.ID), transformer.Source, err) + return nil, fmt.Errorf( + "failed to map user defined transformer %s with source %d: %w", + 
neosyncdb.UUIDString(transformer.ID), + transformer.Source, + err, + ) } dtoTransformers = append(dtoTransformers, dto) } @@ -73,14 +85,23 @@ func (s *Service) GetUserDefinedTransformerById( dto, err := dtomaps.ToUserDefinedTransformerDto(&transformer, s.getSystemTransformerSourceMap()) if err != nil { - return nil, fmt.Errorf("failed to map user defined transformer %s with source %d: %w", neosyncdb.UUIDString(transformer.ID), transformer.Source, err) + return nil, fmt.Errorf( + "failed to map user defined transformer %s with source %d: %w", + neosyncdb.UUIDString(transformer.ID), + transformer.Source, + err, + ) } user, err := s.userdataclient.GetUser(ctx) if err != nil { return nil, err } - err = user.EnforceJob(ctx, userdata.NewWildcardDomainEntity(dto.GetAccountId()), rbac.JobAction_View) + err = user.EnforceJob( + ctx, + userdata.NewWildcardDomainEntity(dto.GetAccountId()), + rbac.JobAction_View, + ) if err != nil { return nil, err } @@ -90,12 +111,19 @@ func (s *Service) GetUserDefinedTransformerById( }), nil } -func (s *Service) CreateUserDefinedTransformer(ctx context.Context, req *connect.Request[mgmtv1alpha1.CreateUserDefinedTransformerRequest]) (*connect.Response[mgmtv1alpha1.CreateUserDefinedTransformerResponse], error) { +func (s *Service) CreateUserDefinedTransformer( + ctx context.Context, + req *connect.Request[mgmtv1alpha1.CreateUserDefinedTransformerRequest], +) (*connect.Response[mgmtv1alpha1.CreateUserDefinedTransformerResponse], error) { user, err := s.userdataclient.GetUser(ctx) if err != nil { return nil, err } - err = user.EnforceJob(ctx, userdata.NewWildcardDomainEntity(req.Msg.GetAccountId()), rbac.JobAction_Edit) + err = user.EnforceJob( + ctx, + userdata.NewWildcardDomainEntity(req.Msg.GetAccountId()), + rbac.JobAction_Edit, + ) if err != nil { return nil, err } @@ -114,7 +142,9 @@ func (s *Service) CreateUserDefinedTransformer(ctx context.Context, req *connect UpdatedByID: user.PgId(), } - err = 
UserDefinedTransformer.TransformerConfig.FromTransformerConfigDto(req.Msg.TransformerConfig) + err = UserDefinedTransformer.TransformerConfig.FromTransformerConfigDto( + req.Msg.TransformerConfig, + ) if err != nil { return nil, err } @@ -126,7 +156,12 @@ func (s *Service) CreateUserDefinedTransformer(ctx context.Context, req *connect dto, err := dtomaps.ToUserDefinedTransformerDto(&ct, s.getSystemTransformerSourceMap()) if err != nil { - return nil, fmt.Errorf("failed to map user defined transformer %s with source %d: %w", neosyncdb.UUIDString(ct.ID), ct.Source, err) + return nil, fmt.Errorf( + "failed to map user defined transformer %s with source %d: %w", + neosyncdb.UUIDString(ct.ID), + ct.Source, + err, + ) } return connect.NewResponse(&mgmtv1alpha1.CreateUserDefinedTransformerResponse{ @@ -134,7 +169,10 @@ func (s *Service) CreateUserDefinedTransformer(ctx context.Context, req *connect }), nil } -func (s *Service) DeleteUserDefinedTransformer(ctx context.Context, req *connect.Request[mgmtv1alpha1.DeleteUserDefinedTransformerRequest]) (*connect.Response[mgmtv1alpha1.DeleteUserDefinedTransformerResponse], error) { +func (s *Service) DeleteUserDefinedTransformer( + ctx context.Context, + req *connect.Request[mgmtv1alpha1.DeleteUserDefinedTransformerRequest], +) (*connect.Response[mgmtv1alpha1.DeleteUserDefinedTransformerResponse], error) { logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) logger = logger.With("transformerId", req.Msg.GetTransformerId()) @@ -154,7 +192,11 @@ func (s *Service) DeleteUserDefinedTransformer(ctx context.Context, req *connect if err != nil { return nil, err } - err = user.EnforceJob(ctx, userdata.NewWildcardDomainEntity(neosyncdb.UUIDString(transformer.AccountID)), rbac.JobAction_Delete) + err = user.EnforceJob( + ctx, + userdata.NewWildcardDomainEntity(neosyncdb.UUIDString(transformer.AccountID)), + rbac.JobAction_Delete, + ) if err != nil { return nil, err } @@ -169,7 +211,10 @@ func (s *Service) 
DeleteUserDefinedTransformer(ctx context.Context, req *connect return connect.NewResponse(&mgmtv1alpha1.DeleteUserDefinedTransformerResponse{}), nil } -func (s *Service) UpdateUserDefinedTransformer(ctx context.Context, req *connect.Request[mgmtv1alpha1.UpdateUserDefinedTransformerRequest]) (*connect.Response[mgmtv1alpha1.UpdateUserDefinedTransformerResponse], error) { +func (s *Service) UpdateUserDefinedTransformer( + ctx context.Context, + req *connect.Request[mgmtv1alpha1.UpdateUserDefinedTransformerRequest], +) (*connect.Response[mgmtv1alpha1.UpdateUserDefinedTransformerResponse], error) { tUuid, err := neosyncdb.ToUuid(req.Msg.TransformerId) if err != nil { return nil, err @@ -185,7 +230,11 @@ func (s *Service) UpdateUserDefinedTransformer(ctx context.Context, req *connect if err != nil { return nil, err } - err = user.EnforceJob(ctx, userdata.NewWildcardDomainEntity(neosyncdb.UUIDString(transformer.AccountID)), rbac.JobAction_Edit) + err = user.EnforceJob( + ctx, + userdata.NewWildcardDomainEntity(neosyncdb.UUIDString(transformer.AccountID)), + rbac.JobAction_Edit, + ) if err != nil { return nil, err } @@ -208,9 +257,17 @@ func (s *Service) UpdateUserDefinedTransformer(ctx context.Context, req *connect return nil, err } - dto, err := dtomaps.ToUserDefinedTransformerDto(&updatedTransformer, s.getSystemTransformerSourceMap()) + dto, err := dtomaps.ToUserDefinedTransformerDto( + &updatedTransformer, + s.getSystemTransformerSourceMap(), + ) if err != nil { - return nil, fmt.Errorf("failed to map user defined transformer %s with source %d: %w", neosyncdb.UUIDString(updatedTransformer.ID), updatedTransformer.Source, err) + return nil, fmt.Errorf( + "failed to map user defined transformer %s with source %d: %w", + neosyncdb.UUIDString(updatedTransformer.ID), + updatedTransformer.Source, + err, + ) } return connect.NewResponse(&mgmtv1alpha1.UpdateUserDefinedTransformerResponse{ @@ -218,12 +275,19 @@ func (s *Service) UpdateUserDefinedTransformer(ctx context.Context, 
req *connect }), err } -func (s *Service) IsTransformerNameAvailable(ctx context.Context, req *connect.Request[mgmtv1alpha1.IsTransformerNameAvailableRequest]) (*connect.Response[mgmtv1alpha1.IsTransformerNameAvailableResponse], error) { +func (s *Service) IsTransformerNameAvailable( + ctx context.Context, + req *connect.Request[mgmtv1alpha1.IsTransformerNameAvailableRequest], +) (*connect.Response[mgmtv1alpha1.IsTransformerNameAvailableResponse], error) { user, err := s.userdataclient.GetUser(ctx) if err != nil { return nil, err } - err = user.EnforceJob(ctx, userdata.NewWildcardDomainEntity(req.Msg.GetAccountId()), rbac.JobAction_View) + err = user.EnforceJob( + ctx, + userdata.NewWildcardDomainEntity(req.Msg.GetAccountId()), + rbac.JobAction_View, + ) if err != nil { return nil, err } @@ -232,10 +296,14 @@ func (s *Service) IsTransformerNameAvailable(ctx context.Context, req *connect.R return nil, err } - count, err := s.db.Q.IsTransformerNameAvailable(ctx, s.db.Db, db_queries.IsTransformerNameAvailableParams{ - AccountId: accountUuid, - TransformerName: req.Msg.TransformerName, - }) + count, err := s.db.Q.IsTransformerNameAvailable( + ctx, + s.db.Db, + db_queries.IsTransformerNameAvailableParams{ + AccountId: accountUuid, + TransformerName: req.Msg.TransformerName, + }, + ) if err != nil { return nil, err } @@ -246,7 +314,10 @@ func (s *Service) IsTransformerNameAvailable(ctx context.Context, req *connect.R } // use the goja library to validate that the javascript can compile and theoretically run -func (s *Service) ValidateUserJavascriptCode(ctx context.Context, req *connect.Request[mgmtv1alpha1.ValidateUserJavascriptCodeRequest]) (*connect.Response[mgmtv1alpha1.ValidateUserJavascriptCodeResponse], error) { +func (s *Service) ValidateUserJavascriptCode( + ctx context.Context, + req *connect.Request[mgmtv1alpha1.ValidateUserJavascriptCodeRequest], +) (*connect.Response[mgmtv1alpha1.ValidateUserJavascriptCodeResponse], error) { js := 
constructJavascriptCode(req.Msg.GetCode()) _, err := goja.Compile("test", js, true) @@ -272,7 +343,10 @@ func constructJavascriptCode(jsCode string) string { } } -func (s *Service) ValidateUserRegexCode(ctx context.Context, req *connect.Request[mgmtv1alpha1.ValidateUserRegexCodeRequest]) (*connect.Response[mgmtv1alpha1.ValidateUserRegexCodeResponse], error) { +func (s *Service) ValidateUserRegexCode( + ctx context.Context, + req *connect.Request[mgmtv1alpha1.ValidateUserRegexCodeRequest], +) (*connect.Response[mgmtv1alpha1.ValidateUserRegexCodeResponse], error) { _, err := regexp.Compile(req.Msg.GetUserProvidedRegex()) // todo: should return error message here and surface to user return connect.NewResponse(&mgmtv1alpha1.ValidateUserRegexCodeResponse{ diff --git a/backend/services/mgmt/v1alpha1/user-account-service/account-onboarding.go b/backend/services/mgmt/v1alpha1/user-account-service/account-onboarding.go index 19f3777f0b..3baa55928e 100644 --- a/backend/services/mgmt/v1alpha1/user-account-service/account-onboarding.go +++ b/backend/services/mgmt/v1alpha1/user-account-service/account-onboarding.go @@ -21,7 +21,11 @@ func (s *Service) GetAccountOnboardingConfig( if err != nil { return nil, err } - err = user.EnforceAccount(ctx, userdata.NewIdentifier(req.Msg.GetAccountId()), rbac.AccountAction_View) + err = user.EnforceAccount( + ctx, + userdata.NewIdentifier(req.Msg.GetAccountId()), + rbac.AccountAction_View, + ) if err != nil { return nil, err } @@ -50,7 +54,11 @@ func (s *Service) SetAccountOnboardingConfig( if err != nil { return nil, err } - err = user.EnforceAccount(ctx, userdata.NewIdentifier(req.Msg.GetAccountId()), rbac.AccountAction_Edit) + err = user.EnforceAccount( + ctx, + userdata.NewIdentifier(req.Msg.GetAccountId()), + rbac.AccountAction_Edit, + ) if err != nil { return nil, err } @@ -68,10 +76,14 @@ func (s *Service) SetAccountOnboardingConfig( onboardingConfigModel := &pg_models.AccountOnboardingConfig{} onboardingConfigModel.FromDto(inputCfg) 
- account, err := s.db.Q.UpdateAccountOnboardingConfig(ctx, s.db.Db, db_queries.UpdateAccountOnboardingConfigParams{ - OnboardingConfig: onboardingConfigModel, - AccountId: accountUuid, - }) + account, err := s.db.Q.UpdateAccountOnboardingConfig( + ctx, + s.db.Db, + db_queries.UpdateAccountOnboardingConfigParams{ + OnboardingConfig: onboardingConfigModel, + AccountId: accountUuid, + }, + ) if err != nil { return nil, err } diff --git a/backend/services/mgmt/v1alpha1/user-account-service/account-temporal-config.go b/backend/services/mgmt/v1alpha1/user-account-service/account-temporal-config.go index 4fbe103b9d..adafcd68bd 100644 --- a/backend/services/mgmt/v1alpha1/user-account-service/account-temporal-config.go +++ b/backend/services/mgmt/v1alpha1/user-account-service/account-temporal-config.go @@ -25,7 +25,11 @@ func (s *Service) GetAccountTemporalConfig( if err != nil { return nil, err } - err = user.EnforceAccount(ctx, userdata.NewIdentifier(req.Msg.GetAccountId()), rbac.AccountAction_View) + err = user.EnforceAccount( + ctx, + userdata.NewIdentifier(req.Msg.GetAccountId()), + rbac.AccountAction_View, + ) if err != nil { return nil, err } @@ -53,7 +57,11 @@ func (s *Service) SetAccountTemporalConfig( return nil, err } - err = user.EnforceAccount(ctx, userdata.NewIdentifier(req.Msg.GetAccountId()), rbac.AccountAction_Edit) + err = user.EnforceAccount( + ctx, + userdata.NewIdentifier(req.Msg.GetAccountId()), + rbac.AccountAction_Edit, + ) if err != nil { return nil, err } @@ -71,10 +79,14 @@ func (s *Service) SetAccountTemporalConfig( tc := &pg_models.TemporalConfig{} tc.FromDto(dtoCfg) - _, err = s.db.Q.UpdateTemporalConfigByAccount(ctx, s.db.Db, db_queries.UpdateTemporalConfigByAccountParams{ - TemporalConfig: tc, - AccountId: accountUuid, - }) + _, err = s.db.Q.UpdateTemporalConfigByAccount( + ctx, + s.db.Db, + db_queries.UpdateTemporalConfigByAccountParams{ + TemporalConfig: tc, + AccountId: accountUuid, + }, + ) if err != nil { return nil, err } diff --git 
a/backend/services/mgmt/v1alpha1/user-account-service/billing.go b/backend/services/mgmt/v1alpha1/user-account-service/billing.go index ff5acc680a..d425733d9f 100644 --- a/backend/services/mgmt/v1alpha1/user-account-service/billing.go +++ b/backend/services/mgmt/v1alpha1/user-account-service/billing.go @@ -39,7 +39,11 @@ func (s *Service) GetAccountStatus( if err != nil { return nil, err } - err = user.EnforceAccount(ctx, userdata.NewIdentifier(req.Msg.GetAccountId()), rbac.AccountAction_View) + err = user.EnforceAccount( + ctx, + userdata.NewIdentifier(req.Msg.GetAccountId()), + rbac.AccountAction_View, + ) if err != nil { return nil, err } @@ -125,7 +129,10 @@ func (s *Service) getStripeSubscriptions(customerId string) ([]*stripe.Subscript output = append(output, subIter.Subscription()) } if subIter.Err() != nil { - return nil, fmt.Errorf("encountered error when retrieving stripe subscriptions: %w", subIter.Err()) + return nil, fmt.Errorf( + "encountered error when retrieving stripe subscriptions: %w", + subIter.Err(), + ) } return output, nil } @@ -164,9 +171,12 @@ func (s *Service) IsAccountStatusValid( ctx context.Context, req *connect.Request[mgmtv1alpha1.IsAccountStatusValidRequest], ) (*connect.Response[mgmtv1alpha1.IsAccountStatusValidResponse], error) { - accountStatusResp, err := s.GetAccountStatus(ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountStatusRequest{ - AccountId: req.Msg.GetAccountId(), - })) + accountStatusResp, err := s.GetAccountStatus( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetAccountStatusRequest{ + AccountId: req.Msg.GetAccountId(), + }), + ) if err != nil { return nil, err } @@ -222,7 +232,15 @@ func (s *Service) GetAccountBillingCheckoutSession( ) (*connect.Response[mgmtv1alpha1.GetAccountBillingCheckoutSessionResponse], error) { logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) if !s.cfg.IsNeosyncCloud || s.billingclient == nil { - return nil, nucleuserrors.NewNotImplemented(fmt.Sprintf("%s is not implemented", 
strings.TrimPrefix(mgmtv1alpha1connect.UserAccountServiceGetAccountBillingCheckoutSessionProcedure, "/"))) + return nil, nucleuserrors.NewNotImplemented( + fmt.Sprintf( + "%s is not implemented", + strings.TrimPrefix( + mgmtv1alpha1connect.UserAccountServiceGetAccountBillingCheckoutSessionProcedure, + "/", + ), + ), + ) } logger = logger.With("accountId", req.Msg.GetAccountId()) userdataclient := s.UserDataClient() @@ -236,7 +254,11 @@ func (s *Service) GetAccountBillingCheckoutSession( return nil, err } - err = user.EnforceAccount(ctx, userdata.NewIdentifier(req.Msg.GetAccountId()), rbac.AccountAction_Edit) + err = user.EnforceAccount( + ctx, + userdata.NewIdentifier(req.Msg.GetAccountId()), + rbac.AccountAction_Edit, + ) if err != nil { return nil, err } @@ -249,13 +271,23 @@ func (s *Service) GetAccountBillingCheckoutSession( logger, ) if err != nil { - return nil, fmt.Errorf("was unable to get account and/or upsert stripe customer id: %w", err) + return nil, fmt.Errorf( + "was unable to get account and/or upsert stripe customer id: %w", + err, + ) } if !account.StripeCustomerID.Valid { - return nil, errors.New("stripe customer id does not exist on account after creation attempt") + return nil, errors.New( + "stripe customer id does not exist on account after creation attempt", + ) } - session, err := s.generateCheckoutSession(account.StripeCustomerID.String, account.AccountSlug, user.Id(), logger) + session, err := s.generateCheckoutSession( + account.StripeCustomerID.String, + account.AccountSlug, + user.Id(), + logger, + ) if err != nil { return nil, fmt.Errorf("unable to generate billing checkout session: %w", err) } @@ -270,7 +302,15 @@ func (s *Service) GetAccountBillingPortalSession( req *connect.Request[mgmtv1alpha1.GetAccountBillingPortalSessionRequest], ) (*connect.Response[mgmtv1alpha1.GetAccountBillingPortalSessionResponse], error) { if !s.cfg.IsNeosyncCloud || s.billingclient == nil { - return nil, nucleuserrors.NewNotImplemented(fmt.Sprintf("%s is 
not implemented", strings.TrimPrefix(mgmtv1alpha1connect.UserAccountServiceGetAccountBillingPortalSessionProcedure, "/"))) + return nil, nucleuserrors.NewNotImplemented( + fmt.Sprintf( + "%s is not implemented", + strings.TrimPrefix( + mgmtv1alpha1connect.UserAccountServiceGetAccountBillingPortalSessionProcedure, + "/", + ), + ), + ) } userdataclient := s.UserDataClient() user, err := userdataclient.GetUser(ctx) @@ -278,7 +318,11 @@ func (s *Service) GetAccountBillingPortalSession( return nil, err } - err = user.EnforceAccount(ctx, userdata.NewIdentifier(req.Msg.GetAccountId()), rbac.AccountAction_Edit) + err = user.EnforceAccount( + ctx, + userdata.NewIdentifier(req.Msg.GetAccountId()), + rbac.AccountAction_Edit, + ) if err != nil { return nil, err } @@ -293,10 +337,15 @@ func (s *Service) GetAccountBillingPortalSession( return nil, err } if !account.StripeCustomerID.Valid { - return nil, nucleuserrors.NewForbidden("requested account does not have a valid stripe customer id") + return nil, nucleuserrors.NewForbidden( + "requested account does not have a valid stripe customer id", + ) } - session, err := s.billingclient.NewBillingPortalSession(account.StripeCustomerID.String, account.AccountSlug) + session, err := s.billingclient.NewBillingPortalSession( + account.StripeCustomerID.String, + account.AccountSlug, + ) if err != nil { return nil, fmt.Errorf("unable to generate billing portal session: %w", err) } @@ -315,7 +364,9 @@ func (s *Service) GetBillingAccounts( return nil, err } if s.cfg.IsNeosyncCloud && !user.IsWorkerApiKey() { - return nil, nucleuserrors.NewUnauthorized("must provide valid authentication credentials for this endpoint") + return nil, nucleuserrors.NewUnauthorized( + "must provide valid authentication credentials for this endpoint", + ) } accountIdsToFilter := []pgtype.UUID{} @@ -353,7 +404,9 @@ func (s *Service) SetBillingMeterEvent( return nil, err } if s.cfg.IsNeosyncCloud && !user.IsWorkerApiKey() { - return nil, 
nucleuserrors.NewUnauthorized("must provide valid authentication credentials for this endpoint") + return nil, nucleuserrors.NewUnauthorized( + "must provide valid authentication credentials for this endpoint", + ) } logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx). @@ -395,7 +448,8 @@ func (s *Service) SetBillingMeterEvent( }) if err != nil { if stripeErr, ok := err.(*stripe.Error); ok { - if stripeErr.Type == stripe.ErrorTypeInvalidRequest && strings.Contains(stripeErr.Msg, "An event already exists with identifier") { + if stripeErr.Type == stripe.ErrorTypeInvalidRequest && + strings.Contains(stripeErr.Msg, "An event already exists with identifier") { logger.Warn("unable to create new meter event, identifier already exists") return connect.NewResponse(&mgmtv1alpha1.SetBillingMeterEventResponse{}), nil } diff --git a/backend/services/mgmt/v1alpha1/user-account-service/users.go b/backend/services/mgmt/v1alpha1/user-account-service/users.go index a34bf49246..668997d521 100644 --- a/backend/services/mgmt/v1alpha1/user-account-service/users.go +++ b/backend/services/mgmt/v1alpha1/user-account-service/users.go @@ -62,7 +62,8 @@ func (s *Service) GetUser( } if tokenctxResp.ApiKeyContextData != nil { - if tokenctxResp.ApiKeyContextData.ApiKeyType == apikey.AccountApiKey && tokenctxResp.ApiKeyContextData.ApiKey != nil { + if tokenctxResp.ApiKeyContextData.ApiKeyType == apikey.AccountApiKey && + tokenctxResp.ApiKeyContextData.ApiKey != nil { return connect.NewResponse(&mgmtv1alpha1.GetUserResponse{ UserId: neosyncdb.UUIDString(tokenctxResp.ApiKeyContextData.ApiKey.UserID), }), nil @@ -71,7 +72,12 @@ func (s *Service) GetUser( UserId: "00000000-0000-0000-0000-000000000000", }), nil } - return nil, nucleuserrors.NewUnauthenticated(fmt.Sprintf("invalid api key type when calling GetUser: %s", tokenctxResp.ApiKeyContextData.ApiKeyType)) + return nil, nucleuserrors.NewUnauthenticated( + fmt.Sprintf( + "invalid api key type when calling GetUser: %s", + 
tokenctxResp.ApiKeyContextData.ApiKeyType, + ), + ) } else if tokenctxResp.JwtContextData != nil { user, err := s.db.Q.GetUserAssociationByProviderSub(ctx, s.db.Db, tokenctxResp.JwtContextData.AuthUserId) if err != nil && !neosyncdb.IsNoRows(err) { @@ -84,7 +90,9 @@ func (s *Service) GetUser( UserId: neosyncdb.UUIDString(user.UserID), }), nil } - return nil, nucleuserrors.NewUnauthenticated("unable to find a valid user based on the provided auth credentials") + return nil, nucleuserrors.NewUnauthenticated( + "unable to find a valid user based on the provided auth credentials", + ) } func (s *Service) SetUser( @@ -132,7 +140,9 @@ func (s *Service) SetUser( UserId: neosyncdb.UUIDString(user.ID), }), nil } - return nil, nucleuserrors.NewUnauthenticated("unable to find a valid user based on the provided auth credentials") + return nil, nucleuserrors.NewUnauthenticated( + "unable to find a valid user based on the provided auth credentials", + ) } func (s *Service) GetUserAccounts( @@ -167,10 +177,14 @@ func (s *Service) ConvertPersonalToTeamAccount( req *connect.Request[mgmtv1alpha1.ConvertPersonalToTeamAccountRequest], ) (*connect.Response[mgmtv1alpha1.ConvertPersonalToTeamAccountResponse], error) { if !s.cfg.IsAuthEnabled { - return nil, nucleuserrors.NewForbidden("unable to convert personal account to team account as authentication is not enabled") + return nil, nucleuserrors.NewForbidden( + "unable to convert personal account to team account as authentication is not enabled", + ) } if s.cfg.IsNeosyncCloud && s.billingclient == nil { - return nil, nucleuserrors.NewForbidden("creating team accounts via the API is currently forbidden in Neosync Cloud environments. Please contact us to create a team account.") + return nil, nucleuserrors.NewForbidden( + "creating team accounts via the API is currently forbidden in Neosync Cloud environments. 
Please contact us to create a team account.", + ) } logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) @@ -186,7 +200,9 @@ func (s *Service) ConvertPersonalToTeamAccount( personalAccountId := req.Msg.GetAccountId() if personalAccountId == "" { - logger.Debug("account id was not provided during personal->team conversion. Attempting to find personal account") + logger.Debug( + "account id was not provided during personal->team conversion. Attempting to find personal account", + ) accounts, err := s.db.Q.GetAccountsByUser(ctx, s.db.Db, userId) if err != nil && !neosyncdb.IsNoRows(err) { return nil, err @@ -197,7 +213,11 @@ func (s *Service) ConvertPersonalToTeamAccount( for idx := range accounts { if accounts[idx].AccountType == int16(neosyncdb.AccountType_Personal) { personalAccountId = neosyncdb.UUIDString(accounts[idx].ID) - logger.Debug("found personal account to convert to team account", "personalAccountId", personalAccountId) + logger.Debug( + "found personal account to convert to team account", + "personalAccountId", + personalAccountId, + ) break } } @@ -229,11 +249,15 @@ func (s *Service) ConvertPersonalToTeamAccount( if err != nil { return nil, err } - resp, err := s.db.ConvertPersonalToTeamAccount(ctx, &neosyncdb.ConvertPersonalToTeamAccountRequest{ - UserId: userId, - PersonalAccountId: personalAccountUuid, - TeamName: req.Msg.GetName(), - }, logger) + resp, err := s.db.ConvertPersonalToTeamAccount( + ctx, + &neosyncdb.ConvertPersonalToTeamAccountRequest{ + UserId: userId, + PersonalAccountId: personalAccountUuid, + TeamName: req.Msg.GetName(), + }, + logger, + ) if err != nil { return nil, err } @@ -241,12 +265,18 @@ func (s *Service) ConvertPersonalToTeamAccount( newPersonalAccountId := neosyncdb.UUIDString(resp.PersonalAccount.ID) if err := s.rbacClient.SetupNewAccount(ctx, newPersonalAccountId, logger); err != nil { // note: if this fails the account is kind of in a broken state... 
- return nil, fmt.Errorf("unable to setup newly converted personal account, please reach out to support for further assistance: %w", err) + return nil, fmt.Errorf( + "unable to setup newly converted personal account, please reach out to support for further assistance: %w", + err, + ) } if err := s.rbacClient.SetAccountRole(ctx, rbac.NewUserIdEntity(user.Msg.GetUserId()), rbac.NewAccountIdEntity(newPersonalAccountId), mgmtv1alpha1.AccountRole_ACCOUNT_ROLE_ADMIN); err != nil { // note: if this fails the account is kind of in a broken state... - return nil, fmt.Errorf("unable to set account role for user in new personal account, please reach out to support for further assistance: %w", err) + return nil, fmt.Errorf( + "unable to set account role for user in new personal account, please reach out to support for further assistance: %w", + err, + ) } var checkoutSessionUrl *string @@ -258,9 +288,17 @@ func (s *Service) ConvertPersonalToTeamAccount( logger, ) if err != nil { - return nil, fmt.Errorf("unable to upsert stripe customer id after account creation: %w", err) + return nil, fmt.Errorf( + "unable to upsert stripe customer id after account creation: %w", + err, + ) } - session, err := s.generateCheckoutSession(account.StripeCustomerID.String, account.AccountSlug, user.Msg.GetUserId(), logger) + session, err := s.generateCheckoutSession( + account.StripeCustomerID.String, + account.AccountSlug, + user.Msg.GetUserId(), + logger, + ) if err != nil { return nil, fmt.Errorf("unable to generate checkout session: %w", err) } @@ -296,16 +334,27 @@ func (s *Service) SetPersonalAccount( } logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) - logger = logger.With("accountId", neosyncdb.UUIDString(account.ID), "userId", user.Msg.GetUserId()) + logger = logger.With( + "accountId", + neosyncdb.UUIDString(account.ID), + "userId", + user.Msg.GetUserId(), + ) if err := s.rbacClient.SetupNewAccount(ctx, neosyncdb.UUIDString(account.ID), logger); err != nil { // note: if 
this fails the account is kind of in a broken state... - return nil, fmt.Errorf("unable to setup new account, please reach out to support for further assistance: %w", err) + return nil, fmt.Errorf( + "unable to setup new account, please reach out to support for further assistance: %w", + err, + ) } if err := s.rbacClient.SetAccountRole(ctx, rbac.NewUserIdEntity(user.Msg.GetUserId()), rbac.NewAccountIdEntity(neosyncdb.UUIDString(account.ID)), mgmtv1alpha1.AccountRole_ACCOUNT_ROLE_ADMIN); err != nil { // note: if this fails the account is kind of in a broken state... - return nil, fmt.Errorf("unable to set account role for user, please reach out to support for further assistance: %w", err) + return nil, fmt.Errorf( + "unable to set account role for user, please reach out to support for further assistance: %w", + err, + ) } return connect.NewResponse(&mgmtv1alpha1.SetPersonalAccountResponse{ @@ -330,10 +379,14 @@ func (s *Service) IsUserInAccount( if err != nil { return nil, err } - apiKeyCount, err := s.db.Q.IsUserInAccountApiKey(ctx, s.db.Db, db_queries.IsUserInAccountApiKeyParams{ - AccountId: accountId, - UserId: userId, - }) + apiKeyCount, err := s.db.Q.IsUserInAccountApiKey( + ctx, + s.db.Db, + db_queries.IsUserInAccountApiKeyParams{ + AccountId: accountId, + UserId: userId, + }, + ) if err != nil { return nil, err } @@ -360,10 +413,14 @@ func (s *Service) CreateTeamAccount( ) (*connect.Response[mgmtv1alpha1.CreateTeamAccountResponse], error) { logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx) if !s.cfg.IsAuthEnabled { - return nil, nucleuserrors.NewForbidden("unable to create team account as authentication is not enabled") + return nil, nucleuserrors.NewForbidden( + "unable to create team account as authentication is not enabled", + ) } if s.cfg.IsNeosyncCloud && s.billingclient == nil { - return nil, nucleuserrors.NewForbidden("creating team accounts via the API is currently forbidden in Neosync Cloud environments. 
Please contact us to create a team account.") + return nil, nucleuserrors.NewForbidden( + "creating team accounts via the API is currently forbidden in Neosync Cloud environments. Please contact us to create a team account.", + ) } user, err := s.GetUser(ctx, connect.NewRequest(&mgmtv1alpha1.GetUserRequest{})) @@ -391,9 +448,17 @@ func (s *Service) CreateTeamAccount( logger, ) if err != nil { - return nil, fmt.Errorf("unable to upsert stripe customer id after account creation: %w", err) + return nil, fmt.Errorf( + "unable to upsert stripe customer id after account creation: %w", + err, + ) } - session, err := s.generateCheckoutSession(account.StripeCustomerID.String, account.AccountSlug, user.Msg.GetUserId(), logger) + session, err := s.generateCheckoutSession( + account.StripeCustomerID.String, + account.AccountSlug, + user.Msg.GetUserId(), + logger, + ) if err != nil { return nil, fmt.Errorf("unable to generate checkout session: %w", err) } @@ -403,12 +468,18 @@ func (s *Service) CreateTeamAccount( if err := s.rbacClient.SetupNewAccount(ctx, neosyncdb.UUIDString(account.ID), logger); err != nil { // note: if this fails the account is kind of in a broken state... - return nil, fmt.Errorf("unable to setup new account, please reach out to support for further assistance: %w", err) + return nil, fmt.Errorf( + "unable to setup new account, please reach out to support for further assistance: %w", + err, + ) } if err := s.rbacClient.SetAccountRole(ctx, rbac.NewUserIdEntity(user.Msg.GetUserId()), rbac.NewAccountIdEntity(neosyncdb.UUIDString(account.ID)), mgmtv1alpha1.AccountRole_ACCOUNT_ROLE_ADMIN); err != nil { // note: if this fails the account is kind of in a broken state... 
- return nil, fmt.Errorf("unable to set account role for user, please reach out to support for further assistance: %w", err) + return nil, fmt.Errorf( + "unable to set account role for user, please reach out to support for further assistance: %w", + err, + ) } return connect.NewResponse(&mgmtv1alpha1.CreateTeamAccountResponse{ @@ -417,11 +488,16 @@ func (s *Service) CreateTeamAccount( }), nil } -func (s *Service) getCreateStripeAccountFunction(userId string, logger *slog.Logger) func(ctx context.Context, account db_queries.NeosyncApiAccount) (string, error) { +func (s *Service) getCreateStripeAccountFunction( + userId string, + logger *slog.Logger, +) func(ctx context.Context, account db_queries.NeosyncApiAccount) (string, error) { return func(ctx context.Context, account db_queries.NeosyncApiAccount) (string, error) { email := s.getEmailFromToken(ctx, logger) if email == nil { - return "", errors.New("unable to retrieve user email from auth token when creating stripe account") + return "", errors.New( + "unable to retrieve user email from auth token when creating stripe account", + ) } customer, err := s.billingclient.NewCustomer(&billing.CustomerRequest{ Email: *email, @@ -436,7 +512,10 @@ func (s *Service) getCreateStripeAccountFunction(userId string, logger *slog.Log } } -func (s *Service) generateCheckoutSession(customerId, accountSlug, userId string, logger *slog.Logger) (*stripe.CheckoutSession, error) { +func (s *Service) generateCheckoutSession( + customerId, accountSlug, userId string, + logger *slog.Logger, +) (*stripe.CheckoutSession, error) { if s.billingclient == nil { return nil, errors.New("unable to generate checkout session as stripe client is nil") } @@ -451,7 +530,9 @@ func (s *Service) generateCheckoutSession(customerId, accountSlug, userId string func (s *Service) getEmailFromToken(ctx context.Context, logger *slog.Logger) *string { tokenctxResp, err := tokenctx.GetTokenCtx(ctx) if err != nil { - logger.Error(fmt.Errorf("unable to retrieve 
token from ctx when getting email: %w", err).Error()) + logger.Error( + fmt.Errorf("unable to retrieve token from ctx when getting email: %w", err).Error(), + ) return nil } if tokenctxResp.JwtContextData != nil && tokenctxResp.JwtContextData.Claims != nil { @@ -495,7 +576,12 @@ func (s *Service) GetTeamAccountMembers( rbacUsers = append(rbacUsers, rbac.NewPgUserIdEntity(user.UserID)) } - userRoles := s.rbacClient.GetUserRoles(ctx, rbacUsers, rbac.NewAccountIdEntity(neosyncdb.UUIDString(accountUuid)), logger) + userRoles := s.rbacClient.GetUserRoles( + ctx, + rbacUsers, + rbac.NewAccountIdEntity(neosyncdb.UUIDString(accountUuid)), + logger, + ) logger.Debug(fmt.Sprintf("found %d users with roles", len(userRoles))) dtoUsers := make([]*mgmtv1alpha1.AccountUser, len(userIdentities)) @@ -509,13 +595,24 @@ func (s *Service) GetTeamAccountMembers( } role, ok := userRoles[rbac.NewPgUserIdEntity(user.UserID).String()] if ok { - logger.Debug(fmt.Sprintf("found role for user: %s - %s", neosyncdb.UUIDString(user.UserID), role.String())) + logger.Debug( + fmt.Sprintf( + "found role for user: %s - %s", + neosyncdb.UUIDString(user.UserID), + role.String(), + ), + ) dtoUsers[i].Role = role.ToDto() } else { dtoUsers[i].Role = mgmtv1alpha1.AccountRole_ACCOUNT_ROLE_UNSPECIFIED } if user.ProviderSub == "" { - logger.Warn(fmt.Sprintf("unable to find provider sub associated with user id: %q", neosyncdb.UUIDString(user.UserID))) + logger.Warn( + fmt.Sprintf( + "unable to find provider sub associated with user id: %q", + neosyncdb.UUIDString(user.UserID), + ), + ) return nil } else { authuser, err := s.authadminclient.GetUserBySub(ctx, user.ProviderSub) @@ -613,7 +710,14 @@ func (s *Service) InviteUserToTeamAccount( role = pgtype.Int4{Int32: int32(req.Msg.GetRole()), Valid: true} } - invite, err := s.db.CreateTeamAccountInvite(ctx, accountUuid, user.PgId(), req.Msg.GetEmail(), expiresAt, role) + invite, err := s.db.CreateTeamAccountInvite( + ctx, + accountUuid, + user.PgId(), + 
req.Msg.GetEmail(), + expiresAt, + role, + ) if err != nil { return nil, err } @@ -718,11 +822,14 @@ func (s *Service) AcceptTeamAccountInvite( return nil, err } if tokenctxResp.JwtContextData == nil { - return nil, nucleuserrors.NewUnauthenticated("must be a valid jwt user to accept team account invites") + return nil, nucleuserrors.NewUnauthenticated( + "must be a valid jwt user to accept team account invites", + ) } var email *string - if tokenctxResp.JwtContextData.Claims != nil && tokenctxResp.JwtContextData.Claims.Email != nil { + if tokenctxResp.JwtContextData.Claims != nil && + tokenctxResp.JwtContextData.Claims.Email != nil { email = tokenctxResp.JwtContextData.Claims.Email } else { userinfo, err := s.authclient.GetUserInfo(ctx, tokenctxResp.JwtContextData.RawToken) @@ -736,7 +843,9 @@ func (s *Service) AcceptTeamAccountInvite( email = &userinfo.Email } if email == nil { - return nil, nucleuserrors.NewUnauthenticated("unable to find email to valid to add user to account") + return nil, nucleuserrors.NewUnauthenticated( + "unable to find email to valid to add user to account", + ) } validateResp, err := s.db.ValidateInviteAddUserToAccount(ctx, userUuid, req.Msg.Token, *email) @@ -745,7 +854,10 @@ func (s *Service) AcceptTeamAccountInvite( } if err := s.rbacClient.SetAccountRole(ctx, rbac.NewUserIdEntity(user.Msg.GetUserId()), rbac.NewAccountIdEntity(neosyncdb.UUIDString(validateResp.AccountId)), validateResp.Role); err != nil { - return nil, fmt.Errorf("unable to set account role for user, please reach out to support for further assistance: %w", err) + return nil, fmt.Errorf( + "unable to set account role for user, please reach out to support for further assistance: %w", + err, + ) } if err := s.verifyTeamAccount(ctx, validateResp.AccountId); err != nil { @@ -797,7 +909,12 @@ func (s *Service) SetUserRole( return nil, nucleuserrors.NewBadRequest("provided user id is not in account") } - err = s.rbacClient.SetAccountRole(ctx, 
rbac.NewPgUserIdEntity(requestingUserUuid), rbac.NewAccountIdEntity(req.Msg.GetAccountId()), req.Msg.GetRole()) + err = s.rbacClient.SetAccountRole( + ctx, + rbac.NewPgUserIdEntity(requestingUserUuid), + rbac.NewAccountIdEntity(req.Msg.GetAccountId()), + req.Msg.GetRole(), + ) if err != nil { return nil, err } @@ -810,13 +927,17 @@ func (s *Service) verifyTeamAccount(ctx context.Context, accountId pgtype.UUID) if err != nil { return err } - if account.AccountType != int16(neosyncdb.AccountType_Team) && account.AccountType != int16(neosyncdb.AccountType_Enterprise) { + if account.AccountType != int16(neosyncdb.AccountType_Team) && + account.AccountType != int16(neosyncdb.AccountType_Enterprise) { return nucleuserrors.NewForbidden("account is not a team account") } return nil } -func (s *Service) GetSystemInformation(ctx context.Context, req *connect.Request[mgmtv1alpha1.GetSystemInformationRequest]) (*connect.Response[mgmtv1alpha1.GetSystemInformationResponse], error) { +func (s *Service) GetSystemInformation( + ctx context.Context, + req *connect.Request[mgmtv1alpha1.GetSystemInformationRequest], +) (*connect.Response[mgmtv1alpha1.GetSystemInformationResponse], error) { versionInfo := version.Get() builtDate, err := time.Parse(time.RFC3339, versionInfo.BuildDate) if err != nil { @@ -854,23 +975,37 @@ func (s *Service) HasPermission( switch req.Msg.GetResource().GetType() { case mgmtv1alpha1.ResourcePermission_TYPE_ACCOUNT: if req.Msg.GetResource().GetId() != req.Msg.GetAccountId() { - return connect.NewResponse(&mgmtv1alpha1.HasPermissionResponse{HasPermission: false}), nil + return connect.NewResponse( + &mgmtv1alpha1.HasPermissionResponse{HasPermission: false}, + ), nil } switch req.Msg.GetResource().GetAction() { case mgmtv1alpha1.ResourcePermission_ACTION_CREATE: - ok, err := user.Account(ctx, userdata.NewIdentifier(req.Msg.GetAccountId()), rbac.AccountAction_Create) + ok, err := user.Account( + ctx, + userdata.NewIdentifier(req.Msg.GetAccountId()), + 
rbac.AccountAction_Create, + ) if err != nil { return nil, err } hasPermission = ok case mgmtv1alpha1.ResourcePermission_ACTION_READ: - ok, err := user.Account(ctx, userdata.NewIdentifier(req.Msg.GetAccountId()), rbac.AccountAction_View) + ok, err := user.Account( + ctx, + userdata.NewIdentifier(req.Msg.GetAccountId()), + rbac.AccountAction_View, + ) if err != nil { return nil, err } hasPermission = ok case mgmtv1alpha1.ResourcePermission_ACTION_UPDATE: - ok, err := user.Account(ctx, userdata.NewIdentifier(req.Msg.GetAccountId()), rbac.AccountAction_Edit) + ok, err := user.Account( + ctx, + userdata.NewIdentifier(req.Msg.GetAccountId()), + rbac.AccountAction_Edit, + ) if err != nil { return nil, err } @@ -879,25 +1014,41 @@ func (s *Service) HasPermission( case mgmtv1alpha1.ResourcePermission_TYPE_CONNECTION: switch req.Msg.GetResource().GetAction() { case mgmtv1alpha1.ResourcePermission_ACTION_CREATE: - ok, err := user.Connection(ctx, userdata.NewDomainEntity(req.Msg.GetAccountId(), req.Msg.GetResource().GetId()), rbac.ConnectionAction_Create) + ok, err := user.Connection( + ctx, + userdata.NewDomainEntity(req.Msg.GetAccountId(), req.Msg.GetResource().GetId()), + rbac.ConnectionAction_Create, + ) if err != nil { return nil, err } hasPermission = ok case mgmtv1alpha1.ResourcePermission_ACTION_READ: - ok, err := user.Connection(ctx, userdata.NewDomainEntity(req.Msg.GetAccountId(), req.Msg.GetResource().GetId()), rbac.ConnectionAction_View) + ok, err := user.Connection( + ctx, + userdata.NewDomainEntity(req.Msg.GetAccountId(), req.Msg.GetResource().GetId()), + rbac.ConnectionAction_View, + ) if err != nil { return nil, err } hasPermission = ok case mgmtv1alpha1.ResourcePermission_ACTION_UPDATE: - ok, err := user.Connection(ctx, userdata.NewDomainEntity(req.Msg.GetAccountId(), req.Msg.GetResource().GetId()), rbac.ConnectionAction_Edit) + ok, err := user.Connection( + ctx, + userdata.NewDomainEntity(req.Msg.GetAccountId(), req.Msg.GetResource().GetId()), + 
rbac.ConnectionAction_Edit, + ) if err != nil { return nil, err } hasPermission = ok case mgmtv1alpha1.ResourcePermission_ACTION_DELETE: - ok, err := user.Connection(ctx, userdata.NewDomainEntity(req.Msg.GetAccountId(), req.Msg.GetResource().GetId()), rbac.ConnectionAction_Delete) + ok, err := user.Connection( + ctx, + userdata.NewDomainEntity(req.Msg.GetAccountId(), req.Msg.GetResource().GetId()), + rbac.ConnectionAction_Delete, + ) if err != nil { return nil, err } @@ -906,32 +1057,50 @@ func (s *Service) HasPermission( case mgmtv1alpha1.ResourcePermission_TYPE_JOB: switch req.Msg.GetResource().GetAction() { case mgmtv1alpha1.ResourcePermission_ACTION_CREATE: - ok, err := user.Job(ctx, userdata.NewDomainEntity(req.Msg.GetAccountId(), req.Msg.GetResource().GetId()), rbac.JobAction_Create) + ok, err := user.Job( + ctx, + userdata.NewDomainEntity(req.Msg.GetAccountId(), req.Msg.GetResource().GetId()), + rbac.JobAction_Create, + ) if err != nil { return nil, err } hasPermission = ok case mgmtv1alpha1.ResourcePermission_ACTION_READ: - ok, err := user.Job(ctx, userdata.NewDomainEntity(req.Msg.GetAccountId(), req.Msg.GetResource().GetId()), rbac.JobAction_View) + ok, err := user.Job( + ctx, + userdata.NewDomainEntity(req.Msg.GetAccountId(), req.Msg.GetResource().GetId()), + rbac.JobAction_View, + ) if err != nil { return nil, err } hasPermission = ok case mgmtv1alpha1.ResourcePermission_ACTION_UPDATE: - ok, err := user.Job(ctx, userdata.NewDomainEntity(req.Msg.GetAccountId(), req.Msg.GetResource().GetId()), rbac.JobAction_Edit) + ok, err := user.Job( + ctx, + userdata.NewDomainEntity(req.Msg.GetAccountId(), req.Msg.GetResource().GetId()), + rbac.JobAction_Edit, + ) if err != nil { return nil, err } hasPermission = ok case mgmtv1alpha1.ResourcePermission_ACTION_DELETE: - ok, err := user.Job(ctx, userdata.NewDomainEntity(req.Msg.GetAccountId(), req.Msg.GetResource().GetId()), rbac.JobAction_Delete) + ok, err := user.Job( + ctx, + 
userdata.NewDomainEntity(req.Msg.GetAccountId(), req.Msg.GetResource().GetId()), + rbac.JobAction_Delete, + ) if err != nil { return nil, err } hasPermission = ok } } - return connect.NewResponse(&mgmtv1alpha1.HasPermissionResponse{HasPermission: hasPermission}), nil + return connect.NewResponse( + &mgmtv1alpha1.HasPermissionResponse{HasPermission: hasPermission}, + ), nil } func (s *Service) HasPermissions( @@ -947,10 +1116,13 @@ func (s *Service) HasPermissions( for i, resource := range req.Msg.GetResources() { i, resource := i, resource // https://golang.org/doc/faq#closures_and_goroutines g.Go(func() error { - resp, err := s.HasPermission(errctx, connect.NewRequest(&mgmtv1alpha1.HasPermissionRequest{ - AccountId: req.Msg.GetAccountId(), - Resource: resource, - })) + resp, err := s.HasPermission( + errctx, + connect.NewRequest(&mgmtv1alpha1.HasPermissionRequest{ + AccountId: req.Msg.GetAccountId(), + Resource: resource, + }), + ) if err != nil { return err } diff --git a/backend/sql/postgresql/models/models.go b/backend/sql/postgresql/models/models.go index 00cf0cba3c..2297b47fb6 100644 --- a/backend/sql/postgresql/models/models.go +++ b/backend/sql/postgresql/models/models.go @@ -347,7 +347,9 @@ type MongoConnectionConfig struct { ClientTls *ClientTls `json:"clientTls,omitempty"` } -func (m *MongoConnectionConfig) ToDto(canViewSensitive bool) (*mgmtv1alpha1.MongoConnectionConfig, error) { +func (m *MongoConnectionConfig) ToDto( + canViewSensitive bool, +) (*mgmtv1alpha1.MongoConnectionConfig, error) { if m.Url == nil { return nil, errors.New("mongo connection does not contain url") } @@ -519,7 +521,9 @@ type DynamoDBConfig struct { Endpoint *string `json:"Endpoint,omitempty"` } -func (d *DynamoDBConfig) ToDto(canViewSensitive bool) (*mgmtv1alpha1.DynamoDBConnectionConfig, error) { +func (d *DynamoDBConfig) ToDto( + canViewSensitive bool, +) (*mgmtv1alpha1.DynamoDBConnectionConfig, error) { var creds *mgmtv1alpha1.AwsS3Credentials if d.Credentials != nil { creds 
= d.Credentials.ToDto(canViewSensitive) @@ -965,7 +969,10 @@ func (p *MssqlColumnRemovalStrategy) ToDto() *mgmtv1alpha1.MssqlSourceConnection } return nil } -func (p *MssqlColumnRemovalStrategy) FromDto(dto *mgmtv1alpha1.MssqlSourceConnectionOptions_ColumnRemovalStrategy) { + +func (p *MssqlColumnRemovalStrategy) FromDto( + dto *mgmtv1alpha1.MssqlSourceConnectionOptions_ColumnRemovalStrategy, +) { if dto == nil { dto = &mgmtv1alpha1.MssqlSourceConnectionOptions_ColumnRemovalStrategy{} } @@ -982,9 +989,12 @@ type MssqlContinueJobColumnRemovalStrategy struct{} func (m *MssqlSourceOptions) ToDto() *mgmtv1alpha1.MssqlSourceConnectionOptions { dto := &mgmtv1alpha1.MssqlSourceConnectionOptions{ - HaltOnNewColumnAddition: m.HaltOnNewColumnAddition, - ConnectionId: m.ConnectionId, - Schemas: make([]*mgmtv1alpha1.MssqlSourceSchemaOption, len(m.Schemas)), + HaltOnNewColumnAddition: m.HaltOnNewColumnAddition, + ConnectionId: m.ConnectionId, + Schemas: make( + []*mgmtv1alpha1.MssqlSourceSchemaOption, + len(m.Schemas), + ), SubsetByForeignKeyConstraints: m.SubsetByForeignKeyConstraints, } for idx := range m.Schemas { @@ -1032,7 +1042,9 @@ func (m *MssqlSourceSchemaOption) FromDto(dto *mgmtv1alpha1.MssqlSourceSchemaOpt m.Tables = FromDtoMssqlSourceTableOption(dto.GetTables()) } -func FromDtoMssqlSourceSchemaOptions(dtos []*mgmtv1alpha1.MssqlSourceSchemaOption) []*MssqlSourceSchemaOption { +func FromDtoMssqlSourceSchemaOptions( + dtos []*mgmtv1alpha1.MssqlSourceSchemaOption, +) []*MssqlSourceSchemaOption { output := make([]*MssqlSourceSchemaOption, len(dtos)) for idx := range dtos { output[idx] = &MssqlSourceSchemaOption{} @@ -1041,7 +1053,9 @@ func FromDtoMssqlSourceSchemaOptions(dtos []*mgmtv1alpha1.MssqlSourceSchemaOptio return output } -func FromDtoMssqlSourceTableOption(dtos []*mgmtv1alpha1.MssqlSourceTableOption) []*MssqlSourceTableOption { +func FromDtoMssqlSourceTableOption( + dtos []*mgmtv1alpha1.MssqlSourceTableOption, +) []*MssqlSourceTableOption { output := 
make([]*MssqlSourceTableOption, len(dtos)) for idx := range dtos { output[idx] = &MssqlSourceTableOption{} @@ -1091,7 +1105,10 @@ func (s *DynamoDBSourceUnmappedTransformConfig) ToDto() *mgmtv1alpha1.DynamoDBSo S: s.S.ToTransformerDto(), } } -func (s *DynamoDBSourceUnmappedTransformConfig) FromDto(dto *mgmtv1alpha1.DynamoDBSourceUnmappedTransformConfig) error { + +func (s *DynamoDBSourceUnmappedTransformConfig) FromDto( + dto *mgmtv1alpha1.DynamoDBSourceUnmappedTransformConfig, +) error { if dto == nil { dto = &mgmtv1alpha1.DynamoDBSourceUnmappedTransformConfig{} } @@ -1247,7 +1264,10 @@ func (p *MysqlColumnRemovalStrategy) ToDto() *mgmtv1alpha1.MysqlSourceConnection } return nil } -func (p *MysqlColumnRemovalStrategy) FromDto(dto *mgmtv1alpha1.MysqlSourceConnectionOptions_ColumnRemovalStrategy) { + +func (p *MysqlColumnRemovalStrategy) FromDto( + dto *mgmtv1alpha1.MysqlSourceConnectionOptions_ColumnRemovalStrategy, +) { if dto == nil { dto = &mgmtv1alpha1.MysqlSourceConnectionOptions_ColumnRemovalStrategy{} } @@ -1293,7 +1313,10 @@ func (p *PostgresNewColumnAdditionStrategy) ToDto() *mgmtv1alpha1.PostgresSource } return nil } -func (p *PostgresNewColumnAdditionStrategy) FromDto(dto *mgmtv1alpha1.PostgresSourceConnectionOptions_NewColumnAdditionStrategy) { + +func (p *PostgresNewColumnAdditionStrategy) FromDto( + dto *mgmtv1alpha1.PostgresSourceConnectionOptions_NewColumnAdditionStrategy, +) { if dto == nil { dto = &mgmtv1alpha1.PostgresSourceConnectionOptions_NewColumnAdditionStrategy{} } @@ -1329,7 +1352,10 @@ func (p *PostgresColumnRemovalStrategy) ToDto() *mgmtv1alpha1.PostgresSourceConn } return nil } -func (p *PostgresColumnRemovalStrategy) FromDto(dto *mgmtv1alpha1.PostgresSourceConnectionOptions_ColumnRemovalStrategy) { + +func (p *PostgresColumnRemovalStrategy) FromDto( + dto *mgmtv1alpha1.PostgresSourceConnectionOptions_ColumnRemovalStrategy, +) { if dto == nil { dto = &mgmtv1alpha1.PostgresSourceConnectionOptions_ColumnRemovalStrategy{} } @@ -1429,7 
+1455,9 @@ func (s *PostgresSourceOptions) FromDto(dto *mgmtv1alpha1.PostgresSourceConnecti } } -func FromDtoPostgresSourceSchemaOptions(dtos []*mgmtv1alpha1.PostgresSourceSchemaOption) []*PostgresSourceSchemaOption { +func FromDtoPostgresSourceSchemaOptions( + dtos []*mgmtv1alpha1.PostgresSourceSchemaOption, +) []*PostgresSourceSchemaOption { output := make([]*PostgresSourceSchemaOption, len(dtos)) for idx := range dtos { schema := dtos[idx] @@ -1508,7 +1536,9 @@ func (s *MysqlNewColumnAdditionStrategy) ToDto() *mgmtv1alpha1.MysqlSourceConnec return nil } -func (s *MysqlNewColumnAdditionStrategy) FromDto(dto *mgmtv1alpha1.MysqlSourceConnectionOptions_NewColumnAdditionStrategy) { +func (s *MysqlNewColumnAdditionStrategy) FromDto( + dto *mgmtv1alpha1.MysqlSourceConnectionOptions_NewColumnAdditionStrategy, +) { if dto.GetStrategy() != nil { switch dto.GetStrategy().(type) { case *mgmtv1alpha1.MysqlSourceConnectionOptions_NewColumnAdditionStrategy_HaltJob_: @@ -1539,7 +1569,9 @@ func (s *MysqlSourceOptions) FromDto(dto *mgmtv1alpha1.MysqlSourceConnectionOpti } } -func FromDtoMysqlSourceSchemaOptions(dtos []*mgmtv1alpha1.MysqlSourceSchemaOption) []*MysqlSourceSchemaOption { +func FromDtoMysqlSourceSchemaOptions( + dtos []*mgmtv1alpha1.MysqlSourceSchemaOption, +) []*MysqlSourceSchemaOption { output := make([]*MysqlSourceSchemaOption, len(dtos)) for idx := range dtos { schema := dtos[idx] @@ -1560,7 +1592,9 @@ func FromDtoMysqlSourceSchemaOptions(dtos []*mgmtv1alpha1.MysqlSourceSchemaOptio return output } -func FromDtoDynamoDBSourceTableOptions(dtos []*mgmtv1alpha1.DynamoDBSourceTableOption) []*DynamoDBSourceTableOption { +func FromDtoDynamoDBSourceTableOptions( + dtos []*mgmtv1alpha1.DynamoDBSourceTableOption, +) []*DynamoDBSourceTableOption { tables := make([]*DynamoDBSourceTableOption, len(dtos)) for i, table := range dtos { t := &DynamoDBSourceTableOption{} @@ -1598,7 +1632,9 @@ func (s *GenerateSourceOptions) FromDto(dto *mgmtv1alpha1.GenerateSourceOptions) 
s.Schemas = FromDtoGenerateSourceSchemaOptions(dto.Schemas) } -func FromDtoGenerateSourceSchemaOptions(dtos []*mgmtv1alpha1.GenerateSourceSchemaOption) []*GenerateSourceSchemaOption { +func FromDtoGenerateSourceSchemaOptions( + dtos []*mgmtv1alpha1.GenerateSourceSchemaOption, +) []*GenerateSourceSchemaOption { output := make([]*GenerateSourceSchemaOption, len(dtos)) for idx := range dtos { schema := dtos[idx] @@ -1654,7 +1690,9 @@ func (s *AiGenerateSourceOptions) FromDto(dto *mgmtv1alpha1.AiGenerateSourceOpti s.GenerateBatchSize = dto.GenerateBatchSize } -func FromDtoAiGenerateSourceSchemaOptions(dtos []*mgmtv1alpha1.AiGenerateSourceSchemaOption) []*AiGenerateSourceSchemaOption { +func FromDtoAiGenerateSourceSchemaOptions( + dtos []*mgmtv1alpha1.AiGenerateSourceSchemaOption, +) []*AiGenerateSourceSchemaOption { output := make([]*AiGenerateSourceSchemaOption, len(dtos)) for idx := range dtos { schema := dtos[idx] @@ -1808,7 +1846,10 @@ func (d *DynamoDBDestinationOptions) ToDto() *mgmtv1alpha1.DynamoDBDestinationCo TableMappings: tableMappings, } } -func (d *DynamoDBDestinationOptions) FromDto(dto *mgmtv1alpha1.DynamoDBDestinationConnectionOptions) { + +func (d *DynamoDBDestinationOptions) FromDto( + dto *mgmtv1alpha1.DynamoDBDestinationConnectionOptions, +) { d.TableMappings = make([]*DynamoDBDestinationTableMapping, 0, len(dto.GetTableMappings())) for _, dtotm := range dto.GetTableMappings() { @@ -1829,7 +1870,10 @@ func (d *DynamoDBDestinationTableMapping) ToDto() *mgmtv1alpha1.DynamoDBDestinat DestinationTable: d.DestinationTable, } } -func (d *DynamoDBDestinationTableMapping) FromDto(dto *mgmtv1alpha1.DynamoDBDestinationTableMapping) { + +func (d *DynamoDBDestinationTableMapping) FromDto( + dto *mgmtv1alpha1.DynamoDBDestinationTableMapping, +) { d.SourceTable = dto.GetSourceTable() d.DestinationTable = dto.GetDestinationTable() } @@ -1880,7 +1924,9 @@ func (m *PostgresDestinationOptions) ToDto() *mgmtv1alpha1.PostgresDestinationCo } } -func (m 
*PostgresDestinationOptions) FromDto(dto *mgmtv1alpha1.PostgresDestinationConnectionOptions) { +func (m *PostgresDestinationOptions) FromDto( + dto *mgmtv1alpha1.PostgresDestinationConnectionOptions, +) { if dto == nil { dto = &mgmtv1alpha1.PostgresDestinationConnectionOptions{} } @@ -2219,7 +2265,9 @@ func (a *AwsS3DestinationOptions) ToDto() *mgmtv1alpha1.AwsS3DestinationConnecti storageClass := mgmtv1alpha1.AwsS3DestinationConnectionOptions_STORAGE_CLASS_UNSPECIFIED if a.StorageClass != nil { if _, ok := mgmtv1alpha1.AwsS3DestinationConnectionOptions_StorageClass_name[*a.StorageClass]; ok { - storageClass = mgmtv1alpha1.AwsS3DestinationConnectionOptions_StorageClass(*a.StorageClass) + storageClass = mgmtv1alpha1.AwsS3DestinationConnectionOptions_StorageClass( + *a.StorageClass, + ) } } var batch *mgmtv1alpha1.BatchConfig diff --git a/backend/sql/postgresql/models/transformers.go b/backend/sql/postgresql/models/transformers.go index 1e8db5dfa8..f88c010f6d 100644 --- a/backend/sql/postgresql/models/transformers.go +++ b/backend/sql/postgresql/models/transformers.go @@ -218,7 +218,9 @@ type GenerateIpAddressConfig struct { IpType *int32 `json:"ipType,omitempty"` } -func (t *JobMappingTransformerModel) FromTransformerDto(tr *mgmtv1alpha1.JobMappingTransformer) error { +func (t *JobMappingTransformerModel) FromTransformerDto( + tr *mgmtv1alpha1.JobMappingTransformer, +) error { if tr == nil { tr = &mgmtv1alpha1.JobMappingTransformer{} } @@ -428,11 +430,15 @@ func (t *TransformerConfig) ToTransformerConfigDto() *mgmtv1alpha1.TransformerCo return &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_TransformEmailConfig{ TransformEmailConfig: &mgmtv1alpha1.TransformEmail{ - PreserveDomain: t.TransformEmail.PreserveDomain, - PreserveLength: t.TransformEmail.PreserveLength, - ExcludedDomains: t.TransformEmail.ExcludedDomains, - EmailType: (*mgmtv1alpha1.GenerateEmailType)(t.TransformEmail.EmailType), - InvalidEmailAction: 
(*mgmtv1alpha1.InvalidEmailAction)(t.TransformEmail.InvalidEmailAction), + PreserveDomain: t.TransformEmail.PreserveDomain, + PreserveLength: t.TransformEmail.PreserveLength, + ExcludedDomains: t.TransformEmail.ExcludedDomains, + EmailType: (*mgmtv1alpha1.GenerateEmailType)( + t.TransformEmail.EmailType, + ), + InvalidEmailAction: (*mgmtv1alpha1.InvalidEmailAction)( + t.TransformEmail.InvalidEmailAction, + ), }, }, } diff --git a/cli/internal/auth/account-id.go b/cli/internal/auth/account-id.go index f217731640..8e65fab04c 100644 --- a/cli/internal/auth/account-id.go +++ b/cli/internal/auth/account-id.go @@ -30,7 +30,10 @@ func ResolveAccountIdFromFlag( } if apiKey != nil && *apiKey != "" { logger.Debug("api key detected, attempting to resolve account id from key.") - uaResp, err := userclient.GetUserAccounts(ctx, connect.NewRequest(&mgmtv1alpha1.GetUserAccountsRequest{})) + uaResp, err := userclient.GetUserAccounts( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetUserAccountsRequest{}), + ) if err != nil { return "", fmt.Errorf("unable to resolve account id from api key: %w", err) } @@ -44,7 +47,10 @@ func ResolveAccountIdFromFlag( } accountId, err := userconfig.GetAccountId() if err != nil { - return "", fmt.Errorf(`unable to resolve account id from account context, please use the "neosync accounts switch" command to set an active account context: %w`, err) + return "", fmt.Errorf( + `unable to resolve account id from account context, please use the "neosync accounts switch" command to set an active account context: %w`, + err, + ) } logger.Debug(fmt.Sprintf("account id %q resolved from user config", accountId)) return accountId, nil diff --git a/cli/internal/auth/tokens.go b/cli/internal/auth/tokens.go index 2938d6296d..1517bd56e4 100644 --- a/cli/internal/auth/tokens.go +++ b/cli/internal/auth/tokens.go @@ -26,7 +26,10 @@ func GetAuthEnabled( ctx context.Context, authclient mgmtv1alpha1connect.AuthServiceClient, ) (bool, error) { - isEnabledResp, err := 
authclient.GetAuthStatus(ctx, connect.NewRequest(&mgmtv1alpha1.GetAuthStatusRequest{})) + isEnabledResp, err := authclient.GetAuthStatus( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetAuthStatusRequest{}), + ) if err != nil { return false, err } @@ -67,7 +70,11 @@ func WithExtraHeaders(headers map[string]string) HttpOption { } // Returns an instance of *http.Client that includes the Neosync API Token if one was found in the environment -func GetNeosyncHttpClient(ctx context.Context, logger *slog.Logger, opts ...HttpOption) (*http.Client, error) { +func GetNeosyncHttpClient( + ctx context.Context, + logger *slog.Logger, + opts ...HttpOption, +) (*http.Client, error) { cfg := &httpClientConfig{} for _, opt := range opts { opt(cfg) @@ -97,7 +104,11 @@ func GetNeosyncHttpClient(ctx context.Context, logger *slog.Logger, opts ...Http // Method that handles retrieving the user's access token from the file system // This method automatically handles checking to see if the token is valid. // If it's invalid for any reason, will attempt to refresh and get + set a new access token -func getAccessToken(ctx context.Context, headers map[string]string, logger *slog.Logger) (string, error) { +func getAccessToken( + ctx context.Context, + headers map[string]string, + logger *slog.Logger, +) (string, error) { httpclient := http_client.NewWithHeaders(headers) neosyncurl := GetNeosyncUrl() authclient := mgmtv1alpha1connect.NewAuthServiceClient(httpclient, neosyncurl) @@ -107,7 +118,9 @@ func getAccessToken(ctx context.Context, headers map[string]string, logger *slog return "", err } authedAuthClient := mgmtv1alpha1connect.NewAuthServiceClient( - http_client.NewWithHeaders(http_client.MergeMaps(headers, http_client.GetBearerAuthHeaders(&accessToken))), + http_client.NewWithHeaders( + http_client.MergeMaps(headers, http_client.GetBearerAuthHeaders(&accessToken)), + ), neosyncurl, ) logger.Debug("found existing access token, checking if still valid") @@ -117,25 +130,39 @@ func 
getAccessToken(ctx context.Context, headers map[string]string, logger *slog if err := userconfig.RemoveAccessToken(); err != nil { return "", err } - logger.Debug(fmt.Errorf("access token is no longer valid. attempting to refresh...: %w", err).Error()) + logger.Debug( + fmt.Errorf("access token is no longer valid. attempting to refresh...: %w", err). + Error(), + ) refreshtoken, err := userconfig.GetRefreshToken() if err != nil { return "", fmt.Errorf("unable to find refresh token: %w", err) } - refreshResp, err := authclient.RefreshCli(ctx, connect.NewRequest(&mgmtv1alpha1.RefreshCliRequest{ - RefreshToken: refreshtoken, - })) + refreshResp, err := authclient.RefreshCli( + ctx, + connect.NewRequest(&mgmtv1alpha1.RefreshCliRequest{ + RefreshToken: refreshtoken, + }), + ) if err != nil { return "", fmt.Errorf("unable to refresh token, must login again: %w", err) } err = userconfig.SetAccessToken(refreshResp.Msg.AccessToken.AccessToken) if err != nil { - logger.Warn("unable to write refreshed access token back to user config", "error", err.Error()) + logger.Warn( + "unable to write refreshed access token back to user config", + "error", + err.Error(), + ) } if refreshResp.Msg.AccessToken.RefreshToken != nil { err = userconfig.SetRefreshToken(*refreshResp.Msg.AccessToken.RefreshToken) if err != nil { - logger.Warn("unable to write refreshed refresh token back to user config", "error", err.Error()) + logger.Warn( + "unable to write refreshed refresh token back to user config", + "error", + err.Error(), + ) } } return refreshResp.Msg.GetAccessToken().GetAccessToken(), nil diff --git a/cli/internal/cmds/neosync/accounts/list.go b/cli/internal/cmds/neosync/accounts/list.go index da67091331..d2c0158631 100644 --- a/cli/internal/cmds/neosync/accounts/list.go +++ b/cli/internal/cmds/neosync/accounts/list.go @@ -53,7 +53,9 @@ func listAccounts( accountsResp, err := userclient.GetUserAccounts( ctx, - 
connect.NewRequest[mgmtv1alpha1.GetUserAccountsRequest](&mgmtv1alpha1.GetUserAccountsRequest{}), + connect.NewRequest[mgmtv1alpha1.GetUserAccountsRequest]( + &mgmtv1alpha1.GetUserAccountsRequest{}, + ), ) if err != nil { return err diff --git a/cli/internal/cmds/neosync/accounts/switch.go b/cli/internal/cmds/neosync/accounts/switch.go index ce7be7f29c..90036bcd30 100644 --- a/cli/internal/cmds/neosync/accounts/switch.go +++ b/cli/internal/cmds/neosync/accounts/switch.go @@ -34,10 +34,13 @@ var ( header = lipgloss.NewStyle().Faint(true).PaddingLeft(4) bold = lipgloss.NewStyle().Bold(true) itemStyle = lipgloss.NewStyle().PaddingLeft(2).Height(1) - selectedItemStyle = lipgloss.NewStyle().PaddingLeft(2).Height(1).Foreground(lipgloss.Color("170")) - paginationStyle = list.DefaultStyles().PaginationStyle.PaddingLeft(4) - helpStyle = list.DefaultStyles().HelpStyle.PaddingLeft(4).PaddingBottom(1) - quitTextStyle = lipgloss.NewStyle().Margin(1, 0, 2, 4) + selectedItemStyle = lipgloss.NewStyle(). + PaddingLeft(2). + Height(1). 
+ Foreground(lipgloss.Color("170")) + paginationStyle = list.DefaultStyles().PaginationStyle.PaddingLeft(4) + helpStyle = list.DefaultStyles().HelpStyle.PaddingLeft(4).PaddingBottom(1) + quitTextStyle = lipgloss.NewStyle().Margin(1, 0, 2, 4) ) func newSwitchCmd() *cobra.Command { @@ -160,7 +163,8 @@ func switchAccount( var account *mgmtv1alpha1.UserAccount for _, a := range accounts { - if strings.EqualFold(a.Name, *accountIdOrName) || strings.EqualFold(a.Id, *accountIdOrName) { + if strings.EqualFold(a.Name, *accountIdOrName) || + strings.EqualFold(a.Id, *accountIdOrName) { account = a } } @@ -174,7 +178,11 @@ func switchAccount( return fmt.Errorf("unable to set account context: %w", err) } - fmt.Println(itemStyle.Render(fmt.Sprintf("\n Switched account to %s (%s) \n", account.Name, account.Id))) //nolint:forbidigo + fmt.Println( + itemStyle.Render( + fmt.Sprintf("\n Switched account to %s (%s) \n", account.Name, account.Id), + ), + ) //nolint:forbidigo return nil } @@ -206,7 +214,13 @@ type itemDelegate struct{} func (d itemDelegate) Height() int { return 1 } func (d itemDelegate) Spacing() int { return 0 } func (d itemDelegate) Update(_ tea.Msg, _ *list.Model) tea.Cmd { return nil } -func (d itemDelegate) Render(w io.Writer, m list.Model, index int, listItem list.Item) { //nolint:gocritic + +func (d itemDelegate) Render( + w io.Writer, + m list.Model, + index int, + listItem list.Item, +) { //nolint:gocritic i, ok := listItem.(item) if !ok { return @@ -229,7 +243,11 @@ func (d itemDelegate) Render(w io.Writer, m list.Model, index int, listItem list } if index == m.Index() { fn = func(s ...string) string { - return fmt.Sprintf("%s%s", itemHeader, selectedItemStyle.Render("● "+strings.Join(s, " "))) + return fmt.Sprintf( + "%s%s", + itemHeader, + selectedItemStyle.Render("● "+strings.Join(s, " ")), + ) } } @@ -276,9 +294,13 @@ func (m *model) View() string { if m.choice.description != "" { err := userconfig.SetAccountId(m.choice.description) if err != nil { - 
return quitTextStyle.Render(fmt.Sprintf("Failed to switch accounts. Error %s", err.Error())) + return quitTextStyle.Render( + fmt.Sprintf("Failed to switch accounts. Error %s", err.Error()), + ) } - return quitTextStyle.Render(fmt.Sprintf("Switched account to %s (%s)", m.choice.title, m.choice.description)) + return quitTextStyle.Render( + fmt.Sprintf("Switched account to %s (%s)", m.choice.title, m.choice.description), + ) } if m.quitting || m.choice.title == "Cancel" { return quitTextStyle.Render("No changes made") diff --git a/cli/internal/cmds/neosync/connections/list.go b/cli/internal/cmds/neosync/connections/list.go index b044416dbe..091bef886d 100644 --- a/cli/internal/cmds/neosync/connections/list.go +++ b/cli/internal/cmds/neosync/connections/list.go @@ -39,7 +39,8 @@ func newListCmd() *cobra.Command { return listConnections(cmd.Context(), debugMode, &apiKey, &accountId) }, } - cmd.Flags().String("account-id", "", "Account to list connections for. Defaults to account id in cli context") + cmd.Flags(). + String("account-id", "", "Account to list connections for. 
Defaults to account id in cli context") return cmd } @@ -82,9 +83,12 @@ func getConnections( connectionclient mgmtv1alpha1connect.ConnectionServiceClient, accountId string, ) ([]*mgmtv1alpha1.Connection, error) { - res, err := connectionclient.GetConnections(ctx, connect.NewRequest[mgmtv1alpha1.GetConnectionsRequest](&mgmtv1alpha1.GetConnectionsRequest{ - AccountId: accountId, - })) + res, err := connectionclient.GetConnections( + ctx, + connect.NewRequest[mgmtv1alpha1.GetConnectionsRequest](&mgmtv1alpha1.GetConnectionsRequest{ + AccountId: accountId, + }), + ) if err != nil { return nil, err } diff --git a/cli/internal/cmds/neosync/jobs/list.go b/cli/internal/cmds/neosync/jobs/list.go index 696ec7a244..4576a2f58a 100644 --- a/cli/internal/cmds/neosync/jobs/list.go +++ b/cli/internal/cmds/neosync/jobs/list.go @@ -39,7 +39,8 @@ func newListCmd() *cobra.Command { return listJobs(cmd.Context(), debugMode, &apiKey, &accountId) }, } - cmd.Flags().String("account-id", "", "Account to list jobs for. Defaults to account id in cli context") + cmd.Flags(). + String("account-id", "", "Account to list jobs for. 
Defaults to account id in cli context") return cmd } @@ -68,9 +69,12 @@ func listJobs( httpclient, neosyncurl, ) - res, err := jobclient.GetJobs(ctx, connect.NewRequest[mgmtv1alpha1.GetJobsRequest](&mgmtv1alpha1.GetJobsRequest{ - AccountId: accountId, - })) + res, err := jobclient.GetJobs( + ctx, + connect.NewRequest[mgmtv1alpha1.GetJobsRequest](&mgmtv1alpha1.GetJobsRequest{ + AccountId: accountId, + }), + ) if err != nil { return err } @@ -80,9 +84,14 @@ func listJobs( for idx := range res.Msg.Jobs { idx := idx errgrp.Go(func() error { - jsres, err := jobclient.GetJobStatus(errctx, connect.NewRequest[mgmtv1alpha1.GetJobStatusRequest](&mgmtv1alpha1.GetJobStatusRequest{ - JobId: res.Msg.Jobs[idx].Id, - })) + jsres, err := jobclient.GetJobStatus( + errctx, + connect.NewRequest[mgmtv1alpha1.GetJobStatusRequest]( + &mgmtv1alpha1.GetJobStatusRequest{ + JobId: res.Msg.Jobs[idx].Id, + }, + ), + ) if err != nil { return err } diff --git a/cli/internal/cmds/neosync/jobs/trigger.go b/cli/internal/cmds/neosync/jobs/trigger.go index 277caf2383..37216b948a 100644 --- a/cli/internal/cmds/neosync/jobs/trigger.go +++ b/cli/internal/cmds/neosync/jobs/trigger.go @@ -49,7 +49,8 @@ func newTriggerCmd() *cobra.Command { return triggerJob(cmd.Context(), debugMode, jobUuid.String(), &apiKey, &accountId) }, } - cmd.Flags().String("account-id", "", "Account that job is in. Defaults to account id in cli context") + cmd.Flags(). + String("account-id", "", "Account that job is in. Defaults to account id in cli context") return cmd } @@ -79,18 +80,24 @@ func triggerJob( httpclient, neosyncurl, ) - job, err := jobclient.GetJob(ctx, connect.NewRequest[mgmtv1alpha1.GetJobRequest](&mgmtv1alpha1.GetJobRequest{ - Id: jobId, - })) + job, err := jobclient.GetJob( + ctx, + connect.NewRequest[mgmtv1alpha1.GetJobRequest](&mgmtv1alpha1.GetJobRequest{ + Id: jobId, + }), + ) if err != nil { return err } if job.Msg.GetJob().GetAccountId() != accountId { return fmt.Errorf("unable to trigger job run. 
job not found. accountId: %s", accountId) } - _, err = jobclient.CreateJobRun(ctx, connect.NewRequest[mgmtv1alpha1.CreateJobRunRequest](&mgmtv1alpha1.CreateJobRunRequest{ - JobId: jobId, - })) + _, err = jobclient.CreateJobRun( + ctx, + connect.NewRequest[mgmtv1alpha1.CreateJobRunRequest](&mgmtv1alpha1.CreateJobRunRequest{ + JobId: jobId, + }), + ) if err != nil { return err } diff --git a/cli/internal/cmds/neosync/login/login.go b/cli/internal/cmds/neosync/login/login.go index e4df6d689a..4df223270d 100644 --- a/cli/internal/cmds/neosync/login/login.go +++ b/cli/internal/cmds/neosync/login/login.go @@ -43,7 +43,9 @@ func NewCmd() *cobra.Command { logger := cli_logger.NewSLogger(cli_logger.GetCharmLevelOrDefault(debugMode)) if apiKey != "" { - logger.Info(`found api key, no need to log in. run "neosync whoami" to verify that the api key is valid`) + logger.Info( + `found api key, no need to log in. run "neosync whoami" to verify that the api key is valid`, + ) return nil } return login(cmd.Context(), logger) @@ -118,11 +120,14 @@ func oAuthLogin( ) error { state := uuid.NewString() - authorizeurlResp, err := authclient.GetAuthorizeUrl(ctx, connect.NewRequest(&mgmtv1alpha1.GetAuthorizeUrlRequest{ - State: state, - RedirectUri: redirectUri, - Scope: "openid profile offline_access", - })) + authorizeurlResp, err := authclient.GetAuthorizeUrl( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetAuthorizeUrlRequest{ + State: state, + RedirectUri: redirectUri, + Scope: "openid profile offline_access", + }), + ) if err != nil { return err } @@ -146,7 +151,10 @@ func oAuthLogin( }() if err := webbrowser.Open(authorizeurlResp.Msg.Url); err != nil { - fmt.Println("There was an issue opening the web browser, proceed to the following url to finish logging in to Neosync:\n", authorizeurlResp.Msg.Url) //nolint + fmt.Println( + "There was an issue opening the web browser, proceed to the following url to finish logging in to Neosync:\n", + authorizeurlResp.Msg.Url, + ) //nolint } select 
{ @@ -164,7 +172,12 @@ func oAuthLogin( if result.State != state { return errors.New("state received from response was not what was sent") } - loginResp, err := authclient.LoginCli(ctx, connect.NewRequest(&mgmtv1alpha1.LoginCliRequest{Code: result.Code, RedirectUri: redirectUri})) + loginResp, err := authclient.LoginCli( + ctx, + connect.NewRequest( + &mgmtv1alpha1.LoginCliRequest{Code: result.Code, RedirectUri: redirectUri}, + ), + ) if err != nil { return err } diff --git a/cli/internal/cmds/neosync/neosync.go b/cli/internal/cmds/neosync/neosync.go index ef7b874f12..4ae394df21 100644 --- a/cli/internal/cmds/neosync/neosync.go +++ b/cli/internal/cmds/neosync/neosync.go @@ -64,7 +64,8 @@ func Execute() { rootCmd.PersistentFlags().StringVar( &cfgFilePath, "config", "", fmt.Sprintf("config file (default is $HOME/%s/%s.%s)", neosyncDirName, cliSettingsFileNameNoExt, cliSettingsFileExt), ) - rootCmd.PersistentFlags().String(apiKeyFlag, "", fmt.Sprintf("Neosync API Key. Takes precedence over $%s", apiKeyEnvVarName)) + rootCmd.PersistentFlags(). + String(apiKeyFlag, "", fmt.Sprintf("Neosync API Key. 
Takes precedence over $%s", apiKeyEnvVarName)) rootCmd.PersistentFlags().Bool("debug", false, "Run in debug mode") @@ -97,7 +98,14 @@ func migrateOldConfig(cfgFilePath string) { if err != nil { return } - err = os.Rename(oldPath, filepath.Join(home, neosyncDirName, fmt.Sprintf("%s.%s", cliSettingsFileNameNoExt, cliSettingsFileExt))) + err = os.Rename( + oldPath, + filepath.Join( + home, + neosyncDirName, + fmt.Sprintf("%s.%s", cliSettingsFileNameNoExt, cliSettingsFileExt), + ), + ) if err != nil { return } diff --git a/cli/internal/cmds/neosync/sync/config.go b/cli/internal/cmds/neosync/sync/config.go index b164256a96..f1df9f0660 100644 --- a/cli/internal/cmds/neosync/sync/config.go +++ b/cli/internal/cmds/neosync/sync/config.go @@ -167,7 +167,10 @@ func newCobraCmdConfig( return nil, err } if _, err := time.ParseDuration(openDuration); err != nil { - return nil, fmt.Errorf("unable to parse destination-open-duration as a valid duration string: %w", err) + return nil, fmt.Errorf( + "unable to parse destination-open-duration as a valid duration string: %w", + err, + ) } config.Destination.ConnectionOpts.OpenDuration = &openDuration } @@ -177,7 +180,10 @@ func newCobraCmdConfig( return nil, err } if _, err := time.ParseDuration(idleDuration); err != nil { - return nil, fmt.Errorf("unable to parse destination-idle-duration as valid duration string: %w", err) + return nil, fmt.Errorf( + "unable to parse destination-idle-duration as valid duration string: %w", + err, + ) } config.Destination.ConnectionOpts.IdleDuration = &idleDuration } @@ -207,34 +213,53 @@ func newCobraCmdConfig( config.Destination.Batch = &batchConfig{} } if _, err := time.ParseDuration(batchperiod); err != nil { - return nil, fmt.Errorf("unable to parse destination-batch-period as valid duration string: %w", err) + return nil, fmt.Errorf( + "unable to parse destination-batch-period as valid duration string: %w", + err, + ) } config.Destination.Batch.Period = &batchperiod } return config, nil } -func 
isConfigValid(cmd *cmdConfig, logger *slog.Logger, sourceConnection *mgmtv1alpha1.Connection, sourceConnectionType benthosbuilder_shared.ConnectionType) error { - if sourceConnectionType == benthosbuilder_shared.ConnectionTypeAwsS3 && (cmd.Source.ConnectionOpts.JobId == nil || *cmd.Source.ConnectionOpts.JobId == "") && (cmd.Source.ConnectionOpts.JobRunId == nil || *cmd.Source.ConnectionOpts.JobRunId == "") { +func isConfigValid( + cmd *cmdConfig, + logger *slog.Logger, + sourceConnection *mgmtv1alpha1.Connection, + sourceConnectionType benthosbuilder_shared.ConnectionType, +) error { + if sourceConnectionType == benthosbuilder_shared.ConnectionTypeAwsS3 && + (cmd.Source.ConnectionOpts.JobId == nil || *cmd.Source.ConnectionOpts.JobId == "") && + (cmd.Source.ConnectionOpts.JobRunId == nil || *cmd.Source.ConnectionOpts.JobRunId == "") { return errors.New("s3 source connection type requires job-id or job-run-id") } - if sourceConnectionType == benthosbuilder_shared.ConnectionTypeGCP && (cmd.Source.ConnectionOpts.JobId == nil || *cmd.Source.ConnectionOpts.JobId == "") && (cmd.Source.ConnectionOpts.JobRunId == nil || *cmd.Source.ConnectionOpts.JobRunId == "") { + if sourceConnectionType == benthosbuilder_shared.ConnectionTypeGCP && + (cmd.Source.ConnectionOpts.JobId == nil || *cmd.Source.ConnectionOpts.JobId == "") && + (cmd.Source.ConnectionOpts.JobRunId == nil || *cmd.Source.ConnectionOpts.JobRunId == "") { return errors.New("gcp cloud storage source connection type requires job-id or job-run-id") } - if (sourceConnectionType == benthosbuilder_shared.ConnectionTypeAwsS3 || sourceConnectionType == benthosbuilder_shared.ConnectionTypeGCP) && cmd.Destination.InitSchema { + if (sourceConnectionType == benthosbuilder_shared.ConnectionTypeAwsS3 || sourceConnectionType == benthosbuilder_shared.ConnectionTypeGCP) && + cmd.Destination.InitSchema { return errors.New("init schema is only supported when source is a SQL Database") } - if cmd.Destination != nil && 
cmd.Destination.TruncateCascade && cmd.Destination.Driver == mysqlDriver { + if cmd.Destination != nil && cmd.Destination.TruncateCascade && + cmd.Destination.Driver == mysqlDriver { return fmt.Errorf("truncate cascade is only supported in postgres") } - if cmd.Destination != nil && cmd.Destination.OnConflict.DoNothing && cmd.Destination.OnConflict.DoUpdate != nil && cmd.Destination.OnConflict.DoUpdate.Enabled { - return errors.New("on-conflict-do-nothing and on-conflict-do-update cannot be used together") + if cmd.Destination != nil && cmd.Destination.OnConflict.DoNothing && + cmd.Destination.OnConflict.DoUpdate != nil && + cmd.Destination.OnConflict.DoUpdate.Enabled { + return errors.New( + "on-conflict-do-nothing and on-conflict-do-update cannot be used together", + ) } - if sourceConnectionType == benthosbuilder_shared.ConnectionTypeMysql || sourceConnectionType == benthosbuilder_shared.ConnectionTypePostgres { + if sourceConnectionType == benthosbuilder_shared.ConnectionTypeMysql || + sourceConnectionType == benthosbuilder_shared.ConnectionTypePostgres { if cmd.Destination.Driver == "" { return fmt.Errorf("must provide destination-driver") } @@ -243,7 +268,9 @@ func isConfigValid(cmd *cmdConfig, logger *slog.Logger, sourceConnection *mgmtv1 } if cmd.Destination.Driver != mysqlDriver && cmd.Destination.Driver != postgresDriver { - return errors.New("unsupported destination driver. only pgx (postgres) and mysql are currently supported") + return errors.New( + "unsupported destination driver. 
only pgx (postgres) and mysql are currently supported", + ) } } diff --git a/cli/internal/cmds/neosync/sync/job.go b/cli/internal/cmds/neosync/sync/job.go index 4fb34e8f1b..6ee7c3c0ae 100644 --- a/cli/internal/cmds/neosync/sync/job.go +++ b/cli/internal/cmds/neosync/sync/job.go @@ -18,7 +18,8 @@ func toJob( return nil, err } jobId := uuid.NewString() - if cmd.Source.ConnectionOpts != nil && cmd.Source.ConnectionOpts.JobId != nil && *cmd.Source.ConnectionOpts.JobId != "" { + if cmd.Source.ConnectionOpts != nil && cmd.Source.ConnectionOpts.JobId != nil && + *cmd.Source.ConnectionOpts.JobId != "" { jobId = *cmd.Source.ConnectionOpts.JobId } tables := map[string]string{} @@ -32,12 +33,18 @@ func toJob( Source: &mgmtv1alpha1.JobSource{ Options: sourceConnOpts, }, - Destinations: []*mgmtv1alpha1.JobDestination{toJobDestination(cmd, destinationConnection, tables)}, - Mappings: toJobMappings(sourceSchema), + Destinations: []*mgmtv1alpha1.JobDestination{ + toJobDestination(cmd, destinationConnection, tables), + }, + Mappings: toJobMappings(sourceSchema), }, nil } -func toJobDestination(cmd *cmdConfig, destinationConnection *mgmtv1alpha1.Connection, tables map[string]string) *mgmtv1alpha1.JobDestination { +func toJobDestination( + cmd *cmdConfig, + destinationConnection *mgmtv1alpha1.Connection, + tables map[string]string, +) *mgmtv1alpha1.JobDestination { return &mgmtv1alpha1.JobDestination{ ConnectionId: destinationConnection.Id, Id: uuid.NewString(), @@ -45,7 +52,9 @@ func toJobDestination(cmd *cmdConfig, destinationConnection *mgmtv1alpha1.Connec } } -func toJobSourceOption(sourceConnection *mgmtv1alpha1.Connection) (*mgmtv1alpha1.JobSourceOptions, error) { +func toJobSourceOption( + sourceConnection *mgmtv1alpha1.Connection, +) (*mgmtv1alpha1.JobSourceOptions, error) { switch sourceConnection.ConnectionConfig.Config.(type) { case *mgmtv1alpha1.ConnectionConfig_PgConfig: return &mgmtv1alpha1.JobSourceOptions{ diff --git a/cli/internal/cmds/neosync/sync/sync.go 
b/cli/internal/cmds/neosync/sync/sync.go index 690a32aa0b..82fb1f8dc9 100644 --- a/cli/internal/cmds/neosync/sync/sync.go +++ b/cli/internal/cmds/neosync/sync/sync.go @@ -100,12 +100,12 @@ type sqlDestinationConfig struct { TruncateCascade bool `yaml:"truncate-cascade,omitempty"` OnConflict onConflictConfig `yaml:"on-conflict,omitempty"` ConnectionOpts sqlConnectionOptions `yaml:"connection-opts,omitempty"` - MaxInFlight *uint32 `yaml:"max-in-flight,omitempty" json:"max-in-flight,omitempty"` - Batch *batchConfig `yaml:"batch,omitempty" json:"batch,omitempty"` + MaxInFlight *uint32 `yaml:"max-in-flight,omitempty" json:"max-in-flight,omitempty"` + Batch *batchConfig `yaml:"batch,omitempty" json:"batch,omitempty"` } type batchConfig struct { - Count *uint32 `yaml:"count,omitempty" json:"count,omitempty"` + Count *uint32 `yaml:"count,omitempty" json:"count,omitempty"` Period *string `yaml:"period,omitempty" json:"period,omitempty"` } @@ -141,24 +141,34 @@ func NewCmd() *cobra.Command { } cmd.Flags().String("connection-id", "", "Connection id for sync source") - cmd.Flags().String("job-id", "", "Id of Job to sync data from. Only used with [AWS S3, GCP Cloud Storage] connections. Can use job-run-id instead.") - cmd.Flags().String("job-run-id", "", "Id of Job run to sync data from. Only used with [AWS S3, GCP Cloud Storage] connections. Can use job-id instead.") + cmd.Flags(). + String("job-id", "", "Id of Job to sync data from. Only used with [AWS S3, GCP Cloud Storage] connections. Can use job-run-id instead.") + cmd.Flags(). + String("job-run-id", "", "Id of Job run to sync data from. Only used with [AWS S3, GCP Cloud Storage] connections. Can use job-id instead.") cmd.Flags().String("destination-connection-url", "", "Connection url for sync output") cmd.Flags().String("destination-driver", "", "Connection driver for sync output") - cmd.Flags().String("account-id", "", "Account source connection is in. Defaults to account id in cli context") + cmd.Flags(). 
+ String("account-id", "", "Account source connection is in. Defaults to account id in cli context") cmd.Flags().String("config", "", "Location of config file") cmd.Flags().Bool("init-schema", false, "Create table schema and its constraints") cmd.Flags().Bool("truncate-before-insert", false, "Truncate table before insert") - cmd.Flags().Bool("truncate-cascade", false, "Truncate cascade table before insert (postgres only)") - cmd.Flags().Bool("on-conflict-do-nothing", false, "If there is a conflict when inserting data do not insert") + cmd.Flags(). + Bool("truncate-cascade", false, "Truncate cascade table before insert (postgres only)") + cmd.Flags(). + Bool("on-conflict-do-nothing", false, "If there is a conflict when inserting data do not insert") cmd.Flags().Int32("destination-open-limit", 0, "Maximum number of open connections") cmd.Flags().Int32("destination-idle-limit", 0, "Maximum number of idle connections") - cmd.Flags().String("destination-idle-duration", "", "Maximum amount of time a connection may be idle (e.g. '5m')") - cmd.Flags().String("destination-open-duration", "", "Maximum amount of time a connection may be open (e.g. '30s')") - cmd.Flags().Uint32("destination-max-in-flight", 0, "Maximum allowed batched rows to sync. If not provided, uses server default of 64") - cmd.Flags().Uint32("destination-batch-count", 0, "Batch size of rows that will be sent to the destination. If not provided, uses server default of 100.") - cmd.Flags().String("destination-batch-period", "", "Duration of time that a batch of rows will be sent. If not provided, uses server default fo 5s. (e.g. 5s, 1m)") + cmd.Flags(). + String("destination-idle-duration", "", "Maximum amount of time a connection may be idle (e.g. '5m')") + cmd.Flags(). + String("destination-open-duration", "", "Maximum amount of time a connection may be open (e.g. '30s')") + cmd.Flags(). + Uint32("destination-max-in-flight", 0, "Maximum allowed batched rows to sync. 
If not provided, uses server default of 64") + cmd.Flags(). + Uint32("destination-batch-count", 0, "Batch size of rows that will be sent to the destination. If not provided, uses server default of 100.") + cmd.Flags(). + String("destination-batch-period", "", "Duration of time that a batch of rows will be sent. If not provided, uses server default fo 5s. (e.g. 5s, 1m)") // dynamo flags cmd.Flags().String("aws-access-key-id", "", "AWS Access Key ID for DynamoDB") @@ -217,10 +227,26 @@ func newCliSyncFromCmd( return nil, err } connectInterceptorOption := connect.WithInterceptors(connectInterceptors...) - connectionclient := mgmtv1alpha1connect.NewConnectionServiceClient(httpclient, neosyncurl, connectInterceptorOption) - connectiondataclient := mgmtv1alpha1connect.NewConnectionDataServiceClient(httpclient, neosyncurl, connectInterceptorOption) - transformerclient := mgmtv1alpha1connect.NewTransformersServiceClient(httpclient, neosyncurl, connectInterceptorOption) - userclient := mgmtv1alpha1connect.NewUserAccountServiceClient(httpclient, neosyncurl, connectInterceptorOption) + connectionclient := mgmtv1alpha1connect.NewConnectionServiceClient( + httpclient, + neosyncurl, + connectInterceptorOption, + ) + connectiondataclient := mgmtv1alpha1connect.NewConnectionDataServiceClient( + httpclient, + neosyncurl, + connectInterceptorOption, + ) + transformerclient := mgmtv1alpha1connect.NewTransformersServiceClient( + httpclient, + neosyncurl, + connectInterceptorOption, + ) + userclient := mgmtv1alpha1connect.NewUserAccountServiceClient( + httpclient, + neosyncurl, + connectInterceptorOption, + ) cmdCfg, err := newCobraCmdConfig( cmd, @@ -236,7 +262,9 @@ func newCliSyncFromCmd( logger.Info("Starting sync") - connmanager := connectionmanager.NewConnectionManager(sqlprovider.NewProvider(&sqlconnect.SqlOpenConnector{})) + connmanager := connectionmanager.NewConnectionManager( + sqlprovider.NewProvider(&sqlconnect.SqlOpenConnector{}), + ) sqlmanagerclient := 
sqlmanager.NewSqlManager(sqlmanager.WithConnectionManager(connmanager)) sync := &clisync{ @@ -256,9 +284,12 @@ func newCliSyncFromCmd( func (c *clisync) configureAndRunSync() error { c.logger.Debug("Retrieving neosync connection") - connResp, err := c.connectionclient.GetConnection(c.ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ - Id: c.cmd.Source.ConnectionId, - })) + connResp, err := c.connectionclient.GetConnection( + c.ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ + Id: c.cmd.Source.ConnectionId, + }), + ) if err != nil { return err } @@ -301,8 +332,13 @@ func (c *clisync) configureAndRunSync() error { benthosEnv, err := benthos_environment.NewEnvironment( c.logger, benthos_environment.WithSqlConfig(&benthos_environment.SqlConfig{ - Provider: pool_sql_provider.NewConnectionProvider(c.connmanager, getConnectionById, c.session, c.logger), - IsRetry: false, + Provider: pool_sql_provider.NewConnectionProvider( + c.connmanager, + getConnectionById, + c.session, + c.logger, + ), + IsRetry: false, }), benthos_environment.WithConnectionDataConfig(&benthos_environment.ConnectionDataConfig{ NeosyncConnectionDataApi: c.connectiondataclient, @@ -410,7 +446,9 @@ func (c *clisync) configureSync() ([][]*benthosbuilder.BenthosConfigResponse, er return groupedConfigs, nil } -func (c *clisync) getConnectionSchemaConfigByConnectionType(connection *mgmtv1alpha1.Connection) (*mgmtv1alpha1.ConnectionSchemaConfig, error) { +func (c *clisync) getConnectionSchemaConfigByConnectionType( + connection *mgmtv1alpha1.Connection, +) (*mgmtv1alpha1.ConnectionSchemaConfig, error) { switch conn := connection.GetConnectionConfig().GetConfig().(type) { case *mgmtv1alpha1.ConnectionConfig_PgConfig: return &mgmtv1alpha1.ConnectionSchemaConfig{ @@ -464,7 +502,13 @@ var ( streamBuilderMu syncmap.Mutex ) -func syncData(ctx context.Context, benv *service.Environment, cfg *benthosbuilder.BenthosConfigResponse, logger *slog.Logger, outputType output.OutputType) error { +func 
syncData( + ctx context.Context, + benv *service.Environment, + cfg *benthosbuilder.BenthosConfigResponse, + logger *slog.Logger, + outputType output.OutputType, +) error { configbits, err := yaml.Marshal(cfg.Config) if err != nil { return err @@ -502,7 +546,18 @@ func syncData(ctx context.Context, benv *service.Environment, cfg *benthosbuilde return fmt.Errorf("failed to create StreamBuilder") } if outputType == output.PlainOutput { - streambldr.SetLogger(logger.With("benthos", "true", "schema", cfg.TableSchema, "table", cfg.TableName, "runType", runType)) + streambldr.SetLogger( + logger.With( + "benthos", + "true", + "schema", + cfg.TableSchema, + "table", + cfg.TableName, + "runType", + runType, + ), + ) } if benv == nil { return fmt.Errorf("benthos env is nil") @@ -567,7 +622,9 @@ func cmdConfigToDestinationConnection(cmd *cmdConfig) *mgmtv1alpha1.Connection { ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Url{ Url: cmd.Destination.ConnectionUrl, }, - ConnectionOptions: toSqlConnectionOptions(cmd.Destination.ConnectionOpts), + ConnectionOptions: toSqlConnectionOptions( + cmd.Destination.ConnectionOpts, + ), }, }, }, @@ -582,7 +639,9 @@ func cmdConfigToDestinationConnection(cmd *cmdConfig) *mgmtv1alpha1.Connection { ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Url{ Url: cmd.Destination.ConnectionUrl, }, - ConnectionOptions: toSqlConnectionOptions(cmd.Destination.ConnectionOpts), + ConnectionOptions: toSqlConnectionOptions( + cmd.Destination.ConnectionOpts, + ), }, }, }, @@ -597,7 +656,9 @@ func cmdConfigToDestinationConnection(cmd *cmdConfig) *mgmtv1alpha1.Connection { ConnectionConfig: &mgmtv1alpha1.MssqlConnectionConfig_Url{ Url: cmd.Destination.ConnectionUrl, }, - ConnectionOptions: toSqlConnectionOptions(cmd.Destination.ConnectionOpts), + ConnectionOptions: toSqlConnectionOptions( + cmd.Destination.ConnectionOpts, + ), }, }, }, @@ -630,12 +691,16 @@ func cmdConfigToDestinationConnection(cmd *cmdConfig) *mgmtv1alpha1.Connection { return 
&mgmtv1alpha1.Connection{} } -func cmdConfigToDestinationConnectionOptions(cmd *cmdConfig, tables map[string]string) *mgmtv1alpha1.JobDestinationOptions { +func cmdConfigToDestinationConnectionOptions( + cmd *cmdConfig, + tables map[string]string, +) *mgmtv1alpha1.JobDestinationOptions { if cmd.Destination != nil { switch cmd.Destination.Driver { case postgresDriver: conflictConfig := &mgmtv1alpha1.PostgresOnConflictConfig{} - if cmd.Destination.OnConflict.DoUpdate != nil && cmd.Destination.OnConflict.DoUpdate.Enabled { + if cmd.Destination.OnConflict.DoUpdate != nil && + cmd.Destination.OnConflict.DoUpdate.Enabled { conflictConfig.Strategy = &mgmtv1alpha1.PostgresOnConflictConfig_Update{ Update: &mgmtv1alpha1.PostgresOnConflictConfig_PostgresOnConflictUpdate{}, } @@ -660,7 +725,8 @@ func cmdConfigToDestinationConnectionOptions(cmd *cmdConfig, tables map[string]s } case mysqlDriver: conflictConfig := &mgmtv1alpha1.MysqlOnConflictConfig{} - if cmd.Destination.OnConflict.DoUpdate != nil && cmd.Destination.OnConflict.DoUpdate.Enabled { + if cmd.Destination.OnConflict.DoUpdate != nil && + cmd.Destination.OnConflict.DoUpdate.Enabled { conflictConfig.Strategy = &mgmtv1alpha1.MysqlOnConflictConfig_Update{ Update: &mgmtv1alpha1.MysqlOnConflictConfig_MysqlOnConflictUpdate{}, } @@ -731,11 +797,18 @@ func (c *clisync) runDestinationInitStatements( defer db.Db().Close() if c.cmd.Destination.InitSchema { for _, block := range schemaConfig.InitSchemaStatements { - c.logger.Info(fmt.Sprintf("[%s] found %d statements to execute during schema initialization", block.Label, len(block.Statements))) + c.logger.Info( + fmt.Sprintf( + "[%s] found %d statements to execute during schema initialization", + block.Label, + len(block.Statements), + ), + ) if len(block.Statements) == 0 { continue } - err = db.Db().BatchExec(c.ctx, batchSize, block.Statements, &sqlmanager_shared.BatchExecOpts{}) + err = db.Db(). 
+ BatchExec(c.ctx, batchSize, block.Statements, &sqlmanager_shared.BatchExecOpts{}) if err != nil { c.logger.Error(fmt.Sprintf("Error creating tables: %v", err)) return fmt.Errorf("unable to exec pg %s statements: %w", block.Label, err) @@ -752,7 +825,8 @@ func (c *clisync) runDestinationInitStatements( truncateCascadeStmts = append(truncateCascadeStmts, stmt) } } - err = db.Db().BatchExec(c.ctx, batchSize, truncateCascadeStmts, &sqlmanager_shared.BatchExecOpts{}) + err = db.Db(). + BatchExec(c.ctx, batchSize, truncateCascadeStmts, &sqlmanager_shared.BatchExecOpts{}) if err != nil { c.logger.Error(fmt.Sprintf("Error truncate cascade tables: %v", err)) return err @@ -779,10 +853,14 @@ func (c *clisync) runDestinationInitStatements( } orderedTableTruncateStatements := []string{} for _, t := range orderedTablesResp.OrderedTables { - orderedTableTruncateStatements = append(orderedTableTruncateStatements, schemaConfig.TruncateTableStatementsMap[t.String()]) + orderedTableTruncateStatements = append( + orderedTableTruncateStatements, + schemaConfig.TruncateTableStatementsMap[t.String()], + ) } disableFkChecks := sqlmanager_shared.DisableForeignKeyChecks - err = db.Db().BatchExec(c.ctx, batchSize, orderedTableTruncateStatements, &sqlmanager_shared.BatchExecOpts{Prefix: &disableFkChecks}) + err = db.Db(). 
+ BatchExec(c.ctx, batchSize, orderedTableTruncateStatements, &sqlmanager_shared.BatchExecOpts{Prefix: &disableFkChecks}) if err != nil { c.logger.Error(fmt.Sprintf("Error truncating tables: %v", err)) return err @@ -819,7 +897,14 @@ func buildSyncConfigs( } } - runConfigs, err := runconfigs.BuildRunConfigs(schemaConfig.TableConstraints, map[string]string{}, primaryKeysMap, tableColMap, uniqueIndexesMap, uniqueConstraintsMap) + runConfigs, err := runconfigs.BuildRunConfigs( + schemaConfig.TableConstraints, + map[string]string{}, + primaryKeysMap, + tableColMap, + uniqueIndexesMap, + uniqueConstraintsMap, + ) if err != nil { logger.Error(err.Error()) return nil @@ -892,10 +977,13 @@ func (c *clisync) getSourceConnectionNonSqlSchemaConfig( connection *mgmtv1alpha1.Connection, sc *mgmtv1alpha1.ConnectionSchemaConfig, ) (*schemaConfig, error) { - schemaResp, err := c.connectiondataclient.GetConnectionSchema(c.ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionSchemaRequest{ - ConnectionId: connection.Id, - SchemaConfig: sc, - })) + schemaResp, err := c.connectiondataclient.GetConnectionSchema( + c.ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionSchemaRequest{ + ConnectionId: connection.Id, + SchemaConfig: sc, + }), + ) if err != nil { return nil, err } @@ -918,10 +1006,13 @@ func (c *clisync) getSourceConnectionSqlSchemaConfig( var initSchemaStatements []*mgmtv1alpha1.SchemaInitStatements errgrp, errctx := errgroup.WithContext(c.ctx) errgrp.Go(func() error { - schemaResp, err := c.connectiondataclient.GetConnectionSchema(errctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionSchemaRequest{ - ConnectionId: connection.Id, - SchemaConfig: sc, - })) + schemaResp, err := c.connectiondataclient.GetConnectionSchema( + errctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionSchemaRequest{ + ConnectionId: connection.Id, + SchemaConfig: sc, + }), + ) if err != nil { return err } @@ -930,7 +1021,14 @@ func (c *clisync) getSourceConnectionSqlSchemaConfig( }) errgrp.Go(func() 
error { - constraintConnectionResp, err := c.connectiondataclient.GetConnectionTableConstraints(errctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionTableConstraintsRequest{ConnectionId: c.cmd.Source.ConnectionId})) + constraintConnectionResp, err := c.connectiondataclient.GetConnectionTableConstraints( + errctx, + connect.NewRequest( + &mgmtv1alpha1.GetConnectionTableConstraintsRequest{ + ConnectionId: c.cmd.Source.ConnectionId, + }, + ), + ) if err != nil { return err } @@ -943,7 +1041,13 @@ func (c *clisync) getSourceConnectionSqlSchemaConfig( if c.cmd.Destination != nil { errgrp.Go(func() error { - initStatementsResp, err := getTableInitStatementMap(errctx, c.logger, c.connectiondataclient, c.cmd.Source.ConnectionId, c.cmd.Destination) + initStatementsResp, err := getTableInitStatementMap( + errctx, + c.logger, + c.connectiondataclient, + c.cmd.Source.ConnectionId, + c.cmd.Destination, + ) if err != nil { return err } @@ -990,10 +1094,13 @@ func (c *clisync) getDestinationSchemaConfig( sourceConnection *mgmtv1alpha1.Connection, sc *mgmtv1alpha1.ConnectionSchemaConfig, ) (*schemaConfig, error) { - schemaResp, err := c.connectiondataclient.GetConnectionSchema(c.ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionSchemaRequest{ - ConnectionId: sourceConnection.Id, - SchemaConfig: sc, - })) + schemaResp, err := c.connectiondataclient.GetConnectionSchema( + c.ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionSchemaRequest{ + ConnectionId: sourceConnection.Id, + SchemaConfig: sc, + }), + ) if err != nil { return nil, fmt.Errorf("unable to retrieve connection schema for connection: %w", err) } @@ -1001,7 +1108,10 @@ func (c *clisync) getDestinationSchemaConfig( destSchemas, err := c.getDestinationSchemas() if err != nil { - return nil, fmt.Errorf("unable to retrieve destination connection schema for connection: %w", err) + return nil, fmt.Errorf( + "unable to retrieve destination connection schema for connection: %w", + err, + ) } tableColMap := 
getTableColMap(sourceSchemas) @@ -1082,7 +1192,9 @@ func (c *clisync) getDestinationSchemaConfig( }, nil } -func (c *clisync) getDestinationTableConstraints(schemas []string) (*sqlmanager_shared.TableConstraints, error) { +func (c *clisync) getDestinationTableConstraints( + schemas []string, +) (*sqlmanager_shared.TableConstraints, error) { cctx, cancel := context.WithDeadline(c.ctx, time.Now().Add(5*time.Second)) defer cancel() destConnection := cmdConfigToDestinationConnection(c.cmd) diff --git a/cli/internal/cmds/neosync/sync/ui.go b/cli/internal/cmds/neosync/sync/ui.go index b9e3093d08..2f7ec9b996 100644 --- a/cli/internal/cmds/neosync/sync/ui.go +++ b/cli/internal/cmds/neosync/sync/ui.go @@ -40,13 +40,22 @@ var ( printlog = lipgloss.NewStyle().PaddingLeft(2) currentPkgNameStyle = lipgloss.NewStyle().PaddingLeft(2).Foreground(lipgloss.Color("211")) doneStyle = lipgloss.NewStyle().Margin(1, 2) - checkMark = lipgloss.NewStyle().PaddingLeft(2).Foreground(lipgloss.Color("42")).SetString("✓") - helpStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("241")).Margin(1, 0) - dotStyle = helpStyle.UnsetMargins() - durationStyle = dotStyle + checkMark = lipgloss.NewStyle(). + PaddingLeft(2). + Foreground(lipgloss.Color("42")). 
+ SetString("✓") + helpStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("241")).Margin(1, 0) + dotStyle = helpStyle.UnsetMargins() + durationStyle = dotStyle ) -func newModel(ctx context.Context, benv *service.Environment, groupedConfigs [][]*benthosbuilder.BenthosConfigResponse, logger *slog.Logger, outputType output.OutputType) *model { +func newModel( + ctx context.Context, + benv *service.Environment, + groupedConfigs [][]*benthosbuilder.BenthosConfigResponse, + logger *slog.Logger, + outputType output.OutputType, +) *model { s := spinner.New() s.Style = lipgloss.NewStyle().Foreground(lipgloss.Color("63")) return &model{ @@ -122,7 +131,13 @@ func (m *model) View() string { var pkgName string if len(processingTables) > 5 { - pkgName = currentPkgNameStyle.Render(fmt.Sprintf("%s \n + %d others...", strings.Join(processingTables[:5], "\n"), len(processingTables))) + pkgName = currentPkgNameStyle.Render( + fmt.Sprintf( + "%s \n + %d others...", + strings.Join(processingTables[:5], "\n"), + len(processingTables), + ), + ) } else { pkgName = currentPkgNameStyle.Render(strings.Join(processingTables, "\n")) } @@ -132,7 +147,10 @@ func (m *model) View() string { type syncedDataMsg map[string]string -func (m *model) syncConfigs(ctx context.Context, configs []*benthosbuilder.BenthosConfigResponse) tea.Cmd { +func (m *model) syncConfigs( + ctx context.Context, + configs []*benthosbuilder.BenthosConfigResponse, +) tea.Cmd { return func() tea.Msg { messageMap := syncmap.Map{} errgrp, errctx := errgroup.WithContext(ctx) @@ -149,7 +167,9 @@ func (m *model) syncConfigs(ctx context.Context, configs []*benthosbuilder.Benth } duration := time.Since(start) messageMap.Store(cfg.Name, duration) - m.logger.Info(fmt.Sprintf("Finished syncing table %s %s", cfg.Name, duration.String())) + m.logger.Info( + fmt.Sprintf("Finished syncing table %s %s", cfg.Name, duration.String()), + ) return nil }) } @@ -161,7 +181,7 @@ func (m *model) syncConfigs(ctx context.Context, configs 
[]*benthosbuilder.Benth results := map[string]string{} //nolint:gofmt - messageMap.Range(func(key, value interface{}) bool { + messageMap.Range(func(key, value any) bool { d := value.(time.Duration) results[key.(string)] = fmt.Sprintf("%s %s %s", checkMark, key, durationStyle.Render(d.String())) @@ -187,7 +207,13 @@ func getConfigCount(groupedConfigs [][]*benthosbuilder.BenthosConfigResponse) in return count } -func runSync(ctx context.Context, outputType output.OutputType, benv *service.Environment, groupedConfigs [][]*benthosbuilder.BenthosConfigResponse, logger *slog.Logger) error { +func runSync( + ctx context.Context, + outputType output.OutputType, + benv *service.Environment, + groupedConfigs [][]*benthosbuilder.BenthosConfigResponse, + logger *slog.Logger, +) error { var opts []tea.ProgramOption var synclogger = logger if outputType == output.PlainOutput { diff --git a/cli/internal/cmds/neosync/sync/util.go b/cli/internal/cmds/neosync/sync/util.go index bf4539ce73..f5b3083ed3 100644 --- a/cli/internal/cmds/neosync/sync/util.go +++ b/cli/internal/cmds/neosync/sync/util.go @@ -25,7 +25,10 @@ func parseDriverString(str string) (DriverType, bool) { return p, ok } -func isConfigReady(config *benthosbuilder.BenthosConfigResponse, queuedMap map[string][]string) bool { +func isConfigReady( + config *benthosbuilder.BenthosConfigResponse, + queuedMap map[string][]string, +) bool { for _, dep := range config.DependsOn { if cols, ok := queuedMap[dep.Table]; ok { for _, dc := range dep.Columns { @@ -40,7 +43,10 @@ func isConfigReady(config *benthosbuilder.BenthosConfigResponse, queuedMap map[s return true } -func groupConfigsByDependency(configs []*benthosbuilder.BenthosConfigResponse, logger *slog.Logger) [][]*benthosbuilder.BenthosConfigResponse { +func groupConfigsByDependency( + configs []*benthosbuilder.BenthosConfigResponse, + logger *slog.Logger, +) [][]*benthosbuilder.BenthosConfigResponse { groupedConfigs := [][]*benthosbuilder.BenthosConfigResponse{} configMap 
:= map[string]*benthosbuilder.BenthosConfigResponse{} queuedMap := map[string][]string{} // map -> table to cols @@ -119,7 +125,10 @@ func buildDependencyMap(syncConfigs []*runconfigs.RunConfig) map[string][]string return dependencyMap } -func areSourceAndDestCompatible(connection *mgmtv1alpha1.Connection, destinationDriver *DriverType) error { +func areSourceAndDestCompatible( + connection *mgmtv1alpha1.Connection, + destinationDriver *DriverType, +) error { switch connection.ConnectionConfig.Config.(type) { case *mgmtv1alpha1.ConnectionConfig_PgConfig: if destinationDriver != nil && *destinationDriver != postgresDriver { diff --git a/cli/internal/output/output.go b/cli/internal/output/output.go index 17d70b77d2..5d1d9b08c1 100644 --- a/cli/internal/output/output.go +++ b/cli/internal/output/output.go @@ -31,7 +31,8 @@ func AttachOutputFlag(cmd *cobra.Command) { outputVals = append(outputVals, outputType) } - cmd.Flags().StringP("output", "o", string(autoOutput), fmt.Sprintf("Set type of output (%s).", strings.Join(outputVals, ", "))) + cmd.Flags(). + StringP("output", "o", string(autoOutput), fmt.Sprintf("Set type of output (%s).", strings.Join(outputVals, ", "))) } func ValidateAndRetrieveOutputFlag(cmd *cobra.Command) (OutputType, error) { diff --git a/cli/internal/userconfig/folder.go b/cli/internal/userconfig/folder.go index 6ccc785623..ae031d37b4 100644 --- a/cli/internal/userconfig/folder.go +++ b/cli/internal/userconfig/folder.go @@ -17,7 +17,9 @@ const ( // 2. Checks for existence of XDG_CONFIG_HOME and append "neosync" to it, if exists // 3. 
Use ~/.neosync func GetOrCreateNeosyncFolder() (string, error) { - configDir := os.Getenv("NEOSYNC_CONFIG_DIR") // helpful for tools such as direnv and people who want it somewhere interesting + configDir := os.Getenv( + "NEOSYNC_CONFIG_DIR", + ) // helpful for tools such as direnv and people who want it somewhere interesting xdgConfigHome := os.Getenv("XDG_CONFIG_HOME") // linux users expect this to be respected var fullName string diff --git a/cli/internal/version/version.go b/cli/internal/version/version.go index 2996dbb00d..ff701971c7 100644 --- a/cli/internal/version/version.go +++ b/cli/internal/version/version.go @@ -10,11 +10,11 @@ import ( type VersionInfo struct { GitVersion string `json:"gitVersion" yaml:"gitVersion"` - GitCommit string `json:"gitCommit" yaml:"gitCommit"` - BuildDate string `json:"buildDate" yaml:"buildDate"` - GoVersion string `json:"goVersion" yaml:"goVersion"` - Compiler string `json:"compiler" yaml:"compiler"` - Platform string `json:"platform" yaml:"platform"` + GitCommit string `json:"gitCommit" yaml:"gitCommit"` + BuildDate string `json:"buildDate" yaml:"buildDate"` + GoVersion string `json:"goVersion" yaml:"goVersion"` + Compiler string `json:"compiler" yaml:"compiler"` + Platform string `json:"platform" yaml:"platform"` } func (info *VersionInfo) String() string { @@ -31,8 +31,15 @@ func (info *VersionInfo) Headers() map[string]string { } func constructUserAgent(info *VersionInfo) string { - return fmt.Sprintf("neosync/%s (commit: %s; build: %s; go: %s; compiler: %s; platform: %s)", - info.GitVersion, info.GitCommit, info.BuildDate, info.GoVersion, info.Compiler, info.Platform) + return fmt.Sprintf( + "neosync/%s (commit: %s; build: %s; go: %s; compiler: %s; platform: %s)", + info.GitVersion, + info.GitCommit, + info.BuildDate, + info.GoVersion, + info.Compiler, + info.Platform, + ) } func (info *VersionInfo) GrpcMetadata() metadata.MD { diff --git a/internal/authmgmt/auth0/admin-client.go 
b/internal/authmgmt/auth0/admin-client.go index 1b0c1e9858..edf6054988 100644 --- a/internal/authmgmt/auth0/admin-client.go +++ b/internal/authmgmt/auth0/admin-client.go @@ -14,7 +14,10 @@ type Auth0MgmtClient struct { } func New(domain, clientId, clientSecret string) (*Auth0MgmtClient, error) { - client, err := management.New(domain, management.WithClientCredentials(context.Background(), clientId, clientSecret)) + client, err := management.New( + domain, + management.WithClientCredentials(context.Background(), clientId, clientSecret), + ) if err != nil { return nil, err } diff --git a/internal/aws/aws-manager.go b/internal/aws/aws-manager.go index e0a54a3fbc..0813df27a0 100644 --- a/internal/aws/aws-manager.go +++ b/internal/aws/aws-manager.go @@ -38,7 +38,10 @@ type NeosyncAwsManagerClient interface { params *s3.GetObjectInput, ) (*s3.GetObjectOutput, error) - NewDynamoDbClient(ctx context.Context, connCfg *mgmtv1alpha1.DynamoDBConnectionConfig) (*DynamoDbClient, error) + NewDynamoDbClient( + ctx context.Context, + connCfg *mgmtv1alpha1.DynamoDBConnectionConfig, + ) (*DynamoDbClient, error) } func New() *NeosyncAwsManager { @@ -46,7 +49,10 @@ func New() *NeosyncAwsManager { } // Returns a wrapper dynamodb client -func (n *NeosyncAwsManager) NewDynamoDbClient(ctx context.Context, connCfg *mgmtv1alpha1.DynamoDBConnectionConfig) (*DynamoDbClient, error) { +func (n *NeosyncAwsManager) NewDynamoDbClient( + ctx context.Context, + connCfg *mgmtv1alpha1.DynamoDBConnectionConfig, +) (*DynamoDbClient, error) { client, err := n.newDynamoDbClient(ctx, connCfg) if err != nil { return nil, err @@ -55,7 +61,10 @@ func (n *NeosyncAwsManager) NewDynamoDbClient(ctx context.Context, connCfg *mgmt } // returns the raw, underlying aws client -func (n *NeosyncAwsManager) newDynamoDbClient(ctx context.Context, connCfg *mgmtv1alpha1.DynamoDBConnectionConfig) (*dynamodb.Client, error) { +func (n *NeosyncAwsManager) newDynamoDbClient( + ctx context.Context, + connCfg 
*mgmtv1alpha1.DynamoDBConnectionConfig, +) (*dynamodb.Client, error) { cfg, err := getDynamoAwsConfig(ctx, connCfg) if err != nil { return nil, err @@ -67,7 +76,10 @@ func (n *NeosyncAwsManager) newDynamoDbClient(ctx context.Context, connCfg *mgmt }), nil } -func (n *NeosyncAwsManager) NewS3Client(ctx context.Context, connCfg *mgmtv1alpha1.AwsS3ConnectionConfig) (*s3.Client, error) { +func (n *NeosyncAwsManager) NewS3Client( + ctx context.Context, + connCfg *mgmtv1alpha1.AwsS3ConnectionConfig, +) (*s3.Client, error) { cfg, err := getS3AwsConfig(ctx, connCfg) if err != nil { return nil, err @@ -116,7 +128,10 @@ func withS3Region(region *string) func(o *s3.Options) { } } -func getS3AwsConfig(ctx context.Context, s3ConnConfig *mgmtv1alpha1.AwsS3ConnectionConfig) (*aws.Config, error) { +func getS3AwsConfig( + ctx context.Context, + s3ConnConfig *mgmtv1alpha1.AwsS3ConnectionConfig, +) (*aws.Config, error) { return GetAwsConfig(ctx, &AwsCredentialsConfig{ Region: s3ConnConfig.GetRegion(), Endpoint: s3ConnConfig.GetEndpoint(), @@ -131,7 +146,10 @@ func getS3AwsConfig(ctx context.Context, s3ConnConfig *mgmtv1alpha1.AwsS3Connect }) } -func getDynamoAwsConfig(ctx context.Context, dynConnConfig *mgmtv1alpha1.DynamoDBConnectionConfig) (*aws.Config, error) { +func getDynamoAwsConfig( + ctx context.Context, + dynConnConfig *mgmtv1alpha1.DynamoDBConnectionConfig, +) (*aws.Config, error) { return GetAwsConfig(ctx, &AwsCredentialsConfig{ Region: dynConnConfig.GetRegion(), Endpoint: dynConnConfig.GetEndpoint(), @@ -187,7 +205,11 @@ type AwsCredentialsConfig struct { UseEc2 bool } -func GetAwsConfig(ctx context.Context, cfg *AwsCredentialsConfig, opts ...func(*config.LoadOptions) error) (*aws.Config, error) { +func GetAwsConfig( + ctx context.Context, + cfg *AwsCredentialsConfig, + opts ...func(*config.LoadOptions) error, +) (*aws.Config, error) { if cfg == nil { return nil, fmt.Errorf("cfg input was nil, expected *AwsCredentialsConfig") } diff --git a/internal/aws/dynamodb-client.go 
b/internal/aws/dynamodb-client.go index 915e261099..5c766c2b18 100644 --- a/internal/aws/dynamodb-client.go +++ b/internal/aws/dynamodb-client.go @@ -13,16 +13,32 @@ type DynamoDbClient struct { } type dynamoDBAPIV2 interface { - DescribeTable(ctx context.Context, params *dynamodb.DescribeTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DescribeTableOutput, error) - Scan(ctx context.Context, params *dynamodb.ScanInput, optFns ...func(*dynamodb.Options)) (*dynamodb.ScanOutput, error) - ListTables(ctx context.Context, params *dynamodb.ListTablesInput, optFns ...func(*dynamodb.Options)) (*dynamodb.ListTablesOutput, error) + DescribeTable( + ctx context.Context, + params *dynamodb.DescribeTableInput, + optFns ...func(*dynamodb.Options), + ) (*dynamodb.DescribeTableOutput, error) + Scan( + ctx context.Context, + params *dynamodb.ScanInput, + optFns ...func(*dynamodb.Options), + ) (*dynamodb.ScanOutput, error) + ListTables( + ctx context.Context, + params *dynamodb.ListTablesInput, + optFns ...func(*dynamodb.Options), + ) (*dynamodb.ListTablesOutput, error) } func NewDynamoDbClient(api dynamoDBAPIV2) *DynamoDbClient { return &DynamoDbClient{client: api} } -func (d *DynamoDbClient) ListAllTables(ctx context.Context, input *dynamodb.ListTablesInput, optFns ...func(*dynamodb.Options)) ([]string, error) { +func (d *DynamoDbClient) ListAllTables( + ctx context.Context, + input *dynamodb.ListTablesInput, + optFns ...func(*dynamodb.Options), +) ([]string, error) { tableNames := []string{} done := false for !done { @@ -42,7 +58,10 @@ type DynamoDbTableKey struct { RangeKey string } -func (d *DynamoDbClient) GetTableKey(ctx context.Context, tableName string) (*DynamoDbTableKey, error) { +func (d *DynamoDbClient) GetTableKey( + ctx context.Context, + tableName string, +) (*DynamoDbTableKey, error) { describeTableOutput, err := d.client.DescribeTable(ctx, &dynamodb.DescribeTableInput{ TableName: &tableName, }) diff --git a/internal/benthos-stream/client.go 
b/internal/benthos-stream/client.go index d02f981b64..ebd3804760 100644 --- a/internal/benthos-stream/client.go +++ b/internal/benthos-stream/client.go @@ -25,7 +25,9 @@ func NewBenthosStreamManager() *BenthosStreamManager { return &BenthosStreamManager{} } -func (b *BenthosStreamManager) NewBenthosStreamFromBuilder(streambldr *service.StreamBuilder) (BenthosStreamClient, error) { +func (b *BenthosStreamManager) NewBenthosStreamFromBuilder( + streambldr *service.StreamBuilder, +) (BenthosStreamClient, error) { stream, err := streambldr.Build() if err != nil { return nil, err diff --git a/internal/benthos/benthos-builder/benthos-builder.go b/internal/benthos/benthos-builder/benthos-builder.go index 0e5e481de8..876a2bc646 100644 --- a/internal/benthos/benthos-builder/benthos-builder.go +++ b/internal/benthos/benthos-builder/benthos-builder.go @@ -63,7 +63,11 @@ func NewBuilderProvider(logger *slog.Logger) *BuilderProvider { } // Handles registering new builders -func (r *BuilderProvider) Register(jobType bb_internal.JobType, connType bb_shared.ConnectionType, builder bb_internal.BenthosBuilder) { +func (r *BuilderProvider) Register( + jobType bb_internal.JobType, + connType bb_shared.ConnectionType, + builder bb_internal.BenthosBuilder, +) { key := BuilderKey{ConnType: connType, JobType: jobType} r.mu.Lock() @@ -71,7 +75,13 @@ func (r *BuilderProvider) Register(jobType bb_internal.JobType, connType bb_shar _, exists := r.builders[key.String()] if !exists { - r.logger.Debug(fmt.Sprintf("registering benthos builder for job type %s and connection type %s", jobType, connType)) + r.logger.Debug( + fmt.Sprintf( + "registering benthos builder for job type %s and connection type %s", + jobType, + connType, + ), + ) r.builders[key.String()] = builder } } @@ -92,7 +102,11 @@ func (r *BuilderProvider) GetBuilder( builder, exists := r.builders[key.String()] r.mu.RUnlock() if !exists { - return nil, fmt.Errorf("builder not registered for connection type (%s) and job type (%s)", 
connectionType, jobType) + return nil, fmt.Errorf( + "builder not registered for connection type (%s) and job type (%s)", + connectionType, + jobType, + ) } return builder, nil } @@ -131,22 +145,59 @@ func (b *BuilderProvider) registerStandardBuilders( for _, connectionType := range connectionTypes { switch connectionType { case bb_shared.ConnectionTypePostgres: - sqlbuilder := bb_conns.NewSqlSyncBuilder(transformerclient, sqlmanagerclient, redisConfig, sqlmanager_shared.PostgresDriver, selectQueryBuilder, defaultPageLimit) + sqlbuilder := bb_conns.NewSqlSyncBuilder( + transformerclient, + sqlmanagerclient, + redisConfig, + sqlmanager_shared.PostgresDriver, + selectQueryBuilder, + defaultPageLimit, + ) b.Register(bb_internal.JobTypeSync, connectionType, sqlbuilder) case bb_shared.ConnectionTypeMysql: - sqlbuilder := bb_conns.NewSqlSyncBuilder(transformerclient, sqlmanagerclient, redisConfig, sqlmanager_shared.MysqlDriver, selectQueryBuilder, defaultPageLimit) + sqlbuilder := bb_conns.NewSqlSyncBuilder( + transformerclient, + sqlmanagerclient, + redisConfig, + sqlmanager_shared.MysqlDriver, + selectQueryBuilder, + defaultPageLimit, + ) b.Register(bb_internal.JobTypeSync, connectionType, sqlbuilder) case bb_shared.ConnectionTypeMssql: - sqlbuilder := bb_conns.NewSqlSyncBuilder(transformerclient, sqlmanagerclient, redisConfig, sqlmanager_shared.MssqlDriver, selectQueryBuilder, defaultPageLimit) + sqlbuilder := bb_conns.NewSqlSyncBuilder( + transformerclient, + sqlmanagerclient, + redisConfig, + sqlmanager_shared.MssqlDriver, + selectQueryBuilder, + defaultPageLimit, + ) b.Register(bb_internal.JobTypeSync, connectionType, sqlbuilder) case bb_shared.ConnectionTypeAwsS3: - b.Register(bb_internal.JobTypeSync, bb_shared.ConnectionTypeAwsS3, bb_conns.NewAwsS3SyncBuilder()) + b.Register( + bb_internal.JobTypeSync, + bb_shared.ConnectionTypeAwsS3, + bb_conns.NewAwsS3SyncBuilder(), + ) case bb_shared.ConnectionTypeDynamodb: - b.Register(bb_internal.JobTypeSync, 
bb_shared.ConnectionTypeDynamodb, bb_conns.NewDynamoDbSyncBuilder(transformerclient)) + b.Register( + bb_internal.JobTypeSync, + bb_shared.ConnectionTypeDynamodb, + bb_conns.NewDynamoDbSyncBuilder(transformerclient), + ) case bb_shared.ConnectionTypeMongo: - b.Register(bb_internal.JobTypeSync, bb_shared.ConnectionTypeMongo, bb_conns.NewMongoDbSyncBuilder(transformerclient)) + b.Register( + bb_internal.JobTypeSync, + bb_shared.ConnectionTypeMongo, + bb_conns.NewMongoDbSyncBuilder(transformerclient), + ) case bb_shared.ConnectionTypeGCP: - b.Register(bb_internal.JobTypeSync, bb_shared.ConnectionTypeGCP, bb_conns.NewGcpCloudStorageSyncBuilder()) + b.Register( + bb_internal.JobTypeSync, + bb_shared.ConnectionTypeGCP, + bb_conns.NewGcpCloudStorageSyncBuilder(), + ) default: return fmt.Errorf("unsupport connection type for sync job: %s", connectionType) } @@ -155,7 +206,10 @@ func (b *BuilderProvider) registerStandardBuilders( if jobType == bb_internal.JobTypeAIGenerate { if len(destinationConnections) != 1 { - return fmt.Errorf("unsupported destination count for AI generate job: %d", len(destinationConnections)) + return fmt.Errorf( + "unsupported destination count for AI generate job: %d", + len(destinationConnections), + ) } destConnType, err := bb_shared.GetConnectionType(destinationConnections[0]) if err != nil { @@ -165,13 +219,22 @@ func (b *BuilderProvider) registerStandardBuilders( if err != nil { return err } - builder := bb_conns.NewGenerateAIBuilder(transformerclient, sqlmanagerclient, connectionclient, driver) + builder := bb_conns.NewGenerateAIBuilder( + transformerclient, + sqlmanagerclient, + connectionclient, + driver, + ) b.Register(bb_internal.JobTypeAIGenerate, bb_shared.ConnectionTypeOpenAI, builder) b.Register(bb_internal.JobTypeAIGenerate, destConnType, builder) } if jobType == bb_internal.JobTypeGenerate { for _, connectionType := range connectionTypes { - b.Register(bb_internal.JobTypeGenerate, connectionType, 
bb_conns.NewGenerateBuilder(transformerclient, sqlmanagerclient, connectionclient)) + b.Register( + bb_internal.JobTypeGenerate, + connectionType, + bb_conns.NewGenerateBuilder(transformerclient, sqlmanagerclient, connectionclient), + ) } } return nil @@ -258,7 +321,8 @@ func NewWorkerBenthosConfigManager( if err != nil { return nil, err } - logger := config.Logger.With(withBenthosConfigLoggerTags(config.Job, config.SourceConnection)...) + logger := config.Logger.With( + withBenthosConfigLoggerTags(config.Job, config.SourceConnection)...) return &BenthosConfigManager{ sourceProvider: provider, destinationProvider: provider, @@ -344,7 +408,8 @@ func NewCliBenthosConfigManager( return nil, err } - logger := config.Logger.With(withBenthosConfigLoggerTags(config.Job, config.SourceConnection)...) + logger := config.Logger.With( + withBenthosConfigLoggerTags(config.Job, config.SourceConnection)...) return &BenthosConfigManager{ sourceProvider: sourceProvider, destinationProvider: destinationProvider, diff --git a/internal/benthos/benthos-builder/builders/aws-s3.go b/internal/benthos/benthos-builder/builders/aws-s3.go index 18962ee8e2..2acd279afc 100644 --- a/internal/benthos/benthos-builder/builders/aws-s3.go +++ b/internal/benthos/benthos-builder/builders/aws-s3.go @@ -20,11 +20,17 @@ func NewAwsS3SyncBuilder() bb_internal.BenthosBuilder { return &awsS3SyncBuilder{} } -func (b *awsS3SyncBuilder) BuildSourceConfigs(ctx context.Context, params *bb_internal.SourceParams) ([]*bb_internal.BenthosSourceConfig, error) { +func (b *awsS3SyncBuilder) BuildSourceConfigs( + ctx context.Context, + params *bb_internal.SourceParams, +) ([]*bb_internal.BenthosSourceConfig, error) { return nil, errors.ErrUnsupported } -func (b *awsS3SyncBuilder) BuildDestinationConfig(ctx context.Context, params *bb_internal.DestinationParams) (*bb_internal.BenthosDestinationConfig, error) { +func (b *awsS3SyncBuilder) BuildDestinationConfig( + ctx context.Context, + params 
*bb_internal.DestinationParams, +) (*bb_internal.BenthosDestinationConfig, error) { config := &bb_internal.BenthosDestinationConfig{} benthosConfig := params.SourceConfig @@ -139,7 +145,9 @@ func (s S3StorageClass) String() string { }[s] } -func convertToS3StorageClass(protoStorageClass mgmtv1alpha1.AwsS3DestinationConnectionOptions_StorageClass) S3StorageClass { +func convertToS3StorageClass( + protoStorageClass mgmtv1alpha1.AwsS3DestinationConnectionOptions_StorageClass, +) S3StorageClass { switch protoStorageClass { case mgmtv1alpha1.AwsS3DestinationConnectionOptions_STORAGE_CLASS_STANDARD: return S3StorageClass_STANDARD diff --git a/internal/benthos/benthos-builder/builders/dynamodb.go b/internal/benthos/benthos-builder/builders/dynamodb.go index 4d21655a19..b226150bd6 100644 --- a/internal/benthos/benthos-builder/builders/dynamodb.go +++ b/internal/benthos/benthos-builder/builders/dynamodb.go @@ -29,13 +29,19 @@ func NewDynamoDbSyncBuilder( } } -func (b *dyanmodbSyncBuilder) BuildSourceConfigs(ctx context.Context, params *bb_internal.SourceParams) ([]*bb_internal.BenthosSourceConfig, error) { +func (b *dyanmodbSyncBuilder) BuildSourceConfigs( + ctx context.Context, + params *bb_internal.SourceParams, +) ([]*bb_internal.BenthosSourceConfig, error) { sourceConnection := params.SourceConnection job := params.Job dynamoSourceConfig := sourceConnection.GetConnectionConfig().GetDynamodbConfig() if dynamoSourceConfig == nil { - return nil, fmt.Errorf("source connection was not dynamodb. Got %T", sourceConnection.GetConnectionConfig().Config) + return nil, fmt.Errorf( + "source connection was not dynamodb. 
Got %T", + sourceConnection.GetConnectionConfig().Config, + ) } awsManager := awsmanager.New() dynamoClient, err := awsManager.NewDynamoDbClient(ctx, dynamoSourceConfig) @@ -56,12 +62,16 @@ func (b *dyanmodbSyncBuilder) BuildSourceConfigs(ctx context.Context, params *bb Input: &neosync_benthos.InputConfig{ Inputs: neosync_benthos.Inputs{ AwsDynamoDB: &neosync_benthos.InputAwsDynamoDB{ - Table: tableMapping.Table, - Where: getWhereFromSourceTableOption(tableOptsMap[tableMapping.Table]), + Table: tableMapping.Table, + Where: getWhereFromSourceTableOption( + tableOptsMap[tableMapping.Table], + ), ConsistentRead: dynamoJobSourceOpts.GetEnableConsistentRead(), Region: dynamoSourceConfig.GetRegion(), Endpoint: dynamoSourceConfig.GetEndpoint(), - Credentials: buildBenthosS3Credentials(dynamoSourceConfig.GetCredentials()), + Credentials: buildBenthosS3Credentials( + dynamoSourceConfig.GetCredentials(), + ), }, }, }, @@ -106,7 +116,17 @@ func (b *dyanmodbSyncBuilder) BuildSourceConfigs(ctx context.Context, params *bb processorConfigs, err := buildProcessorConfigsByRunType( ctx, b.transformerclient, - runconfigs.NewRunConfig(runconfigId, schemaTable, runconfigType, []string{}, nil, columns, columns, nil, splitColumnPaths), + runconfigs.NewRunConfig( + runconfigId, + schemaTable, + runconfigType, + []string{}, + nil, + columns, + columns, + nil, + splitColumnPaths, + ), map[string][]*bb_internal.ReferenceKey{}, map[string][]*bb_internal.ReferenceKey{}, params.Job.Id, @@ -144,7 +164,10 @@ func (b *dyanmodbSyncBuilder) BuildSourceConfigs(ctx context.Context, params *bb return benthosConfigs, nil } -func (b *dyanmodbSyncBuilder) BuildDestinationConfig(ctx context.Context, params *bb_internal.DestinationParams) (*bb_internal.BenthosDestinationConfig, error) { +func (b *dyanmodbSyncBuilder) BuildDestinationConfig( + ctx context.Context, + params *bb_internal.DestinationParams, +) (*bb_internal.BenthosDestinationConfig, error) { config := &bb_internal.BenthosDestinationConfig{} 
benthosConfig := params.SourceConfig @@ -164,7 +187,10 @@ func (b *dyanmodbSyncBuilder) BuildDestinationConfig(ctx context.Context, params } mappedTable, ok := tableMap[benthosConfig.TableName] if !ok { - return nil, fmt.Errorf("did not find table map for %q when building dynamodb destination config", benthosConfig.TableName) + return nil, fmt.Errorf( + "did not find table map for %q when building dynamodb destination config", + benthosConfig.TableName, + ) } config.Outputs = append(config.Outputs, neosync_benthos.Outputs{ AwsDynamoDB: &neosync_benthos.OutputAwsDynamoDB{ @@ -198,7 +224,9 @@ func getWhereFromSourceTableOption(opt *mgmtv1alpha1.DynamoDBSourceTableOption) return opt.WhereClause } -func toDynamoDbSourceTableOptionMap(tableOpts []*mgmtv1alpha1.DynamoDBSourceTableOption) map[string]*mgmtv1alpha1.DynamoDBSourceTableOption { +func toDynamoDbSourceTableOptionMap( + tableOpts []*mgmtv1alpha1.DynamoDBSourceTableOption, +) map[string]*mgmtv1alpha1.DynamoDBSourceTableOption { output := map[string]*mgmtv1alpha1.DynamoDBSourceTableOption{} for _, opt := range tableOpts { output[opt.Table] = opt @@ -206,7 +234,9 @@ func toDynamoDbSourceTableOptionMap(tableOpts []*mgmtv1alpha1.DynamoDBSourceTabl return output } -func buildBenthosS3Credentials(mgmtCreds *mgmtv1alpha1.AwsS3Credentials) *neosync_benthos.AwsCredentials { +func buildBenthosS3Credentials( + mgmtCreds *mgmtv1alpha1.AwsS3Credentials, +) *neosync_benthos.AwsCredentials { if mgmtCreds == nil { return nil } diff --git a/internal/benthos/benthos-builder/builders/gcp-cloud-storage.go b/internal/benthos/benthos-builder/builders/gcp-cloud-storage.go index bff5e0435f..655f4969bb 100644 --- a/internal/benthos/benthos-builder/builders/gcp-cloud-storage.go +++ b/internal/benthos/benthos-builder/builders/gcp-cloud-storage.go @@ -18,11 +18,17 @@ func NewGcpCloudStorageSyncBuilder() bb_internal.BenthosBuilder { return &gcpCloudStorageSyncBuilder{} } -func (b *gcpCloudStorageSyncBuilder) BuildSourceConfigs(ctx 
context.Context, params *bb_internal.SourceParams) ([]*bb_internal.BenthosSourceConfig, error) { +func (b *gcpCloudStorageSyncBuilder) BuildSourceConfigs( + ctx context.Context, + params *bb_internal.SourceParams, +) ([]*bb_internal.BenthosSourceConfig, error) { return nil, errors.ErrUnsupported } -func (b *gcpCloudStorageSyncBuilder) BuildDestinationConfig(ctx context.Context, params *bb_internal.DestinationParams) (*bb_internal.BenthosDestinationConfig, error) { +func (b *gcpCloudStorageSyncBuilder) BuildDestinationConfig( + ctx context.Context, + params *bb_internal.DestinationParams, +) (*bb_internal.BenthosDestinationConfig, error) { config := &bb_internal.BenthosDestinationConfig{} benthosConfig := params.SourceConfig diff --git a/internal/benthos/benthos-builder/builders/generate-ai.go b/internal/benthos/benthos-builder/builders/generate-ai.go index 1e34b0f625..fed2f69e01 100644 --- a/internal/benthos/benthos-builder/builders/generate-ai.go +++ b/internal/benthos/benthos-builder/builders/generate-ai.go @@ -48,11 +48,17 @@ type aiGenerateColumn struct { DataType string } -func (b *generateAIBuilder) BuildSourceConfigs(ctx context.Context, params *bb_internal.SourceParams) ([]*bb_internal.BenthosSourceConfig, error) { +func (b *generateAIBuilder) BuildSourceConfigs( + ctx context.Context, + params *bb_internal.SourceParams, +) ([]*bb_internal.BenthosSourceConfig, error) { jobSource := params.Job.GetSource() sourceOptions := jobSource.GetOptions().GetAiGenerate() if sourceOptions == nil { - return nil, fmt.Errorf("job does not have AiGenerate source options, has: %T", jobSource.GetOptions().Config) + return nil, fmt.Errorf( + "job does not have AiGenerate source options, has: %T", + jobSource.GetOptions().Config, + ) } sourceConnection := params.SourceConnection @@ -60,11 +66,21 @@ func (b *generateAIBuilder) BuildSourceConfigs(ctx context.Context, params *bb_i if openaiConfig == nil { return nil, errors.New("configured source connection is not an openai 
configuration") } - constraintConnection, err := getConstraintConnection(ctx, jobSource, b.connectionclient, shared.GetConnectionById) + constraintConnection, err := getConstraintConnection( + ctx, + jobSource, + b.connectionclient, + shared.GetConnectionById, + ) if err != nil { return nil, err } - db, err := b.sqlmanagerclient.NewSqlConnection(ctx, connectionmanager.NewUniqueSession(connectionmanager.WithSessionGroup(params.JobRunId)), constraintConnection, params.Logger) + db, err := b.sqlmanagerclient.NewSqlConnection( + ctx, + connectionmanager.NewUniqueSession(connectionmanager.WithSessionGroup(params.JobRunId)), + constraintConnection, + params.Logger, + ) if err != nil { return nil, fmt.Errorf("unable to create new sql db: %w", err) } @@ -82,7 +98,10 @@ func (b *generateAIBuilder) BuildSourceConfigs(ctx context.Context, params *bb_i tableColsMap, ok := groupedSchemas[sqlmanager_shared.BuildTable(schema.GetSchema(), table.GetTable())] if !ok { - return nil, fmt.Errorf("did not find schema data when building AI Generate config: %s", schema.GetSchema()) + return nil, fmt.Errorf( + "did not find schema data when building AI Generate config: %s", + schema.GetSchema(), + ) } for col, info := range tableColsMap { columns = append(columns, &aiGenerateColumn{ @@ -100,7 +119,10 @@ func (b *generateAIBuilder) BuildSourceConfigs(ctx context.Context, params *bb_i } } if len(mappings) == 0 { - return nil, fmt.Errorf("did not generate any mapping configs during AI Generate build for connection: %s", constraintConnection.GetId()) + return nil, fmt.Errorf( + "did not generate any mapping configs during AI Generate build for connection: %s", + constraintConnection.GetId(), + ) } var userPrompt *string @@ -179,7 +201,10 @@ func buildBenthosAiGenerateSourceConfigResponses( } responses = append(responses, &bb_internal.BenthosSourceConfig{ - Name: neosync_benthos.BuildBenthosTable(tableMapping.Schema, tableMapping.Table), // todo: may need to expand on this + Name: 
neosync_benthos.BuildBenthosTable( + tableMapping.Schema, + tableMapping.Table, + ), // todo: may need to expand on this Config: bc, DependsOn: []*runconfigs.DependsOn{}, @@ -197,7 +222,10 @@ func buildBenthosAiGenerateSourceConfigResponses( return responses } -func (b *generateAIBuilder) BuildDestinationConfig(ctx context.Context, params *bb_internal.DestinationParams) (*bb_internal.BenthosDestinationConfig, error) { +func (b *generateAIBuilder) BuildDestinationConfig( + ctx context.Context, + params *bb_internal.DestinationParams, +) (*bb_internal.BenthosDestinationConfig, error) { config := &bb_internal.BenthosDestinationConfig{} benthosConfig := params.SourceConfig @@ -211,7 +239,10 @@ func (b *generateAIBuilder) BuildDestinationConfig(ctx context.Context, params * processorConfigs = append(processorConfigs, *pc) } - config.BenthosDsns = append(config.BenthosDsns, &bb_shared.BenthosDsn{ConnectionId: params.DestConnection.Id}) + config.BenthosDsns = append( + config.BenthosDsns, + &bb_shared.BenthosDsn{ConnectionId: params.DestConnection.Id}, + ) config.Outputs = append(config.Outputs, neosync_benthos.Outputs{ // retry processor and output several times Retry: &neosync_benthos.RetryConfig{ @@ -273,7 +304,11 @@ func getConstraintConnection( } connection, err := getConnectionById(ctx, connclient, connectionId) if err != nil { - return nil, fmt.Errorf("unable to get constraint connection by id (%s): %w", connectionId, err) + return nil, fmt.Errorf( + "unable to get constraint connection by id (%s): %w", + connectionId, + err, + ) } return connection, nil } diff --git a/internal/benthos/benthos-builder/builders/generate.go b/internal/benthos/benthos-builder/builders/generate.go index 2deb48cfc5..0d97e04952 100644 --- a/internal/benthos/benthos-builder/builders/generate.go +++ b/internal/benthos/benthos-builder/builders/generate.go @@ -35,7 +35,10 @@ func NewGenerateBuilder( } } -func (b *generateBuilder) BuildSourceConfigs(ctx context.Context, params 
*bb_internal.SourceParams) ([]*bb_internal.BenthosSourceConfig, error) { +func (b *generateBuilder) BuildSourceConfigs( + ctx context.Context, + params *bb_internal.SourceParams, +) ([]*bb_internal.BenthosSourceConfig, error) { logger := params.Logger job := params.Job configs := []*bb_internal.BenthosSourceConfig{} @@ -43,14 +46,22 @@ func (b *generateBuilder) BuildSourceConfigs(ctx context.Context, params *bb_int jobSource := job.GetSource() sourceOptions := jobSource.GetOptions().GetGenerate() if sourceOptions == nil { - return nil, fmt.Errorf("job does not have Generate source options, has: %T", jobSource.GetOptions().Config) + return nil, fmt.Errorf( + "job does not have Generate source options, has: %T", + jobSource.GetOptions().Config, + ) } sourceConnection, err := shared.GetJobSourceConnection(ctx, jobSource, b.connectionclient) if err != nil { return nil, fmt.Errorf("unable to get connection by id: %w", err) } - db, err := b.sqlmanagerclient.NewSqlConnection(ctx, connectionmanager.NewUniqueSession(connectionmanager.WithSessionGroup(params.JobRunId)), sourceConnection, logger) + db, err := b.sqlmanagerclient.NewSqlConnection( + ctx, + connectionmanager.NewUniqueSession(connectionmanager.WithSessionGroup(params.JobRunId)), + sourceConnection, + logger, + ) if err != nil { return nil, fmt.Errorf("unable to create new sql db: %w", err) } @@ -88,7 +99,13 @@ func (b *generateBuilder) BuildSourceConfigs(ctx context.Context, params *bb_int return nil, err } - mutations, err := buildMutationConfigs(ctx, b.transformerclient, tableMapping.Mappings, tableColInfo, false) + mutations, err := buildMutationConfigs( + ctx, + b.transformerclient, + tableMapping.Mappings, + tableColInfo, + false, + ) if err != nil { return nil, err } @@ -101,15 +118,23 @@ func (b *generateBuilder) BuildSourceConfigs(ctx context.Context, params *bb_int processors = append(processors, &neosync_benthos.ProcessorConfig{Mutation: &mutations}) if jsCode != "" { - processors = append(processors, 
&neosync_benthos.ProcessorConfig{NeosyncJavascript: &neosync_benthos.NeosyncJavascriptConfig{Code: jsCode}}) + processors = append( + processors, + &neosync_benthos.ProcessorConfig{ + NeosyncJavascript: &neosync_benthos.NeosyncJavascriptConfig{Code: jsCode}, + }, + ) } if len(processors) > 0 { // add catch and error processor - processors = append(processors, &neosync_benthos.ProcessorConfig{Catch: []*neosync_benthos.ProcessorConfig{ - {Error: &neosync_benthos.ErrorProcessorConfig{ - ErrorMsg: `${! error()}`, + processors = append( + processors, + &neosync_benthos.ProcessorConfig{Catch: []*neosync_benthos.ProcessorConfig{ + {Error: &neosync_benthos.ErrorProcessorConfig{ + ErrorMsg: `${! error()}`, + }}, }}, - }}) + ) } bc := &neosync_benthos.BenthosConfig{ @@ -139,13 +164,22 @@ func (b *generateBuilder) BuildSourceConfigs(ctx context.Context, params *bb_int } columns := buildPlainColumns(tableMapping.Mappings) - columnDefaultProperties, err := getColumnDefaultProperties(logger, db.Driver(), columns, tableColInfo, tableColTransformers) + columnDefaultProperties, err := getColumnDefaultProperties( + logger, + db.Driver(), + columns, + tableColInfo, + tableColTransformers, + ) if err != nil { return nil, err } configs = append(configs, &bb_internal.BenthosSourceConfig{ - Name: neosync_benthos.BuildBenthosTable(tableMapping.Schema, tableMapping.Table), // todo: may need to expand on this + Name: neosync_benthos.BuildBenthosTable( + tableMapping.Schema, + tableMapping.Table, + ), // todo: may need to expand on this Config: bc, DependsOn: []*runconfigs.DependsOn{}, @@ -167,7 +201,10 @@ func (b *generateBuilder) BuildSourceConfigs(ctx context.Context, params *bb_int return configs, nil } -func (b *generateBuilder) BuildDestinationConfig(ctx context.Context, params *bb_internal.DestinationParams) (*bb_internal.BenthosDestinationConfig, error) { +func (b *generateBuilder) BuildDestinationConfig( + ctx context.Context, + params *bb_internal.DestinationParams, +) 
(*bb_internal.BenthosDestinationConfig, error) { config := &bb_internal.BenthosDestinationConfig{} benthosConfig := params.SourceConfig @@ -181,12 +218,20 @@ func (b *generateBuilder) BuildDestinationConfig(ctx context.Context, params *bb processorConfigs = append(processorConfigs, *pc) } - sqlProcessor, err := getSqlBatchProcessors(b.driver, benthosConfig.Columns, map[string]string{}, benthosConfig.ColumnDefaultProperties) + sqlProcessor, err := getSqlBatchProcessors( + b.driver, + benthosConfig.Columns, + map[string]string{}, + benthosConfig.ColumnDefaultProperties, + ) if err != nil { return nil, err } - config.BenthosDsns = append(config.BenthosDsns, &bb_shared.BenthosDsn{ConnectionId: params.DestConnection.Id}) + config.BenthosDsns = append( + config.BenthosDsns, + &bb_shared.BenthosDsn{ConnectionId: params.DestConnection.Id}, + ) config.Outputs = append(config.Outputs, neosync_benthos.Outputs{ // retry processor and output several times Retry: &neosync_benthos.RetryConfig{ @@ -248,7 +293,9 @@ func groupGenerateSourceOptionsByTable( tableOpt := schemaOpt.Tables[tidx] key := neosync_benthos.BuildBenthosTable(schemaOpt.Schema, tableOpt.Table) groupedMappings[key] = &generateSourceTableOptions{ - Count: int(tableOpt.RowCount), // todo: probably need to update rowcount int64 to int32 + Count: int( + tableOpt.RowCount, + ), // todo: probably need to update rowcount int64 to int32 } } } diff --git a/internal/benthos/benthos-builder/builders/mongodb.go b/internal/benthos/benthos-builder/builders/mongodb.go index d5648f46c8..14d84c2413 100644 --- a/internal/benthos/benthos-builder/builders/mongodb.go +++ b/internal/benthos/benthos-builder/builders/mongodb.go @@ -26,7 +26,10 @@ func NewMongoDbSyncBuilder( } } -func (b *mongodbSyncBuilder) BuildSourceConfigs(ctx context.Context, params *bb_internal.SourceParams) ([]*bb_internal.BenthosSourceConfig, error) { +func (b *mongodbSyncBuilder) BuildSourceConfigs( + ctx context.Context, + params *bb_internal.SourceParams, +) 
([]*bb_internal.BenthosSourceConfig, error) { sourceConnection := params.SourceConnection job := params.Job groupedMappings := groupMappingsByTable(job.GetMappings()) @@ -75,7 +78,17 @@ func (b *mongodbSyncBuilder) BuildSourceConfigs(ctx context.Context, params *bb_ processorConfigs, err := buildProcessorConfigsByRunType( ctx, b.transformerclient, - runconfigs.NewRunConfig(runconfigId, schemaTable, runconfigType, []string{}, nil, columns, columns, nil, splitColumnPaths), + runconfigs.NewRunConfig( + runconfigId, + schemaTable, + runconfigType, + []string{}, + nil, + columns, + columns, + nil, + splitColumnPaths, + ), map[string][]*bb_internal.ReferenceKey{}, map[string][]*bb_internal.ReferenceKey{}, params.Job.Id, @@ -114,30 +127,39 @@ func (b *mongodbSyncBuilder) BuildSourceConfigs(ctx context.Context, params *bb_ return benthosConfigs, nil } -func (b *mongodbSyncBuilder) BuildDestinationConfig(ctx context.Context, params *bb_internal.DestinationParams) (*bb_internal.BenthosDestinationConfig, error) { +func (b *mongodbSyncBuilder) BuildDestinationConfig( + ctx context.Context, + params *bb_internal.DestinationParams, +) (*bb_internal.BenthosDestinationConfig, error) { config := &bb_internal.BenthosDestinationConfig{} benthosConfig := params.SourceConfig - config.BenthosDsns = append(config.BenthosDsns, &bb_shared.BenthosDsn{ConnectionId: params.DestConnection.GetId()}) - config.Outputs = append(config.Outputs, neosync_benthos.Outputs{PooledMongoDB: &neosync_benthos.OutputMongoDb{ - ConnectionId: params.DestConnection.GetId(), + config.BenthosDsns = append( + config.BenthosDsns, + &bb_shared.BenthosDsn{ConnectionId: params.DestConnection.GetId()}, + ) + config.Outputs = append( + config.Outputs, + neosync_benthos.Outputs{PooledMongoDB: &neosync_benthos.OutputMongoDb{ + ConnectionId: params.DestConnection.GetId(), - Database: benthosConfig.TableSchema, - Collection: benthosConfig.TableName, - Operation: "update-one", - Upsert: true, - DocumentMap: ` + Database: 
benthosConfig.TableSchema, + Collection: benthosConfig.TableName, + Operation: "update-one", + Upsert: true, + DocumentMap: ` root = { "$set": this } `, - FilterMap: ` + FilterMap: ` root._id = this._id `, - WriteConcern: &neosync_benthos.MongoWriteConcern{ - W: "1", + WriteConcern: &neosync_benthos.MongoWriteConcern{ + W: "1", + }, + }, }, - }, - }) + ) return config, nil } diff --git a/internal/benthos/benthos-builder/builders/neosync-connection-data.go b/internal/benthos/benthos-builder/builders/neosync-connection-data.go index 78e7ddcfa4..4841cc06af 100644 --- a/internal/benthos/benthos-builder/builders/neosync-connection-data.go +++ b/internal/benthos/benthos-builder/builders/neosync-connection-data.go @@ -41,7 +41,10 @@ func NewNeosyncConnectionDataSyncBuilder( } } -func (b *neosyncConnectionDataBuilder) BuildSourceConfigs(ctx context.Context, params *bb_internal.SourceParams) ([]*bb_internal.BenthosSourceConfig, error) { +func (b *neosyncConnectionDataBuilder) BuildSourceConfigs( + ctx context.Context, + params *bb_internal.SourceParams, +) ([]*bb_internal.BenthosSourceConfig, error) { sourceConnection := params.SourceConnection job := params.Job configs := []*bb_internal.BenthosSourceConfig{} @@ -96,6 +99,9 @@ func (b *neosyncConnectionDataBuilder) BuildSourceConfigs(ctx context.Context, p return configs, nil } -func (b *neosyncConnectionDataBuilder) BuildDestinationConfig(ctx context.Context, params *bb_internal.DestinationParams) (*bb_internal.BenthosDestinationConfig, error) { +func (b *neosyncConnectionDataBuilder) BuildDestinationConfig( + ctx context.Context, + params *bb_internal.DestinationParams, +) (*bb_internal.BenthosDestinationConfig, error) { return nil, errors.ErrUnsupported } diff --git a/internal/benthos/benthos-builder/builders/processors.go b/internal/benthos/benthos-builder/builders/processors.go index 820b86a087..6125d33c05 100644 --- a/internal/benthos/benthos-builder/builders/processors.go +++ 
b/internal/benthos/benthos-builder/builders/processors.go @@ -40,7 +40,13 @@ func buildProcessorConfigsByRunType( ) ([]*neosync_benthos.ProcessorConfig, error) { if config.RunType() == runconfigs.RunTypeUpdate { // sql update processor configs - processorConfigs, err := buildSqlUpdateProcessorConfigs(config, redisConfig, jobId, runId, transformedFktoPkMap) + processorConfigs, err := buildSqlUpdateProcessorConfigs( + config, + redisConfig, + jobId, + runId, + transformedFktoPkMap, + ) if err != nil { return nil, err } @@ -87,14 +93,25 @@ func buildSqlUpdateProcessorConfigs( // circular dependent foreign key hashedKey := neosync_benthos.HashBenthosCacheKey(jobId, runId, pk.Table, pk.Column) - requestMap := fmt.Sprintf(`root = if this.%q == null { deleted() } else { this }`, fkCol) + requestMap := fmt.Sprintf( + `root = if this.%q == null { deleted() } else { this }`, + fkCol, + ) argsMapping := fmt.Sprintf(`root = [%q, json(%q)]`, hashedKey, fkCol) resultMap := fmt.Sprintf("root.%q = this", fkCol) - fkBranch, err := buildRedisGetBranchConfig(resultMap, argsMapping, &requestMap, redisConfig) + fkBranch, err := buildRedisGetBranchConfig( + resultMap, + argsMapping, + &requestMap, + redisConfig, + ) if err != nil { return nil, err } - processorConfigs = append(processorConfigs, &neosync_benthos.ProcessorConfig{Branch: fkBranch}) + processorConfigs = append( + processorConfigs, + &neosync_benthos.ProcessorConfig{Branch: fkBranch}, + ) } } @@ -105,18 +122,29 @@ func buildSqlUpdateProcessorConfigs( pkRequestMap := fmt.Sprintf(`root = if this.%q == null { deleted() } else { this }`, pk) pkArgsMapping := fmt.Sprintf(`root = [%q, json(%q)]`, hashedKey, pk) pkResultMap := fmt.Sprintf("root.%q = this", pk) - pkBranch, err := buildRedisGetBranchConfig(pkResultMap, pkArgsMapping, &pkRequestMap, redisConfig) + pkBranch, err := buildRedisGetBranchConfig( + pkResultMap, + pkArgsMapping, + &pkRequestMap, + redisConfig, + ) if err != nil { return nil, err } - processorConfigs = 
append(processorConfigs, &neosync_benthos.ProcessorConfig{Branch: pkBranch}) + processorConfigs = append( + processorConfigs, + &neosync_benthos.ProcessorConfig{Branch: pkBranch}, + ) } // add catch and error processor - processorConfigs = append(processorConfigs, &neosync_benthos.ProcessorConfig{Catch: []*neosync_benthos.ProcessorConfig{ - {Error: &neosync_benthos.ErrorProcessorConfig{ - ErrorMsg: `${! error()}`, + processorConfigs = append( + processorConfigs, + &neosync_benthos.ProcessorConfig{Catch: []*neosync_benthos.ProcessorConfig{ + {Error: &neosync_benthos.ErrorProcessorConfig{ + ErrorMsg: `${! error()}`, + }}, }}, - }}) + ) } return processorConfigs, nil } @@ -146,12 +174,24 @@ func buildProcessorConfigs( return nil, err } - mutations, err := buildMutationConfigs(ctx, transformerclient, filteredCols, tableColumnInfo, runconfig.SplitColumnPaths()) + mutations, err := buildMutationConfigs( + ctx, + transformerclient, + filteredCols, + tableColumnInfo, + runconfig.SplitColumnPaths(), + ) if err != nil { return nil, err } - cacheBranches, err := buildBranchCacheConfigs(filteredCols, transformedFktoPkMap, jobId, runId, redisConfig) + cacheBranches, err := buildBranchCacheConfigs( + filteredCols, + transformedFktoPkMap, + jobId, + runId, + redisConfig, + ) if err != nil { return nil, err } @@ -164,30 +204,50 @@ func buildProcessorConfigs( } var processorConfigs []*neosync_benthos.ProcessorConfig if pkMapping != "" { - processorConfigs = append(processorConfigs, &neosync_benthos.ProcessorConfig{Mapping: &pkMapping}) + processorConfigs = append( + processorConfigs, + &neosync_benthos.ProcessorConfig{Mapping: &pkMapping}, + ) } if mutations != "" { - processorConfigs = append(processorConfigs, &neosync_benthos.ProcessorConfig{Mutation: &mutations}) + processorConfigs = append( + processorConfigs, + &neosync_benthos.ProcessorConfig{Mutation: &mutations}, + ) } if jsCode != "" { - processorConfigs = append(processorConfigs, 
&neosync_benthos.ProcessorConfig{NeosyncJavascript: &neosync_benthos.NeosyncJavascriptConfig{Code: jsCode}}) + processorConfigs = append( + processorConfigs, + &neosync_benthos.ProcessorConfig{ + NeosyncJavascript: &neosync_benthos.NeosyncJavascriptConfig{Code: jsCode}, + }, + ) } if len(cacheBranches) > 0 { for _, config := range cacheBranches { - processorConfigs = append(processorConfigs, &neosync_benthos.ProcessorConfig{Branch: config}) + processorConfigs = append( + processorConfigs, + &neosync_benthos.ProcessorConfig{Branch: config}, + ) } } if defaultTransformerConfig != nil { - processorConfigs = append(processorConfigs, &neosync_benthos.ProcessorConfig{NeosyncDefaultTransformer: defaultTransformerConfig}) + processorConfigs = append( + processorConfigs, + &neosync_benthos.ProcessorConfig{NeosyncDefaultTransformer: defaultTransformerConfig}, + ) } if len(processorConfigs) > 0 { // add catch and error processor - processorConfigs = append(processorConfigs, &neosync_benthos.ProcessorConfig{Catch: []*neosync_benthos.ProcessorConfig{ - {Error: &neosync_benthos.ErrorProcessorConfig{ - ErrorMsg: `${! error()}`, + processorConfigs = append( + processorConfigs, + &neosync_benthos.ProcessorConfig{Catch: []*neosync_benthos.ProcessorConfig{ + {Error: &neosync_benthos.ErrorProcessorConfig{ + ErrorMsg: `${! 
error()}`, + }}, }}, - }}) + ) } return processorConfigs, err @@ -212,7 +272,11 @@ func buildDefaultTransformerConfigs( }, nil } -func extractJsFunctionsAndOutputs(ctx context.Context, transformerclient mgmtv1alpha1connect.TransformersServiceClient, cols []*mgmtv1alpha1.JobMapping) (string, error) { +func extractJsFunctionsAndOutputs( + ctx context.Context, + transformerclient mgmtv1alpha1connect.TransformersServiceClient, + cols []*mgmtv1alpha1.JobMapping, +) (string, error) { var benthosOutputs []string var jsFunctions []string @@ -220,7 +284,11 @@ func extractJsFunctionsAndOutputs(ctx context.Context, transformerclient mgmtv1a jmTransformer := col.GetTransformer() if shouldProcessStrict(jmTransformer) { if jmTransformer.GetConfig().GetUserDefinedTransformerConfig() != nil { - val, err := convertUserDefinedFunctionConfig(ctx, transformerclient, col.GetTransformer()) + val, err := convertUserDefinedFunctionConfig( + ctx, + transformerclient, + col.GetTransformer(), + ) if err != nil { return "", errors.New("unable to look up user defined transformer config by id") } @@ -256,7 +324,8 @@ func isJavascriptTransformer(jmt *mgmtv1alpha1.JobMappingTransformer) bool { return false } - isConfig := jmt.GetConfig().GetTransformJavascriptConfig() != nil || jmt.GetConfig().GetGenerateJavascriptConfig() != nil + isConfig := jmt.GetConfig().GetTransformJavascriptConfig() != nil || + jmt.GetConfig().GetGenerateJavascriptConfig() != nil return isConfig } @@ -273,7 +342,10 @@ func buildIdentityCursors( if transformer.GetConfig().GetUserDefinedTransformerConfig() != nil { val, err := convertUserDefinedFunctionConfig(ctx, transformerclient, transformer) if err != nil { - return nil, fmt.Errorf("unable to look up user defined transformer config by id: %w", err) + return nil, fmt.Errorf( + "unable to look up user defined transformer config by id: %w", + err, + ) } transformer = val } @@ -301,7 +373,11 @@ func buildMutationConfigs( if shouldProcessColumn(col.GetTransformer()) { if 
col.GetTransformer().GetConfig().GetUserDefinedTransformerConfig() != nil { // handle user defined transformer -> get the user defined transformer configs using the id - val, err := convertUserDefinedFunctionConfig(ctx, transformerclient, col.GetTransformer()) + val, err := convertUserDefinedFunctionConfig( + ctx, + transformerclient, + col.GetTransformer(), + ) if err != nil { return "", errors.New("unable to look up user defined transformer config by id") } @@ -310,9 +386,20 @@ func buildMutationConfigs( if !isJavascriptTransformer(col.GetTransformer()) { mutation, err := computeMutationFunction(col, colInfo, splitColumnPaths) if err != nil { - return "", fmt.Errorf("%s is not a supported transformer: %w", col.GetTransformer(), err) + return "", fmt.Errorf( + "%s is not a supported transformer: %w", + col.GetTransformer(), + err, + ) } - mutations = append(mutations, fmt.Sprintf("root.%s = %s", getBenthosColumnKey(col.GetColumn(), splitColumnPaths), mutation)) + mutations = append( + mutations, + fmt.Sprintf( + "root.%s = %s", + getBenthosColumnKey(col.GetColumn(), splitColumnPaths), + mutation, + ), + ) } } } @@ -337,7 +424,14 @@ func buildPrimaryKeyMappingConfigs(cols []*mgmtv1alpha1.JobMapping, primaryKeys mappings := []string{} for _, col := range cols { if shouldProcessColumn(col.Transformer) && slices.Contains(primaryKeys, col.Column) { - mappings = append(mappings, fmt.Sprintf("meta %s = this.%q", hashPrimaryKeyMetaKey(col.Schema, col.Table, col.Column), col.Column)) + mappings = append( + mappings, + fmt.Sprintf( + "meta %s = this.%q", + hashPrimaryKeyMetaKey(col.Schema, col.Table, col.Column), + col.Column, + ), + ) } } return strings.Join(mappings, "\n") @@ -371,10 +465,18 @@ func buildBranchCacheConfigs( } hashedKey := neosync_benthos.HashBenthosCacheKey(jobId, runId, fk.Table, fk.Column) - requestMap := fmt.Sprintf(`root = if this.%q == null { deleted() } else { this }`, col.Column) + requestMap := fmt.Sprintf( + `root = if this.%q == null { deleted() 
} else { this }`, + col.Column, + ) argsMapping := fmt.Sprintf(`root = [%q, json(%q)]`, hashedKey, col.Column) resultMap := fmt.Sprintf("root.%q = this", col.Column) - br, err := buildRedisGetBranchConfig(resultMap, argsMapping, &requestMap, redisConfig) + br, err := buildRedisGetBranchConfig( + resultMap, + argsMapping, + &requestMap, + redisConfig, + ) if err != nil { return nil, err } @@ -439,7 +541,14 @@ func convertUserDefinedFunctionConfig( transformerclient mgmtv1alpha1connect.TransformersServiceClient, t *mgmtv1alpha1.JobMappingTransformer, ) (*mgmtv1alpha1.JobMappingTransformer, error) { - transformerResp, err := transformerclient.GetUserDefinedTransformerById(ctx, connect.NewRequest(&mgmtv1alpha1.GetUserDefinedTransformerByIdRequest{TransformerId: t.Config.GetUserDefinedTransformerConfig().Id})) + transformerResp, err := transformerclient.GetUserDefinedTransformerById( + ctx, + connect.NewRequest( + &mgmtv1alpha1.GetUserDefinedTransformerByIdRequest{ + TransformerId: t.Config.GetUserDefinedTransformerConfig().Id, + }, + ), + ) if err != nil { return nil, err } @@ -450,7 +559,11 @@ func convertUserDefinedFunctionConfig( }, nil } -func computeMutationFunction(col *mgmtv1alpha1.JobMapping, colInfo *sqlmanager_shared.DatabaseSchemaRow, splitColumnPath bool) (string, error) { +func computeMutationFunction( + col *mgmtv1alpha1.JobMapping, + colInfo *sqlmanager_shared.DatabaseSchemaRow, + splitColumnPath bool, +) (string, error) { var maxLen int64 = 10000 if colInfo != nil && colInfo.CharacterMaximumLength > 0 { maxLen = int64(colInfo.CharacterMaximumLength) @@ -749,7 +862,9 @@ func computeMutationFunction(col *mgmtv1alpha1.JobMapping, colInfo *sqlmanager_s } func buildScrambleIdentityToken(col *mgmtv1alpha1.JobMapping) string { - return neosync_benthos.ToSha256(fmt.Sprintf("%s.%s.%s", col.GetSchema(), col.GetTable(), col.GetColumn())) + return neosync_benthos.ToSha256( + fmt.Sprintf("%s.%s.%s", col.GetSchema(), col.GetTable(), col.GetColumn()), + ) } func 
shouldProcessColumn(t *mgmtv1alpha1.JobMappingTransformer) bool { diff --git a/internal/benthos/benthos-builder/builders/sql-util.go b/internal/benthos/benthos-builder/builders/sql-util.go index 604e1b7f3a..b708184196 100644 --- a/internal/benthos/benthos-builder/builders/sql-util.go +++ b/internal/benthos/benthos-builder/builders/sql-util.go @@ -126,7 +126,10 @@ func buildPlainColumns(mappings []*mgmtv1alpha1.JobMapping) []string { return columns } -func buildTableSubsetMap(tableOpts map[string]*sqlSourceTableOptions, tableMap map[string]*tableMapping) map[string]string { +func buildTableSubsetMap( + tableOpts map[string]*sqlSourceTableOptions, + tableMap map[string]*tableMapping, +) map[string]string { tableSubsetMap := map[string]string{} for table, opts := range tableOpts { if _, ok := tableMap[table]; !ok { @@ -224,7 +227,9 @@ func getTableMappingsMap(groupedMappings []*tableMapping) map[string]*tableMappi return groupedTableMapping } -func getColumnTransformerMap(tableMappingMap map[string]*tableMapping) map[string]map[string]*mgmtv1alpha1.JobMappingTransformer { +func getColumnTransformerMap( + tableMappingMap map[string]*tableMapping, +) map[string]map[string]*mgmtv1alpha1.JobMappingTransformer { colTransformerMap := map[string]map[string]*mgmtv1alpha1.JobMappingTransformer{} // schema.table -> column -> transformer for table, mapping := range tableMappingMap { colTransformerMap[table] = map[string]*mgmtv1alpha1.JobMappingTransformer{} @@ -273,7 +278,10 @@ func filterForeignKeysMap( newFk.Columns = append(newFk.Columns, c) newFk.NotNullable = append(newFk.NotNullable, fk.NotNullable[i]) - newFk.ForeignKey.Columns = append(newFk.ForeignKey.Columns, fk.ForeignKey.Columns[i]) + newFk.ForeignKey.Columns = append( + newFk.ForeignKey.Columns, + fk.ForeignKey.Columns[i], + ) } if len(newFk.Columns) > 0 { @@ -303,7 +311,9 @@ func isDefaultJobMappingTransformer(t *mgmtv1alpha1.JobMappingTransformer) bool } // map of table primary key cols to foreign key cols -func 
getPrimaryKeyDependencyMap(tableDependencies map[string][]*sqlmanager_shared.ForeignConstraint) map[string]map[string][]*bb_internal.ReferenceKey { +func getPrimaryKeyDependencyMap( + tableDependencies map[string][]*sqlmanager_shared.ForeignConstraint, +) map[string]map[string][]*bb_internal.ReferenceKey { tc := map[string]map[string][]*bb_internal.ReferenceKey{} // schema.table -> column -> ForeignKey for table, constraints := range tableDependencies { for _, c := range constraints { @@ -312,24 +322,34 @@ func getPrimaryKeyDependencyMap(tableDependencies map[string][]*sqlmanager_share tc[c.ForeignKey.Table] = map[string][]*bb_internal.ReferenceKey{} } for idx, col := range c.ForeignKey.Columns { - tc[c.ForeignKey.Table][col] = append(tc[c.ForeignKey.Table][col], &bb_internal.ReferenceKey{ - Table: table, - Column: c.Columns[idx], - }) + tc[c.ForeignKey.Table][col] = append( + tc[c.ForeignKey.Table][col], + &bb_internal.ReferenceKey{ + Table: table, + Column: c.Columns[idx], + }, + ) } } } return tc } -func findTopForeignKeySource(tableName, col string, tableDependencies map[string][]*sqlmanager_shared.ForeignConstraint) *bb_internal.ReferenceKey { +func findTopForeignKeySource( + tableName, col string, + tableDependencies map[string][]*sqlmanager_shared.ForeignConstraint, +) *bb_internal.ReferenceKey { // Add the foreign key dependencies of the current table if foreignKeys, ok := tableDependencies[tableName]; ok { for _, fk := range foreignKeys { for idx, c := range fk.Columns { if c == col { // Recursively add dependent tables and their foreign keys - return findTopForeignKeySource(fk.ForeignKey.Table, fk.ForeignKey.Columns[idx], tableDependencies) + return findTopForeignKeySource( + fk.ForeignKey.Table, + fk.ForeignKey.Columns[idx], + tableDependencies, + ) } } } @@ -342,7 +362,9 @@ func findTopForeignKeySource(tableName, col string, tableDependencies map[string // builds schema.table -> FK column -> PK schema table column // find top level primary key column if 
foreign keys are nested -func buildForeignKeySourceMap(tableDeps map[string][]*sqlmanager_shared.ForeignConstraint) map[string]map[string]*bb_internal.ReferenceKey { +func buildForeignKeySourceMap( + tableDeps map[string][]*sqlmanager_shared.ForeignConstraint, +) map[string]map[string]*bb_internal.ReferenceKey { outputMap := map[string]map[string]*bb_internal.ReferenceKey{} for tableName, constraints := range tableDeps { if _, ok := outputMap[tableName]; !ok { @@ -374,7 +396,10 @@ func getTransformedFksMap( // only add constraint if foreign key has transformer transformer, transformerOk := colTransformerMap[tc.Table][tc.Column] if transformerOk && shouldProcessStrict(transformer) { - transformedForeignKeyToSourceMap[table][col] = append(transformedForeignKeyToSourceMap[table][col], tc) + transformedForeignKeyToSourceMap[table][col] = append( + transformedForeignKeyToSourceMap[table][col], + tc, + ) } } } @@ -394,9 +419,18 @@ func getColumnDefaultProperties( if !ok { return nil, fmt.Errorf("column default type missing. 
column: %s", cName) } - needsOverride, needsReset, err := sqlmanager.GetColumnOverrideAndResetProperties(driver, info) + needsOverride, needsReset, err := sqlmanager.GetColumnOverrideAndResetProperties( + driver, + info, + ) if err != nil { - slogger.Error("unable to determine SQL column default flags", "error", err, "column", cName) + slogger.Error( + "unable to determine SQL column default flags", + "error", + err, + "column", + cName, + ) return nil, err } @@ -432,7 +466,9 @@ type destinationOptions struct { BatchPeriod string } -func getDestinationOptions(destOpts *mgmtv1alpha1.JobDestinationOptions) (*destinationOptions, error) { +func getDestinationOptions( + destOpts *mgmtv1alpha1.JobDestinationOptions, +) (*destinationOptions, error) { if destOpts.GetConfig() == nil { return &destinationOptions{}, nil } @@ -541,14 +577,19 @@ func getParsedBatchingConfig(destOpt batchDestinationOption) (batchingConfig, er if batchConfig.GetPeriod() != "" { _, err := time.ParseDuration(batchConfig.GetPeriod()) if err != nil { - return batchingConfig{}, fmt.Errorf("unable to parse batch period for s3 destination config: %w", err) + return batchingConfig{}, fmt.Errorf( + "unable to parse batch period for s3 destination config: %w", + err, + ) } } output.BatchPeriod = batchConfig.GetPeriod() } if output.BatchCount == 0 && output.BatchPeriod == "" { - return batchingConfig{}, fmt.Errorf("must have at least one batch policy configured. Cannot disable both period and count") + return batchingConfig{}, fmt.Errorf( + "must have at least one batch policy configured. Cannot disable both period and count", + ) } return output, nil } @@ -569,7 +610,11 @@ func getAdditionalJobMappings( mappedCols, ok := tableColMappings[schematable] if !ok { // todo: we may want to generate mappings for this entire table? 
However this may be dead code as we get the grouped schemas based on the mappings - logger.Warn("table found in schema data that is not present in job mappings", "table", schematable) + logger.Warn( + "table found in schema data that is not present in job mappings", + "table", + schematable, + ) continue } if len(cols) == len(mappedCols) { @@ -582,7 +627,8 @@ func getAdditionalJobMappings( return nil, err } // we found a column that is not present in the mappings, let's create a mapping for it - if info.ColumnDefault != "" || info.IdentityGeneration != nil || info.GeneratedType != nil { + if info.ColumnDefault != "" || info.IdentityGeneration != nil || + info.GeneratedType != nil { output = append(output, &mgmtv1alpha1.JobMapping{ Schema: schema, Table: table, @@ -645,7 +691,10 @@ func getAdditionalJobMappings( return output, nil } -func getJmTransformerByPostgresDataType(colInfo *sqlmanager_shared.DatabaseSchemaRow) (*mgmtv1alpha1.JobMappingTransformer, error) { + +func getJmTransformerByPostgresDataType( + colInfo *sqlmanager_shared.DatabaseSchemaRow, +) (*mgmtv1alpha1.JobMappingTransformer, error) { cleanedDataType := cleanPostgresType(colInfo.DataType) switch cleanedDataType { case "smallint": @@ -738,7 +787,10 @@ func getJmTransformerByPostgresDataType(colInfo *sqlmanager_shared.DatabaseSchem }, }, }, nil - case "text", "bpchar", "character", "character varying": // todo: test to see if this works when (n) has been specified + case "text", + "bpchar", + "character", + "character varying": // todo: test to see if this works when (n) has been specified return &mgmtv1alpha1.JobMappingTransformer{ Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateStringConfig{ @@ -777,18 +829,27 @@ func getJmTransformerByPostgresDataType(colInfo *sqlmanager_shared.DatabaseSchem return &mgmtv1alpha1.JobMappingTransformer{ Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateUuidConfig{ - GenerateUuidConfig: 
&mgmtv1alpha1.GenerateUuid{IncludeHyphens: shared.Ptr(true)}, + GenerateUuidConfig: &mgmtv1alpha1.GenerateUuid{ + IncludeHyphens: shared.Ptr(true), + }, }, }, }, nil default: - return nil, fmt.Errorf("uncountered unsupported data type %q for %q.%q.%q when attempting to generate an auto-mapper. To continue, provide a discrete job mapping for this column.: %w", - colInfo.DataType, colInfo.TableSchema, colInfo.TableName, colInfo.ColumnName, errors.ErrUnsupported, + return nil, fmt.Errorf( + "uncountered unsupported data type %q for %q.%q.%q when attempting to generate an auto-mapper. To continue, provide a discrete job mapping for this column.: %w", + colInfo.DataType, + colInfo.TableSchema, + colInfo.TableName, + colInfo.ColumnName, + errors.ErrUnsupported, ) } } -func getJmTransformerByMysqlDataType(colInfo *sqlmanager_shared.DatabaseSchemaRow) (*mgmtv1alpha1.JobMappingTransformer, error) { +func getJmTransformerByMysqlDataType( + colInfo *sqlmanager_shared.DatabaseSchemaRow, +) (*mgmtv1alpha1.JobMappingTransformer, error) { cleanedDataType := cleanMysqlType(colInfo.MysqlColumnType) switch cleanedDataType { case "char": @@ -798,7 +859,11 @@ func getJmTransformerByMysqlDataType(colInfo *sqlmanager_shared.DatabaseSchemaRo if len(params) > 0 { fixedLength, err := strconv.ParseInt(params[0], 10, 64) if err != nil { - return nil, fmt.Errorf("failed to parse length for type %q: %w", colInfo.MysqlColumnType, err) + return nil, fmt.Errorf( + "failed to parse length for type %q: %w", + colInfo.MysqlColumnType, + err, + ) } minLength = fixedLength maxLength = fixedLength @@ -808,7 +873,10 @@ func getJmTransformerByMysqlDataType(colInfo *sqlmanager_shared.DatabaseSchemaRo return &mgmtv1alpha1.JobMappingTransformer{ Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateStringConfig{ - GenerateStringConfig: &mgmtv1alpha1.GenerateString{Min: shared.Ptr(minLength), Max: shared.Ptr(maxLength)}, + GenerateStringConfig: &mgmtv1alpha1.GenerateString{ 
+ Min: shared.Ptr(minLength), + Max: shared.Ptr(maxLength), + }, }, }, }, nil @@ -819,7 +887,11 @@ func getJmTransformerByMysqlDataType(colInfo *sqlmanager_shared.DatabaseSchemaRo if len(params) > 0 { fixedLength, err := strconv.ParseInt(params[0], 10, 64) if err != nil { - return nil, fmt.Errorf("failed to parse length for type %q: %w", colInfo.MysqlColumnType, err) + return nil, fmt.Errorf( + "failed to parse length for type %q: %w", + colInfo.MysqlColumnType, + err, + ) } maxLength = fixedLength } else if colInfo.CharacterMaximumLength > 0 { @@ -848,7 +920,11 @@ func getJmTransformerByMysqlDataType(colInfo *sqlmanager_shared.DatabaseSchemaRo if len(params) > 0 { length, err := strconv.ParseInt(params[0], 10, 64) if err != nil { - return nil, fmt.Errorf("failed to parse length for type %q: %w", colInfo.MysqlColumnType, err) + return nil, fmt.Errorf( + "failed to parse length for type %q: %w", + colInfo.MysqlColumnType, + err, + ) } maxLength = length } else if colInfo.CharacterMaximumLength > 0 { @@ -866,7 +942,9 @@ func getJmTransformerByMysqlDataType(colInfo *sqlmanager_shared.DatabaseSchemaRo return &mgmtv1alpha1.JobMappingTransformer{ Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateStringConfig{ - GenerateStringConfig: &mgmtv1alpha1.GenerateString{Max: shared.Ptr(int64(16_777_215))}, + GenerateStringConfig: &mgmtv1alpha1.GenerateString{ + Max: shared.Ptr(int64(16_777_215)), + }, }, }, }, nil @@ -874,7 +952,9 @@ func getJmTransformerByMysqlDataType(colInfo *sqlmanager_shared.DatabaseSchemaRo return &mgmtv1alpha1.JobMappingTransformer{ Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateStringConfig{ - GenerateStringConfig: &mgmtv1alpha1.GenerateString{Max: shared.Ptr(int64(4_294_967_295))}, + GenerateStringConfig: &mgmtv1alpha1.GenerateString{ + Max: shared.Ptr(int64(4_294_967_295)), + }, }, }, }, nil @@ -883,7 +963,9 @@ func getJmTransformerByMysqlDataType(colInfo 
*sqlmanager_shared.DatabaseSchemaRo return &mgmtv1alpha1.JobMappingTransformer{ Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_GenerateCategoricalConfig{ - GenerateCategoricalConfig: &mgmtv1alpha1.GenerateCategorical{Categories: shared.Ptr(strings.Join(params, ","))}, + GenerateCategoricalConfig: &mgmtv1alpha1.GenerateCategorical{ + Categories: shared.Ptr(strings.Join(params, ",")), + }, }, }, }, nil @@ -1213,8 +1295,13 @@ func getJmTransformerByMysqlDataType(colInfo *sqlmanager_shared.DatabaseSchemaRo }, }, nil default: - return nil, fmt.Errorf("uncountered unsupported data type %q for %q.%q.%q when attempting to generate an auto-mapper. To continue, provide a discrete job mapping for this column.: %w", - colInfo.DataType, colInfo.TableSchema, colInfo.TableName, colInfo.ColumnName, errors.ErrUnsupported, + return nil, fmt.Errorf( + "uncountered unsupported data type %q for %q.%q.%q when attempting to generate an auto-mapper. To continue, provide a discrete job mapping for this column.: %w", + colInfo.DataType, + colInfo.TableSchema, + colInfo.TableName, + colInfo.ColumnName, + errors.ErrUnsupported, ) } } @@ -1281,7 +1368,9 @@ func extractMysqlTypeParams(dataType string) []string { return result } -func shouldOverrideColumnDefault(columnDefaults map[string]*neosync_benthos.ColumnDefaultProperties) bool { +func shouldOverrideColumnDefault( + columnDefaults map[string]*neosync_benthos.ColumnDefaultProperties, +) bool { for _, cd := range columnDefaults { if cd != nil && !cd.HasDefaultTransformer && cd.NeedsOverride { return true @@ -1290,15 +1379,41 @@ func shouldOverrideColumnDefault(columnDefaults map[string]*neosync_benthos.Colu return false } -func getSqlBatchProcessors(driver string, columns []string, columnDataTypes map[string]string, columnDefaultProperties map[string]*neosync_benthos.ColumnDefaultProperties) (*neosync_benthos.BatchProcessor, error) { +func getSqlBatchProcessors( + driver string, + columns []string, + 
columnDataTypes map[string]string, + columnDefaultProperties map[string]*neosync_benthos.ColumnDefaultProperties, +) (*neosync_benthos.BatchProcessor, error) { switch driver { case sqlmanager_shared.PostgresDriver: - return &neosync_benthos.BatchProcessor{NeosyncToPgx: &neosync_benthos.NeosyncToPgxConfig{Columns: columns, ColumnDataTypes: columnDataTypes, ColumnDefaultProperties: columnDefaultProperties}}, nil + return &neosync_benthos.BatchProcessor{ + NeosyncToPgx: &neosync_benthos.NeosyncToPgxConfig{ + Columns: columns, + ColumnDataTypes: columnDataTypes, + ColumnDefaultProperties: columnDefaultProperties, + }, + }, nil case sqlmanager_shared.MysqlDriver: - return &neosync_benthos.BatchProcessor{NeosyncToMysql: &neosync_benthos.NeosyncToMysqlConfig{Columns: columns, ColumnDataTypes: columnDataTypes, ColumnDefaultProperties: columnDefaultProperties}}, nil + return &neosync_benthos.BatchProcessor{ + NeosyncToMysql: &neosync_benthos.NeosyncToMysqlConfig{ + Columns: columns, + ColumnDataTypes: columnDataTypes, + ColumnDefaultProperties: columnDefaultProperties, + }, + }, nil case sqlmanager_shared.MssqlDriver: - return &neosync_benthos.BatchProcessor{NeosyncToMssql: &neosync_benthos.NeosyncToMssqlConfig{Columns: columns, ColumnDataTypes: columnDataTypes, ColumnDefaultProperties: columnDefaultProperties}}, nil + return &neosync_benthos.BatchProcessor{ + NeosyncToMssql: &neosync_benthos.NeosyncToMssqlConfig{ + Columns: columns, + ColumnDataTypes: columnDataTypes, + ColumnDefaultProperties: columnDefaultProperties, + }, + }, nil default: - return nil, fmt.Errorf("unsupported driver %q when attempting to get sql batch processors", driver) + return nil, fmt.Errorf( + "unsupported driver %q when attempting to get sql batch processors", + driver, + ) } } diff --git a/internal/benthos/benthos-builder/builders/sql.go b/internal/benthos/benthos-builder/builders/sql.go index 567a44f6f1..8a27722fe3 100644 --- a/internal/benthos/benthos-builder/builders/sql.go +++ 
b/internal/benthos/benthos-builder/builders/sql.go @@ -59,7 +59,10 @@ func NewSqlSyncBuilder( } } -func (b *sqlSyncBuilder) BuildSourceConfigs(ctx context.Context, params *bb_internal.SourceParams) ([]*bb_internal.BenthosSourceConfig, error) { +func (b *sqlSyncBuilder) BuildSourceConfigs( + ctx context.Context, + params *bb_internal.SourceParams, +) ([]*bb_internal.BenthosSourceConfig, error) { sourceConnection := params.SourceConnection job := params.Job logger := params.Logger @@ -73,7 +76,12 @@ func (b *sqlSyncBuilder) BuildSourceConfigs(ctx context.Context, params *bb_inte sourceTableOpts = groupSqlJobSourceOptionsByTable(sqlSourceOpts) } - db, err := b.sqlmanagerclient.NewSqlConnection(ctx, connectionmanager.NewUniqueSession(connectionmanager.WithSessionGroup(params.JobRunId)), sourceConnection, logger) + db, err := b.sqlmanagerclient.NewSqlConnection( + ctx, + connectionmanager.NewUniqueSession(connectionmanager.WithSessionGroup(params.JobRunId)), + sourceConnection, + logger, + ) if err != nil { return nil, fmt.Errorf("unable to create new sql db: %w", err) } @@ -88,14 +96,25 @@ func (b *sqlSyncBuilder) BuildSourceConfigs(ctx context.Context, params *bb_inte if sqlSourceOpts != nil && sqlSourceOpts.HaltOnNewColumnAddition { newColumns, shouldHalt := shouldHaltOnSchemaAddition(groupedColumnInfo, job.Mappings) if shouldHalt { - return nil, fmt.Errorf("%s: [%s]", haltOnSchemaAdditionErrMsg, strings.Join(newColumns, ", ")) + return nil, fmt.Errorf( + "%s: [%s]", + haltOnSchemaAdditionErrMsg, + strings.Join(newColumns, ", "), + ) } } if sqlSourceOpts != nil && sqlSourceOpts.HaltOnColumnRemoval { - missing, shouldHalt := isSourceMissingColumnsFoundInMappings(groupedColumnInfo, job.Mappings) + missing, shouldHalt := isSourceMissingColumnsFoundInMappings( + groupedColumnInfo, + job.Mappings, + ) if shouldHalt { - return nil, fmt.Errorf("%s: [%s]", haltOnSchemaAdditionErrMsg, strings.Join(missing, ", ")) + return nil, fmt.Errorf( + "%s: [%s]", + 
haltOnSchemaAdditionErrMsg, + strings.Join(missing, ", "), + ) } } @@ -103,11 +122,19 @@ func (b *sqlSyncBuilder) BuildSourceConfigs(ctx context.Context, params *bb_inte existingSourceMappings := removeMappingsNotFoundInSource(job.Mappings, groupedColumnInfo) if sqlSourceOpts != nil && sqlSourceOpts.GenerateNewColumnTransformers { - extraMappings, err := getAdditionalJobMappings(b.driver, groupedColumnInfo, existingSourceMappings, splitKeyToTablePieces, logger) + extraMappings, err := getAdditionalJobMappings( + b.driver, + groupedColumnInfo, + existingSourceMappings, + splitKeyToTablePieces, + logger, + ) if err != nil { return nil, err } - logger.Debug(fmt.Sprintf("adding %d extra mappings due to unmapped columns", len(extraMappings))) + logger.Debug( + fmt.Sprintf("adding %d extra mappings due to unmapped columns", len(extraMappings)), + ) existingSourceMappings = append(existingSourceMappings, extraMappings...) } uniqueSchemas := shared.GetUniqueSchemasFromMappings(existingSourceMappings) @@ -117,24 +144,47 @@ func (b *sqlSyncBuilder) BuildSourceConfigs(ctx context.Context, params *bb_inte return nil, fmt.Errorf("unable to retrieve database table constraints: %w", err) } - foreignKeysMap, err := mergeVirtualForeignKeys(tableConstraints.ForeignKeyConstraints, job.GetVirtualForeignKeys(), groupedColumnInfo) + foreignKeysMap, err := mergeVirtualForeignKeys( + tableConstraints.ForeignKeyConstraints, + job.GetVirtualForeignKeys(), + groupedColumnInfo, + ) if err != nil { return nil, err } - logger.Info(fmt.Sprintf("found %d foreign key constraints for database", getMapValuesCount(tableConstraints.ForeignKeyConstraints))) - logger.Info(fmt.Sprintf("found %d primary key constraints for database", getMapValuesCount(tableConstraints.PrimaryKeyConstraints))) + logger.Info( + fmt.Sprintf( + "found %d foreign key constraints for database", + getMapValuesCount(tableConstraints.ForeignKeyConstraints), + ), + ) + logger.Info( + fmt.Sprintf( + "found %d primary key constraints 
for database", + getMapValuesCount(tableConstraints.PrimaryKeyConstraints), + ), + ) groupedMappings := groupMappingsByTable(existingSourceMappings) groupedTableMapping := getTableMappingsMap(groupedMappings) - colTransformerMap := getColumnTransformerMap(groupedTableMapping) // schema.table -> column -> transformer + colTransformerMap := getColumnTransformerMap( + groupedTableMapping, + ) // schema.table -> column -> transformer b.colTransformerMap = colTransformerMap // include virtual foreign keys and removes fks that have null transformers filteredForeignKeysMap := filterForeignKeysMap(colTransformerMap, foreignKeysMap) tableSubsetMap := buildTableSubsetMap(sourceTableOpts, groupedTableMapping) tableColMap := getTableColMapFromMappings(groupedMappings) - runConfigs, err := rc.BuildRunConfigs(filteredForeignKeysMap, tableSubsetMap, tableConstraints.PrimaryKeyConstraints, tableColMap, tableConstraints.UniqueIndexes, tableConstraints.UniqueConstraints) + runConfigs, err := rc.BuildRunConfigs( + filteredForeignKeysMap, + tableSubsetMap, + tableConstraints.PrimaryKeyConstraints, + tableColMap, + tableConstraints.UniqueIndexes, + tableConstraints.UniqueConstraints, + ) if err != nil { return nil, err } @@ -142,13 +192,33 @@ func (b *sqlSyncBuilder) BuildSourceConfigs(ctx context.Context, params *bb_inte primaryKeyToForeignKeysMap := getPrimaryKeyDependencyMap(filteredForeignKeysMap) b.primaryKeyToForeignKeysMap = primaryKeyToForeignKeysMap - configQueryMap, err := b.selectQueryBuilder.BuildSelectQueryMap(db.Driver(), runConfigs, sqlSourceOpts.SubsetByForeignKeyConstraints, b.pageLimit) + configQueryMap, err := b.selectQueryBuilder.BuildSelectQueryMap( + db.Driver(), + runConfigs, + sqlSourceOpts.SubsetByForeignKeyConstraints, + b.pageLimit, + ) if err != nil { return nil, fmt.Errorf("unable to build select queries: %w", err) } b.configQueryMap = configQueryMap - configs, err := buildBenthosSqlSourceConfigResponses(logger, ctx, b.transformerclient, 
groupedTableMapping, runConfigs, sourceConnection.Id, configQueryMap, groupedColumnInfo, filteredForeignKeysMap, colTransformerMap, job.Id, params.JobRunId, b.redisConfig, primaryKeyToForeignKeysMap) + configs, err := buildBenthosSqlSourceConfigResponses( + logger, + ctx, + b.transformerclient, + groupedTableMapping, + runConfigs, + sourceConnection.Id, + configQueryMap, + groupedColumnInfo, + filteredForeignKeysMap, + colTransformerMap, + job.Id, + params.JobRunId, + b.redisConfig, + primaryKeyToForeignKeysMap, + ) if err != nil { return nil, fmt.Errorf("unable to build benthos sql source config responses: %w", err) } @@ -280,20 +350,36 @@ func buildBenthosSqlSourceConfigResponses( return configs, nil } -func (b *sqlSyncBuilder) BuildDestinationConfig(ctx context.Context, params *bb_internal.DestinationParams) (*bb_internal.BenthosDestinationConfig, error) { +func (b *sqlSyncBuilder) BuildDestinationConfig( + ctx context.Context, + params *bb_internal.DestinationParams, +) (*bb_internal.BenthosDestinationConfig, error) { logger := params.Logger benthosConfig := params.SourceConfig - tableKey := neosync_benthos.BuildBenthosTable(benthosConfig.TableSchema, benthosConfig.TableName) + tableKey := neosync_benthos.BuildBenthosTable( + benthosConfig.TableSchema, + benthosConfig.TableName, + ) config := &bb_internal.BenthosDestinationConfig{} // lazy load if len(b.mergedSchemaColumnMap) == 0 { - sqlSchemaColMap := getSqlSchemaColumnMap(ctx, connectionmanager.NewUniqueSession(connectionmanager.WithSessionGroup(params.JobRunId)), params.DestConnection, b.sqlSourceSchemaColumnInfoMap, b.sqlmanagerclient, params.Logger) + sqlSchemaColMap := getSqlSchemaColumnMap( + ctx, + connectionmanager.NewUniqueSession(connectionmanager.WithSessionGroup(params.JobRunId)), + params.DestConnection, + b.sqlSourceSchemaColumnInfoMap, + b.sqlmanagerclient, + params.Logger, + ) b.mergedSchemaColumnMap = sqlSchemaColMap } if len(b.mergedSchemaColumnMap) == 0 { - return nil, fmt.Errorf("unable to 
retrieve schema columns for either source or destination: %s", params.DestConnection.Name) + return nil, fmt.Errorf( + "unable to retrieve schema columns for either source or destination: %s", + params.DestConnection.Name, + ) } var colInfoMap map[string]*sqlmanager_shared.DatabaseSchemaRow @@ -303,7 +389,11 @@ func (b *sqlSyncBuilder) BuildDestinationConfig(ctx context.Context, params *bb_ } if len(colInfoMap) == 0 { - return nil, fmt.Errorf("unable to retrieve schema columns for destination: %s table: %s", params.DestConnection.Name, tableKey) + return nil, fmt.Errorf( + "unable to retrieve schema columns for destination: %s table: %s", + params.DestConnection.Name, + tableKey, + ) } colTransformerMap := b.colTransformerMap @@ -311,7 +401,9 @@ func (b *sqlSyncBuilder) BuildDestinationConfig(ctx context.Context, params *bb_ if len(colTransformerMap) == 0 { groupedMappings := groupMappingsByTable(params.Job.Mappings) groupedTableMapping := getTableMappingsMap(groupedMappings) - colTMap := getColumnTransformerMap(groupedTableMapping) // schema.table -> column -> transformer + colTMap := getColumnTransformerMap( + groupedTableMapping, + ) // schema.table -> column -> transformer b.colTransformerMap = colTMap colTransformerMap = colTMap } @@ -321,7 +413,13 @@ func (b *sqlSyncBuilder) BuildDestinationConfig(ctx context.Context, params *bb_ return nil, fmt.Errorf("column transformer mappings not found for table: %s", tableKey) } - columnDefaultProperties, err := getColumnDefaultProperties(logger, b.driver, benthosConfig.Columns, colInfoMap, tableColTransformers) + columnDefaultProperties, err := getColumnDefaultProperties( + logger, + b.driver, + benthosConfig.Columns, + colInfoMap, + tableColTransformers, + ) if err != nil { return nil, err } @@ -343,13 +441,22 @@ func (b *sqlSyncBuilder) BuildDestinationConfig(ctx context.Context, params *bb_ query := b.configQueryMap[benthosConfig.Name] // skip foreign key violations if the query could return rows that violate 
foreign key constraints - skipForeignKeyViolations := destOpts.SkipForeignKeyViolations || (query != nil && query.IsNotForeignKeySafeSubset) + skipForeignKeyViolations := destOpts.SkipForeignKeyViolations || + (query != nil && query.IsNotForeignKeySafeSubset) - config.BenthosDsns = append(config.BenthosDsns, &bb_shared.BenthosDsn{ConnectionId: params.DestConnection.Id}) + config.BenthosDsns = append( + config.BenthosDsns, + &bb_shared.BenthosDsn{ConnectionId: params.DestConnection.Id}, + ) if benthosConfig.RunType == rc.RunTypeUpdate { processorColumns := benthosConfig.Columns processorColumns = append(processorColumns, benthosConfig.PrimaryKeys...) - sqlProcessor, err := getProcessors(b.driver, processorColumns, colInfoMap, columnDefaultProperties) + sqlProcessor, err := getProcessors( + b.driver, + processorColumns, + colInfoMap, + columnDefaultProperties, + ) if err != nil { return nil, err } @@ -459,7 +566,12 @@ func (b *sqlSyncBuilder) BuildDestinationConfig(ctx context.Context, params *bb_ return config, nil } -func getProcessors(driver string, columns []string, colInfoMap map[string]*sqlmanager_shared.DatabaseSchemaRow, columnDefaultProperties map[string]*neosync_benthos.ColumnDefaultProperties) (*neosync_benthos.BatchProcessor, error) { +func getProcessors( + driver string, + columns []string, + colInfoMap map[string]*sqlmanager_shared.DatabaseSchemaRow, + columnDefaultProperties map[string]*neosync_benthos.ColumnDefaultProperties, +) (*neosync_benthos.BatchProcessor, error) { columnDataTypes := map[string]string{} for _, c := range columns { colType, ok := colInfoMap[c] @@ -483,9 +595,17 @@ func getInsertPrefixAndSuffix( case sqlmanager_shared.MssqlDriver: if hasPassthroughIdentityColumn(columnDefaultProperties) { enableIdentityInsert := true - p := sqlmanager_mssql.BuildMssqlSetIdentityInsertStatement(schema, table, enableIdentityInsert) + p := sqlmanager_mssql.BuildMssqlSetIdentityInsertStatement( + schema, + table, + enableIdentityInsert, + ) pre = &p - 
s := sqlmanager_mssql.BuildMssqlSetIdentityInsertStatement(schema, table, !enableIdentityInsert) + s := sqlmanager_mssql.BuildMssqlSetIdentityInsertStatement( + schema, + table, + !enableIdentityInsert, + ) suff = &s } return pre, suff @@ -494,7 +614,9 @@ func getInsertPrefixAndSuffix( } } -func hasPassthroughIdentityColumn(columnDefaultProperties map[string]*neosync_benthos.ColumnDefaultProperties) bool { +func hasPassthroughIdentityColumn( + columnDefaultProperties map[string]*neosync_benthos.ColumnDefaultProperties, +) bool { for _, d := range columnDefaultProperties { if d.NeedsOverride && d.NeedsReset && !d.HasDefaultTransformer { return true diff --git a/internal/benthos/benthos-builder/generate-benthos.go b/internal/benthos/benthos-builder/generate-benthos.go index 13c0b63b29..556ebe3b39 100644 --- a/internal/benthos/benthos-builder/generate-benthos.go +++ b/internal/benthos/benthos-builder/generate-benthos.go @@ -48,7 +48,10 @@ func (b *BenthosConfigManager) GenerateBenthosConfigs( destOpts, ok := destinationOpts[destConnection.GetId()] if !ok { - return nil, fmt.Errorf("unable to find destination options for connection: %s", destConnection.GetId()) + return nil, fmt.Errorf( + "unable to find destination options for connection: %s", + destConnection.GetId(), + ) } for _, sourceConfig := range sourceConfigs { @@ -65,7 +68,9 @@ func (b *BenthosConfigManager) GenerateBenthosConfigs( if err != nil { return nil, err } - sourceConfig.Config.Output.Broker.Outputs = append(sourceConfig.Config.Output.Broker.Outputs, destConfig.Outputs...) + sourceConfig.Config.Output.Broker.Outputs = append( + sourceConfig.Config.Output.Broker.Outputs, + destConfig.Outputs...) sourceConfig.BenthosDsns = append(sourceConfig.BenthosDsns, destConfig.BenthosDsns...) 
} b.logger.Debug(fmt.Sprintf("applied destination to %d source configs", len(sourceConfigs))) @@ -76,7 +81,10 @@ func (b *BenthosConfigManager) GenerateBenthosConfigs( labels := metrics.MetricLabels{ metrics.NewEqLabel(metrics.AccountIdLabel, b.job.AccountId), metrics.NewEqLabel(metrics.JobIdLabel, b.job.Id), - metrics.NewEqLabel(metrics.NeosyncDateLabel, bb_shared.WithEnvInterpolation(metrics.NeosyncDateEnvKey)), + metrics.NewEqLabel( + metrics.NeosyncDateLabel, + bb_shared.WithEnvInterpolation(metrics.NeosyncDateEnvKey), + ), } for key, val := range b.metricLabelKeyVals { labels = append(labels, metrics.NewEqLabel(key, val)) @@ -112,7 +120,9 @@ func (b *BenthosConfigManager) GenerateBenthosConfigs( } // builds map of destination id -> destination options -func buildDestinationOptionsMap(jobDests []*mgmtv1alpha1.JobDestination) map[string]*mgmtv1alpha1.JobDestinationOptions { +func buildDestinationOptionsMap( + jobDests []*mgmtv1alpha1.JobDestination, +) map[string]*mgmtv1alpha1.JobDestinationOptions { destOpts := map[string]*mgmtv1alpha1.JobDestinationOptions{} for _, dest := range jobDests { destOpts[dest.GetConnectionId()] = dest.GetOptions() @@ -139,7 +149,8 @@ func convertToResponse(sourceConfig *bb_internal.BenthosSourceConfig) *BenthosCo func isOnlyBucketDestinations(destinations []*mgmtv1alpha1.JobDestination) bool { for _, dest := range destinations { - if dest.GetOptions().GetAwsS3Options() == nil && dest.GetOptions().GetGcpCloudstorageOptions() == nil { + if dest.GetOptions().GetAwsS3Options() == nil && + dest.GetOptions().GetGcpCloudstorageOptions() == nil { return false } } diff --git a/internal/benthos/benthos-builder/internal/types.go b/internal/benthos/benthos-builder/internal/types.go index c73d22f341..dc696ac559 100644 --- a/internal/benthos/benthos-builder/internal/types.go +++ b/internal/benthos/benthos-builder/internal/types.go @@ -64,7 +64,10 @@ type BenthosBuilder interface { BuildSourceConfigs(ctx context.Context, params *SourceParams) 
([]*BenthosSourceConfig, error) // BuildDestinationConfig creates a Benthos destination configuration for writing processed data. // Returns single config for a schema.table configuration - BuildDestinationConfig(ctx context.Context, params *DestinationParams) (*BenthosDestinationConfig, error) + BuildDestinationConfig( + ctx context.Context, + params *DestinationParams, + ) (*BenthosDestinationConfig, error) } // SourceParams contains all parameters needed to build a source benthos configuration diff --git a/internal/benthos_slogger/logger.go b/internal/benthos_slogger/logger.go index c7bfcea55f..75bfaefd7a 100644 --- a/internal/benthos_slogger/logger.go +++ b/internal/benthos_slogger/logger.go @@ -20,7 +20,10 @@ func (h *benthosLogHandler) Enabled(ctx context.Context, level slog.Level) bool return true } -func (h *benthosLogHandler) Handle(ctx context.Context, r slog.Record) error { //nolint:gocritic // Needs to conform to the slog.Handler interface +func (h *benthosLogHandler) Handle( + ctx context.Context, + r slog.Record, +) error { //nolint:gocritic // Needs to conform to the slog.Handler interface // Combine pre-defined attrs with record attrs allAttrs := make([]slog.Attr, 0, len(h.attrs)+r.NumAttrs()) allAttrs = append(allAttrs, h.attrs...) 
diff --git a/internal/billing/stripe-billing.go b/internal/billing/stripe-billing.go index 86e9a70e95..878f06209d 100644 --- a/internal/billing/stripe-billing.go +++ b/internal/billing/stripe-billing.go @@ -18,7 +18,10 @@ type SubscriptionIter interface { type Interface interface { NewCustomer(req *CustomerRequest) (*stripe.Customer, error) NewBillingPortalSession(customerId, accountSlug string) (*stripe.BillingPortalSession, error) - NewCheckoutSession(customerId, accountSlug, userId string, logger *slog.Logger) (*stripe.CheckoutSession, error) + NewCheckoutSession( + customerId, accountSlug, userId string, + logger *slog.Logger, + ) (*stripe.CheckoutSession, error) GetSubscriptions(customerId string) SubscriptionIter NewMeterEvent(req *MeterEventRequest) (*stripe.BillingMeterEvent, error) } @@ -68,14 +71,21 @@ func (c *Client) NewCustomer(req *CustomerRequest) (*stripe.Customer, error) { }) } -func (c *Client) NewBillingPortalSession(customerId, accountSlug string) (*stripe.BillingPortalSession, error) { +func (c *Client) NewBillingPortalSession( + customerId, accountSlug string, +) (*stripe.BillingPortalSession, error) { return c.client.BillingPortalSessions.New(&stripe.BillingPortalSessionParams{ - Customer: stripe.String(customerId), - ReturnURL: stripe.String(fmt.Sprintf("%s/%s/settings/billing", c.cfg.AppBaseUrl, accountSlug)), + Customer: stripe.String(customerId), + ReturnURL: stripe.String( + fmt.Sprintf("%s/%s/settings/billing", c.cfg.AppBaseUrl, accountSlug), + ), }) } -func (c *Client) NewCheckoutSession(customerId, accountSlug, userId string, logger *slog.Logger) (*stripe.CheckoutSession, error) { +func (c *Client) NewCheckoutSession( + customerId, accountSlug, userId string, + logger *slog.Logger, +) (*stripe.CheckoutSession, error) { priceMap, err := c.getPricesFromLookupKeys() if err != nil { return nil, err @@ -97,12 +107,16 @@ func (c *Client) NewCheckoutSession(customerId, accountSlug, userId string, logg } logger.Debug("creating stripe checkout 
session", "numLineItems", len(lineitems)) return c.client.CheckoutSessions.New(&stripe.CheckoutSessionParams{ - Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), - LineItems: lineitems, - SuccessURL: stripe.String(fmt.Sprintf("%s/%s/settings/billing", c.cfg.AppBaseUrl, accountSlug)), - CancelURL: stripe.String(fmt.Sprintf("%s/%s/settings/billing", c.cfg.AppBaseUrl, accountSlug)), - Customer: stripe.String(customerId), - Metadata: map[string]string{"userId": userId}, + Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), + LineItems: lineitems, + SuccessURL: stripe.String( + fmt.Sprintf("%s/%s/settings/billing", c.cfg.AppBaseUrl, accountSlug), + ), + CancelURL: stripe.String( + fmt.Sprintf("%s/%s/settings/billing", c.cfg.AppBaseUrl, accountSlug), + ), + Customer: stripe.String(customerId), + Metadata: map[string]string{"userId": userId}, SubscriptionData: &stripe.CheckoutSessionSubscriptionDataParams{ BillingCycleAnchor: stripe.Int64(getNextMonthBillingCycleAnchor(time.Now().UTC())), }, @@ -155,7 +169,11 @@ func (c *Client) getPricesFromLookupKeys() (map[string]*stripe.Price, error) { return nil, iter.Err() } if len(output) != len(c.cfg.PriceLookups) { - return nil, fmt.Errorf("unable to resolve all stripe price lookups to valid prices. need %d, found %d", len(c.cfg.PriceLookups), len(output)) + return nil, fmt.Errorf( + "unable to resolve all stripe price lookups to valid prices. 
need %d, found %d", + len(c.cfg.PriceLookups), + len(output), + ) } return output, nil } diff --git a/internal/connection-manager/manager.go b/internal/connection-manager/manager.go index 246aacc398..80fa8918a2 100644 --- a/internal/connection-manager/manager.go +++ b/internal/connection-manager/manager.go @@ -12,7 +12,10 @@ import ( ) type ConnectionProvider[T any] interface { - GetConnectionClient(connectionConfig *mgmtv1alpha1.ConnectionConfig, logger *slog.Logger) (T, error) + GetConnectionClient( + connectionConfig *mgmtv1alpha1.ConnectionConfig, + logger *slog.Logger, + ) (T, error) CloseClientConnection(client T) error } @@ -39,7 +42,11 @@ type ConnectionInput interface { } type Interface[T any] interface { - GetConnection(session SessionInterface, connection ConnectionInput, logger *slog.Logger) (T, error) + GetConnection( + session SessionInterface, + connection ConnectionInput, + logger *slog.Logger, + ) (T, error) ReleaseSession(session SessionInterface, logger *slog.Logger) bool Shutdown(logger *slog.Logger) Reaper(logger *slog.Logger) @@ -114,7 +121,10 @@ func (c *ConnectionManager[T]) GetConnection( logger.Debug("no cached connection found, creating new connection client") // Create new connection - connectionClient, err := c.connectionProvider.GetConnectionClient(connection.GetConnectionConfig(), logger) + connectionClient, err := c.connectionProvider.GetConnectionClient( + connection.GetConnectionConfig(), + logger, + ) if err != nil { var result T return result, err @@ -170,7 +180,9 @@ func (c *ConnectionManager[T]) ReleaseSession(session SessionInterface, logger * } if c.config.closeOnRelease { - logger.Debug("close on release is enabled, pruning connections that are not bound to any sessions in the group") + logger.Debug( + "close on release is enabled, pruning connections that are not bound to any sessions in the group", + ) remainingConns := getUniqueConnectionIdsFromGroupSessions(groupSessions) c.closeSpecificGroupConnections(groupId, 
sessionConnIds, remainingConns, logger) } @@ -178,7 +190,12 @@ func (c *ConnectionManager[T]) ReleaseSession(session SessionInterface, logger * } // does not handle locks as it assumes the parent caller holds the lock -func (c *ConnectionManager[T]) closeSpecificGroupConnections(groupId string, candidateConnIds []string, remainingConns map[string]struct{}, logger *slog.Logger) { +func (c *ConnectionManager[T]) closeSpecificGroupConnections( + groupId string, + candidateConnIds []string, + remainingConns map[string]struct{}, + logger *slog.Logger, +) { groupConns, exists := c.groupConnMap[groupId] if !exists { return @@ -189,7 +206,12 @@ func (c *ConnectionManager[T]) closeSpecificGroupConnections(groupId string, can if dbConn, exists := groupConns[connId]; exists { logger.Debug(fmt.Sprintf("closing connection %q", connId)) if err := c.connectionProvider.CloseClientConnection(dbConn); err != nil { - logger.Error(fmt.Sprintf("unable to close client connection during release: %s", err.Error())) + logger.Error( + fmt.Sprintf( + "unable to close client connection during release: %s", + err.Error(), + ), + ) } delete(groupConns, connId) } @@ -245,18 +267,49 @@ func (c *ConnectionManager[T]) cleanUnusedConnections(logger *slog.Logger) { for session := range sessions { groupSessions = append(groupSessions, session) } - logger.Debug(fmt.Sprintf("[ConnectionManager][Reaper] group %q with sessions %s", groupId, strings.Join(groupSessions, ","))) + logger.Debug( + fmt.Sprintf( + "[ConnectionManager][Reaper] group %q with sessions %s", + groupId, + strings.Join(groupSessions, ","), + ), + ) } for groupId, groupConns := range c.groupConnMap { - logger.Debug(fmt.Sprintf("[ConnectionManager][Reaper] checking group %q with %d connection(s)", groupId, len(groupConns))) + logger.Debug( + fmt.Sprintf( + "[ConnectionManager][Reaper] checking group %q with %d connection(s)", + groupId, + len(groupConns), + ), + ) sessionConns := groupSessionConnections[groupId] for connId, dbConn := 
range groupConns { - logger.Debug(fmt.Sprintf("[ConnectionManager][Reaper] checking group %q for connection %q", groupId, connId)) + logger.Debug( + fmt.Sprintf( + "[ConnectionManager][Reaper] checking group %q for connection %q", + groupId, + connId, + ), + ) if _, ok := sessionConns[connId]; !ok { - logger.Debug(fmt.Sprintf("[ConnectionManager][Reaper] closing client connection: %q in group %q", connId, groupId)) + logger.Debug( + fmt.Sprintf( + "[ConnectionManager][Reaper] closing client connection: %q in group %q", + connId, + groupId, + ), + ) if err := c.connectionProvider.CloseClientConnection(dbConn); err != nil { - logger.Warn(fmt.Sprintf("[ConnectionManager][Reaper] unable to fully close client connection %q in group %q during cleanup: %s", connId, groupId, err.Error())) + logger.Warn( + fmt.Sprintf( + "[ConnectionManager][Reaper] unable to fully close client connection %q in group %q during cleanup: %s", + connId, + groupId, + err.Error(), + ), + ) } delete(groupConns, connId) } @@ -275,7 +328,12 @@ func (c *ConnectionManager[T]) hardClose(logger *slog.Logger) { for groupId, groupConns := range c.groupConnMap { for connId, dbConn := range groupConns { if err := c.connectionProvider.CloseClientConnection(dbConn); err != nil { - logger.Error(fmt.Sprintf("unable to fully close client connection during hard close: %s", err.Error())) + logger.Error( + fmt.Sprintf( + "unable to fully close client connection during hard close: %s", + err.Error(), + ), + ) } delete(groupConns, connId) } @@ -288,7 +346,9 @@ func (c *ConnectionManager[T]) hardClose(logger *slog.Logger) { } } -func getUniqueConnectionIdsFromGroupSessions(sessions map[string]map[string]struct{}) map[string]struct{} { +func getUniqueConnectionIdsFromGroupSessions( + sessions map[string]map[string]struct{}, +) map[string]struct{} { connSet := map[string]struct{}{} for _, sessionConns := range sessions { for connId := range sessionConns { diff --git 
a/internal/connection-manager/pool/providers/mongo/mongo-pool-provider.go b/internal/connection-manager/pool/providers/mongo/mongo-pool-provider.go index fac351e432..5fc380a82c 100644 --- a/internal/connection-manager/pool/providers/mongo/mongo-pool-provider.go +++ b/internal/connection-manager/pool/providers/mongo/mongo-pool-provider.go @@ -24,10 +24,18 @@ func NewProvider( session connectionmanager.SessionInterface, logger *slog.Logger, ) *Provider { - return &Provider{connmanager: connmanager, getConnection: getConnection, session: session, logger: logger} + return &Provider{ + connmanager: connmanager, + getConnection: getConnection, + session: session, + logger: logger, + } } -func (p *Provider) GetClient(ctx context.Context, connectionId string) (neosync_benthos_mongodb.MongoClient, error) { +func (p *Provider) GetClient( + ctx context.Context, + connectionId string, +) (neosync_benthos_mongodb.MongoClient, error) { conn, err := p.getConnection(connectionId) if err != nil { return nil, err diff --git a/internal/connection-manager/pool/providers/sql/sql-pool-provider.go b/internal/connection-manager/pool/providers/sql/sql-pool-provider.go index f705734648..2fc6ac4fe6 100644 --- a/internal/connection-manager/pool/providers/sql/sql-pool-provider.go +++ b/internal/connection-manager/pool/providers/sql/sql-pool-provider.go @@ -27,10 +27,18 @@ func NewConnectionProvider( session connectionmanager.SessionInterface, logger *slog.Logger, ) *Provider { - return &Provider{connmanager: connmanager, getConnection: getConnection, session: session, logger: logger} + return &Provider{ + connmanager: connmanager, + getConnection: getConnection, + session: session, + logger: logger, + } } -func (p *Provider) GetDb(ctx context.Context, connectionId string) (neosync_benthos_sql.SqlDbtx, error) { +func (p *Provider) GetDb( + ctx context.Context, + connectionId string, +) (neosync_benthos_sql.SqlDbtx, error) { conn, err := p.getConnection(connectionId) if err != nil { return nil, 
err diff --git a/internal/connection-manager/providers/mongoprovider/provider.go b/internal/connection-manager/providers/mongoprovider/provider.go index 1a7cc2afe5..6b31f6aaa5 100644 --- a/internal/connection-manager/providers/mongoprovider/provider.go +++ b/internal/connection-manager/providers/mongoprovider/provider.go @@ -21,7 +21,10 @@ func NewProvider() *Provider { var _ connectionmanager.ConnectionProvider[neosync_benthos_mongodb.MongoClient] = &Provider{} // this is currently untested as it isn't really used anywhere -func (p *Provider) GetConnectionClient(cc *mgmtv1alpha1.ConnectionConfig, logger *slog.Logger) (neosync_benthos_mongodb.MongoClient, error) { +func (p *Provider) GetConnectionClient( + cc *mgmtv1alpha1.ConnectionConfig, + logger *slog.Logger, +) (neosync_benthos_mongodb.MongoClient, error) { connStr := cc.GetMongoConfig().GetUrl() if connStr == "" { return nil, errors.New("unable to find mongodb url on connection config") diff --git a/internal/connection-manager/providers/sqlprovider/provider.go b/internal/connection-manager/providers/sqlprovider/provider.go index 2b168ad052..1e674a0bd3 100644 --- a/internal/connection-manager/providers/sqlprovider/provider.go +++ b/internal/connection-manager/providers/sqlprovider/provider.go @@ -33,8 +33,15 @@ func (s *sqlDbtxWrapper) Close() error { const defaultConnectionTimeoutSeconds = uint32(10) -func (p *Provider) GetConnectionClient(cc *mgmtv1alpha1.ConnectionConfig, logger *slog.Logger) (neosync_benthos_sql.SqlDbtx, error) { - container, err := p.connector.NewDbFromConnectionConfig(cc, logger, sqlconnect.WithConnectionTimeout(defaultConnectionTimeoutSeconds)) +func (p *Provider) GetConnectionClient( + cc *mgmtv1alpha1.ConnectionConfig, + logger *slog.Logger, +) (neosync_benthos_sql.SqlDbtx, error) { + container, err := p.connector.NewDbFromConnectionConfig( + cc, + logger, + sqlconnect.WithConnectionTimeout(defaultConnectionTimeoutSeconds), + ) if err != nil { return nil, err } diff --git 
a/internal/connectiondata/aws-s3.go b/internal/connectiondata/aws-s3.go index 882b27ac87..fe571861b9 100644 --- a/internal/connectiondata/aws-s3.go +++ b/internal/connectiondata/aws-s3.go @@ -115,11 +115,16 @@ func (s *AwsS3ConnectionDataService) StreamData( path := strings.Join(s3pathpieces, "/") var pageToken *string for { - output, err := s.awsmanager.ListObjectsV2(ctx, s3Client, s.connconfig.Region, &s3.ListObjectsV2Input{ - Bucket: aws.String(s.connconfig.Bucket), - Prefix: aws.String(path), - ContinuationToken: pageToken, - }) + output, err := s.awsmanager.ListObjectsV2( + ctx, + s3Client, + s.connconfig.Region, + &s3.ListObjectsV2Input{ + Bucket: aws.String(s.connconfig.Bucket), + Prefix: aws.String(path), + ContinuationToken: pageToken, + }, + ) if err != nil { return err } @@ -128,10 +133,15 @@ func (s *AwsS3ConnectionDataService) StreamData( break } for _, item := range output.Contents { - result, err := s.awsmanager.GetObject(ctx, s3Client, s.connconfig.Region, &s3.GetObjectInput{ - Bucket: aws.String(s.connconfig.Bucket), - Key: aws.String(*item.Key), - }) + result, err := s.awsmanager.GetObject( + ctx, + s3Client, + s.connconfig.Region, + &s3.GetObjectInput{ + Bucket: aws.String(s.connconfig.Bucket), + Key: aws.String(*item.Key), + }, + ) if err != nil { return err } @@ -158,7 +168,10 @@ func (s *AwsS3ConnectionDataService) StreamData( for k, v := range rowData { newVal, err := s.neosynctyperegistry.Unmarshal(v) if err != nil { - return fmt.Errorf("unable to unmarshal row value using neosync type registry: %w", err) + return fmt.Errorf( + "unable to unmarshal row value using neosync type registry: %w", + err, + ) } rowData[k] = newVal } @@ -235,12 +248,17 @@ func (s *AwsS3ConnectionDataService) GetSchema( schemas := []*mgmtv1alpha1.DatabaseColumn{} var pageToken *string for { - output, err := s.awsmanager.ListObjectsV2(ctx, s3Client, s.connconfig.Region, &s3.ListObjectsV2Input{ - Bucket: aws.String(s.connconfig.Bucket), - Prefix: aws.String(path), - 
Delimiter: aws.String("/"), - ContinuationToken: pageToken, - }) + output, err := s.awsmanager.ListObjectsV2( + ctx, + s3Client, + s.connconfig.Region, + &s3.ListObjectsV2Input{ + Bucket: aws.String(s.connconfig.Bucket), + Prefix: aws.String(path), + Delimiter: aws.String("/"), + ContinuationToken: pageToken, + }, + ) if err != nil { return nil, err } @@ -252,29 +270,53 @@ func (s *AwsS3ConnectionDataService) GetSchema( tableFolder := strings.ReplaceAll(folders[len(folders)-1], "/", "") schemaTableList := strings.Split(tableFolder, ".") - filePath := fmt.Sprintf("%s%s/data", path, sqlmanager_shared.BuildTable(schemaTableList[0], schemaTableList[1])) - out, err := s.awsmanager.ListObjectsV2(ctx, s3Client, s.connconfig.Region, &s3.ListObjectsV2Input{ - Bucket: aws.String(s.connconfig.Bucket), - Prefix: aws.String(filePath), - MaxKeys: aws.Int32(1), - }) + filePath := fmt.Sprintf( + "%s%s/data", + path, + sqlmanager_shared.BuildTable(schemaTableList[0], schemaTableList[1]), + ) + out, err := s.awsmanager.ListObjectsV2( + ctx, + s3Client, + s.connconfig.Region, + &s3.ListObjectsV2Input{ + Bucket: aws.String(s.connconfig.Bucket), + Prefix: aws.String(filePath), + MaxKeys: aws.Int32(1), + }, + ) if err != nil { return nil, err } if out == nil { - s.logger.Warn(fmt.Sprintf("AWS S3 table folder missing data folder: %s, continuing..", tableFolder)) + s.logger.Warn( + fmt.Sprintf( + "AWS S3 table folder missing data folder: %s, continuing..", + tableFolder, + ), + ) continue } item := out.Contents[0] - result, err := s.awsmanager.GetObject(ctx, s3Client, s.connconfig.Region, &s3.GetObjectInput{ - Bucket: aws.String(s.connconfig.Bucket), - Key: aws.String(*item.Key), - }) + result, err := s.awsmanager.GetObject( + ctx, + s3Client, + s.connconfig.Region, + &s3.GetObjectInput{ + Bucket: aws.String(s.connconfig.Bucket), + Key: aws.String(*item.Key), + }, + ) if err != nil { return nil, err } if result.ContentLength == nil || *result.ContentLength == 0 { - 
s.logger.Warn(fmt.Sprintf("empty AWS S3 data folder for table: %s, continuing...", tableFolder)) + s.logger.Warn( + fmt.Sprintf( + "empty AWS S3 data folder for table: %s, continuing...", + tableFolder, + ), + ) continue } @@ -368,7 +410,12 @@ func (s *AwsS3ConnectionDataService) getLastestJobRunFromAwsS3( sort.Sort(sort.Reverse(sort.StringSlice(runIDs))) if len(runIDs) == 0 { - return "", nucleuserrors.NewNotFound(fmt.Sprintf("unable to find latest job run for job in s3 after processing common prefixes: %s", jobId)) + return "", nucleuserrors.NewNotFound( + fmt.Sprintf( + "unable to find latest job run for job in s3 after processing common prefixes: %s", + jobId, + ), + ) } s.logger.Debug(fmt.Sprintf("found %d run ids for job in s3", len(runIDs))) return runIDs[0], nil @@ -387,10 +434,17 @@ func (s *AwsS3ConnectionDataService) GetTableConstraints( return nil, errors.ErrUnsupported } -func (s *AwsS3ConnectionDataService) GetTableSchema(ctx context.Context, schema, table string) ([]*mgmtv1alpha1.DatabaseColumn, error) { +func (s *AwsS3ConnectionDataService) GetTableSchema( + ctx context.Context, + schema, table string, +) ([]*mgmtv1alpha1.DatabaseColumn, error) { return nil, errors.ErrUnsupported } -func (s *AwsS3ConnectionDataService) GetTableRowCount(ctx context.Context, schema, table string, whereClause *string) (int64, error) { +func (s *AwsS3ConnectionDataService) GetTableRowCount( + ctx context.Context, + schema, table string, + whereClause *string, +) (int64, error) { return 0, errors.ErrUnsupported } diff --git a/internal/connectiondata/connectiondata.go b/internal/connectiondata/connectiondata.go index fab0358814..fd7d16de3d 100644 --- a/internal/connectiondata/connectiondata.go +++ b/internal/connectiondata/connectiondata.go @@ -28,10 +28,21 @@ type ConnectionDataService interface { StreamConfig *mgmtv1alpha1.ConnectionStreamConfig, schema, table string, ) error - GetSchema(ctx context.Context, config *mgmtv1alpha1.ConnectionSchemaConfig) 
([]*mgmtv1alpha1.DatabaseColumn, error) - GetInitStatements(ctx context.Context, options *mgmtv1alpha1.InitStatementOptions) (*mgmtv1alpha1.GetConnectionInitStatementsResponse, error) - GetTableConstraints(ctx context.Context) (*mgmtv1alpha1.GetConnectionTableConstraintsResponse, error) - GetTableSchema(ctx context.Context, schema, table string) ([]*mgmtv1alpha1.DatabaseColumn, error) + GetSchema( + ctx context.Context, + config *mgmtv1alpha1.ConnectionSchemaConfig, + ) ([]*mgmtv1alpha1.DatabaseColumn, error) + GetInitStatements( + ctx context.Context, + options *mgmtv1alpha1.InitStatementOptions, + ) (*mgmtv1alpha1.GetConnectionInitStatementsResponse, error) + GetTableConstraints( + ctx context.Context, + ) (*mgmtv1alpha1.GetConnectionTableConstraintsResponse, error) + GetTableSchema( + ctx context.Context, + schema, table string, + ) ([]*mgmtv1alpha1.DatabaseColumn, error) GetTableRowCount(ctx context.Context, schema, table string, whereClause *string) (int64, error) SampleData( ctx context.Context, @@ -49,7 +60,10 @@ type TableIdentifier struct { } type ConnectionDataBuilder interface { - NewDataConnection(logger *slog.Logger, connection *mgmtv1alpha1.Connection) (ConnectionDataService, error) + NewDataConnection( + logger *slog.Logger, + connection *mgmtv1alpha1.Connection, + ) (ConnectionDataService, error) } type DefaultConnectionDataBuilder struct { diff --git a/internal/connectiondata/dynamodb.go b/internal/connectiondata/dynamodb.go index 3966fafba5..281df252c5 100644 --- a/internal/connectiondata/dynamodb.go +++ b/internal/connectiondata/dynamodb.go @@ -35,7 +35,9 @@ func NewAwsDynamodbConnectionDataService( } } -func (s *AwsDynamodbConnectionDataService) GetAllTables(ctx context.Context) ([]TableIdentifier, error) { +func (s *AwsDynamodbConnectionDataService) GetAllTables( + ctx context.Context, +) ([]TableIdentifier, error) { return nil, errors.ErrUnsupported } @@ -129,11 +131,18 @@ func (s *AwsDynamodbConnectionDataService) GetTableConstraints( return 
nil, errors.ErrUnsupported } -func (s *AwsDynamodbConnectionDataService) GetTableSchema(ctx context.Context, schema, table string) ([]*mgmtv1alpha1.DatabaseColumn, error) { +func (s *AwsDynamodbConnectionDataService) GetTableSchema( + ctx context.Context, + schema, table string, +) ([]*mgmtv1alpha1.DatabaseColumn, error) { return nil, errors.ErrUnsupported } -func (s *AwsDynamodbConnectionDataService) GetTableRowCount(ctx context.Context, schema, table string, whereClause *string) (int64, error) { +func (s *AwsDynamodbConnectionDataService) GetTableRowCount( + ctx context.Context, + schema, table string, + whereClause *string, +) (int64, error) { return 0, errors.ErrUnsupported } diff --git a/internal/connectiondata/gcp.go b/internal/connectiondata/gcp.go index 2a6dd1edb5..bdabc05762 100644 --- a/internal/connectiondata/gcp.go +++ b/internal/connectiondata/gcp.go @@ -61,7 +61,9 @@ func (s *GcpConnectionDataService) StreamData( ) error { gcpStreamCfg := config.GetGcpCloudstorageConfig() if gcpStreamCfg == nil { - return nucleuserrors.NewBadRequest("must provide non-nil gcp cloud storage config in request") + return nucleuserrors.NewBadRequest( + "must provide non-nil gcp cloud storage config in request", + ) } gcpclient, err := s.gcpmanager.GetClient(ctx, s.logger) if err != nil { @@ -88,9 +90,15 @@ func (s *GcpConnectionDataService) StreamData( if err := enc.Encode(record); err != nil { return fmt.Errorf("unable to encode gcp record using gob: %w", err) } - return stream.Send(&mgmtv1alpha1.GetConnectionDataStreamResponse{RowBytes: rowbytes.Bytes()}) + return stream.Send( + &mgmtv1alpha1.GetConnectionDataStreamResponse{RowBytes: rowbytes.Bytes()}, + ) } - tablePath := neosync_gcp.GetWorkflowActivityDataPrefix(jobRunId, sqlmanager_shared.BuildTable(schema, table), s.connconfig.PathPrefix) + tablePath := neosync_gcp.GetWorkflowActivityDataPrefix( + jobRunId, + sqlmanager_shared.BuildTable(schema, table), + s.connconfig.PathPrefix, + ) err = 
gcpclient.GetRecordStreamFromPrefix(ctx, s.connconfig.GetBucket(), tablePath, onRecord) if err != nil { return fmt.Errorf("unable to finish sending record stream: %w", err) @@ -128,7 +136,8 @@ func (s *GcpConnectionDataService) GetSchema( schemas, err := gcpclient.GetDbSchemaFromPrefix( ctx, - s.connconfig.GetBucket(), neosync_gcp.GetWorkflowActivityPrefix(jobRunId, s.connconfig.PathPrefix), + s.connconfig.GetBucket(), + neosync_gcp.GetWorkflowActivityPrefix(jobRunId, s.connconfig.PathPrefix), ) if err != nil { return nil, fmt.Errorf("uanble to retrieve db schema from gcs: %w", err) @@ -199,10 +208,17 @@ func (s *GcpConnectionDataService) GetTableConstraints( return nil, errors.ErrUnsupported } -func (s *GcpConnectionDataService) GetTableSchema(ctx context.Context, schema, table string) ([]*mgmtv1alpha1.DatabaseColumn, error) { +func (s *GcpConnectionDataService) GetTableSchema( + ctx context.Context, + schema, table string, +) ([]*mgmtv1alpha1.DatabaseColumn, error) { return nil, errors.ErrUnsupported } -func (s *GcpConnectionDataService) GetTableRowCount(ctx context.Context, schema, table string, whereClause *string) (int64, error) { +func (s *GcpConnectionDataService) GetTableRowCount( + ctx context.Context, + schema, table string, + whereClause *string, +) (int64, error) { return 0, errors.ErrUnsupported } diff --git a/internal/connectiondata/mongodb.go b/internal/connectiondata/mongodb.go index 976d1d0d45..3c747db99d 100644 --- a/internal/connectiondata/mongodb.go +++ b/internal/connectiondata/mongodb.go @@ -31,7 +31,9 @@ func NewMongoDbConnectionDataService( } } -func (s *MongoDbConnectionDataService) GetAllTables(ctx context.Context) ([]TableIdentifier, error) { +func (s *MongoDbConnectionDataService) GetAllTables( + ctx context.Context, +) ([]TableIdentifier, error) { return nil, errors.ErrUnsupported } @@ -103,10 +105,17 @@ func (s *MongoDbConnectionDataService) GetTableConstraints( return nil, errors.ErrUnsupported } -func (s 
*MongoDbConnectionDataService) GetTableSchema(ctx context.Context, schema, table string) ([]*mgmtv1alpha1.DatabaseColumn, error) { +func (s *MongoDbConnectionDataService) GetTableSchema( + ctx context.Context, + schema, table string, +) ([]*mgmtv1alpha1.DatabaseColumn, error) { return nil, errors.ErrUnsupported } -func (s *MongoDbConnectionDataService) GetTableRowCount(ctx context.Context, schema, table string, whereClause *string) (int64, error) { +func (s *MongoDbConnectionDataService) GetTableRowCount( + ctx context.Context, + schema, table string, + whereClause *string, +) (int64, error) { return 0, errors.ErrUnsupported } diff --git a/internal/connectiondata/sql.go b/internal/connectiondata/sql.go index c559818918..3658e2acfe 100644 --- a/internal/connectiondata/sql.go +++ b/internal/connectiondata/sql.go @@ -46,7 +46,12 @@ func NewSQLConnectionDataService( } func (s *SQLConnectionDataService) GetAllSchemas(ctx context.Context) ([]string, error) { - db, err := s.sqlmanager.NewSqlConnection(ctx, connectionmanager.NewUniqueSession(), s.connection, s.logger) + db, err := s.sqlmanager.NewSqlConnection( + ctx, + connectionmanager.NewUniqueSession(), + s.connection, + s.logger, + ) if err != nil { return nil, err } @@ -64,7 +69,12 @@ func (s *SQLConnectionDataService) GetAllSchemas(ctx context.Context) ([]string, } func (s *SQLConnectionDataService) GetAllTables(ctx context.Context) ([]TableIdentifier, error) { - db, err := s.sqlmanager.NewSqlConnection(ctx, connectionmanager.NewUniqueSession(), s.connection, s.logger) + db, err := s.sqlmanager.NewSqlConnection( + ctx, + connectionmanager.NewUniqueSession(), + s.connection, + s.logger, + ) if err != nil { return nil, err } @@ -101,7 +111,11 @@ func (s *SQLConnectionDataService) SampleData( return fmt.Errorf("invalid schema or table: %w", err) } - conn, err := s.sqlconnector.NewDbFromConnectionConfig(s.connconfig, s.logger, sqlconnect.WithConnectionTimeout(uint32(5))) + conn, err := 
s.sqlconnector.NewDbFromConnectionConfig( + s.connconfig, + s.logger, + sqlconnect.WithConnectionTimeout(uint32(5)), + ) if err != nil { return fmt.Errorf("error creating connection: %w", err) } @@ -128,19 +142,34 @@ func (s *SQLConnectionDataService) SampleData( } rows, err := db.QueryContext(ctx, query) if err != nil && !neosyncdb.IsNoRows(err) { - return fmt.Errorf("error querying table %s with database type %s: %w", schemaTable, goquDriver, err) + return fmt.Errorf( + "error querying table %s with database type %s: %w", + schemaTable, + goquDriver, + err, + ) } defer rows.Close() for rows.Next() { r, err := mapper.MapRecord(rows) if err != nil { - return fmt.Errorf("unable to convert row to map for table %s with database type %s: %w", schemaTable, goquDriver, err) + return fmt.Errorf( + "unable to convert row to map for table %s with database type %s: %w", + schemaTable, + goquDriver, + err, + ) } var rowbytes bytes.Buffer enc := gob.NewEncoder(&rowbytes) if err := enc.Encode(r); err != nil { - return fmt.Errorf("unable to encode row for table %s with database type %s: %w", schemaTable, goquDriver, err) + return fmt.Errorf( + "unable to encode row for table %s with database type %s: %w", + schemaTable, + goquDriver, + err, + ) } if err := stream.Send(&mgmtv1alpha1.GetConnectionDataStreamResponse{RowBytes: rowbytes.Bytes()}); err != nil { return err @@ -160,7 +189,11 @@ func (s *SQLConnectionDataService) StreamData( return err } - conn, err := s.sqlconnector.NewDbFromConnectionConfig(s.connconfig, s.logger, sqlconnect.WithConnectionTimeout(uint32(5))) + conn, err := s.sqlconnector.NewDbFromConnectionConfig( + s.connconfig, + s.logger, + sqlconnect.WithConnectionTimeout(uint32(5)), + ) if err != nil { return err } @@ -188,12 +221,22 @@ func (s *SQLConnectionDataService) StreamData( } r, err := db.QueryContext(ctx, query) if err != nil && !neosyncdb.IsNoRows(err) { - return fmt.Errorf("error querying table %s with database type %s: %w", schemaTable, goquDriver, 
err) + return fmt.Errorf( + "error querying table %s with database type %s: %w", + schemaTable, + goquDriver, + err, + ) } columnNames, err := r.Columns() if err != nil { - return fmt.Errorf("unable to get column names from table %s with database type %s: %w", schemaTable, goquDriver, err) + return fmt.Errorf( + "unable to get column names from table %s with database type %s: %w", + schemaTable, + goquDriver, + err, + ) } selectQuery, err := querybuilder.BuildSelectQuery(goquDriver, schemaTable, columnNames, nil) @@ -202,19 +245,34 @@ func (s *SQLConnectionDataService) StreamData( } rows, err := db.QueryContext(ctx, selectQuery) if err != nil && !neosyncdb.IsNoRows(err) { - return fmt.Errorf("error querying table %s with goqu driver %s: %w", schemaTable, goquDriver, err) + return fmt.Errorf( + "error querying table %s with goqu driver %s: %w", + schemaTable, + goquDriver, + err, + ) } // todo: rows.Close needs to be called here? for rows.Next() { r, err := mapper.MapRecord(rows) if err != nil { - return fmt.Errorf("unable to convert row to map for table %s with database type %s: %w", schemaTable, goquDriver, err) + return fmt.Errorf( + "unable to convert row to map for table %s with database type %s: %w", + schemaTable, + goquDriver, + err, + ) } var rowbytes bytes.Buffer enc := gob.NewEncoder(&rowbytes) if err := enc.Encode(r); err != nil { - return fmt.Errorf("unable to encode row for table %s with database type %s: %w", schemaTable, goquDriver, err) + return fmt.Errorf( + "unable to encode row for table %s with database type %s: %w", + schemaTable, + goquDriver, + err, + ) } if err := stream.Send(&mgmtv1alpha1.GetConnectionDataStreamResponse{RowBytes: rowbytes.Bytes()}); err != nil { return err @@ -223,8 +281,16 @@ func (s *SQLConnectionDataService) StreamData( return nil } -func (s *SQLConnectionDataService) GetSchema(ctx context.Context, config *mgmtv1alpha1.ConnectionSchemaConfig) ([]*mgmtv1alpha1.DatabaseColumn, error) { - db, err := 
s.sqlmanager.NewSqlConnection(ctx, connectionmanager.NewUniqueSession(), s.connection, s.logger) +func (s *SQLConnectionDataService) GetSchema( + ctx context.Context, + config *mgmtv1alpha1.ConnectionSchemaConfig, +) ([]*mgmtv1alpha1.DatabaseColumn, error) { + db, err := s.sqlmanager.NewSqlConnection( + ctx, + connectionmanager.NewUniqueSession(), + s.connection, + s.logger, + ) if err != nil { return nil, err } @@ -270,7 +336,12 @@ func (s *SQLConnectionDataService) GetInitStatements( schemaTableMap[sqlmanager_shared.BuildTable(s.Schema, s.Table)] = s } - db, err := s.sqlmanager.NewSqlConnection(ctx, connectionmanager.NewUniqueSession(), s.connection, s.logger) + db, err := s.sqlmanager.NewSqlConnection( + ctx, + connectionmanager.NewUniqueSession(), + s.connection, + s.logger, + ) if err != nil { return nil, err } @@ -281,7 +352,10 @@ func (s *SQLConnectionDataService) GetInitStatements( if options.GetInitSchema() { tables := []*sqlmanager_shared.SchemaTable{} for _, v := range schemaTableMap { - tables = append(tables, &sqlmanager_shared.SchemaTable{Schema: v.Schema, Table: v.Table}) + tables = append( + tables, + &sqlmanager_shared.SchemaTable{Schema: v.Schema, Table: v.Table}, + ) } initBlocks, err := db.Db().GetSchemaInitStatements(ctx, tables) if err != nil { @@ -348,7 +422,12 @@ func (s *SQLConnectionDataService) GetTableConstraints( schemas = append(schemas, s) } - db, err := s.sqlmanager.NewSqlConnection(ctx, connectionmanager.NewUniqueSession(), s.connection, s.logger) + db, err := s.sqlmanager.NewSqlConnection( + ctx, + connectionmanager.NewUniqueSession(), + s.connection, + s.logger, + ) if err != nil { return nil, err } @@ -364,12 +443,15 @@ func (s *SQLConnectionDataService) GetTableConstraints( Constraints: []*mgmtv1alpha1.ForeignConstraint{}, } for _, constraint := range d { - fkConstraintsMap[tableName].Constraints = append(fkConstraintsMap[tableName].Constraints, &mgmtv1alpha1.ForeignConstraint{ - Columns: constraint.Columns, NotNullable: 
constraint.NotNullable, ForeignKey: &mgmtv1alpha1.ForeignKey{ - Table: constraint.ForeignKey.Table, - Columns: constraint.ForeignKey.Columns, + fkConstraintsMap[tableName].Constraints = append( + fkConstraintsMap[tableName].Constraints, + &mgmtv1alpha1.ForeignConstraint{ + Columns: constraint.Columns, NotNullable: constraint.NotNullable, ForeignKey: &mgmtv1alpha1.ForeignKey{ + Table: constraint.ForeignKey.Table, + Columns: constraint.ForeignKey.Columns, + }, }, - }) + ) } } @@ -386,9 +468,12 @@ func (s *SQLConnectionDataService) GetTableConstraints( Constraints: []*mgmtv1alpha1.UniqueConstraint{}, } for _, uc := range uniqueConstraints { - uniqueConstraintsMap[table].Constraints = append(uniqueConstraintsMap[table].Constraints, &mgmtv1alpha1.UniqueConstraint{ - Columns: uc, - }) + uniqueConstraintsMap[table].Constraints = append( + uniqueConstraintsMap[table].Constraints, + &mgmtv1alpha1.UniqueConstraint{ + Columns: uc, + }, + ) } } @@ -398,9 +483,12 @@ func (s *SQLConnectionDataService) GetTableConstraints( Indexes: []*mgmtv1alpha1.UniqueIndex{}, } for _, ui := range uniqueIndexes { - uniqueIndexesMap[table].Indexes = append(uniqueIndexesMap[table].Indexes, &mgmtv1alpha1.UniqueIndex{ - Columns: ui, - }) + uniqueIndexesMap[table].Indexes = append( + uniqueIndexesMap[table].Indexes, + &mgmtv1alpha1.UniqueIndex{ + Columns: ui, + }, + ) } } @@ -412,14 +500,23 @@ func (s *SQLConnectionDataService) GetTableConstraints( }, nil } -func (s *SQLConnectionDataService) GetTableSchema(ctx context.Context, schema, table string) ([]*mgmtv1alpha1.DatabaseColumn, error) { - db, err := s.sqlmanager.NewSqlConnection(ctx, connectionmanager.NewUniqueSession(), s.connection, s.logger) +func (s *SQLConnectionDataService) GetTableSchema( + ctx context.Context, + schema, table string, +) ([]*mgmtv1alpha1.DatabaseColumn, error) { + db, err := s.sqlmanager.NewSqlConnection( + ctx, + connectionmanager.NewUniqueSession(), + s.connection, + s.logger, + ) if err != nil { return nil, err } defer 
db.Db().Close() schematable := &sqlmanager_shared.SchemaTable{Schema: schema, Table: table} - dbschema, err := db.Db().GetDatabaseTableSchemasBySchemasAndTables(ctx, []*sqlmanager_shared.SchemaTable{schematable}) + dbschema, err := db.Db(). + GetDatabaseTableSchemasBySchemasAndTables(ctx, []*sqlmanager_shared.SchemaTable{schematable}) if err != nil { return nil, err } @@ -440,8 +537,17 @@ func (s *SQLConnectionDataService) GetTableSchema(ctx context.Context, schema, t return schemas, nil } -func (s *SQLConnectionDataService) GetTableRowCount(ctx context.Context, schema, table string, whereClause *string) (int64, error) { - db, err := s.sqlmanager.NewSqlConnection(ctx, connectionmanager.NewUniqueSession(), s.connection, s.logger) +func (s *SQLConnectionDataService) GetTableRowCount( + ctx context.Context, + schema, table string, + whereClause *string, +) (int64, error) { + db, err := s.sqlmanager.NewSqlConnection( + ctx, + connectionmanager.NewUniqueSession(), + s.connection, + s.logger, + ) if err != nil { return 0, err } @@ -449,7 +555,10 @@ func (s *SQLConnectionDataService) GetTableRowCount(ctx context.Context, schema, return db.Db().GetTableRowCount(ctx, schema, table, whereClause) } -func (s *SQLConnectionDataService) areSchemaAndTableValid(ctx context.Context, schema, table string) error { +func (s *SQLConnectionDataService) areSchemaAndTableValid( + ctx context.Context, + schema, table string, +) error { schemas, err := s.GetTableSchema(ctx, schema, table) if err != nil { return err diff --git a/internal/connectrpc/interceptors/retry/interceptor.go b/internal/connectrpc/interceptors/retry/interceptor.go index a527957c4f..ca70850c2e 100644 --- a/internal/connectrpc/interceptors/retry/interceptor.go +++ b/internal/connectrpc/interceptors/retry/interceptor.go @@ -31,7 +31,13 @@ func DefaultRetryInterceptor(logger *slog.Logger) *Interceptor { backoff.WithMaxTries(10), backoff.WithMaxElapsedTime(1 * time.Minute), backoff.WithNotify(func(err error, d 
time.Duration) { - logger.Warn(fmt.Sprintf("error with retry: %s, retrying in %s", err.Error(), d.String())) + logger.Warn( + fmt.Sprintf( + "error with retry: %s, retrying in %s", + err.Error(), + d.String(), + ), + ) }), } }), @@ -75,7 +81,9 @@ func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { } } -func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { +func (i *Interceptor) WrapStreamingClient( + next connect.StreamingClientFunc, +) connect.StreamingClientFunc { return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { conn := next(ctx, spec) @@ -138,7 +146,9 @@ func (r *retryStreamingClientConn) Receive(msg any) error { return unwrapPermanentError(err) } -func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { +func (i *Interceptor) WrapStreamingHandler( + next connect.StreamingHandlerFunc, +) connect.StreamingHandlerFunc { return func(ctx context.Context, conn connect.StreamingHandlerConn) error { operation := func() (any, error) { err := next(ctx, conn) diff --git a/internal/connectrpc/validate/validate.go b/internal/connectrpc/validate/validate.go index a47ab1f213..74ff4c3cec 100644 --- a/internal/connectrpc/validate/validate.go +++ b/internal/connectrpc/validate/validate.go @@ -96,7 +96,9 @@ func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { } // WrapStreamingClient implements connect.Interceptor. 
-func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { +func (i *Interceptor) WrapStreamingClient( + next connect.StreamingClientFunc, +) connect.StreamingClientFunc { return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { return &streamingClientInterceptor{ StreamingClientConn: next(ctx, spec), @@ -106,7 +108,9 @@ func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) conn } // WrapStreamingHandler implements connect.Interceptor. -func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { +func (i *Interceptor) WrapStreamingHandler( + next connect.StreamingHandlerFunc, +) connect.StreamingHandlerFunc { return func(ctx context.Context, conn connect.StreamingHandlerConn) error { return next(ctx, &streamingHandlerInterceptor{ StreamingHandlerConn: conn, diff --git a/internal/database-record-mapper/builder/builder.go b/internal/database-record-mapper/builder/builder.go index bb79de9c7f..3885f2bb57 100644 --- a/internal/database-record-mapper/builder/builder.go +++ b/internal/database-record-mapper/builder/builder.go @@ -13,7 +13,9 @@ type DatabaseRecordMapper[T any] interface { MapRecord(record T) (map[string]any, error) // deprecated - use MapRecord instead with neosync types - MapRecordWithKeyType(record T) (valuemap map[string]any, typemap map[string]neosync_types.KeyType, err error) + MapRecordWithKeyType( + record T, + ) (valuemap map[string]any, typemap map[string]neosync_types.KeyType, err error) } type Builder[T any] struct { @@ -29,7 +31,9 @@ func (b *Builder[T]) MapRecord(record any) (map[string]any, error) { } // deprecated - use MapRecord instead with neosync types -func (b *Builder[T]) MapRecordWithKeyType(record any) (valuemap map[string]any, typemap map[string]neosync_types.KeyType, err error) { +func (b *Builder[T]) MapRecordWithKeyType( + record any, +) (valuemap map[string]any, typemap 
map[string]neosync_types.KeyType, err error) { typedRecord, ok := record.(T) if !ok { return nil, nil, fmt.Errorf("invalid record type: expected %T, got %T", *new(T), record) diff --git a/internal/database-record-mapper/database-record-mapper.go b/internal/database-record-mapper/database-record-mapper.go index 582f6ae997..f91da4dd01 100644 --- a/internal/database-record-mapper/database-record-mapper.go +++ b/internal/database-record-mapper/database-record-mapper.go @@ -30,7 +30,9 @@ func NewDatabaseRecordMapper(dbType string) (builder.DatabaseRecordMapper[any], } } -func NewDatabaseRecordMapperFromConnection(connection *mgmtv1alpha1.Connection) (builder.DatabaseRecordMapper[any], error) { +func NewDatabaseRecordMapperFromConnection( + connection *mgmtv1alpha1.Connection, +) (builder.DatabaseRecordMapper[any], error) { switch connection.GetConnectionConfig().GetConfig().(type) { case *mgmtv1alpha1.ConnectionConfig_PgConfig: return NewDatabaseRecordMapper(sqlmanager_shared.PostgresDriver) diff --git a/internal/database-record-mapper/dynamodb/mapper.go b/internal/database-record-mapper/dynamodb/mapper.go index 89644d4389..d5992df85e 100644 --- a/internal/database-record-mapper/dynamodb/mapper.go +++ b/internal/database-record-mapper/dynamodb/mapper.go @@ -22,7 +22,9 @@ func (m *DynamoDBMapper) MapRecord(item map[string]types.AttributeValue) (map[st return nil, errors.ErrUnsupported } -func (m *DynamoDBMapper) MapRecordWithKeyType(item map[string]types.AttributeValue) (valuemap map[string]any, typemap map[string]neosync_types.KeyType, err error) { +func (m *DynamoDBMapper) MapRecordWithKeyType( + item map[string]types.AttributeValue, +) (valuemap map[string]any, typemap map[string]neosync_types.KeyType, err error) { standardJSON := make(map[string]any) ktm := make(map[string]neosync_types.KeyType) for k, v := range item { @@ -36,7 +38,11 @@ func (m *DynamoDBMapper) MapRecordWithKeyType(item map[string]types.AttributeVal } // ParseAttributeValue converts a DynamoDB 
AttributeValue to a standard value -func parseAttributeValue(key string, v types.AttributeValue, keyTypeMap map[string]neosync_types.KeyType) (any, error) { +func parseAttributeValue( + key string, + v types.AttributeValue, + keyTypeMap map[string]neosync_types.KeyType, +) (any, error) { switch t := v.(type) { case *types.AttributeValueMemberB: return t.Value, nil diff --git a/internal/database-record-mapper/mongodb/mapper.go b/internal/database-record-mapper/mongodb/mapper.go index 6388babd53..f1ed42bc8c 100644 --- a/internal/database-record-mapper/mongodb/mapper.go +++ b/internal/database-record-mapper/mongodb/mapper.go @@ -23,7 +23,9 @@ func (m *MongoDBMapper) MapRecord(item map[string]any) (map[string]any, error) { return nil, errors.ErrUnsupported } -func (m *MongoDBMapper) MapRecordWithKeyType(item map[string]any) (valuemap map[string]any, typemap map[string]neosync_types.KeyType, err error) { +func (m *MongoDBMapper) MapRecordWithKeyType( + item map[string]any, +) (valuemap map[string]any, typemap map[string]neosync_types.KeyType, err error) { result := make(map[string]any) ktm := make(map[string]neosync_types.KeyType) for k, v := range item { @@ -36,7 +38,11 @@ func (m *MongoDBMapper) MapRecordWithKeyType(item map[string]any) (valuemap map[ return result, ktm, nil } -func parsePrimitives(key string, value any, keyTypeMap map[string]neosync_types.KeyType) (any, error) { +func parsePrimitives( + key string, + value any, + keyTypeMap map[string]neosync_types.KeyType, +) (any, error) { switch v := value.(type) { case primitive.Decimal128: keyTypeMap[key] = neosync_types.Decimal128 diff --git a/internal/database-record-mapper/mssql/mapper.go b/internal/database-record-mapper/mssql/mapper.go index 5d0175d8fc..2a7c73e85d 100644 --- a/internal/database-record-mapper/mssql/mapper.go +++ b/internal/database-record-mapper/mssql/mapper.go @@ -21,7 +21,9 @@ func NewMSSQLBuilder() *builder.Builder[*sql.Rows] { } } -func (m *MSSQLMapper) MapRecordWithKeyType(rows 
*sql.Rows) (valuemap map[string]any, typemap map[string]neosync_types.KeyType, err error) { +func (m *MSSQLMapper) MapRecordWithKeyType( + rows *sql.Rows, +) (valuemap map[string]any, typemap map[string]neosync_types.KeyType, err error) { return nil, nil, errors.ErrUnsupported } diff --git a/internal/database-record-mapper/mysql/mapper.go b/internal/database-record-mapper/mysql/mapper.go index fa08fd9519..979281efb9 100644 --- a/internal/database-record-mapper/mysql/mapper.go +++ b/internal/database-record-mapper/mysql/mapper.go @@ -21,7 +21,9 @@ func NewMySQLBuilder() *builder.Builder[*sql.Rows] { } } -func (m *MySQLMapper) MapRecordWithKeyType(rows *sql.Rows) (valuemap map[string]any, typemap map[string]neosync_types.KeyType, err error) { +func (m *MySQLMapper) MapRecordWithKeyType( + rows *sql.Rows, +) (valuemap map[string]any, typemap map[string]neosync_types.KeyType, err error) { return nil, nil, errors.ErrUnsupported } @@ -55,7 +57,10 @@ func (m *MySQLMapper) MapRecord(rows *sql.Rows) (map[string]any, error) { return jObj, nil } -func parseMysqlRowValues(values []any, columnNames, columnDbTypes []string) (map[string]any, error) { +func parseMysqlRowValues( + values []any, + columnNames, columnDbTypes []string, +) (map[string]any, error) { jObj := map[string]any{} for i, v := range values { col := columnNames[i] diff --git a/internal/database-record-mapper/postgres/mapper.go b/internal/database-record-mapper/postgres/mapper.go index e62cd9043c..1616a8c511 100644 --- a/internal/database-record-mapper/postgres/mapper.go +++ b/internal/database-record-mapper/postgres/mapper.go @@ -24,7 +24,9 @@ func NewPostgresBuilder() *builder.Builder[*sql.Rows] { } } -func (m *PostgresMapper) MapRecordWithKeyType(rows *sql.Rows) (valuemap map[string]any, typemap map[string]neosync_types.KeyType, err error) { +func (m *PostgresMapper) MapRecordWithKeyType( + rows *sql.Rows, +) (valuemap map[string]any, typemap map[string]neosync_types.KeyType, err error) { return nil, nil, 
errors.ErrUnsupported } @@ -272,7 +274,10 @@ func (a *PgxArray[T]) Scan(src any) error { pgt, ok = m.TypeForName(strings.ToLower(a.colDataType)) } if !ok { - return fmt.Errorf("cannot convert to sql.Scanner: cannot find registered type for %s", a.colDataType) + return fmt.Errorf( + "cannot convert to sql.Scanner: cannot find registered type for %s", + a.colDataType, + ) } v := &a.Array @@ -337,7 +342,10 @@ func toBinaryArray(array *PgxArray[[]byte]) (*neosynctypes.NeosyncArray, error) return nil, errors.ErrUnsupported } - binaryArray, err := neosynctypes.NewBinaryArrayFromPgx(array.Elements, []neosynctypes.NeosyncTypeOption{}) + binaryArray, err := neosynctypes.NewBinaryArrayFromPgx( + array.Elements, + []neosynctypes.NeosyncTypeOption{}, + ) if err != nil { return nil, err } @@ -354,7 +362,10 @@ func toBitsArray(array *PgxArray[*pgtype.Bits]) (*neosynctypes.NeosyncArray, err return nil, errors.ErrUnsupported } - bitsArray, err := neosynctypes.NewBitsArrayFromPgx(array.Elements, []neosynctypes.NeosyncTypeOption{}) + bitsArray, err := neosynctypes.NewBitsArrayFromPgx( + array.Elements, + []neosynctypes.NeosyncTypeOption{}, + ) if err != nil { return nil, err } @@ -371,7 +382,10 @@ func toIntervalArray(array *PgxArray[*pgtype.Interval]) (*neosynctypes.NeosyncAr return nil, errors.ErrUnsupported } - neoIntervalArray, err := neosynctypes.NewIntervalArrayFromPgx(array.Elements, []neosynctypes.NeosyncTypeOption{}) + neoIntervalArray, err := neosynctypes.NewIntervalArrayFromPgx( + array.Elements, + []neosynctypes.NeosyncTypeOption{}, + ) if err != nil { return nil, err } @@ -392,7 +406,9 @@ func pgArrayToGoSlice(array *PgxArray[any]) (any, error) { return createMultiDimSlice(dims, array.Elements), nil } - if strings.EqualFold(array.colDataType, "timestamp") || strings.EqualFold(array.colDataType, "date") || strings.EqualFold(array.colDataType, "timestampz") { + if strings.EqualFold(array.colDataType, "timestamp") || + strings.EqualFold(array.colDataType, "date") || + 
strings.EqualFold(array.colDataType, "timestampz") { timeArray := make([]time.Time, len(array.Elements)) for i, elem := range array.Elements { if t, ok := elem.(time.Time); ok { @@ -401,7 +417,10 @@ func pgArrayToGoSlice(array *PgxArray[any]) (any, error) { return nil, fmt.Errorf("expected time.Time, got %T", elem) } } - dtArray, err := neosynctypes.NewDateTimeArrayFromPgx(timeArray, []neosynctypes.NeosyncTypeOption{}) + dtArray, err := neosynctypes.NewDateTimeArrayFromPgx( + timeArray, + []neosynctypes.NeosyncTypeOption{}, + ) if err != nil { return nil, err } diff --git a/internal/ee/cloud-license/license.go b/internal/ee/cloud-license/license.go index 7044724f30..b2ed7da49c 100644 --- a/internal/ee/cloud-license/license.go +++ b/internal/ee/cloud-license/license.go @@ -84,7 +84,11 @@ func getFromEnv() (*licenseContents, bool, error) { input := viper.GetString(cloudLicenseEvKey) if input == "" { - return nil, false, fmt.Errorf("%s was true but no license was found at %s", cloudEnabledEvKey, cloudLicenseEvKey) + return nil, false, fmt.Errorf( + "%s was true but no license was found at %s", + cloudEnabledEvKey, + cloudLicenseEvKey, + ) } pk, err := parsePublicKey(publicKeyPEM) if err != nil { @@ -126,7 +130,10 @@ func getLicense(licenseData string, publicKey ed25519.PublicKey) (*licenseConten var lc licenseContents err = json.Unmarshal(contents, &lc) if err != nil { - return nil, fmt.Errorf("contents verified, but unable to unmarshal license contents from input: %w", err) + return nil, fmt.Errorf( + "contents verified, but unable to unmarshal license contents from input: %w", + err, + ) } return &lc, nil diff --git a/internal/ee/license/license.go b/internal/ee/license/license.go index 1726340871..7214c2f27d 100644 --- a/internal/ee/license/license.go +++ b/internal/ee/license/license.go @@ -124,7 +124,10 @@ func getLicense(licenseData string, publicKey ed25519.PublicKey) (*licenseConten var lc licenseContents err = json.Unmarshal(contents, &lc) if err != nil { - 
return nil, fmt.Errorf("contents verified, but unable to unmarshal license contents from input: %w", err) + return nil, fmt.Errorf( + "contents verified, but unable to unmarshal license contents from input: %w", + err, + ) } return &lc, nil diff --git a/internal/ee/mssql-manager/ee-mssql-manager.go b/internal/ee/mssql-manager/ee-mssql-manager.go index 2f2c61b469..d7b00b3bce 100644 --- a/internal/ee/mssql-manager/ee-mssql-manager.go +++ b/internal/ee/mssql-manager/ee-mssql-manager.go @@ -27,11 +27,19 @@ type Manager struct { logger *slog.Logger } -func NewManager(querier mssql_queries.Querier, db mysql_queries.DBTX, closer func(), logger *slog.Logger) *Manager { +func NewManager( + querier mssql_queries.Querier, + db mysql_queries.DBTX, + closer func(), + logger *slog.Logger, +) *Manager { return &Manager{querier: querier, db: db, close: closer, logger: logger} } -func (m *Manager) GetTableInitStatements(ctx context.Context, tables []*sqlmanager_shared.SchemaTable) ([]*sqlmanager_shared.TableInitStatement, error) { +func (m *Manager) GetTableInitStatements( + ctx context.Context, + tables []*sqlmanager_shared.SchemaTable, +) ([]*sqlmanager_shared.TableInitStatement, error) { if len(tables) == 0 { return []*sqlmanager_shared.TableInitStatement{}, nil } @@ -51,12 +59,19 @@ func (m *Manager) GetTableInitStatements(ctx context.Context, tables []*sqlmanag colDefMap := map[string][]*mssql_queries.GetDatabaseTableSchemasBySchemasAndTablesRow{} errgrp.Go(func() error { - columnDefs, err := m.querier.GetDatabaseTableSchemasBySchemasAndTables(errctx, m.db, combined) + columnDefs, err := m.querier.GetDatabaseTableSchemasBySchemasAndTables( + errctx, + m.db, + combined, + ) if err != nil { return err } for _, columnDefinition := range columnDefs { - key := sqlmanager_shared.SchemaTable{Schema: columnDefinition.TableSchema, Table: columnDefinition.TableName} + key := sqlmanager_shared.SchemaTable{ + Schema: columnDefinition.TableSchema, + Table: columnDefinition.TableName, + } 
colDefMap[key.String()] = append(colDefMap[key.String()], columnDefinition) } return nil @@ -64,12 +79,19 @@ func (m *Manager) GetTableInitStatements(ctx context.Context, tables []*sqlmanag constraintmap := map[string][]*mssql_queries.GetTableConstraintsBySchemasRow{} errgrp.Go(func() error { - constraints, err := m.querier.GetTableConstraintsBySchemas(errctx, m.db, schemas) // todo: update this to only grab what is necessary instead of entire schema + constraints, err := m.querier.GetTableConstraintsBySchemas( + errctx, + m.db, + schemas, + ) // todo: update this to only grab what is necessary instead of entire schema if err != nil { return err } for _, constraint := range constraints { - key := sqlmanager_shared.SchemaTable{Schema: constraint.SchemaName, Table: constraint.TableName} + key := sqlmanager_shared.SchemaTable{ + Schema: constraint.SchemaName, + Table: constraint.TableName, + } constraintmap[key.String()] = append(constraintmap[key.String()], constraint) } return nil @@ -83,7 +105,10 @@ func (m *Manager) GetTableInitStatements(ctx context.Context, tables []*sqlmanag } for _, record := range idxrecords { key := sqlmanager_shared.SchemaTable{Schema: record.SchemaName, Table: record.TableName} - indexmap[key.String()] = append(indexmap[key.String()], generateCreateIndexStatement(record)) + indexmap[key.String()] = append( + indexmap[key.String()], + generateCreateIndexStatement(record), + ) } return nil }) @@ -112,14 +137,19 @@ func (m *Manager) GetTableInitStatements(ctx context.Context, tables []*sqlmanag continue } stmt := generateAddConstraintStatement(constraint) - constraintType, err := sqlmanager_shared.ToConstraintType(toStandardConstraintType(constraint.ConstraintType)) + constraintType, err := sqlmanager_shared.ToConstraintType( + toStandardConstraintType(constraint.ConstraintType), + ) if err != nil { return nil, err } - info.AlterTableStatements = append(info.AlterTableStatements, &sqlmanager_shared.AlterTableStatement{ - Statement: stmt, - 
ConstraintType: constraintType, - }) + info.AlterTableStatements = append( + info.AlterTableStatements, + &sqlmanager_shared.AlterTableStatement{ + Statement: stmt, + ConstraintType: constraintType, + }, + ) } output = append(output, info) } @@ -141,7 +171,10 @@ func toStandardConstraintType(constraintType string) string { } } -func (m *Manager) GetSchemaInitStatements(ctx context.Context, tables []*sqlmanager_shared.SchemaTable) ([]*sqlmanager_shared.InitSchemaStatements, error) { +func (m *Manager) GetSchemaInitStatements( + ctx context.Context, + tables []*sqlmanager_shared.SchemaTable, +) ([]*sqlmanager_shared.InitSchemaStatements, error) { schemasMap := map[string]struct{}{} for _, t := range tables { schemasMap[t.Schema] = struct{}{} @@ -156,7 +189,14 @@ func (m *Manager) GetSchemaInitStatements(ctx context.Context, tables []*sqlmana schemaStmts := []string{} errgrp.Go(func() error { for schema := range schemasMap { - schemaStmts = append(schemaStmts, fmt.Sprintf("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '%s')\nBEGIN\n EXEC('CREATE SCHEMA [%s]')\nEND;", schema, schema)) + schemaStmts = append( + schemaStmts, + fmt.Sprintf( + "IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '%s')\nBEGIN\n EXEC('CREATE SCHEMA [%s]')\nEND;", + schema, + schema, + ), + ) } return nil }) @@ -235,7 +275,10 @@ func (m *Manager) GetSchemaInitStatements(ctx context.Context, tables []*sqlmana }, nil } -func (m *Manager) GetSchemaTableDataTypes(ctx context.Context, tables []*sqlmanager_shared.SchemaTable) (*sqlmanager_shared.SchemaTableDataTypeResponse, error) { +func (m *Manager) GetSchemaTableDataTypes( + ctx context.Context, + tables []*sqlmanager_shared.SchemaTable, +) (*sqlmanager_shared.SchemaTableDataTypeResponse, error) { if len(tables) == 0 { return &sqlmanager_shared.SchemaTableDataTypeResponse{}, nil } @@ -278,7 +321,10 @@ func (m *Manager) GetSchemaTableDataTypes(ctx context.Context, tables []*sqlmana return output, nil } -func (m *Manager) 
GetSchemaTableTriggers(ctx context.Context, tables []*sqlmanager_shared.SchemaTable) ([]*sqlmanager_shared.TableTrigger, error) { +func (m *Manager) GetSchemaTableTriggers( + ctx context.Context, + tables []*sqlmanager_shared.SchemaTable, +) ([]*sqlmanager_shared.TableTrigger, error) { if len(tables) == 0 { return []*sqlmanager_shared.TableTrigger{}, nil } @@ -299,20 +345,35 @@ func (m *Manager) GetSchemaTableTriggers(ctx context.Context, tables []*sqlmanag for _, row := range rows { if !row.Definition.Valid { // This may occur if the trigger is encrypted or implemented as a CLR trigger (i.e., not written in T-SQL). - m.logger.Warn("mssql trigger definition is missing", "schema", row.SchemaName, "table", row.TableName, "trigger", row.TriggerName) + m.logger.Warn( + "mssql trigger definition is missing", + "schema", + row.SchemaName, + "table", + row.TableName, + "trigger", + row.TriggerName, + ) continue } output = append(output, &sqlmanager_shared.TableTrigger{ Schema: row.SchemaName, Table: row.TableName, TriggerName: row.TriggerName, - Definition: generateCreateTriggerStatement(row.TriggerName, row.SchemaName, row.Definition.String), + Definition: generateCreateTriggerStatement( + row.TriggerName, + row.SchemaName, + row.Definition.String, + ), }) } return output, nil } -func (m *Manager) getSequencesBySchemas(ctx context.Context, schemas []string) ([]*sqlmanager_shared.DataType, error) { +func (m *Manager) getSequencesBySchemas( + ctx context.Context, + schemas []string, +) ([]*sqlmanager_shared.DataType, error) { rows, err := m.querier.GetCustomSequencesBySchemas(ctx, m.db, schemas) if err != nil && !isNoRows(err) { return nil, err @@ -331,7 +392,11 @@ func (m *Manager) getSequencesBySchemas(ctx context.Context, schemas []string) ( return output, nil } -func (m *Manager) GetSequencesByTables(ctx context.Context, schema string, tables []string) ([]*sqlmanager_shared.DataType, error) { +func (m *Manager) GetSequencesByTables( + ctx context.Context, + schema 
string, + tables []string, +) ([]*sqlmanager_shared.DataType, error) { rows, err := m.querier.GetCustomSequencesBySchemas(ctx, m.db, []string{schema}) if err != nil && !isNoRows(err) { return nil, err @@ -350,7 +415,10 @@ func (m *Manager) GetSequencesByTables(ctx context.Context, schema string, table return output, nil } -func (m *Manager) getViewsAndFunctionsBySchemas(ctx context.Context, schemas []string) ([]*sqlmanager_shared.DataType, error) { +func (m *Manager) getViewsAndFunctionsBySchemas( + ctx context.Context, + schemas []string, +) ([]*sqlmanager_shared.DataType, error) { rows, err := m.querier.GetViewsAndFunctionsBySchemas(ctx, m.db, schemas) if err != nil && !isNoRows(err) { return nil, err @@ -363,15 +431,21 @@ func (m *Manager) getViewsAndFunctionsBySchemas(ctx context.Context, schemas []s output := make([]*sqlmanager_shared.DataType, 0, len(orderedObjects)) for _, row := range orderedObjects { output = append(output, &sqlmanager_shared.DataType{ - Schema: row.SchemaName, - Name: row.ObjectName, - Definition: generateCreateDatabaseObjectStatement(row.ObjectName, row.SchemaName, row.Definition), + Schema: row.SchemaName, + Name: row.ObjectName, + Definition: generateCreateDatabaseObjectStatement( + row.ObjectName, + row.SchemaName, + row.Definition, + ), }) } return output, nil } -func orderObjectsByDependency(objects []*mssql_queries.GetViewsAndFunctionsBySchemasRow) []*mssql_queries.GetViewsAndFunctionsBySchemasRow { +func orderObjectsByDependency( + objects []*mssql_queries.GetViewsAndFunctionsBySchemasRow, +) []*mssql_queries.GetViewsAndFunctionsBySchemasRow { objectMap := make(map[string]*mssql_queries.GetViewsAndFunctionsBySchemasRow) processedObjects := make(map[string]bool) diff --git a/internal/ee/mssql-manager/generate-sql.go b/internal/ee/mssql-manager/generate-sql.go index 096d98a5ab..b08478fd2a 100644 --- a/internal/ee/mssql-manager/generate-sql.go +++ b/internal/ee/mssql-manager/generate-sql.go @@ -8,7 +8,9 @@ import ( ) // Creates 
idempotent create table statement -func generateCreateTableStatement(rows []*mssql_queries.GetDatabaseTableSchemasBySchemasAndTablesRow) string { +func generateCreateTableStatement( + rows []*mssql_queries.GetDatabaseTableSchemasBySchemasAndTablesRow, +) string { if len(rows) == 0 { return "" } @@ -19,8 +21,13 @@ func generateCreateTableStatement(rows []*mssql_queries.GetDatabaseTableSchemasB var sb strings.Builder // Create table if not exists - sb.WriteString(fmt.Sprintf("IF NOT EXISTS (SELECT * FROM sys.objects WHERE object_id = OBJECT_ID(N'[%s].[%s]') AND type in (N'U'))\nBEGIN\n", - tableSchema, tableName)) + sb.WriteString( + fmt.Sprintf( + "IF NOT EXISTS (SELECT * FROM sys.objects WHERE object_id = OBJECT_ID(N'[%s].[%s]') AND type in (N'U'))\nBEGIN\n", + tableSchema, + tableName, + ), + ) sb.WriteString(fmt.Sprintf("CREATE TABLE [%s].[%s] (\n", tableSchema, tableName)) primaryKeys := []string{} @@ -118,7 +125,13 @@ func generateCreateTableStatement(rows []*mssql_queries.GetDatabaseTableSchemasB } if len(primaryKeys) > 0 { - sb.WriteString(fmt.Sprintf("CONSTRAINT pk_%s PRIMARY KEY (%s)", tableName, strings.Join(primaryKeys, ","))) + sb.WriteString( + fmt.Sprintf( + "CONSTRAINT pk_%s PRIMARY KEY (%s)", + tableName, + strings.Join(primaryKeys, ","), + ), + ) } if periodDefinition != nil && *periodDefinition != "" { @@ -229,7 +242,9 @@ END`, record.TypeName, record.SchemaName, record.Definition)) } // Creates idempotent alter table add constraint statement -func generateAddConstraintStatement(constraint *mssql_queries.GetTableConstraintsBySchemasRow) string { +func generateAddConstraintStatement( + constraint *mssql_queries.GetTableConstraintsBySchemasRow, +) string { var sb strings.Builder // Start IF NOT EXISTS check @@ -260,8 +275,11 @@ BEGIN sb.WriteString(fmt.Sprintf("(%s)", escapeColumnList(constraint.ConstraintColumns))) case "FOREIGN KEY": - sb.WriteString(fmt.Sprintf("FOREIGN KEY (%s) ", escapeColumnList(constraint.ConstraintColumns))) - if 
constraint.ReferencedSchema.Valid && constraint.ReferencedTable.Valid && constraint.ReferencedColumns.Valid { + sb.WriteString( + fmt.Sprintf("FOREIGN KEY (%s) ", escapeColumnList(constraint.ConstraintColumns)), + ) + if constraint.ReferencedSchema.Valid && constraint.ReferencedTable.Valid && + constraint.ReferencedColumns.Valid { sb.WriteString(fmt.Sprintf("REFERENCES [%s].[%s] (%s)", constraint.ReferencedSchema.String, constraint.ReferencedTable.String, diff --git a/internal/ee/presidio/interface.go b/internal/ee/presidio/interface.go index 455dd3f884..fae2553d2e 100644 --- a/internal/ee/presidio/interface.go +++ b/internal/ee/presidio/interface.go @@ -6,15 +6,27 @@ import ( // Slimmed down Presidio Analyze Interface for use in Neosync systems type AnalyzeInterface interface { - PostAnalyzeWithResponse(ctx context.Context, body PostAnalyzeJSONRequestBody, reqEditors ...RequestEditorFn) (*PostAnalyzeResponse, error) + PostAnalyzeWithResponse( + ctx context.Context, + body PostAnalyzeJSONRequestBody, + reqEditors ...RequestEditorFn, + ) (*PostAnalyzeResponse, error) } // Slimmed down Presidio Anonymize Interface for use in Neosync systems type AnonymizeInterface interface { - PostAnonymizeWithResponse(ctx context.Context, body PostAnonymizeJSONRequestBody, reqEditors ...RequestEditorFn) (*PostAnonymizeResponse, error) + PostAnonymizeWithResponse( + ctx context.Context, + body PostAnonymizeJSONRequestBody, + reqEditors ...RequestEditorFn, + ) (*PostAnonymizeResponse, error) } // Slimmed down Presidio Entity Interface for use in Neosync systems type EntityInterface interface { - GetSupportedentitiesWithResponse(ctx context.Context, params *GetSupportedentitiesParams, reqEditors ...RequestEditorFn) (*GetSupportedentitiesResponse, error) + GetSupportedentitiesWithResponse( + ctx context.Context, + params *GetSupportedentitiesParams, + reqEditors ...RequestEditorFn, + ) (*GetSupportedentitiesResponse, error) } diff --git a/internal/ee/presidio/util.go 
b/internal/ee/presidio/util.go index 29ed8335ba..8d6520460a 100644 --- a/internal/ee/presidio/util.go +++ b/internal/ee/presidio/util.go @@ -1,6 +1,8 @@ package presidioapi -func ToAnonymizeRecognizerResults(input []RecognizerResultWithAnaysisExplanation) []RecognizerResult { +func ToAnonymizeRecognizerResults( + input []RecognizerResultWithAnaysisExplanation, +) []RecognizerResult { output := make([]RecognizerResult, 0, len(input)) for _, rr := range input { output = append(output, ToAnonymizeRecognizerResult(rr)) diff --git a/internal/ee/rbac/allow_all_client.go b/internal/ee/rbac/allow_all_client.go index 98c6bbfa31..967789e61d 100644 --- a/internal/ee/rbac/allow_all_client.go +++ b/internal/ee/rbac/allow_all_client.go @@ -12,35 +12,67 @@ type AllowAllClient struct { var _ Interface = (*AllowAllClient)(nil) -func (a *AllowAllClient) Job(ctx context.Context, user, account, job EntityString, action JobAction) (bool, error) { +func (a *AllowAllClient) Job( + ctx context.Context, + user, account, job EntityString, + action JobAction, +) (bool, error) { return true, nil } -func (a *AllowAllClient) Connection(ctx context.Context, user, account, connection EntityString, action ConnectionAction) (bool, error) { +func (a *AllowAllClient) Connection( + ctx context.Context, + user, account, connection EntityString, + action ConnectionAction, +) (bool, error) { return true, nil } -func (a *AllowAllClient) Account(ctx context.Context, user, account EntityString, action AccountAction) (bool, error) { +func (a *AllowAllClient) Account( + ctx context.Context, + user, account EntityString, + action AccountAction, +) (bool, error) { return true, nil } -func (a *AllowAllClient) EnforceJob(ctx context.Context, user, account, job EntityString, action JobAction) error { +func (a *AllowAllClient) EnforceJob( + ctx context.Context, + user, account, job EntityString, + action JobAction, +) error { return nil } -func (a *AllowAllClient) EnforceConnection(ctx context.Context, user, 
account, connection EntityString, action ConnectionAction) error { +func (a *AllowAllClient) EnforceConnection( + ctx context.Context, + user, account, connection EntityString, + action ConnectionAction, +) error { return nil } -func (a *AllowAllClient) EnforceAccount(ctx context.Context, user, account EntityString, action AccountAction) error { +func (a *AllowAllClient) EnforceAccount( + ctx context.Context, + user, account EntityString, + action AccountAction, +) error { return nil } -func (a *AllowAllClient) SetAccountRole(ctx context.Context, user, account EntityString, role mgmtv1alpha1.AccountRole) error { +func (a *AllowAllClient) SetAccountRole( + ctx context.Context, + user, account EntityString, + role mgmtv1alpha1.AccountRole, +) error { return nil } -func (a *AllowAllClient) RemoveAccountRole(ctx context.Context, user, account EntityString, role mgmtv1alpha1.AccountRole) error { +func (a *AllowAllClient) RemoveAccountRole( + ctx context.Context, + user, account EntityString, + role mgmtv1alpha1.AccountRole, +) error { return nil } @@ -48,10 +80,20 @@ func (a *AllowAllClient) RemoveAccountUser(ctx context.Context, user, account En return nil } -func (a *AllowAllClient) SetupNewAccount(ctx context.Context, accountId string, logger *slog.Logger) error { +func (a *AllowAllClient) SetupNewAccount( + ctx context.Context, + accountId string, + logger *slog.Logger, +) error { return nil } -func (a *AllowAllClient) GetUserRoles(ctx context.Context, users []EntityString, account EntityString, logger *slog.Logger) map[string]Role { + +func (a *AllowAllClient) GetUserRoles( + ctx context.Context, + users []EntityString, + account EntityString, + logger *slog.Logger, +) map[string]Role { return map[string]Role{} } diff --git a/internal/ee/rbac/enforcer/enforcer.go b/internal/ee/rbac/enforcer/enforcer.go index c382fcce40..ba36cb8396 100644 --- a/internal/ee/rbac/enforcer/enforcer.go +++ b/internal/ee/rbac/enforcer/enforcer.go @@ -37,8 +37,12 @@ func newEnforcer( if 
err != nil { return nil, fmt.Errorf("unable to initialize casbin synced cached enforcer: %w", err) } - enforcer.EnableAutoSave(true) // seems to do this automatically but it doesn't hurt - enforcer.StartAutoLoadPolicy(time.Second * 10) // allows HA between neosync-api instances or backend changes to RBAC policies to be picked up. + enforcer.EnableAutoSave( + true, + ) // seems to do this automatically but it doesn't hurt + enforcer.StartAutoLoadPolicy( + time.Second * 10, + ) // allows HA between neosync-api instances or backend changes to RBAC policies to be picked up. return enforcer, nil } diff --git a/internal/ee/rbac/policy.go b/internal/ee/rbac/policy.go index 34d36f9e2a..e75b357210 100644 --- a/internal/ee/rbac/policy.go +++ b/internal/ee/rbac/policy.go @@ -19,21 +19,70 @@ type Db interface { // Interface that handles enforcing entity level policies type EntityEnforcer interface { - Job(ctx context.Context, user EntityString, account EntityString, job EntityString, action JobAction) (bool, error) - EnforceJob(ctx context.Context, user EntityString, account EntityString, job EntityString, action JobAction) error - Connection(ctx context.Context, user EntityString, account EntityString, connection EntityString, action ConnectionAction) (bool, error) - EnforceConnection(ctx context.Context, user EntityString, account EntityString, connection EntityString, action ConnectionAction) error - Account(ctx context.Context, user EntityString, account EntityString, action AccountAction) (bool, error) - EnforceAccount(ctx context.Context, user EntityString, account EntityString, action AccountAction) error + Job( + ctx context.Context, + user EntityString, + account EntityString, + job EntityString, + action JobAction, + ) (bool, error) + EnforceJob( + ctx context.Context, + user EntityString, + account EntityString, + job EntityString, + action JobAction, + ) error + Connection( + ctx context.Context, + user EntityString, + account EntityString, + connection 
EntityString, + action ConnectionAction, + ) (bool, error) + EnforceConnection( + ctx context.Context, + user EntityString, + account EntityString, + connection EntityString, + action ConnectionAction, + ) error + Account( + ctx context.Context, + user EntityString, + account EntityString, + action AccountAction, + ) (bool, error) + EnforceAccount( + ctx context.Context, + user EntityString, + account EntityString, + action AccountAction, + ) error } // Interface that handles setting and removing roles for users type RoleAdmin interface { - SetAccountRole(ctx context.Context, user EntityString, account EntityString, role mgmtv1alpha1.AccountRole) error - RemoveAccountRole(ctx context.Context, user EntityString, account EntityString, role mgmtv1alpha1.AccountRole) error + SetAccountRole( + ctx context.Context, + user EntityString, + account EntityString, + role mgmtv1alpha1.AccountRole, + ) error + RemoveAccountRole( + ctx context.Context, + user EntityString, + account EntityString, + role mgmtv1alpha1.AccountRole, + ) error RemoveAccountUser(ctx context.Context, user EntityString, account EntityString) error SetupNewAccount(ctx context.Context, accountId string, logger *slog.Logger) error - GetUserRoles(ctx context.Context, users []EntityString, account EntityString, logger *slog.Logger) map[string]Role + GetUserRoles( + ctx context.Context, + users []EntityString, + account EntityString, + logger *slog.Logger, + ) map[string]Role } // Initialize default policies for existing accounts at startup @@ -59,8 +108,14 @@ func (r *Rbac) InitPolicies( return nil } -func setupAccountPolicies(enforcer casbin.IEnforcer, accountIds []string, logger *slog.Logger) error { - logger.Debug(fmt.Sprintf("found %d account ids to associate with rbac policies", len(accountIds))) +func setupAccountPolicies( + enforcer casbin.IEnforcer, + accountIds []string, + logger *slog.Logger, +) error { + logger.Debug( + fmt.Sprintf("found %d account ids to associate with rbac policies", 
len(accountIds)), + ) policyRules := [][]string{} for _, accountId := range accountIds { @@ -94,7 +149,13 @@ func setupAccountPolicies(enforcer casbin.IEnforcer, accountIds []string, logger } // For the given accounts, assign users to the account admin role if the account does not currently have any role assignments -func setupUserAssignments(ctx context.Context, db Db, enforcer casbin.IEnforcer, accountIds []string, logger *slog.Logger) error { +func setupUserAssignments( + ctx context.Context, + db Db, + enforcer casbin.IEnforcer, + accountIds []string, + logger *slog.Logger, +) error { policiesByDomain, err := getGroupingPoliciesByDomain(enforcer) if err != nil { return err @@ -116,7 +177,13 @@ func setupUserAssignments(ctx context.Context, db Db, enforcer casbin.IEnforcer, logger.Debug(fmt.Sprintf("no users found for account %s, skipping", accountId)) continue } - logger.Debug(fmt.Sprintf("found %d users for account %s, assigning all account admin role", len(users), accountId)) + logger.Debug( + fmt.Sprintf( + "found %d users for account %s, assigning all account admin role", + len(users), + accountId, + ), + ) for _, user := range users { groupedRules = append(groupedRules, []string{ NewUserIdEntity(user).String(), @@ -161,7 +228,13 @@ func (r *Rbac) SetupNewAccount( ) error { accountRules := getAccountPolicyRules(accountId) if len(accountRules) > 0 { - logger.Debug(fmt.Sprintf("adding %d policy rules to rbac engine for account %s", len(accountRules), accountId)) + logger.Debug( + fmt.Sprintf( + "adding %d policy rules to rbac engine for account %s", + len(accountRules), + accountId, + ), + ) shouldReloadPolicy := false for _, policy := range accountRules { result, err := setPolicy(r.e, policy) @@ -351,7 +424,9 @@ func (r *Rbac) EnforceJob( return err } if !ok { - return nucleuserrors.NewUnauthorized(fmt.Sprintf("user does not have permission to %s job", action)) + return nucleuserrors.NewUnauthorized( + fmt.Sprintf("user does not have permission to %s job", 
action), + ) } return nil } @@ -378,7 +453,9 @@ func (r *Rbac) EnforceConnection( return err } if !ok { - return nucleuserrors.NewUnauthorized(fmt.Sprintf("user does not have permission to %s connection", action)) + return nucleuserrors.NewUnauthorized( + fmt.Sprintf("user does not have permission to %s connection", action), + ) } return nil } @@ -403,7 +480,9 @@ func (r *Rbac) EnforceAccount( return err } if !ok { - return nucleuserrors.NewUnauthorized(fmt.Sprintf("user does not have permission to %s account", action)) + return nucleuserrors.NewUnauthorized( + fmt.Sprintf("user does not have permission to %s account", action), + ) } return nil } @@ -424,7 +503,9 @@ func setPolicy(e casbin.IEnforcer, policy []string) (*setPolicyResult, error) { return nil, fmt.Errorf("unable to check if policy exists: %w", err) } if !ok { - _, err = e.AddPolicy(policy) // always resolves to true even if it was not added, may be adapter dependent + _, err = e.AddPolicy( + policy, + ) // always resolves to true even if it was not added, may be adapter dependent if err != nil && !neosyncdb.IsConflict(err) { return nil, fmt.Errorf("unable to add policy: %w", err) } else if err != nil && neosyncdb.IsConflict(err) { diff --git a/internal/ee/rbac/roles.go b/internal/ee/rbac/roles.go index 49e7a0f9d7..b74b095716 100644 --- a/internal/ee/rbac/roles.go +++ b/internal/ee/rbac/roles.go @@ -45,6 +45,9 @@ func fromRoleDto(role mgmtv1alpha1.AccountRole) (string, error) { case mgmtv1alpha1.AccountRole_ACCOUNT_ROLE_JOB_VIEWER: return Role_JobViewer.String(), nil default: - return "", fmt.Errorf("account role provided has not be mapped to a casbin role name: %d", role) + return "", fmt.Errorf( + "account role provided has not be mapped to a casbin role name: %d", + role, + ) } } diff --git a/internal/ee/slack/slack.go b/internal/ee/slack/slack.go index 82167993b6..5d8cf5b850 100644 --- a/internal/ee/slack/slack.go +++ b/internal/ee/slack/slack.go @@ -17,10 +17,18 @@ type IsUserInAccountFunc func(ctx 
context.Context, userId, accountId string) (bo type Interface interface { GetAuthorizeUrl(accountId, userId string) (string, error) - ValidateState(ctx context.Context, state, userId string, isUserInAccount IsUserInAccountFunc) (*OauthState, error) + ValidateState( + ctx context.Context, + state, userId string, + isUserInAccount IsUserInAccountFunc, + ) (*OauthState, error) ExchangeCodeForAccessToken(ctx context.Context, code string) (*slack.OAuthV2Response, error) Test(ctx context.Context, accessToken string) (*slack.AuthTestResponse, error) - SendMessage(ctx context.Context, accessToken, channelId string, options ...slack.MsgOption) error + SendMessage( + ctx context.Context, + accessToken, channelId string, + options ...slack.MsgOption, + ) error JoinChannel(ctx context.Context, accessToken, channelId string, logger *slog.Logger) error GetPublicChannels(ctx context.Context, accessToken string) ([]slack.Channel, error) } @@ -110,7 +118,11 @@ func (c *Client) GetAuthorizeUrl(accountId, userId string) (string, error) { return slackUrl.String(), nil } -func (c *Client) ValidateState(ctx context.Context, state, userId string, isUserInAccount IsUserInAccountFunc) (*OauthState, error) { +func (c *Client) ValidateState( + ctx context.Context, + state, userId string, + isUserInAccount IsUserInAccountFunc, +) (*OauthState, error) { stateDecrypted, err := c.encryptor.Decrypt(state) if err != nil { return nil, fmt.Errorf("unable to decrypt slack oauth state: %w", err) @@ -140,8 +152,18 @@ func (c *Client) ValidateState(ctx context.Context, state, userId string, isUser return &decodedState, nil } -func (c *Client) ExchangeCodeForAccessToken(ctx context.Context, code string) (*slack.OAuthV2Response, error) { - resp, err := slack.GetOAuthV2ResponseContext(ctx, c.cfg.httpClient, c.cfg.authClientId, c.cfg.authClientSecret, code, c.cfg.redirectUrl) +func (c *Client) ExchangeCodeForAccessToken( + ctx context.Context, + code string, +) (*slack.OAuthV2Response, error) { + resp, err 
:= slack.GetOAuthV2ResponseContext( + ctx, + c.cfg.httpClient, + c.cfg.authClientId, + c.cfg.authClientSecret, + code, + c.cfg.redirectUrl, + ) if err != nil { return nil, fmt.Errorf("unable to exchange code for access token: %w", err) } @@ -161,7 +183,11 @@ func (c *Client) Test(ctx context.Context, accessToken string) (*slack.AuthTestR return resp, nil } -func (c *Client) SendMessage(ctx context.Context, accessToken, channelId string, options ...slack.MsgOption) error { +func (c *Client) SendMessage( + ctx context.Context, + accessToken, channelId string, + options ...slack.MsgOption, +) error { api := slack.New(accessToken) _, _, err := api.PostMessageContext(ctx, channelId, options...) if err != nil { @@ -170,7 +196,11 @@ func (c *Client) SendMessage(ctx context.Context, accessToken, channelId string, return nil } -func (c *Client) JoinChannel(ctx context.Context, accessToken, channelId string, logger *slog.Logger) error { +func (c *Client) JoinChannel( + ctx context.Context, + accessToken, channelId string, + logger *slog.Logger, +) error { api := slack.New(accessToken) _, _, warnings, err := api.JoinConversationContext(ctx, channelId) @@ -183,7 +213,10 @@ func (c *Client) JoinChannel(ctx context.Context, accessToken, channelId string, return nil } -func (c *Client) GetPublicChannels(ctx context.Context, accessToken string) ([]slack.Channel, error) { +func (c *Client) GetPublicChannels( + ctx context.Context, + accessToken string, +) ([]slack.Channel, error) { api := slack.New(accessToken) channels, _, err := api.GetConversationsContext(ctx, &slack.GetConversationsParameters{ diff --git a/internal/ee/transformers/functions/functions.go b/internal/ee/transformers/functions/functions.go index 2996fd40a2..3da26feed6 100644 --- a/internal/ee/transformers/functions/functions.go +++ b/internal/ee/transformers/functions/functions.go @@ -18,7 +18,11 @@ var ( // Used when using the PII Anonymizer with Neosync Transformers type NeosyncOperatorApi interface { - 
Transform(ctx context.Context, config *mgmtv1alpha1.TransformerConfig, value string) (string, error) + Transform( + ctx context.Context, + config *mgmtv1alpha1.TransformerConfig, + value string, + ) (string, error) } func TransformPiiText( @@ -48,12 +52,21 @@ func TransformPiiText( } if analyzeResp.JSON200 == nil { - return "", fmt.Errorf("received non-200 response from analyzer: %s %d %s", analyzeResp.Status(), analyzeResp.StatusCode(), string(analyzeResp.Body)) + return "", fmt.Errorf( + "received non-200 response from analyzer: %s %d %s", + analyzeResp.Status(), + analyzeResp.StatusCode(), + string(analyzeResp.Body), + ) } analysisResults := removeAllowedPhrases(*analyzeResp.JSON200, value, config.GetAllowedPhrases()) - analysisResults, neosyncEntityMap := processAnalysisResultsForNeosyncTransformers(analysisResults, getNeosyncConfiguredEntities(config), value) + analysisResults, neosyncEntityMap := processAnalysisResultsForNeosyncTransformers( + analysisResults, + getNeosyncConfiguredEntities(config), + value, + ) anonymizers, err := buildAnonymizers(config) if err != nil { return "", fmt.Errorf("unable to build anonymizers: %w", err) @@ -133,7 +146,11 @@ func handleNeosyncEntityAnonymization( transformerConfig = defaultTransformerConfig } if transformerConfig == nil { - logger.Warn("no transformer config found for entity (a default presidio profile may have been used)", "entity", presidioEntity) + logger.Warn( + "no transformer config found for entity (a default presidio profile may have been used)", + "entity", + presidioEntity, + ) continue } @@ -147,11 +164,17 @@ func handleNeosyncEntityAnonymization( logger.Warn("no original values found in queue for entity", "entity", item.EntityType) continue } - transformedSnippet, err := neosyncOperatorApi.Transform(ctx, transformerConfig, originalValue) + transformedSnippet, err := neosyncOperatorApi.Transform( + ctx, + transformerConfig, + originalValue, + ) if err != nil { return "", fmt.Errorf("unable to transform 
neosync entity %s: %w", presidioEntity, err) } - logger.Debug(fmt.Sprintf("transformed snippet %s replacing %s", transformedSnippet, *item.Text)) + logger.Debug( + fmt.Sprintf("transformed snippet %s replacing %s", transformedSnippet, *item.Text), + ) outputText = strings.Replace(outputText, *item.Text, transformedSnippet, 1) } return outputText, nil @@ -228,7 +251,10 @@ func getDefaultTransformerConfigByEntity(entity string) *mgmtv1alpha1.Transforme invalidEmailAction := mgmtv1alpha1.InvalidEmailAction_INVALID_EMAIL_ACTION_GENERATE return &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_TransformEmailConfig{ - TransformEmailConfig: &mgmtv1alpha1.TransformEmail{EmailType: &emailType, InvalidEmailAction: &invalidEmailAction}, + TransformEmailConfig: &mgmtv1alpha1.TransformEmail{ + EmailType: &emailType, + InvalidEmailAction: &invalidEmailAction, + }, }, } default: @@ -247,7 +273,9 @@ func getNeosyncConfiguredEntities(config *mgmtv1alpha1.TransformPiiText) []strin return entities } -func buildAnonymizers(config *mgmtv1alpha1.TransformPiiText) (map[string]presidioapi.AnonymizeRequest_Anonymizers_AdditionalProperties, error) { +func buildAnonymizers( + config *mgmtv1alpha1.TransformPiiText, +) (map[string]presidioapi.AnonymizeRequest_Anonymizers_AdditionalProperties, error) { output := map[string]presidioapi.AnonymizeRequest_Anonymizers_AdditionalProperties{} defaultAnon, ok, err := toPresidioAnonymizerConfig("DEFAULT", config.GetDefaultAnonymizer()) if err != nil { @@ -340,7 +368,10 @@ func buildAdhocRecognizers(dtos []*mgmtv1alpha1.PiiDenyRecognizer) []presidioapi return output } -func toPresidioAnonymizerConfig(entity string, dto *mgmtv1alpha1.PiiAnonymizer) (*presidioapi.AnonymizeRequest_Anonymizers_AdditionalProperties, bool, error) { +func toPresidioAnonymizerConfig( + entity string, + dto *mgmtv1alpha1.PiiAnonymizer, +) (*presidioapi.AnonymizeRequest_Anonymizers_AdditionalProperties, bool, error) { switch cfg := dto.GetConfig().(type) { 
case *mgmtv1alpha1.PiiAnonymizer_Redact_: ap := &presidioapi.AnonymizeRequest_Anonymizers_AdditionalProperties{} @@ -416,7 +447,12 @@ func handleAnonRespErr(resp *presidioapi.PostAnonymizeResponse) error { return fmt.Errorf("%s", *resp.JSON422.Error) } if resp.JSON200 == nil { - return fmt.Errorf("received non-200 response from anonymizer: %s %d %s", resp.Status(), resp.StatusCode(), string(resp.Body)) + return fmt.Errorf( + "received non-200 response from anonymizer: %s %d %s", + resp.Status(), + resp.StatusCode(), + string(resp.Body), + ) } return nil } diff --git a/internal/ee/transformers/transformers.go b/internal/ee/transformers/transformers.go index 96ceac8483..0b3c2e02ea 100644 --- a/internal/ee/transformers/transformers.go +++ b/internal/ee/transformers/transformers.go @@ -10,8 +10,10 @@ var ( mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_STRING, mgmtv1alpha1.TransformerDataType_TRANSFORMER_DATA_TYPE_NULL, }, - SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC}, - Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_PII_TEXT, + SupportedJobTypes: []mgmtv1alpha1.SupportedJobType{ + mgmtv1alpha1.SupportedJobType_SUPPORTED_JOB_TYPE_SYNC, + }, + Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_PII_TEXT, Config: &mgmtv1alpha1.TransformerConfig{ Config: &mgmtv1alpha1.TransformerConfig_TransformPiiTextConfig{ TransformPiiTextConfig: &mgmtv1alpha1.TransformPiiText{ diff --git a/internal/errors/errors.go b/internal/errors/errors.go index 7a4a173e9c..73444e61cc 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -65,7 +65,8 @@ func IsNotFound(err error) bool { if status.Code(err) == codes.NotFound { return true } - if connectErr := new(connect.Error); errors.As(err, &connectErr) && connectErr.Code() == connect.CodeNotFound { + if connectErr := new(connect.Error); errors.As(err, &connectErr) && + connectErr.Code() == connect.CodeNotFound { return 
true } } diff --git a/internal/gcp/client.go b/internal/gcp/client.go index 86f525e2ef..81e4dac161 100644 --- a/internal/gcp/client.go +++ b/internal/gcp/client.go @@ -19,7 +19,11 @@ import ( ) type ClientInterface interface { - GetDbSchemaFromPrefix(ctx context.Context, bucketName string, prefix string) ([]*mgmtv1alpha1.DatabaseColumn, error) + GetDbSchemaFromPrefix( + ctx context.Context, + bucketName string, + prefix string, + ) ([]*mgmtv1alpha1.DatabaseColumn, error) DoesPrefixContainTables(ctx context.Context, bucketName string, prefix string) (bool, error) GetRecordStreamFromPrefix( ctx context.Context, @@ -146,7 +150,12 @@ func (c *Client) GetRecordStreamFromPrefix( } err = streamRecordsFromReader(reader, onRecord) if closeErr := reader.Close(); closeErr != nil { - c.logger.Warn(fmt.Sprintf("failed to close reader while streaming records from prefix: %s", closeErr.Error())) + c.logger.Warn( + fmt.Sprintf( + "failed to close reader while streaming records from prefix: %s", + closeErr.Error(), + ), + ) } return err }) @@ -154,7 +163,10 @@ func (c *Client) GetRecordStreamFromPrefix( return errgrp.Wait() } -func (c *Client) ListObjectPrefixes(ctx context.Context, bucketName, prefix, delimiter string) ([]string, error) { +func (c *Client) ListObjectPrefixes( + ctx context.Context, + bucketName, prefix, delimiter string, +) ([]string, error) { prefixes := []string{} it := c.client.Bucket(bucketName).Objects(ctx, &storage.Query{ Prefix: prefix, @@ -178,7 +190,10 @@ func (c *Client) getTableColumnsFromFile( bucket *storage.BucketHandle, prefix string, ) ([]string, error) { - dataiterator := bucket.Objects(ctx, &storage.Query{Prefix: fmt.Sprintf("%s/data", strings.TrimSuffix(prefix, "/"))}) + dataiterator := bucket.Objects( + ctx, + &storage.Query{Prefix: fmt.Sprintf("%s/data", strings.TrimSuffix(prefix, "/"))}, + ) columns := []string{} var firstFile *storage.ObjectAttrs @@ -205,7 +220,9 @@ func (c *Client) getTableColumnsFromFile( } defer func() { if closeErr := 
reader.Close(); closeErr != nil { - c.logger.Warn(fmt.Sprintf("unable to successfully close gcs reader: %s", closeErr.Error())) + c.logger.Warn( + fmt.Sprintf("unable to successfully close gcs reader: %s", closeErr.Error()), + ) } }() @@ -229,7 +246,10 @@ func getSchemaTableFromPrefix(prefix string) (*sqlmanager_shared.SchemaTable, er if len(schemaTableList) == 1 { return &sqlmanager_shared.SchemaTable{Schema: "", Table: schemaTableList[0]}, nil } - return &sqlmanager_shared.SchemaTable{Schema: schemaTableList[0], Table: schemaTableList[1]}, nil + return &sqlmanager_shared.SchemaTable{ + Schema: schemaTableList[0], + Table: schemaTableList[1], + }, nil } // Returns the prefix that contains the table folders in GCS @@ -265,7 +285,10 @@ func getFirstRecordFromReader(reader io.Reader) (map[string]any, error) { return result, nil } -func streamRecordsFromReader(reader io.Reader, onRecord func(record map[string][]byte) error) error { +func streamRecordsFromReader( + reader io.Reader, + onRecord func(record map[string][]byte) error, +) error { gzipReader, err := gzip.NewReader(reader) if err != nil { return fmt.Errorf("failed to create gzip reader: %w", err) @@ -284,7 +307,10 @@ func streamRecordsFromReader(reader io.Reader, onRecord func(record map[string][ record, err := valToRecord(result) if err != nil { - return fmt.Errorf("unable to convert record from map[string]any to map[string][]byte: %w", err) + return fmt.Errorf( + "unable to convert record from map[string]any to map[string][]byte: %w", + err, + ) } err = onRecord(record) if err != nil { diff --git a/internal/integration-tests/worker/workflow/datasync-workflow.go b/internal/integration-tests/worker/workflow/datasync-workflow.go index 621bed92bd..6952758afc 100644 --- a/internal/integration-tests/worker/workflow/datasync-workflow.go +++ b/internal/integration-tests/worker/workflow/datasync-workflow.go @@ -172,28 +172,46 @@ func NewTestDataSyncWorkflowEnv( func (w *TestWorkflowEnv) 
ExecuteTestDataSyncWorkflow(jobId string) { w.TestEnv.SetStartWorkflowOptions(client.StartWorkflowOptions{ID: jobId}) datasyncWorkflow := datasync_workflow.New(w.fakeEELicense) - w.TestEnv.ExecuteWorkflow(datasyncWorkflow.Workflow, &datasync_workflow.WorkflowRequest{JobId: jobId}) + w.TestEnv.ExecuteWorkflow( + datasyncWorkflow.Workflow, + &datasync_workflow.WorkflowRequest{JobId: jobId}, + ) } // RequireActivitiesCompletedSuccessfully verifies all activities completed without errors // NOTE: this should be called before ExecuteTestDataSyncWorkflow func (w *TestWorkflowEnv) RequireActivitiesCompletedSuccessfully(t testing.TB) { - w.TestEnv.SetOnActivityCompletedListener(func(activityInfo *activity.Info, result converter.EncodedValue, err error) { - require.NoError(t, err, "Activity %s failed", activityInfo.ActivityType.Name) - if activityInfo.ActivityType.Name == "RunPostTableSync" && result.HasValue() { - var postTableSyncResp posttablesync_activity.RunPostTableSyncResponse - decodeErr := result.Get(&postTableSyncResp) - require.NoError(t, decodeErr, "Failed to decode result for activity %s", activityInfo.ActivityType.Name) - require.Emptyf(t, postTableSyncResp.Errors, "Post table sync activity returned errors: %v", formatPostTableSyncErrors(postTableSyncResp.Errors)) - } - }) + w.TestEnv.SetOnActivityCompletedListener( + func(activityInfo *activity.Info, result converter.EncodedValue, err error) { + require.NoError(t, err, "Activity %s failed", activityInfo.ActivityType.Name) + if activityInfo.ActivityType.Name == "RunPostTableSync" && result.HasValue() { + var postTableSyncResp posttablesync_activity.RunPostTableSyncResponse + decodeErr := result.Get(&postTableSyncResp) + require.NoError( + t, + decodeErr, + "Failed to decode result for activity %s", + activityInfo.ActivityType.Name, + ) + require.Emptyf( + t, + postTableSyncResp.Errors, + "Post table sync activity returned errors: %v", + formatPostTableSyncErrors(postTableSyncResp.Errors), + ) + } + }, + ) } 
func formatPostTableSyncErrors(errors []*posttablesync_activity.PostTableSyncError) []string { formatted := []string{} for _, err := range errors { for _, e := range err.Errors { - formatted = append(formatted, fmt.Sprintf("statement: %s error: %s", e.Statement, e.Error)) + formatted = append( + formatted, + fmt.Sprintf("statement: %s error: %s", e.Statement, e.Error), + ) } } return formatted @@ -208,7 +226,10 @@ type TestDatabaseManagers struct { // NewTestDatabaseManagers creates and configures database connection managers for testing func NewTestDatabaseManagers(t testing.TB) *TestDatabaseManagers { - sqlconnmanager := connectionmanager.NewConnectionManager(sqlprovider.NewProvider(&sqlconnect.SqlOpenConnector{}), connectionmanager.WithReaperPoll(10*time.Second)) + sqlconnmanager := connectionmanager.NewConnectionManager( + sqlprovider.NewProvider(&sqlconnect.SqlOpenConnector{}), + connectionmanager.WithReaperPoll(10*time.Second), + ) go sqlconnmanager.Reaper(testutil.GetConcurrentTestLogger(t)) mongoconnmanager := connectionmanager.NewConnectionManager(mongoprovider.NewProvider()) go mongoconnmanager.Reaper(testutil.GetConcurrentTestLogger(t)) diff --git a/internal/javascript/functions/benthos/functions.go b/internal/javascript/functions/benthos/functions.go index c1aca27fec..4360a03463 100644 --- a/internal/javascript/functions/benthos/functions.go +++ b/internal/javascript/functions/benthos/functions.go @@ -32,244 +32,276 @@ func Get() []*javascript_functions.FunctionDefinition { func getV0Fetch(namespace string) *javascript_functions.FunctionDefinition { fnName := "v0_fetch" - return javascript_functions.NewFunctionDefinition(namespace, fnName, func(r javascript_functions.Runner) javascript_functions.Function { - return func(ctx context.Context, call goja.FunctionCall, rt *goja.Runtime, l *slog.Logger) (result any, err error) { - defer func() { - if r := recover(); r != nil { - // we set the named "err" argument to the error so that it can be returned - err 
= fmt.Errorf("panic recovered: %s.%s: %v", namespace, fnName, r) - l.Error( - "recovered from panic in custom benthos function", - "error", err, - "function", fmt.Sprintf("%s.%s", namespace, fnName), - "stack", string(debug.Stack()), - ) + return javascript_functions.NewFunctionDefinition( + namespace, + fnName, + func(r javascript_functions.Runner) javascript_functions.Function { + return func(ctx context.Context, call goja.FunctionCall, rt *goja.Runtime, l *slog.Logger) (result any, err error) { + defer func() { + if r := recover(); r != nil { + // we set the named "err" argument to the error so that it can be returned + err = fmt.Errorf("panic recovered: %s.%s: %v", namespace, fnName, r) + l.Error( + "recovered from panic in custom benthos function", + "error", err, + "function", fmt.Sprintf("%s.%s", namespace, fnName), + "stack", string(debug.Stack()), + ) + } + }() + var ( + url string + httpHeaders map[string]any + method = "GET" + payload = "" + ) + if err := javascript_functions.ParseFunctionArguments(call, &url, &httpHeaders, &method, &payload); err != nil { + return nil, err } - }() - var ( - url string - httpHeaders map[string]any - method = "GET" - payload = "" - ) - if err := javascript_functions.ParseFunctionArguments(call, &url, &httpHeaders, &method, &payload); err != nil { - return nil, err - } - var payloadReader io.Reader - if payload != "" { - payloadReader = strings.NewReader(payload) - } + var payloadReader io.Reader + if payload != "" { + payloadReader = strings.NewReader(payload) + } - req, err := http.NewRequestWithContext(ctx, method, url, payloadReader) - if err != nil { - return nil, err - } + req, err := http.NewRequestWithContext(ctx, method, url, payloadReader) + if err != nil { + return nil, err + } - // Parse HTTP headers - for k, v := range httpHeaders { - vStr, _ := v.(string) - req.Header.Add(k, vStr) - } + // Parse HTTP headers + for k, v := range httpHeaders { + vStr, _ := v.(string) + req.Header.Add(k, vStr) + } - // Do 
request - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() + // Do request + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } - return map[string]any{ - "status": resp.StatusCode, - "body": string(respBody), - }, nil - } - }) + return map[string]any{ + "status": resp.StatusCode, + "body": string(respBody), + }, nil + } + }, + ) } func getV0MsgSetString(namespace string) *javascript_functions.FunctionDefinition { fnName := "v0_msg_set_string" - return javascript_functions.NewFunctionDefinition(namespace, fnName, func(r javascript_functions.Runner) javascript_functions.Function { - return func(ctx context.Context, call goja.FunctionCall, rt *goja.Runtime, l *slog.Logger) (result any, err error) { - defer func() { - if r := recover(); r != nil { - // we set the named "err" argument to the error so that it can be returned - err = fmt.Errorf("panic recovered: %s.%s: %v", namespace, fnName, r) - l.Error( - "recovered from panic in custom benthos function", - "error", err, - "function", fmt.Sprintf("%s.%s", namespace, fnName), - "stack", string(debug.Stack()), - ) + return javascript_functions.NewFunctionDefinition( + namespace, + fnName, + func(r javascript_functions.Runner) javascript_functions.Function { + return func(ctx context.Context, call goja.FunctionCall, rt *goja.Runtime, l *slog.Logger) (result any, err error) { + defer func() { + if r := recover(); r != nil { + // we set the named "err" argument to the error so that it can be returned + err = fmt.Errorf("panic recovered: %s.%s: %v", namespace, fnName, r) + l.Error( + "recovered from panic in custom benthos function", + "error", err, + "function", fmt.Sprintf("%s.%s", namespace, fnName), + "stack", string(debug.Stack()), + ) + } + }() + 
var value string + if err := javascript_functions.ParseFunctionArguments(call, &value); err != nil { + return nil, err } - }() - var value string - if err := javascript_functions.ParseFunctionArguments(call, &value); err != nil { - return nil, err - } - r.ValueApi().SetBytes([]byte(value)) - return nil, nil - } - }) + r.ValueApi().SetBytes([]byte(value)) + return nil, nil + } + }, + ) } func getV0MsgAsString(namespace string) *javascript_functions.FunctionDefinition { fnName := "v0_msg_as_string" - return javascript_functions.NewFunctionDefinition(namespace, fnName, func(r javascript_functions.Runner) javascript_functions.Function { - return func(ctx context.Context, call goja.FunctionCall, rt *goja.Runtime, l *slog.Logger) (result any, err error) { - defer func() { - if r := recover(); r != nil { - // we set the named "err" argument to the error so that it can be returned - err = fmt.Errorf("panic recovered: %s.%s: %v", namespace, fnName, r) - l.Error( - "recovered from panic in custom benthos function", - "error", err, - "function", fmt.Sprintf("%s.%s", namespace, fnName), - "stack", string(debug.Stack()), - ) + return javascript_functions.NewFunctionDefinition( + namespace, + fnName, + func(r javascript_functions.Runner) javascript_functions.Function { + return func(ctx context.Context, call goja.FunctionCall, rt *goja.Runtime, l *slog.Logger) (result any, err error) { + defer func() { + if r := recover(); r != nil { + // we set the named "err" argument to the error so that it can be returned + err = fmt.Errorf("panic recovered: %s.%s: %v", namespace, fnName, r) + l.Error( + "recovered from panic in custom benthos function", + "error", err, + "function", fmt.Sprintf("%s.%s", namespace, fnName), + "stack", string(debug.Stack()), + ) + } + }() + b, err := r.ValueApi().AsBytes() + if err != nil { + return nil, err } - }() - b, err := r.ValueApi().AsBytes() - if err != nil { - return nil, err + return string(b), nil } - return string(b), nil - } - }) + }, + ) } func 
getV0MsgSetStructured(namespace string) *javascript_functions.FunctionDefinition { fnName := "v0_msg_set_structured" - return javascript_functions.NewFunctionDefinition(namespace, fnName, func(r javascript_functions.Runner) javascript_functions.Function { - return func(ctx context.Context, call goja.FunctionCall, rt *goja.Runtime, l *slog.Logger) (result any, err error) { - defer func() { - if r := recover(); r != nil { - // we set the named "err" argument to the error so that it can be returned - err = fmt.Errorf("panic recovered: %s.%s: %v", namespace, fnName, r) - l.Error( - "recovered from panic in custom benthos function", - "error", err, - "function", fmt.Sprintf("%s.%s", namespace, fnName), - "stack", string(debug.Stack()), - ) + return javascript_functions.NewFunctionDefinition( + namespace, + fnName, + func(r javascript_functions.Runner) javascript_functions.Function { + return func(ctx context.Context, call goja.FunctionCall, rt *goja.Runtime, l *slog.Logger) (result any, err error) { + defer func() { + if r := recover(); r != nil { + // we set the named "err" argument to the error so that it can be returned + err = fmt.Errorf("panic recovered: %s.%s: %v", namespace, fnName, r) + l.Error( + "recovered from panic in custom benthos function", + "error", err, + "function", fmt.Sprintf("%s.%s", namespace, fnName), + "stack", string(debug.Stack()), + ) + } + }() + var value any + if err := javascript_functions.ParseFunctionArguments(call, &value); err != nil { + return nil, err } - }() - var value any - if err := javascript_functions.ParseFunctionArguments(call, &value); err != nil { - return nil, err - } - r.ValueApi().SetStructured(value) - return nil, nil - } - }) + r.ValueApi().SetStructured(value) + return nil, nil + } + }, + ) } func getV0MsgAsStructured(namespace string) *javascript_functions.FunctionDefinition { fnName := "v0_msg_as_structured" - return javascript_functions.NewFunctionDefinition(namespace, fnName, func(r javascript_functions.Runner) 
javascript_functions.Function { - return func(ctx context.Context, call goja.FunctionCall, rt *goja.Runtime, l *slog.Logger) (result any, err error) { - defer func() { - if r := recover(); r != nil { - // we set the named "err" argument to the error so that it can be returned - err = fmt.Errorf("panic recovered: %s.%s: %v", namespace, fnName, r) - l.Error( - "recovered from panic in custom benthos function", - "error", err, - "function", fmt.Sprintf("%s.%s", namespace, fnName), - "stack", string(debug.Stack()), - ) - } - }() - return r.ValueApi().AsStructured() - } - }) + return javascript_functions.NewFunctionDefinition( + namespace, + fnName, + func(r javascript_functions.Runner) javascript_functions.Function { + return func(ctx context.Context, call goja.FunctionCall, rt *goja.Runtime, l *slog.Logger) (result any, err error) { + defer func() { + if r := recover(); r != nil { + // we set the named "err" argument to the error so that it can be returned + err = fmt.Errorf("panic recovered: %s.%s: %v", namespace, fnName, r) + l.Error( + "recovered from panic in custom benthos function", + "error", err, + "function", fmt.Sprintf("%s.%s", namespace, fnName), + "stack", string(debug.Stack()), + ) + } + }() + return r.ValueApi().AsStructured() + } + }, + ) } func getV0MsgSetMeta(namespace string) *javascript_functions.FunctionDefinition { fnName := "v0_msg_set_meta" - return javascript_functions.NewFunctionDefinition(namespace, fnName, func(r javascript_functions.Runner) javascript_functions.Function { - return func(ctx context.Context, call goja.FunctionCall, rt *goja.Runtime, l *slog.Logger) (result any, err error) { - defer func() { - if r := recover(); r != nil { - // we set the named "err" argument to the error so that it can be returned - err = fmt.Errorf("panic recovered: %s.%s: %v", namespace, fnName, r) - l.Error( - "recovered from panic in custom benthos function", - "error", err, - "function", fmt.Sprintf("%s.%s", namespace, fnName), - "stack", 
string(debug.Stack()), - ) + return javascript_functions.NewFunctionDefinition( + namespace, + fnName, + func(r javascript_functions.Runner) javascript_functions.Function { + return func(ctx context.Context, call goja.FunctionCall, rt *goja.Runtime, l *slog.Logger) (result any, err error) { + defer func() { + if r := recover(); r != nil { + // we set the named "err" argument to the error so that it can be returned + err = fmt.Errorf("panic recovered: %s.%s: %v", namespace, fnName, r) + l.Error( + "recovered from panic in custom benthos function", + "error", err, + "function", fmt.Sprintf("%s.%s", namespace, fnName), + "stack", string(debug.Stack()), + ) + } + }() + var key string + var value any + if err := javascript_functions.ParseFunctionArguments(call, &key, &value); err != nil { + return nil, err } - }() - var key string - var value any - if err := javascript_functions.ParseFunctionArguments(call, &key, &value); err != nil { - return nil, err + r.ValueApi().MetaSetMut(key, value) + return nil, nil } - r.ValueApi().MetaSetMut(key, value) - return nil, nil - } - }) + }, + ) } func getV0MsgGetMeta(namespace string) *javascript_functions.FunctionDefinition { fnName := "v0_msg_get_meta" - return javascript_functions.NewFunctionDefinition(namespace, fnName, func(r javascript_functions.Runner) javascript_functions.Function { - return func(ctx context.Context, call goja.FunctionCall, rt *goja.Runtime, l *slog.Logger) (result any, err error) { - defer func() { - if r := recover(); r != nil { - // we set the named "err" argument to the error so that it can be returned - err = fmt.Errorf("panic recovered: %s.%s: %v", namespace, fnName, r) - l.Error( - "recovered from panic in custom benthos function", - "error", err, - "function", fmt.Sprintf("%s.%s", namespace, fnName), - "stack", string(debug.Stack()), - ) + return javascript_functions.NewFunctionDefinition( + namespace, + fnName, + func(r javascript_functions.Runner) javascript_functions.Function { + return func(ctx 
context.Context, call goja.FunctionCall, rt *goja.Runtime, l *slog.Logger) (result any, err error) { + defer func() { + if r := recover(); r != nil { + // we set the named "err" argument to the error so that it can be returned + err = fmt.Errorf("panic recovered: %s.%s: %v", namespace, fnName, r) + l.Error( + "recovered from panic in custom benthos function", + "error", err, + "function", fmt.Sprintf("%s.%s", namespace, fnName), + "stack", string(debug.Stack()), + ) + } + }() + var key string + if err := javascript_functions.ParseFunctionArguments(call, &key); err != nil { + return nil, err } - }() - var key string - if err := javascript_functions.ParseFunctionArguments(call, &key); err != nil { - return nil, err - } - result, ok := r.ValueApi().MetaGet(key) - if !ok { - return nil, fmt.Errorf("key %s not found", key) + result, ok := r.ValueApi().MetaGet(key) + if !ok { + return nil, fmt.Errorf("key %s not found", key) + } + return result, nil } - return result, nil - } - }) + }, + ) } func getV0MsgMetaExists(namespace string) *javascript_functions.FunctionDefinition { fnName := "v0_msg_exists_meta" - return javascript_functions.NewFunctionDefinition(namespace, fnName, func(r javascript_functions.Runner) javascript_functions.Function { - return func(ctx context.Context, call goja.FunctionCall, rt *goja.Runtime, l *slog.Logger) (result any, err error) { - defer func() { - if r := recover(); r != nil { - // we set the named "err" argument to the error so that it can be returned - err = fmt.Errorf("panic recovered: %s.%s: %v", namespace, fnName, r) - l.Error( - "recovered from panic in custom benthos function", - "error", err, - "function", fmt.Sprintf("%s.%s", namespace, fnName), - "stack", string(debug.Stack()), - ) + return javascript_functions.NewFunctionDefinition( + namespace, + fnName, + func(r javascript_functions.Runner) javascript_functions.Function { + return func(ctx context.Context, call goja.FunctionCall, rt *goja.Runtime, l *slog.Logger) (result any, 
err error) { + defer func() { + if r := recover(); r != nil { + // we set the named "err" argument to the error so that it can be returned + err = fmt.Errorf("panic recovered: %s.%s: %v", namespace, fnName, r) + l.Error( + "recovered from panic in custom benthos function", + "error", err, + "function", fmt.Sprintf("%s.%s", namespace, fnName), + "stack", string(debug.Stack()), + ) + } + }() + var key string + if err := javascript_functions.ParseFunctionArguments(call, &key); err != nil { + return nil, err } - }() - var key string - if err := javascript_functions.ParseFunctionArguments(call, &key); err != nil { - return nil, err + _, ok := r.ValueApi().MetaGet(key) + return ok, nil } - _, ok := r.ValueApi().MetaGet(key) - return ok, nil - } - }) + }, + ) } diff --git a/internal/javascript/functions/functions.go b/internal/javascript/functions/functions.go index 78012cbf54..d6210becfc 100644 --- a/internal/javascript/functions/functions.go +++ b/internal/javascript/functions/functions.go @@ -74,7 +74,11 @@ func getTypeString(arg goja.Value) string { // Returns an error if the arguments are not of the expected type. 
func ParseFunctionArguments(call goja.FunctionCall, ptrs ...any) error { if len(ptrs) < len(call.Arguments) { - return fmt.Errorf("have %d arguments, but only %d pointers to parse into", len(call.Arguments), len(ptrs)) + return fmt.Errorf( + "have %d arguments, but only %d pointers to parse into", + len(call.Arguments), + len(ptrs), + ) } for i := range call.Arguments { @@ -112,7 +116,14 @@ func ParseFunctionArguments(call goja.FunctionCall, ptrs ...any) error { } if err != nil { typeStr := getTypeString(arg) - return fmt.Errorf("could not parse %v (%s) into %v (%T): %v", arg, typeStr, ptr, ptr, err) + return fmt.Errorf( + "could not parse %v (%s) into %v (%T): %v", + arg, + typeStr, + ptr, + ptr, + err, + ) } } diff --git a/internal/javascript/functions/neosync/functions.go b/internal/javascript/functions/neosync/functions.go index 0ad8a42342..99863cde90 100644 --- a/internal/javascript/functions/neosync/functions.go +++ b/internal/javascript/functions/neosync/functions.go @@ -27,7 +27,11 @@ func Get() ([]*javascript_functions.FunctionDefinition, error) { } patchStructuredMessage := getPatchStructuredMessage(namespace) - output := make([]*javascript_functions.FunctionDefinition, 0, len(generatorFns)+len(transformerFns)+1) + output := make( + []*javascript_functions.FunctionDefinition, + 0, + len(generatorFns)+len(transformerFns)+1, + ) output = append(output, generatorFns...) output = append(output, transformerFns...) 
output = append(output, patchStructuredMessage) @@ -36,44 +40,48 @@ func Get() ([]*javascript_functions.FunctionDefinition, error) { func getPatchStructuredMessage(namespace string) *javascript_functions.FunctionDefinition { fnName := "patchStructuredMessage" - return javascript_functions.NewFunctionDefinition(namespace, fnName, func(r javascript_functions.Runner) javascript_functions.Function { - return func(ctx context.Context, call goja.FunctionCall, rt *goja.Runtime, l *slog.Logger) (result any, err error) { - defer func() { - if r := recover(); r != nil { - // we set the named "err" argument to the error so that it can be returned - err = fmt.Errorf("panic recovered: %s.%s: %v", namespace, fnName, r) - l.Error( - "recovered from panic in custom neosync function", - "error", err, - "function", fmt.Sprintf("%s.%s", namespace, fnName), - "stack", string(debug.Stack()), - ) + return javascript_functions.NewFunctionDefinition( + namespace, + fnName, + func(r javascript_functions.Runner) javascript_functions.Function { + return func(ctx context.Context, call goja.FunctionCall, rt *goja.Runtime, l *slog.Logger) (result any, err error) { + defer func() { + if r := recover(); r != nil { + // we set the named "err" argument to the error so that it can be returned + err = fmt.Errorf("panic recovered: %s.%s: %v", namespace, fnName, r) + l.Error( + "recovered from panic in custom neosync function", + "error", err, + "function", fmt.Sprintf("%s.%s", namespace, fnName), + "stack", string(debug.Stack()), + ) + } + }() + var updates map[string]any + if err := javascript_functions.ParseFunctionArguments(call, &updates); err != nil { + return nil, err } - }() - var updates map[string]any - if err := javascript_functions.ParseFunctionArguments(call, &updates); err != nil { - return nil, err - } - originalData, err := r.ValueApi().AsStructured() - if err != nil { - return nil, fmt.Errorf("failed to get structured data: %w", err) - } + originalData, err := 
r.ValueApi().AsStructured() + if err != nil { + return nil, fmt.Errorf("failed to get structured data: %w", err) + } - originalMap, ok := originalData.(map[string]any) - if !ok { - return nil, fmt.Errorf("structured data is not a map") - } + originalMap, ok := originalData.(map[string]any) + if !ok { + return nil, fmt.Errorf("structured data is not a map") + } - for key, value := range updates { - setNestedProperty(originalMap, key, value) - } + for key, value := range updates { + setNestedProperty(originalMap, key, value) + } - r.ValueApi().SetStructured(originalMap) + r.ValueApi().SetStructured(originalMap) - return nil, nil - } - }) + return nil, nil + } + }, + ) } func setNestedProperty(obj map[string]any, path string, value any) { @@ -101,34 +109,43 @@ func getNeosyncGenerators() ([]*javascript_functions.FunctionDefinition, error) return nil, err } - fn := javascript_functions.NewFunctionDefinition(namespace, templateData.Name, func(r javascript_functions.Runner) javascript_functions.Function { - return func(ctx context.Context, call goja.FunctionCall, rt *goja.Runtime, l *slog.Logger) (result any, err error) { - defer func() { - if r := recover(); r != nil { - // we set the named "err" argument to the error so that it can be returned - err = fmt.Errorf("panic recovered: %s.%s: %v", namespace, templateData.Name, r) - l.Error( - "recovered from panic in custom neosync function", - "error", err, - "function", fmt.Sprintf("%s.%s", namespace, templateData.Name), - "stack", string(debug.Stack()), - ) - } - }() - var ( - opts map[string]any - ) + fn := javascript_functions.NewFunctionDefinition( + namespace, + templateData.Name, + func(r javascript_functions.Runner) javascript_functions.Function { + return func(ctx context.Context, call goja.FunctionCall, rt *goja.Runtime, l *slog.Logger) (result any, err error) { + defer func() { + if r := recover(); r != nil { + // we set the named "err" argument to the error so that it can be returned + err = fmt.Errorf( + "panic 
recovered: %s.%s: %v", + namespace, + templateData.Name, + r, + ) + l.Error( + "recovered from panic in custom neosync function", + "error", err, + "function", fmt.Sprintf("%s.%s", namespace, templateData.Name), + "stack", string(debug.Stack()), + ) + } + }() + var ( + opts map[string]any + ) - if err := javascript_functions.ParseFunctionArguments(call, &opts); err != nil { - return nil, err - } - goOpts, err := f.ParseOptions(opts) - if err != nil { - return nil, err + if err := javascript_functions.ParseFunctionArguments(call, &opts); err != nil { + return nil, err + } + goOpts, err := f.ParseOptions(opts) + if err != nil { + return nil, err + } + return f.Generate(goOpts) } - return f.Generate(goOpts) - } - }) + }, + ) fns = append(fns, fn) } return fns, nil @@ -143,35 +160,44 @@ func getNeosyncTransformers() ([]*javascript_functions.FunctionDefinition, error return nil, err } - fn := javascript_functions.NewFunctionDefinition(namespace, templateData.Name, func(r javascript_functions.Runner) javascript_functions.Function { - return func(ctx context.Context, call goja.FunctionCall, rt *goja.Runtime, l *slog.Logger) (result any, err error) { - defer func() { - if r := recover(); r != nil { - // we set the named "err" argument to the error so that it can be returned - err = fmt.Errorf("panic recovered: %s.%s: %v", namespace, templateData.Name, r) - l.Error( - "recovered from panic in custom neosync function", - "error", err, - "function", fmt.Sprintf("%s.%s", namespace, templateData.Name), - "stack", string(debug.Stack()), - ) - } - }() - var ( - value any - opts map[string]any - ) + fn := javascript_functions.NewFunctionDefinition( + namespace, + templateData.Name, + func(r javascript_functions.Runner) javascript_functions.Function { + return func(ctx context.Context, call goja.FunctionCall, rt *goja.Runtime, l *slog.Logger) (result any, err error) { + defer func() { + if r := recover(); r != nil { + // we set the named "err" argument to the error so that it can 
be returned + err = fmt.Errorf( + "panic recovered: %s.%s: %v", + namespace, + templateData.Name, + r, + ) + l.Error( + "recovered from panic in custom neosync function", + "error", err, + "function", fmt.Sprintf("%s.%s", namespace, templateData.Name), + "stack", string(debug.Stack()), + ) + } + }() + var ( + value any + opts map[string]any + ) - if err := javascript_functions.ParseFunctionArguments(call, &value, &opts); err != nil { - return nil, err - } - goOpts, err := f.ParseOptions(opts) - if err != nil { - return nil, err + if err := javascript_functions.ParseFunctionArguments(call, &value, &opts); err != nil { + return nil, err + } + goOpts, err := f.ParseOptions(opts) + if err != nil { + return nil, err + } + return f.Transform(value, goOpts) } - return f.Transform(value, goOpts) - } - }) + }, + ) fns = append(fns, fn) } return fns, nil diff --git a/internal/javascript/vm/vm.go b/internal/javascript/vm/vm.go index ebb1f9c6f3..780270489b 100644 --- a/internal/javascript/vm/vm.go +++ b/internal/javascript/vm/vm.go @@ -83,7 +83,10 @@ func NewRunner(opts ...Option) (*Runner, error) { // if the stars align, we'll register the custom console module with the logger // must come before requireRegistry.Enable() if options.requireRegistry != nil && options.consoleEnabled && options.logger != nil { - options.requireRegistry.RegisterNativeModule(console.ModuleName, console.RequireWithPrinter(newConsoleLogger(stdPrefix, options.logger))) + options.requireRegistry.RegisterNativeModule( + console.ModuleName, + console.RequireWithPrinter(newConsoleLogger(stdPrefix, options.logger)), + ) } if options.requireRegistry != nil { @@ -139,7 +142,12 @@ func registerFunction(runner *Runner, function *javascript_functions.FunctionDef } return rt.ToValue(result) }); err != nil { - return fmt.Errorf("failed to set global %s function %v: %w", function.Namespace(), function.Name(), err) + return fmt.Errorf( + "failed to set global %s function %v: %w", + function.Namespace(), + 
function.Name(), + err, + ) } return nil } diff --git a/internal/job/jobmapping-validator.go b/internal/job/jobmapping-validator.go index 18b5a6f634..6681058b88 100644 --- a/internal/job/jobmapping-validator.go +++ b/internal/job/jobmapping-validator.go @@ -34,7 +34,10 @@ func WithJobSourceOptions(jobSourceOptions *SqlJobSourceOpts) Option { } } -func NewJobMappingsValidator(jobMappings []*mgmtv1alpha1.JobMapping, opts ...Option) *JobMappingsValidator { +func NewJobMappingsValidator( + jobMappings []*mgmtv1alpha1.JobMapping, + opts ...Option, +) *JobMappingsValidator { tableToColumnMappings := map[string]map[string]*mgmtv1alpha1.JobMapping{} for _, m := range jobMappings { tn := sqlmanager_shared.BuildTable(m.Schema, m.Table) @@ -75,14 +78,20 @@ func (j *JobMappingsValidator) GetColumnWarnings() map[string]map[string][]*mgmt return j.columnWarnings } -func (j *JobMappingsValidator) addDatabaseError(err string, code mgmtv1alpha1.DatabaseError_DatabaseErrorCode) { +func (j *JobMappingsValidator) addDatabaseError( + err string, + code mgmtv1alpha1.DatabaseError_DatabaseErrorCode, +) { j.databaseErrors = append(j.databaseErrors, &mgmtv1alpha1.DatabaseError_DatabaseErrorReport{ Code: code, Message: err, }) } -func (j *JobMappingsValidator) addTableError(table, err string, code mgmtv1alpha1.TableError_TableErrorCode) { +func (j *JobMappingsValidator) addTableError( + table, err string, + code mgmtv1alpha1.TableError_TableErrorCode, +) { if _, ok := j.tableErrors[table]; !ok { j.tableErrors[table] = []*mgmtv1alpha1.TableError_TableErrorReport{} } @@ -92,24 +101,36 @@ func (j *JobMappingsValidator) addTableError(table, err string, code mgmtv1alpha }) } -func (j *JobMappingsValidator) addColumnError(table, column, err string, code mgmtv1alpha1.ColumnError_ColumnErrorCode) { +func (j *JobMappingsValidator) addColumnError( + table, column, err string, + code mgmtv1alpha1.ColumnError_ColumnErrorCode, +) { if _, ok := j.columnErrors[table]; !ok { j.columnErrors[table] = 
map[string][]*mgmtv1alpha1.ColumnError_ColumnErrorReport{} } - j.columnErrors[table][column] = append(j.columnErrors[table][column], &mgmtv1alpha1.ColumnError_ColumnErrorReport{ - Code: code, - Message: err, - }) + j.columnErrors[table][column] = append( + j.columnErrors[table][column], + &mgmtv1alpha1.ColumnError_ColumnErrorReport{ + Code: code, + Message: err, + }, + ) } -func (j *JobMappingsValidator) addColumnWarning(table, column, err string, code mgmtv1alpha1.ColumnWarning_ColumnWarningCode) { +func (j *JobMappingsValidator) addColumnWarning( + table, column, err string, + code mgmtv1alpha1.ColumnWarning_ColumnWarningCode, +) { if _, ok := j.columnWarnings[table]; !ok { j.columnWarnings[table] = map[string][]*mgmtv1alpha1.ColumnWarning_ColumnWarningReport{} } - j.columnWarnings[table][column] = append(j.columnWarnings[table][column], &mgmtv1alpha1.ColumnWarning_ColumnWarningReport{ - Code: code, - Message: err, - }) + j.columnWarnings[table][column] = append( + j.columnWarnings[table][column], + &mgmtv1alpha1.ColumnWarning_ColumnWarningReport{ + Code: code, + Message: err, + }, + ) } func (j *JobMappingsValidator) Validate( @@ -119,7 +140,12 @@ func (j *JobMappingsValidator) Validate( ) (*JobMappingsValidatorResponse, error) { j.ValidateJobMappingsExistInSource(tableColumnMap) j.ValidateVirtualForeignKeys(virtualForeignKeys, tableColumnMap, tableConstraints) - err := j.ValidateCircularDependencies(tableConstraints.ForeignKeyConstraints, tableConstraints.PrimaryKeyConstraints, virtualForeignKeys, tableColumnMap) + err := j.ValidateCircularDependencies( + tableConstraints.ForeignKeyConstraints, + tableConstraints.PrimaryKeyConstraints, + virtualForeignKeys, + tableColumnMap, + ) if err != nil { return nil, err } @@ -140,14 +166,27 @@ func (j *JobMappingsValidator) ValidateJobMappingsExistInSource( // check for job mappings that do not exist in the source for table, colMappings := range j.jobMappings { if _, ok := tableColumnMap[table]; !ok { - 
j.addTableError(table, fmt.Sprintf("Table does not exist [%s] in source", table), mgmtv1alpha1.TableError_TABLE_ERROR_CODE_TABLE_NOT_FOUND_IN_SOURCE) + j.addTableError( + table, + fmt.Sprintf("Table does not exist [%s] in source", table), + mgmtv1alpha1.TableError_TABLE_ERROR_CODE_TABLE_NOT_FOUND_IN_SOURCE, + ) continue } for col := range colMappings { if _, ok := tableColumnMap[table][col]; !ok { - msg := fmt.Sprintf("Column does not exist in source. Remove column from job mappings: %s.%s", table, col) + msg := fmt.Sprintf( + "Column does not exist in source. Remove column from job mappings: %s.%s", + table, + col, + ) if j.jobSourceOptions != nil && !j.jobSourceOptions.HaltOnColumnRemoval { - j.addColumnWarning(table, col, msg, mgmtv1alpha1.ColumnWarning_COLUMN_WARNING_CODE_NOT_FOUND_IN_SOURCE) + j.addColumnWarning( + table, + col, + msg, + mgmtv1alpha1.ColumnWarning_COLUMN_WARNING_CODE_NOT_FOUND_IN_SOURCE, + ) } else { j.addColumnError(table, col, msg, mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_NOT_FOUND_IN_SOURCE) } @@ -162,9 +201,18 @@ func (j *JobMappingsValidator) ValidateJobMappingsExistInSource( } for col := range colMap { if _, ok := j.jobMappings[table][col]; !ok { - msg := fmt.Sprintf("Column does not exist in job mappings. Add column to job mappings: %s.%s", table, col) + msg := fmt.Sprintf( + "Column does not exist in job mappings. 
Add column to job mappings: %s.%s", + table, + col, + ) if j.jobSourceOptions != nil && !j.jobSourceOptions.HaltOnNewColumnAddition { - j.addColumnWarning(table, col, msg, mgmtv1alpha1.ColumnWarning_COLUMN_WARNING_CODE_NOT_FOUND_IN_MAPPING) + j.addColumnWarning( + table, + col, + msg, + mgmtv1alpha1.ColumnWarning_COLUMN_WARNING_CODE_NOT_FOUND_IN_MAPPING, + ) } else { j.addColumnError(table, col, msg, mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_NOT_FOUND_IN_MAPPING) } @@ -195,7 +243,10 @@ func (j *JobMappingsValidator) ValidateCircularDependencies( if ok { fkCol := fk.ForeignKey.Columns[idx] if _, ok = fkColMappings[fkCol]; ok { - validForeignKeyDependencies[table] = append(validForeignKeyDependencies[table], fk) + validForeignKeyDependencies[table] = append( + validForeignKeyDependencies[table], + fk, + ) } } } @@ -221,7 +272,12 @@ func (j *JobMappingsValidator) ValidateCircularDependencies( for _, col := range vfk.GetColumns() { colInfo, ok := tableCols[col] if !ok { - j.addColumnError(tableName, col, "Column does not exist in source but required by virtual foreign key", mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_VFK_SOURCE_COLUMN_NOT_FOUND_IN_SOURCE) + j.addColumnError( + tableName, + col, + "Column does not exist in source but required by virtual foreign key", + mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_VFK_SOURCE_COLUMN_NOT_FOUND_IN_SOURCE, + ) return nil } notNullable = append(notNullable, !colInfo.IsNullable) @@ -236,7 +292,10 @@ func (j *JobMappingsValidator) ValidateCircularDependencies( }, } allForeignKeys[tableName] = append(allForeignKeys[tableName], virt) - validForeignKeyDependencies[tableName] = append(validForeignKeyDependencies[tableName], virt) + validForeignKeyDependencies[tableName] = append( + validForeignKeyDependencies[tableName], + virt, + ) } tableColumnNameMap := map[string][]string{} @@ -246,9 +305,19 @@ func (j *JobMappingsValidator) ValidateCircularDependencies( } } - _, err := runconfigs.BuildRunConfigs(validForeignKeyDependencies, 
map[string]string{}, primaryKeys, tableColumnNameMap, map[string][][]string{}, map[string][][]string{}) + _, err := runconfigs.BuildRunConfigs( + validForeignKeyDependencies, + map[string]string{}, + primaryKeys, + tableColumnNameMap, + map[string][][]string{}, + map[string][][]string{}, + ) if err != nil { - j.addDatabaseError(err.Error(), mgmtv1alpha1.DatabaseError_DATABASE_ERROR_CODE_UNSUPPORTED_CIRCULAR_DEPENDENCY_AT_LEAST_ONE_NULLABLE) + j.addDatabaseError( + err.Error(), + mgmtv1alpha1.DatabaseError_DATABASE_ERROR_CODE_UNSUPPORTED_CIRCULAR_DEPENDENCY_AT_LEAST_ONE_NULLABLE, + ) } return nil @@ -273,12 +342,30 @@ func (j *JobMappingsValidator) ValidateRequiredForeignKeys( fkColMappings, ok := j.jobMappings[fk.ForeignKey.Table] fkCol := fk.ForeignKey.Columns[idx] if !ok { - j.addColumnError(fk.ForeignKey.Table, fkCol, fmt.Sprintf("Missing required foreign key. Table: %s Column: %s", fk.ForeignKey.Table, fkCol), mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_REQUIRED_FOREIGN_KEY_NOT_FOUND_IN_MAPPING) + j.addColumnError( + fk.ForeignKey.Table, + fkCol, + fmt.Sprintf( + "Missing required foreign key. Table: %s Column: %s", + fk.ForeignKey.Table, + fkCol, + ), + mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_REQUIRED_FOREIGN_KEY_NOT_FOUND_IN_MAPPING, + ) continue } _, ok = fkColMappings[fkCol] if !ok { - j.addColumnError(fk.ForeignKey.Table, fkCol, fmt.Sprintf("Missing required foreign key. Table: %s Column: %s", fk.ForeignKey.Table, fkCol), mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_REQUIRED_FOREIGN_KEY_NOT_FOUND_IN_MAPPING) + j.addColumnError( + fk.ForeignKey.Table, + fkCol, + fmt.Sprintf( + "Missing required foreign key. Table: %s Column: %s", + fk.ForeignKey.Table, + fkCol, + ), + mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_REQUIRED_FOREIGN_KEY_NOT_FOUND_IN_MAPPING, + ) } } } @@ -301,7 +388,16 @@ func (j *JobMappingsValidator) ValidateRequiredColumns( continue } if _, ok := cm[col]; !ok { - j.addColumnError(table, col, fmt.Sprintf("Violates not-null constraint. 
Missing required column. Table: %s Column: %s", table, col), mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_REQUIRED_COLUMN_NOT_FOUND_IN_MAPPING) + j.addColumnError( + table, + col, + fmt.Sprintf( + "Violates not-null constraint. Missing required column. Table: %s Column: %s", + table, + col, + ), + mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_REQUIRED_COLUMN_NOT_FOUND_IN_MAPPING, + ) } } } @@ -320,24 +416,52 @@ func (j *JobMappingsValidator) ValidateVirtualForeignKeys( // check that source table exist in job mappings sourceColMappings, ok := j.jobMappings[sourceTable] if !ok { - j.addTableError(sourceTable, fmt.Sprintf("Virtual foreign key source table missing in job mappings. Table: %s", sourceTable), mgmtv1alpha1.TableError_TABLE_ERROR_CODE_VFK_SOURCE_TABLE_NOT_FOUND_IN_MAPPING) + j.addTableError( + sourceTable, + fmt.Sprintf( + "Virtual foreign key source table missing in job mappings. Table: %s", + sourceTable, + ), + mgmtv1alpha1.TableError_TABLE_ERROR_CODE_VFK_SOURCE_TABLE_NOT_FOUND_IN_MAPPING, + ) continue } sourceCols, ok := tableColumnMap[sourceTable] if !ok { - j.addTableError(sourceTable, fmt.Sprintf("Virtual foreign key source table missing in source database. Table: %s", sourceTable), mgmtv1alpha1.TableError_TABLE_ERROR_CODE_VFK_SOURCE_TABLE_NOT_FOUND_IN_SOURCE) + j.addTableError( + sourceTable, + fmt.Sprintf( + "Virtual foreign key source table missing in source database. Table: %s", + sourceTable, + ), + mgmtv1alpha1.TableError_TABLE_ERROR_CODE_VFK_SOURCE_TABLE_NOT_FOUND_IN_SOURCE, + ) return } // check that target table exist in job mappings targetColMappings, ok := j.jobMappings[targetTable] if !ok { - j.addTableError(targetTable, fmt.Sprintf("Virtual foreign key target table missing in job mappings. Table: %s", targetTable), mgmtv1alpha1.TableError_TABLE_ERROR_CODE_VFK_TARGET_TABLE_NOT_FOUND_IN_MAPPING) + j.addTableError( + targetTable, + fmt.Sprintf( + "Virtual foreign key target table missing in job mappings. 
Table: %s", + targetTable, + ), + mgmtv1alpha1.TableError_TABLE_ERROR_CODE_VFK_TARGET_TABLE_NOT_FOUND_IN_MAPPING, + ) continue } targetCols, ok := tableColumnMap[targetTable] if !ok { - j.addTableError(targetTable, fmt.Sprintf("Virtual foreign key target table missing in source database. Table: %s", targetTable), mgmtv1alpha1.TableError_TABLE_ERROR_CODE_VFK_TARGET_TABLE_NOT_FOUND_IN_SOURCE) + j.addTableError( + targetTable, + fmt.Sprintf( + "Virtual foreign key target table missing in source database. Table: %s", + targetTable, + ), + mgmtv1alpha1.TableError_TABLE_ERROR_CODE_VFK_TARGET_TABLE_NOT_FOUND_IN_SOURCE, + ) continue } @@ -346,7 +470,18 @@ func (j *JobMappingsValidator) ValidateVirtualForeignKeys( j.validateCircularVfk(sourceTable, targetTable, vfk, targetColMappings, targetCols) if len(vfk.GetColumns()) != len(vfk.GetForeignKey().GetColumns()) { - j.addDatabaseError(fmt.Sprintf("length of source columns was not equal to length of foreign key cols: %d %d. SourceTable: %s SourceColumn: %+v TargetTable: %s TargetColumn: %+v", len(vfk.GetColumns()), len(vfk.GetForeignKey().GetColumns()), sourceTable, vfk.GetColumns(), targetTable, vfk.GetForeignKey().GetColumns()), mgmtv1alpha1.DatabaseError_DATABASE_ERROR_CODE_VFK_COLUMN_MISMATCH) + j.addDatabaseError( + fmt.Sprintf( + "length of source columns was not equal to length of foreign key cols: %d %d. SourceTable: %s SourceColumn: %+v TargetTable: %s TargetColumn: %+v", + len(vfk.GetColumns()), + len(vfk.GetForeignKey().GetColumns()), + sourceTable, + vfk.GetColumns(), + targetTable, + vfk.GetForeignKey().GetColumns(), + ), + mgmtv1alpha1.DatabaseError_DATABASE_ERROR_CODE_VFK_COLUMN_MISMATCH, + ) continue } @@ -359,7 +494,20 @@ func (j *JobMappingsValidator) ValidateVirtualForeignKeys( continue } if srcColInfo.DataType != tarColInfo.DataType { - j.addColumnError(targetTable, tarCol, fmt.Sprintf("Column datatype mismatch. 
Source: %s.%s %s Target: %s.%s %s", sourceTable, srcCol, srcColInfo.DataType, targetTable, tarCol, tarColInfo.DataType), mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_VFK_COLUMN_DATATYPE_MISMATCH) + j.addColumnError( + targetTable, + tarCol, + fmt.Sprintf( + "Column datatype mismatch. Source: %s.%s %s Target: %s.%s %s", + sourceTable, + srcCol, + srcColInfo.DataType, + targetTable, + tarCol, + tarColInfo.DataType, + ), + mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_VFK_COLUMN_DATATYPE_MISMATCH, + ) } } } @@ -375,11 +523,29 @@ func (j *JobMappingsValidator) validateVfkTableColumnsExistInSource( for _, c := range vfk.GetForeignKey().GetColumns() { _, ok := colMappings[c] if !ok { - j.addColumnError(table, c, fmt.Sprintf("Virtual foreign key source column missing in job mappings. Table: %s Column: %s", table, c), mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_VFK_SOURCE_COLUMN_NOT_FOUND_IN_MAPPING) + j.addColumnError( + table, + c, + fmt.Sprintf( + "Virtual foreign key source column missing in job mappings. Table: %s Column: %s", + table, + c, + ), + mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_VFK_SOURCE_COLUMN_NOT_FOUND_IN_MAPPING, + ) } _, ok = sourceCols[c] if !ok { - j.addColumnError(table, c, fmt.Sprintf("Virtual foreign key source column missing in source database. Table: %s Column: %s", table, c), mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_VFK_SOURCE_COLUMN_NOT_FOUND_IN_SOURCE) + j.addColumnError( + table, + c, + fmt.Sprintf( + "Virtual foreign key source column missing in source database. 
Table: %s Column: %s", + table, + c, + ), + mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_VFK_SOURCE_COLUMN_NOT_FOUND_IN_SOURCE, + ) } } } @@ -395,7 +561,16 @@ func (j *JobMappingsValidator) validateVfkSourceColumnHasConstraint( isVfkValid := isVirtualForeignKeySourceUnique(vfk, pks, uniqueConstraints) if !isVfkValid { for _, c := range vfk.GetForeignKey().GetColumns() { - j.addColumnError(table, c, fmt.Sprintf("Virtual foreign key source must be either a primary key or have a unique constraint. Table: %s Columns: %+v", table, vfk.GetForeignKey().GetColumns()), mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_VFK_SOURCE_COLUMN_NOT_UNIQUE) + j.addColumnError( + table, + c, + fmt.Sprintf( + "Virtual foreign key source must be either a primary key or have a unique constraint. Table: %s Columns: %+v", + table, + vfk.GetForeignKey().GetColumns(), + ), + mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_VFK_SOURCE_COLUMN_NOT_UNIQUE, + ) } } } @@ -410,15 +585,42 @@ func (j *JobMappingsValidator) validateCircularVfk( for _, c := range vfk.GetColumns() { _, ok := targetColMappings[c] if !ok { - j.addColumnError(targetTable, c, fmt.Sprintf("Virtual foreign key target column missing in job mappings. Table: %s Column: %s", targetTable, c), mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_VFK_TARGET_COLUMN_NOT_FOUND_IN_MAPPING) + j.addColumnError( + targetTable, + c, + fmt.Sprintf( + "Virtual foreign key target column missing in job mappings. Table: %s Column: %s", + targetTable, + c, + ), + mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_VFK_TARGET_COLUMN_NOT_FOUND_IN_MAPPING, + ) } colInfo, ok := targetCols[c] if !ok { - j.addColumnError(targetTable, c, fmt.Sprintf("Virtual foreign key target column missing in source database. Table: %s Column: %s", targetTable, c), mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_VFK_TARGET_COLUMN_NOT_FOUND_IN_SOURCE) + j.addColumnError( + targetTable, + c, + fmt.Sprintf( + "Virtual foreign key target column missing in source database. 
Table: %s Column: %s", + targetTable, + c, + ), + mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_VFK_TARGET_COLUMN_NOT_FOUND_IN_SOURCE, + ) continue } if sourceTable == targetTable && !colInfo.IsNullable { - j.addColumnError(targetTable, c, fmt.Sprintf("Self referencing virtual foreign key target column must be nullable. Table: %s Column: %s", targetTable, c), mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_UNSUPPORTED_CIRCULAR_DEPENDENCY_AT_LEAST_ONE_NULLABLE) + j.addColumnError( + targetTable, + c, + fmt.Sprintf( + "Self referencing virtual foreign key target column must be nullable. Table: %s Column: %s", + targetTable, + c, + ), + mgmtv1alpha1.ColumnError_COLUMN_ERROR_CODE_UNSUPPORTED_CIRCULAR_DEPENDENCY_AT_LEAST_ONE_NULLABLE, + ) } } } diff --git a/internal/job/validate-schema.go b/internal/job/validate-schema.go index 2dde9d92ab..2e17b699b2 100644 --- a/internal/job/validate-schema.go +++ b/internal/job/validate-schema.go @@ -30,7 +30,11 @@ func getJobMappingKey(mapping *mgmtv1alpha1.JobMapping) string { // - missing: columns defined in the job mappings but not found in the schema (excluding those whose schema or table are missing). // - extra: columns present in the schema but not defined in the job mappings. // This function takes into account missing schemas and missing tables (passed in as arguments) to avoid duplicate missing column errors. 
-func diffColumnsAgainstMappings(schema []*mgmtv1alpha1.DatabaseColumn, mappings []*mgmtv1alpha1.JobMapping, missingSchemaSet, missingTableSet map[string]bool) (missingCols, extraCols []*mgmtv1alpha1.DatabaseColumn) { +func diffColumnsAgainstMappings( + schema []*mgmtv1alpha1.DatabaseColumn, + mappings []*mgmtv1alpha1.JobMapping, + missingSchemaSet, missingTableSet map[string]bool, +) (missingCols, extraCols []*mgmtv1alpha1.DatabaseColumn) { extra := []*mgmtv1alpha1.DatabaseColumn{} missing := []*mgmtv1alpha1.DatabaseColumn{} @@ -86,7 +90,11 @@ func diffColumnsAgainstMappings(schema []*mgmtv1alpha1.DatabaseColumn, mappings // It returns a list of tables (as SchemaTable objects) that are defined in the job mappings but missing in the schema, // excluding those whose schemas are already missing. // This function accepts missingSchemas as an argument. -func diffTablesAgainstMappings(schema []*mgmtv1alpha1.DatabaseColumn, mappings []*mgmtv1alpha1.JobMapping, missingSchemaSet map[string]bool) []*sqlmanager_shared.SchemaTable { +func diffTablesAgainstMappings( + schema []*mgmtv1alpha1.DatabaseColumn, + mappings []*mgmtv1alpha1.JobMapping, + missingSchemaSet map[string]bool, +) []*sqlmanager_shared.SchemaTable { missing := []*sqlmanager_shared.SchemaTable{} schemaTables := make(map[string]*sqlmanager_shared.SchemaTable) mappingTables := make(map[string]*sqlmanager_shared.SchemaTable) @@ -117,7 +125,10 @@ func diffTablesAgainstMappings(schema []*mgmtv1alpha1.DatabaseColumn, mappings [ // diffSchemasAgainstMappings compares the schemas used in job mappings with the schema. // It returns a list of schemas that appear in the job mappings but not in the schema. 
-func diffSchemaAgainstMappings(schema []*mgmtv1alpha1.DatabaseColumn, mappings []*mgmtv1alpha1.JobMapping) []string { +func diffSchemaAgainstMappings( + schema []*mgmtv1alpha1.DatabaseColumn, + mappings []*mgmtv1alpha1.JobMapping, +) []string { missing := []string{} schemaSchemas := make(map[string]bool) mappingSchemas := make(map[string]bool) @@ -145,7 +156,10 @@ func diffSchemaAgainstMappings(schema []*mgmtv1alpha1.DatabaseColumn, mappings [ // ValidateSchemaAgainstJobMappings returns the differences between the schema and the job mappings. // It produces a SchemaDiff that details columns, tables, and schemas in the schema that lack corresponding job mappings, // as well as extra columns that are in the schema but not in the mappings. -func ValidateSchemaAgainstJobMappings(schema []*mgmtv1alpha1.DatabaseColumn, mappings []*mgmtv1alpha1.JobMapping) ValidationReport { +func ValidateSchemaAgainstJobMappings( + schema []*mgmtv1alpha1.DatabaseColumn, + mappings []*mgmtv1alpha1.JobMapping, +) ValidationReport { missingSchemas := diffSchemaAgainstMappings(schema, mappings) missingSchemaSet := make(map[string]bool) for _, s := range missingSchemas { @@ -157,7 +171,12 @@ func ValidateSchemaAgainstJobMappings(schema []*mgmtv1alpha1.DatabaseColumn, map for _, mt := range missingTables { missingTableSet[mt.String()] = true } - missingColumns, extraColumns := diffColumnsAgainstMappings(schema, mappings, missingSchemaSet, missingTableSet) + missingColumns, extraColumns := diffColumnsAgainstMappings( + schema, + mappings, + missingSchemaSet, + missingTableSet, + ) return ValidationReport{ MissingColumns: missingColumns, diff --git a/internal/json-anonymizer/json-anonymizer.go b/internal/json-anonymizer/json-anonymizer.go index 7930bf7825..0aca36c6eb 100644 --- a/internal/json-anonymizer/json-anonymizer.go +++ b/internal/json-anonymizer/json-anonymizer.go @@ -54,19 +54,31 @@ func NewAnonymizer(opts ...Option) (*JsonAnonymizer, error) { } if len(a.transformerMappings) == 0 && 
a.defaultTransformers == nil { - return nil, fmt.Errorf("failed to initialize JSON anonymizer. must provide either default transformers or transformer mappings") + return nil, fmt.Errorf( + "failed to initialize JSON anonymizer. must provide either default transformers or transformer mappings", + ) } // Initialize transformerExecutors var err error - a.transformerExecutors, err = initTransformerExecutors(a.transformerMappings, a.anonymizeConfig, a.transformerClient, a.logger) + a.transformerExecutors, err = initTransformerExecutors( + a.transformerMappings, + a.anonymizeConfig, + a.transformerClient, + a.logger, + ) if err != nil { return nil, err } // Initialize defaultTransformerExecutor if needed if a.defaultTransformers != nil { - a.defaultTransformerExecutor, err = initDefaultTransformerExecutors(a.defaultTransformers, a.anonymizeConfig, a.transformerClient, a.logger) + a.defaultTransformerExecutor, err = initDefaultTransformerExecutors( + a.defaultTransformers, + a.anonymizeConfig, + a.transformerClient, + a.logger, + ) if err != nil { return nil, err } @@ -92,7 +104,12 @@ func WithTransformerClient(transformerClient mgmtv1alpha1connect.TransformersSer } // WithAnonymizeConfig sets the analyze and anonymize clients for use by the presidio transformers only if isEnabled is true -func WithConditionalAnonymizeConfig(isEnabled bool, analyze presidioapi.AnalyzeInterface, anonymize presidioapi.AnonymizeInterface, defaultLanguage *string) Option { +func WithConditionalAnonymizeConfig( + isEnabled bool, + analyze presidioapi.AnalyzeInterface, + anonymize presidioapi.AnonymizeInterface, + defaultLanguage *string, +) Option { return func(ja *JsonAnonymizer) { if isEnabled && analyze != nil && anonymize != nil { ja.anonymizeConfig = &anonymizeConfig{ @@ -142,14 +159,21 @@ func (a *JsonAnonymizer) initializeJq() error { fnName := functionNames[idx] exec := a.transformerExecutors[idx] path := mapping.GetExpression() - compilerOpts = append(compilerOpts, 
gojq.WithFunction(fnName, 1, 1, func(_ any, args []any) any { - value := args[0] - result, err := exec.Mutate(value, exec.Opts) - if err != nil { - return fmt.Errorf("unable to anonymize value. expression: %s error: %w", path, err) - } - return derefPointer(result) - })) + compilerOpts = append( + compilerOpts, + gojq.WithFunction(fnName, 1, 1, func(_ any, args []any) any { + value := args[0] + result, err := exec.Mutate(value, exec.Opts) + if err != nil { + return fmt.Errorf( + "unable to anonymize value. expression: %s error: %w", + path, + err, + ) + } + return derefPointer(result) + }), + ) sanitizedPath := strings.ReplaceAll(path, "?", "") a.skipPaths[sanitizedPath] = struct{}{} @@ -162,7 +186,10 @@ func (a *JsonAnonymizer) initializeJq() error { } return gojq.NewIter(result) } - compilerOpts = append(compilerOpts, gojq.WithIterFunction("applyDefaultTransformers", 0, 0, applyDefaultTransformersFunc)) + compilerOpts = append( + compilerOpts, + gojq.WithIterFunction("applyDefaultTransformers", 0, 0, applyDefaultTransformersFunc), + ) compiledQuery, err := gojq.Compile(query, compilerOpts...) 
if err != nil { @@ -190,7 +217,8 @@ func (a *JsonAnonymizer) buildJqQuery() (query string, transformerFunctions []st functionNames := []string{} if a.defaultTransformers != nil { - if a.defaultTransformers.S != nil || a.defaultTransformers.N != nil || a.defaultTransformers.Boolean != nil { + if a.defaultTransformers.S != nil || a.defaultTransformers.N != nil || + a.defaultTransformers.Boolean != nil { queryParts = append(queryParts, "applyDefaultTransformers") } } @@ -365,17 +393,29 @@ func initTransformerExecutors( transformer_executor.WithLogger(logger), transformer_executor.WithUserDefinedTransformerResolver(newUdtResolver(transformerClient)), } - if anonymizeConfig != nil && anonymizeConfig.analyze != nil && anonymizeConfig.anonymize != nil { + if anonymizeConfig != nil && anonymizeConfig.analyze != nil && + anonymizeConfig.anonymize != nil { execOpts = append( execOpts, - transformer_executor.WithTransformPiiTextConfig(anonymizeConfig.analyze, anonymizeConfig.anonymize, newNeosyncOperatorApi(execOpts), anonymizeConfig.defaultLanguage), + transformer_executor.WithTransformPiiTextConfig( + anonymizeConfig.analyze, + anonymizeConfig.anonymize, + newNeosyncOperatorApi(execOpts), + anonymizeConfig.defaultLanguage, + ), ) } for _, mapping := range transformerMappings { - executor, err := transformer_executor.InitializeTransformerByConfigType(mapping.GetTransformer(), execOpts...) + executor, err := transformer_executor.InitializeTransformerByConfigType( + mapping.GetTransformer(), + execOpts...) 
if err != nil { - return nil, fmt.Errorf("failed to initialize transformer for expression '%s': %v", mapping.GetExpression(), err) + return nil, fmt.Errorf( + "failed to initialize transformer for expression '%s': %v", + mapping.GetExpression(), + err, + ) } executors = append(executors, executor) } @@ -399,26 +439,41 @@ func initDefaultTransformerExecutors( transformer_executor.WithLogger(logger), transformer_executor.WithUserDefinedTransformerResolver(newUdtResolver(transformerClient)), } - if anonymizeConfig != nil && anonymizeConfig.analyze != nil && anonymizeConfig.anonymize != nil { - execOpts = append(execOpts, transformer_executor.WithTransformPiiTextConfig(anonymizeConfig.analyze, anonymizeConfig.anonymize, newNeosyncOperatorApi(execOpts), anonymizeConfig.defaultLanguage)) + if anonymizeConfig != nil && anonymizeConfig.analyze != nil && + anonymizeConfig.anonymize != nil { + execOpts = append( + execOpts, + transformer_executor.WithTransformPiiTextConfig( + anonymizeConfig.analyze, + anonymizeConfig.anonymize, + newNeosyncOperatorApi(execOpts), + anonymizeConfig.defaultLanguage, + ), + ) } var stringExecutor, numberExecutor, booleanExecutor *transformer_executor.TransformerExecutor var err error if defaultTransformer.S != nil { - stringExecutor, err = transformer_executor.InitializeTransformerByConfigType(defaultTransformer.S, execOpts...) + stringExecutor, err = transformer_executor.InitializeTransformerByConfigType( + defaultTransformer.S, + execOpts...) if err != nil { return nil, err } } if defaultTransformer.N != nil { - numberExecutor, err = transformer_executor.InitializeTransformerByConfigType(defaultTransformer.N, execOpts...) + numberExecutor, err = transformer_executor.InitializeTransformerByConfigType( + defaultTransformer.N, + execOpts...) if err != nil { return nil, err } } if defaultTransformer.Boolean != nil { - booleanExecutor, err = transformer_executor.InitializeTransformerByConfigType(defaultTransformer.Boolean, execOpts...) 
+ booleanExecutor, err = transformer_executor.InitializeTransformerByConfigType( + defaultTransformer.Boolean, + execOpts...) if err != nil { return nil, err } diff --git a/internal/json-anonymizer/neosync-operator.go b/internal/json-anonymizer/neosync-operator.go index 7fe21d9249..508e8e9e36 100644 --- a/internal/json-anonymizer/neosync-operator.go +++ b/internal/json-anonymizer/neosync-operator.go @@ -14,11 +14,17 @@ type neosyncOperatorApi struct { opts []transformer_executor.TransformerExecutorOption } -func newNeosyncOperatorApi(executorOpts []transformer_executor.TransformerExecutorOption) *neosyncOperatorApi { +func newNeosyncOperatorApi( + executorOpts []transformer_executor.TransformerExecutorOption, +) *neosyncOperatorApi { return &neosyncOperatorApi{opts: executorOpts} } -func (n *neosyncOperatorApi) Transform(ctx context.Context, config *mgmtv1alpha1.TransformerConfig, value string) (string, error) { +func (n *neosyncOperatorApi) Transform( + ctx context.Context, + config *mgmtv1alpha1.TransformerConfig, + value string, +) (string, error) { executor, err := transformer_executor.InitializeTransformerByConfigType(config, n.opts...) 
if err != nil { return "", err @@ -46,10 +52,16 @@ func newUdtResolver(transformerClient mgmtv1alpha1connect.TransformersServiceCli return &udtResolver{transformerClient: transformerClient} } -func (u *udtResolver) GetUserDefinedTransformer(ctx context.Context, id string) (*mgmtv1alpha1.TransformerConfig, error) { - resp, err := u.transformerClient.GetUserDefinedTransformerById(ctx, connect.NewRequest(&mgmtv1alpha1.GetUserDefinedTransformerByIdRequest{ - TransformerId: id, - })) +func (u *udtResolver) GetUserDefinedTransformer( + ctx context.Context, + id string, +) (*mgmtv1alpha1.TransformerConfig, error) { + resp, err := u.transformerClient.GetUserDefinedTransformerById( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetUserDefinedTransformerByIdRequest{ + TransformerId: id, + }), + ) if err != nil { return nil, err } diff --git a/internal/neosync-types/array.go b/internal/neosync-types/array.go index e998f23078..5606a7699a 100644 --- a/internal/neosync-types/array.go +++ b/internal/neosync-types/array.go @@ -8,7 +8,7 @@ import ( ) type NeosyncArray struct { - BaseType `json:",inline"` + BaseType ` json:",inline"` Elements []NeosyncAdapter `json:"elements"` } @@ -43,7 +43,11 @@ func (a *NeosyncArray) ScanPgx(value any) error { return err } if len(valueSlice) != len(a.Elements) { - return fmt.Errorf("length mismatch: got %d elements, expected %d", len(valueSlice), len(a.Elements)) + return fmt.Errorf( + "length mismatch: got %d elements, expected %d", + len(valueSlice), + len(a.Elements), + ) } for i, v := range valueSlice { if err := a.Elements[i].ScanPgx(v); err != nil { @@ -71,7 +75,11 @@ func (a *NeosyncArray) ScanJson(value any) error { return err } if len(valueSlice) != len(a.Elements) { - return fmt.Errorf("length mismatch: got %d elements, expected %d", len(valueSlice), len(a.Elements)) + return fmt.Errorf( + "length mismatch: got %d elements, expected %d", + len(valueSlice), + len(a.Elements), + ) } for i, v := range valueSlice { if err := 
a.Elements[i].ScanJson(v); err != nil { @@ -99,7 +107,11 @@ func (a *NeosyncArray) ScanMysql(value any) error { return err } if len(valueSlice) != len(a.Elements) { - return fmt.Errorf("length mismatch: got %d elements, expected %d", len(valueSlice), len(a.Elements)) + return fmt.Errorf( + "length mismatch: got %d elements, expected %d", + len(valueSlice), + len(a.Elements), + ) } for i, v := range valueSlice { if err := a.Elements[i].ScanMysql(v); err != nil { diff --git a/internal/neosync-types/binary.go b/internal/neosync-types/binary.go index fd231a0b86..4e693debcf 100644 --- a/internal/neosync-types/binary.go +++ b/internal/neosync-types/binary.go @@ -1,8 +1,8 @@ package neosynctypes type Binary struct { - BaseType `json:",inline"` - JsonScanner `json:"-"` + BaseType ` json:",inline"` + JsonScanner ` json:"-"` Bytes []byte `json:"bytes"` } @@ -110,7 +110,11 @@ func NewBinaryFromMssql(value any, opts ...NeosyncTypeOption) (*Binary, error) { return binary, nil } -func NewBinaryArrayFromPgx(elements [][]byte, opts []NeosyncTypeOption, arrayOpts ...NeosyncTypeOption) (*NeosyncArray, error) { +func NewBinaryArrayFromPgx( + elements [][]byte, + opts []NeosyncTypeOption, + arrayOpts ...NeosyncTypeOption, +) (*NeosyncArray, error) { neosyncAdapters := make([]NeosyncAdapter, len(elements)) for i, e := range elements { newBinary, err := NewBinary(opts...) 
diff --git a/internal/neosync-types/bits.go b/internal/neosync-types/bits.go index f76b98214c..a404121b05 100644 --- a/internal/neosync-types/bits.go +++ b/internal/neosync-types/bits.go @@ -8,8 +8,8 @@ import ( ) type Bits struct { - BaseType `json:",inline"` - JsonScanner `json:"-"` + BaseType ` json:",inline"` + JsonScanner ` json:"-"` Bytes []byte `json:"bytes"` Len int32 `json:"len"` } @@ -129,7 +129,11 @@ func NewBits(opts ...NeosyncTypeOption) (*Bits, error) { return bits, nil } -func NewBitsArrayFromPgx(elements []*pgtype.Bits, opts []NeosyncTypeOption, arrayOpts ...NeosyncTypeOption) (*NeosyncArray, error) { +func NewBitsArrayFromPgx( + elements []*pgtype.Bits, + opts []NeosyncTypeOption, + arrayOpts ...NeosyncTypeOption, +) (*NeosyncArray, error) { neosyncAdapters := make([]NeosyncAdapter, len(elements)) for i, e := range elements { newBits, err := NewBits(opts...) diff --git a/internal/neosync-types/datetime.go b/internal/neosync-types/datetime.go index b8a8c7579f..5eb8d3b11f 100644 --- a/internal/neosync-types/datetime.go +++ b/internal/neosync-types/datetime.go @@ -338,7 +338,11 @@ func NewDateTime(opts ...NeosyncTypeOption) (*NeosyncDateTime, error) { return dt, nil } -func NewDateTimeArrayFromPgx(elements []time.Time, opts []NeosyncTypeOption, arrayOpts ...NeosyncTypeOption) (*NeosyncArray, error) { +func NewDateTimeArrayFromPgx( + elements []time.Time, + opts []NeosyncTypeOption, + arrayOpts ...NeosyncTypeOption, +) (*NeosyncArray, error) { neosyncAdapters := make([]NeosyncAdapter, len(elements)) for i, e := range elements { newDateTime, err := NewDateTime(opts...) 
diff --git a/internal/neosync-types/interval.go b/internal/neosync-types/interval.go index c43ef2b444..ff134ff7ce 100644 --- a/internal/neosync-types/interval.go +++ b/internal/neosync-types/interval.go @@ -8,8 +8,8 @@ import ( ) type Interval struct { - BaseType `json:",inline"` - JsonScanner `json:"-"` + BaseType ` json:",inline"` + JsonScanner ` json:"-"` Microseconds int64 `json:"microseconds"` Days int32 `json:"days"` Months int32 `json:"months"` @@ -96,7 +96,11 @@ func NewInterval(opts ...NeosyncTypeOption) (*Interval, error) { return interval, nil } -func NewIntervalArrayFromPgx(elements []*pgtype.Interval, opts []NeosyncTypeOption, arrayOpts ...NeosyncTypeOption) (*NeosyncArray, error) { +func NewIntervalArrayFromPgx( + elements []*pgtype.Interval, + opts []NeosyncTypeOption, + arrayOpts ...NeosyncTypeOption, +) (*NeosyncArray, error) { neosyncAdapters := make([]NeosyncAdapter, len(elements)) for i, e := range elements { newInterval, err := NewInterval(opts...) diff --git a/internal/neosync-types/registry.go b/internal/neosync-types/registry.go index 7cf00eba54..79b0e77ab3 100644 --- a/internal/neosync-types/registry.go +++ b/internal/neosync-types/registry.go @@ -44,7 +44,11 @@ func NewTypeRegistry(logger *slog.Logger) *TypeRegistry { return registry } -func (r *TypeRegistry) Register(typeId string, version Version, newTypeFunc func() (NeosyncAdapter, error)) { +func (r *TypeRegistry) Register( + typeId string, + version Version, + newTypeFunc func() (NeosyncAdapter, error), +) { if _, exists := r.types[typeId]; !exists { r.types[typeId] = make(map[Version]func() (NeosyncAdapter, error)) } @@ -63,12 +67,22 @@ func (r *TypeRegistry) New(typeId string, version Version) (NeosyncAdapter, erro } // Try LatestVersion - r.logger.Warn(fmt.Sprintf("version %d not registered for Type Id: %s using latest version instead", version, typeId)) + r.logger.Warn( + fmt.Sprintf( + "version %d not registered for Type Id: %s using latest version instead", + version, + typeId, 
+ ), + ) if newTypeFunc, ok := versionedTypes[LatestVersion]; ok { return newTypeFunc() } - return nil, fmt.Errorf("unknown version %d for type Id: %s. latest version not found", version, typeId) + return nil, fmt.Errorf( + "unknown version %d for type Id: %s. latest version not found", + version, + typeId, + ) } // UnmarshalAny deserializes a value of type any into an appropriate type based on the Neosync type system. diff --git a/internal/neosyncdb/db.go b/internal/neosyncdb/db.go index d55422af31..8e07449dc7 100644 --- a/internal/neosyncdb/db.go +++ b/internal/neosyncdb/db.go @@ -25,7 +25,12 @@ type BaseDBTX interface { Query(context.Context, string, ...any) (pgx.Rows, error) QueryRow(context.Context, string, ...any) pgx.Row - CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) + CopyFrom( + ctx context.Context, + tableName pgx.Identifier, + columnNames []string, + rowSrc pgx.CopyFromSource, + ) (int64, error) } type NeosyncDb struct { diff --git a/internal/neosyncdb/users.go b/internal/neosyncdb/users.go index c02bfafd23..ce335340da 100644 --- a/internal/neosyncdb/users.go +++ b/internal/neosyncdb/users.go @@ -95,7 +95,12 @@ type upsertPersonalAccountResponse struct { Account *db_queries.NeosyncApiAccount } -func upsertPersonalAccount(ctx context.Context, q db_queries.Querier, dbtx BaseDBTX, req *upsertPersonalAccountRequest) (*upsertPersonalAccountResponse, error) { +func upsertPersonalAccount( + ctx context.Context, + q db_queries.Querier, + dbtx BaseDBTX, + req *upsertPersonalAccountRequest, +) (*upsertPersonalAccountResponse, error) { resp := &upsertPersonalAccountResponse{} account, err := q.GetPersonalAccountByUserId(ctx, dbtx, req.UserId) if err != nil && !IsNoRows(err) { @@ -175,13 +180,18 @@ func (d *NeosyncDb) CreateTeamAccount( func verifyAccountNameUnique(accounts []db_queries.NeosyncApiAccount, name string) error { for idx := range accounts { if 
strings.EqualFold(accounts[idx].AccountSlug, name) { - return nucleuserrors.NewAlreadyExists(fmt.Sprintf("team account with the name %s already exists", name)) + return nucleuserrors.NewAlreadyExists( + fmt.Sprintf("team account with the name %s already exists", name), + ) } } return nil } -func getAccountById(accounts []db_queries.NeosyncApiAccount, id pgtype.UUID) (*db_queries.NeosyncApiAccount, error) { +func getAccountById( + accounts []db_queries.NeosyncApiAccount, + id pgtype.UUID, +) (*db_queries.NeosyncApiAccount, error) { for idx := range accounts { if accounts[idx].ID.Valid && id.Valid && UUIDString(accounts[idx].ID) == UUIDString(id) { return &accounts[idx], nil diff --git a/internal/neosyncdb/util.go b/internal/neosyncdb/util.go index d566fcf3c6..13a35dbbff 100644 --- a/internal/neosyncdb/util.go +++ b/internal/neosyncdb/util.go @@ -64,7 +64,9 @@ func GetDbUrl(cfg *ConnectConfig) string { pgOpts["x-migrations-table"] = []string{*cfg.MigrationsTableName} } if cfg.MigrationsTableQuoted != nil { - pgOpts["x-migrations-table-quoted"] = []string{strconv.FormatBool(*cfg.MigrationsTableQuoted)} + pgOpts["x-migrations-table-quoted"] = []string{ + strconv.FormatBool(*cfg.MigrationsTableQuoted), + } } if cfg.Options != nil { pgOpts["options"] = []string{*cfg.Options} @@ -76,7 +78,14 @@ func GetDbUrl(cfg *ConnectConfig) string { } func UUIDString(value pgtype.UUID) string { - return fmt.Sprintf("%x-%x-%x-%x-%x", value.Bytes[0:4], value.Bytes[4:6], value.Bytes[6:8], value.Bytes[8:10], value.Bytes[10:16]) + return fmt.Sprintf( + "%x-%x-%x-%x-%x", + value.Bytes[0:4], + value.Bytes[4:6], + value.Bytes[6:8], + value.Bytes[8:10], + value.Bytes[10:16], + ) } func UUIDStrings(values []pgtype.UUID) []string { diff --git a/internal/otel/otel.go b/internal/otel/otel.go index 4d7cbf958d..91bc121b8e 100644 --- a/internal/otel/otel.go +++ b/internal/otel/otel.go @@ -94,7 +94,10 @@ type TraceProviderConfig struct { Opts TraceExporterOpts } -func NewTraceProvider(ctx 
context.Context, config *TraceProviderConfig) (*tracesdk.TracerProvider, error) { +func NewTraceProvider( + ctx context.Context, + config *TraceProviderConfig, +) (*tracesdk.TracerProvider, error) { exporter, err := getTraceExporter(ctx, config.Exporter, config.Opts) if err != nil { return nil, err @@ -119,7 +122,11 @@ const ( noneExporter = "none" ) -func getTraceExporter(ctx context.Context, exporter string, opts TraceExporterOpts) (tracesdk.SpanExporter, error) { +func getTraceExporter( + ctx context.Context, + exporter string, + opts TraceExporterOpts, +) (tracesdk.SpanExporter, error) { switch exporter { case otlpExporter: return otlptracegrpc.New(ctx, opts.Otlp...) @@ -128,7 +135,11 @@ func getTraceExporter(ctx context.Context, exporter string, opts TraceExporterOp case noneExporter: return nil, nil default: - return nil, fmt.Errorf("this tracer exporter is not currently supported %q: %w", exporter, errors.ErrUnsupported) + return nil, fmt.Errorf( + "this tracer exporter is not currently supported %q: %w", + exporter, + errors.ErrUnsupported, + ) } } @@ -138,7 +149,10 @@ type MeterProviderConfig struct { AppVersion string } -func NewMeterProvider(ctx context.Context, config *MeterProviderConfig) (*metricsdk.MeterProvider, error) { +func NewMeterProvider( + ctx context.Context, + config *MeterProviderConfig, +) (*metricsdk.MeterProvider, error) { exporter, err := getMeterExporter(ctx, config.Exporter, config.Opts) if err != nil { return nil, err @@ -158,7 +172,11 @@ type MeterExporterOpts struct { Console []stdoutmetric.Option } -func getMeterExporter(ctx context.Context, exporter string, opts MeterExporterOpts) (metricsdk.Exporter, error) { +func getMeterExporter( + ctx context.Context, + exporter string, + opts MeterExporterOpts, +) (metricsdk.Exporter, error) { switch exporter { case otlpExporter: return otlpmetricgrpc.New(ctx, opts.Otlp...) 
@@ -167,22 +185,30 @@ func getMeterExporter(ctx context.Context, exporter string, opts MeterExporterOp case noneExporter: return nil, nil default: - return nil, fmt.Errorf("this meter exporter is not currently supported %q: %w", exporter, errors.ErrUnsupported) + return nil, fmt.Errorf( + "this meter exporter is not currently supported %q: %w", + exporter, + errors.ErrUnsupported, + ) } } func WithDefaultDeltaTemporalitySelector() otlpmetricgrpc.Option { - return otlpmetricgrpc.WithTemporalitySelector(func(ik metricsdk.InstrumentKind) metricdata.Temporality { - // Delta Temporality causes metrics to be reset after some time. - // We are using this today for benthos metrics so that they don't persist indefinitely in the time series database - return metricdata.DeltaTemporality - }) + return otlpmetricgrpc.WithTemporalitySelector( + func(ik metricsdk.InstrumentKind) metricdata.Temporality { + // Delta Temporality causes metrics to be reset after some time. + // We are using this today for benthos metrics so that they don't persist indefinitely in the time series database + return metricdata.DeltaTemporality + }, + ) } func withCumulativeTemporalitySelector() otlpmetricgrpc.Option { - return otlpmetricgrpc.WithTemporalitySelector(func(ik metricsdk.InstrumentKind) metricdata.Temporality { - return metricdata.CumulativeTemporality - }) + return otlpmetricgrpc.WithTemporalitySelector( + func(ik metricsdk.InstrumentKind) metricdata.Temporality { + return metricdata.CumulativeTemporality + }, + ) } type OtelEnvConfig struct { diff --git a/internal/pgx-slog/adapter.go b/internal/pgx-slog/adapter.go index cd2f11ddd5..0908a3892b 100644 --- a/internal/pgx-slog/adapter.go +++ b/internal/pgx-slog/adapter.go @@ -21,7 +21,12 @@ func NewLogger(l *slog.Logger, omitArgs bool) *Logger { return &Logger{l: l, omitArgs: omitArgs} } -func (l *Logger) Log(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]any) { +func (l *Logger) Log( + ctx context.Context, + level 
tracelog.LogLevel, + msg string, + data map[string]any, +) { if level == tracelog.LogLevelNone { return } diff --git a/internal/runconfigs/builder.go b/internal/runconfigs/builder.go index e8889507a6..272bc5b1b8 100644 --- a/internal/runconfigs/builder.go +++ b/internal/runconfigs/builder.go @@ -156,7 +156,9 @@ func (b *tableConfigsBuilder) computeAllSubsetPaths() map[string][]*SubsetPath { for _, fc := range b.foreignKeys[child] { if fc.ForeignKey != nil && fc.ForeignKey.Table == entry.current { if len(fc.ForeignKey.Columns) > 0 && len(fc.Columns) > 0 { - referenceSchema, referenceTable := sqlmanager_shared.SplitTableKey(fc.ForeignKey.Table) + referenceSchema, referenceTable := sqlmanager_shared.SplitTableKey( + fc.ForeignKey.Table, + ) js = &JoinStep{ ToKey: entry.current, FromKey: child, @@ -378,7 +380,15 @@ func (b *runConfigBuilder) buildConstraintHandlingConfigs() []*RunConfig { if len(updateConfigs) > 0 { prevConfig = updateConfigs[len(updateConfigs)-1] } - updateConfig := b.buildUpdateConfig(fc, updateCols, updateFkCols, where, orderByColumns, updateConfigCount, prevConfig) + updateConfig := b.buildUpdateConfig( + fc, + updateCols, + updateFkCols, + where, + orderByColumns, + updateConfigCount, + prevConfig, + ) updateConfigs = append(updateConfigs, updateConfig) } } diff --git a/internal/runconfigs/circular-dependencies.go b/internal/runconfigs/circular-dependencies.go index 977d64ed7d..a253d0f2db 100644 --- a/internal/runconfigs/circular-dependencies.go +++ b/internal/runconfigs/circular-dependencies.go @@ -17,7 +17,13 @@ func FindCircularDependencies(dependencies map[string][]string) [][]string { } // finds all possible path variations -func dfsCycles(start, current string, dependencies map[string][]string, recStack map[string]bool, path []string, result *[][]string) { +func dfsCycles( + start, current string, + dependencies map[string][]string, + recStack map[string]bool, + path []string, + result *[][]string, +) { if recStack[current] { if current == 
start { // make copy to prevent reference issues diff --git a/internal/runconfigs/runconfigs.go b/internal/runconfigs/runconfigs.go index 28d7b784bc..6c0df95b87 100644 --- a/internal/runconfigs/runconfigs.go +++ b/internal/runconfigs/runconfigs.go @@ -183,10 +183,20 @@ func (rc *RunConfig) String() string { sb.WriteString(fmt.Sprintf(" [%d] Root: %s, Subset: %s\n", i, sp.Root, sp.Subset)) sb.WriteString(" JoinSteps:\n") for j, js := range sp.JoinSteps { - sb.WriteString(fmt.Sprintf(" [%d] FromKey: %s, ToKey: %s\n", j, js.FromKey, js.ToKey)) + sb.WriteString( + fmt.Sprintf(" [%d] FromKey: %s, ToKey: %s\n", j, js.FromKey, js.ToKey), + ) if js.ForeignKey != nil { - sb.WriteString(fmt.Sprintf(" FK: Columns: %v, NotNullable: %v, ReferenceSchema: %s, ReferenceTable: %s, ReferenceColumns: %v\n", - js.ForeignKey.Columns, js.ForeignKey.NotNullable, js.ForeignKey.ReferenceSchema, js.ForeignKey.ReferenceTable, js.ForeignKey.ReferenceColumns)) + sb.WriteString( + fmt.Sprintf( + " FK: Columns: %v, NotNullable: %v, ReferenceSchema: %s, ReferenceTable: %s, ReferenceColumns: %v\n", + js.ForeignKey.Columns, + js.ForeignKey.NotNullable, + js.ForeignKey.ReferenceSchema, + js.ForeignKey.ReferenceTable, + js.ForeignKey.ReferenceColumns, + ), + ) } } } @@ -215,7 +225,14 @@ func BuildRunConfigs( // filter dependencies to only include tables in tableColumnsMap (jobmappings) filteredFks := filterDependencies(dependencyMap, tableColumnsMap) - tableConfigsBuilder := newTableConfigsBuilder(tableColumnsMap, primaryKeyMap, subsets, uniqueIndexesMap, uniqueConstraintsMap, filteredFks) + tableConfigsBuilder := newTableConfigsBuilder( + tableColumnsMap, + primaryKeyMap, + subsets, + uniqueIndexesMap, + uniqueConstraintsMap, + filteredFks, + ) // build configs for each table for schematable := range tableColumnsMap { @@ -230,7 +247,9 @@ func BuildRunConfigs( // check run path if !isValidRunOrder(configs) { - return nil, errors.New("unsupported circular dependency detected. 
at least one foreign key in circular dependency must be nullable") + return nil, errors.New( + "unsupported circular dependency detected. at least one foreign key in circular dependency must be nullable", + ) } return configs, nil @@ -295,7 +314,9 @@ func isValidRunOrder(configs []*RunConfig) bool { prevTableLen = len(configMap) for id, config := range configMap { if AreConfigDependenciesSatisfied(config.DependsOn(), seenTables) { - seenTables[config.Table()] = append(seenTables[config.Table()], config.InsertColumns()...) + seenTables[config.Table()] = append( + seenTables[config.Table()], + config.InsertColumns()...) delete(configMap, id) } } diff --git a/internal/schema-manager/mssql/mssql.go b/internal/schema-manager/mssql/mssql.go index a6ef9a4852..ed6ab6dff5 100644 --- a/internal/schema-manager/mssql/mssql.go +++ b/internal/schema-manager/mssql/mssql.go @@ -60,14 +60,19 @@ func NewMssqlSchemaManager( }, nil } -func (d *MssqlSchemaManager) InitializeSchema(ctx context.Context, uniqueTables map[string]struct{}) ([]*shared.InitSchemaError, error) { +func (d *MssqlSchemaManager) InitializeSchema( + ctx context.Context, + uniqueTables map[string]struct{}, +) ([]*shared.InitSchemaError, error) { initErrors := []*shared.InitSchemaError{} if !d.destOpts.GetInitTableSchema() { d.logger.Info("skipping schema init as it is not enabled") return initErrors, nil } if !d.eelicense.IsValid() { - return nil, fmt.Errorf("invalid or non-existent Neosync License. SQL Server schema init requires valid Enterprise license") + return nil, fmt.Errorf( + "invalid or non-existent Neosync License. 
SQL Server schema init requires valid Enterprise license", + ) } tables := []*sqlmanager_shared.SchemaTable{} for tableKey := range uniqueTables { @@ -81,16 +86,30 @@ func (d *MssqlSchemaManager) InitializeSchema(ctx context.Context, uniqueTables } for _, block := range initblocks { - d.logger.Info(fmt.Sprintf("[%s] found %d statements to execute during schema initialization", block.Label, len(block.Statements))) + d.logger.Info( + fmt.Sprintf( + "[%s] found %d statements to execute during schema initialization", + block.Label, + len(block.Statements), + ), + ) if len(block.Statements) == 0 { continue } for _, stmt := range block.Statements { err = d.destdb.Db().Exec(ctx, stmt) if err != nil { - d.logger.Error(fmt.Sprintf("unable to exec mssql %s statements: %s", block.Label, err.Error())) - if block.Label != ee_sqlmanager_mssql.SchemasLabel && block.Label != ee_sqlmanager_mssql.ViewsFunctionsLabel && block.Label != ee_sqlmanager_mssql.TableIndexLabel { - return nil, fmt.Errorf("unable to exec mssql %s statements: %w", block.Label, err) + d.logger.Error( + fmt.Sprintf("unable to exec mssql %s statements: %s", block.Label, err.Error()), + ) + if block.Label != ee_sqlmanager_mssql.SchemasLabel && + block.Label != ee_sqlmanager_mssql.ViewsFunctionsLabel && + block.Label != ee_sqlmanager_mssql.TableIndexLabel { + return nil, fmt.Errorf( + "unable to exec mssql %s statements: %w", + block.Label, + err, + ) } initErrors = append(initErrors, &shared.InitSchemaError{ Statement: stmt, @@ -102,7 +121,11 @@ func (d *MssqlSchemaManager) InitializeSchema(ctx context.Context, uniqueTables return initErrors, nil } -func (d *MssqlSchemaManager) TruncateData(ctx context.Context, uniqueTables map[string]struct{}, uniqueSchemas []string) error { +func (d *MssqlSchemaManager) TruncateData( + ctx context.Context, + uniqueTables map[string]struct{}, + uniqueSchemas []string, +) error { if !d.destOpts.GetTruncateTable().GetTruncateBeforeInsert() { d.logger.Info("skipping truncate as it 
is not enabled") return nil @@ -111,9 +134,19 @@ func (d *MssqlSchemaManager) TruncateData(ctx context.Context, uniqueTables map[ if err != nil { return fmt.Errorf("unable to retrieve database foreign key constraints: %w", err) } - d.logger.Info(fmt.Sprintf("found %d foreign key constraints for database", len(tableDependencies.ForeignKeyConstraints))) - tablePrimaryDependencyMap := shared.GetFilteredForeignToPrimaryTableMap(tableDependencies.ForeignKeyConstraints, uniqueTables) - orderedTablesResp, err := tabledependency.GetTablesOrderedByDependency(tablePrimaryDependencyMap) + d.logger.Info( + fmt.Sprintf( + "found %d foreign key constraints for database", + len(tableDependencies.ForeignKeyConstraints), + ), + ) + tablePrimaryDependencyMap := shared.GetFilteredForeignToPrimaryTableMap( + tableDependencies.ForeignKeyConstraints, + uniqueTables, + ) + orderedTablesResp, err := tabledependency.GetTablesOrderedByDependency( + tablePrimaryDependencyMap, + ) if err != nil { return err } @@ -128,7 +161,12 @@ func (d *MssqlSchemaManager) TruncateData(ctx context.Context, uniqueTables map[ orderedTableDelete = append(orderedTableDelete, stmt) } - d.logger.Info(fmt.Sprintf("executing %d sql statements that will delete from tables", len(orderedTableDelete))) + d.logger.Info( + fmt.Sprintf( + "executing %d sql statements that will delete from tables", + len(orderedTableDelete), + ), + ) err = d.destdb.Db().BatchExec(ctx, 10, orderedTableDelete, &sqlmanager_shared.BatchExecOpts{}) if err != nil { return fmt.Errorf("unable to exec ordered delete from statements: %w", err) @@ -148,7 +186,12 @@ func (d *MssqlSchemaManager) TruncateData(ctx context.Context, uniqueTables map[ for _, c := range cols { if c.IdentityGeneration != nil && *c.IdentityGeneration != "" { schema, table := sqlmanager_shared.SplitTableKey(table) - identityResetStatement := sqlmanager_mssql.BuildMssqlIdentityColumnResetStatement(schema, table, c.IdentitySeed, c.IdentityIncrement) + identityResetStatement := 
sqlmanager_mssql.BuildMssqlIdentityColumnResetStatement( + schema, + table, + c.IdentitySeed, + c.IdentityIncrement, + ) identityStmts = append(identityStmts, identityResetStatement) } } @@ -162,19 +205,32 @@ func (d *MssqlSchemaManager) TruncateData(ctx context.Context, uniqueTables map[ return nil } -func (d *MssqlSchemaManager) CalculateSchemaDiff(ctx context.Context, uniqueTables map[string]*sqlmanager_shared.SchemaTable) (*shared.SchemaDifferences, error) { +func (d *MssqlSchemaManager) CalculateSchemaDiff( + ctx context.Context, + uniqueTables map[string]*sqlmanager_shared.SchemaTable, +) (*shared.SchemaDifferences, error) { return nil, errors.ErrUnsupported } -func (d *MssqlSchemaManager) BuildSchemaDiffStatements(ctx context.Context, diff *shared.SchemaDifferences) ([]*sqlmanager_shared.InitSchemaStatements, error) { +func (d *MssqlSchemaManager) BuildSchemaDiffStatements( + ctx context.Context, + diff *shared.SchemaDifferences, +) ([]*sqlmanager_shared.InitSchemaStatements, error) { return nil, errors.ErrUnsupported } -func (d *MssqlSchemaManager) ReconcileDestinationSchema(ctx context.Context, uniqueTables map[string]*sqlmanager_shared.SchemaTable, schemaStatements []*sqlmanager_shared.InitSchemaStatements) ([]*shared.InitSchemaError, error) { +func (d *MssqlSchemaManager) ReconcileDestinationSchema( + ctx context.Context, + uniqueTables map[string]*sqlmanager_shared.SchemaTable, + schemaStatements []*sqlmanager_shared.InitSchemaStatements, +) ([]*shared.InitSchemaError, error) { return nil, errors.ErrUnsupported } -func (d *MssqlSchemaManager) TruncateTables(ctx context.Context, schemaDiff *shared.SchemaDifferences) error { +func (d *MssqlSchemaManager) TruncateTables( + ctx context.Context, + schemaDiff *shared.SchemaDifferences, +) error { return errors.ErrUnsupported } func (d *MssqlSchemaManager) CloseConnections() { diff --git a/internal/schema-manager/mysql/mysql.go b/internal/schema-manager/mysql/mysql.go index 0e1a10143c..33f692123b 100644 --- 
a/internal/schema-manager/mysql/mysql.go +++ b/internal/schema-manager/mysql/mysql.go @@ -56,7 +56,10 @@ func NewMysqlSchemaManager( }, nil } -func (d *MysqlSchemaManager) CalculateSchemaDiff(ctx context.Context, uniqueTables map[string]*sqlmanager_shared.SchemaTable) (*shared.SchemaDifferences, error) { +func (d *MysqlSchemaManager) CalculateSchemaDiff( + ctx context.Context, + uniqueTables map[string]*sqlmanager_shared.SchemaTable, +) (*shared.SchemaDifferences, error) { d.logger.Debug("calculating schema diff") tables := []*sqlmanager_shared.SchemaTable{} schemaMap := map[string][]*sqlmanager_shared.SchemaTable{} @@ -150,7 +153,11 @@ func getDatabaseDataForSchemaDiff( errgrp.Go(func() error { tableconstraints, err := db.Db().GetTableConstraintsByTables(errctx, schema, tableNames) if err != nil { - return fmt.Errorf("failed to retrieve database table constraints for schema %s: %w", schema, err) + return fmt.Errorf( + "failed to retrieve database table constraints for schema %s: %w", + schema, + err, + ) } mu.Lock() defer mu.Unlock() @@ -180,7 +187,10 @@ func getDatabaseDataForSchemaDiff( }, nil } -func (d *MysqlSchemaManager) BuildSchemaDiffStatements(ctx context.Context, diff *shared.SchemaDifferences) ([]*sqlmanager_shared.InitSchemaStatements, error) { +func (d *MysqlSchemaManager) BuildSchemaDiffStatements( + ctx context.Context, + diff *shared.SchemaDifferences, +) ([]*sqlmanager_shared.InitSchemaStatements, error) { d.logger.Debug("building schema diff statements") if !d.destOpts.GetInitTableSchema() { d.logger.Info("skipping schema init as it is not enabled") @@ -197,23 +207,45 @@ func (d *MysqlSchemaManager) BuildSchemaDiffStatements(ctx context.Context, diff dropNonFkConstraintStatements := []string{} for _, constraint := range diff.ExistsInDestination.NonForeignKeyConstraints { - dropNonFkConstraintStatements = append(dropNonFkConstraintStatements, sqlmanager_mysql.BuildDropConstraintStatement(constraint.SchemaName, constraint.TableName, 
constraint.ConstraintType, constraint.ConstraintName)) + dropNonFkConstraintStatements = append( + dropNonFkConstraintStatements, + sqlmanager_mysql.BuildDropConstraintStatement( + constraint.SchemaName, + constraint.TableName, + constraint.ConstraintType, + constraint.ConstraintName, + ), + ) } orderedForeignKeysToDrop := shared.BuildOrderedForeignKeyConstraintsToDrop(d.logger, diff) orderedForeignKeyDropStatements := []string{} for _, fk := range orderedForeignKeysToDrop { - orderedForeignKeyDropStatements = append(orderedForeignKeyDropStatements, sqlmanager_mysql.BuildDropConstraintStatement(fk.ReferencingSchema, fk.ReferencingTable, fk.ConstraintType, fk.ConstraintName)) + orderedForeignKeyDropStatements = append( + orderedForeignKeyDropStatements, + sqlmanager_mysql.BuildDropConstraintStatement( + fk.ReferencingSchema, + fk.ReferencingTable, + fk.ConstraintType, + fk.ConstraintName, + ), + ) } dropColumnStatements := []string{} for _, column := range diff.ExistsInDestination.Columns { - dropColumnStatements = append(dropColumnStatements, sqlmanager_mysql.BuildDropColumnStatement(column)) + dropColumnStatements = append( + dropColumnStatements, + sqlmanager_mysql.BuildDropColumnStatement(column), + ) } dropTriggerStatements := []string{} for _, trigger := range diff.ExistsInDestination.Triggers { - dropTriggerStatements = append(dropTriggerStatements, sqlmanager_mysql.BuildDropTriggerStatement(trigger.TriggerSchema, trigger.TriggerName)) + dropTriggerStatements = append( + dropTriggerStatements, + sqlmanager_mysql.BuildDropTriggerStatement(trigger.TriggerSchema, trigger.TriggerName), + ) } updateColumnStatements := []string{} @@ -227,7 +259,10 @@ func (d *MysqlSchemaManager) BuildSchemaDiffStatements(ctx context.Context, diff dropFunctionStatements := []string{} for _, function := range diff.ExistsInDestination.Functions { - dropFunctionStatements = append(dropFunctionStatements, sqlmanager_mysql.BuildDropFunctionStatement(function.Schema, function.Name)) + 
dropFunctionStatements = append( + dropFunctionStatements, + sqlmanager_mysql.BuildDropFunctionStatement(function.Schema, function.Name), + ) } return []*sqlmanager_shared.InitSchemaStatements{ @@ -262,7 +297,11 @@ func (d *MysqlSchemaManager) BuildSchemaDiffStatements(ctx context.Context, diff }, nil } -func (d *MysqlSchemaManager) ReconcileDestinationSchema(ctx context.Context, uniqueTables map[string]*sqlmanager_shared.SchemaTable, schemaStatements []*sqlmanager_shared.InitSchemaStatements) ([]*shared.InitSchemaError, error) { +func (d *MysqlSchemaManager) ReconcileDestinationSchema( + ctx context.Context, + uniqueTables map[string]*sqlmanager_shared.SchemaTable, + schemaStatements []*sqlmanager_shared.InitSchemaStatements, +) ([]*shared.InitSchemaError, error) { d.logger.Debug("reconciling destination schema") initErrors := []*shared.InitSchemaError{} if !d.destOpts.GetInitTableSchema() { @@ -281,7 +320,10 @@ func (d *MysqlSchemaManager) ReconcileDestinationSchema(ctx context.Context, uni schemaStatementsByLabel := map[string][]*sqlmanager_shared.InitSchemaStatements{} for _, statement := range schemaStatements { - schemaStatementsByLabel[statement.Label] = append(schemaStatementsByLabel[statement.Label], statement) + schemaStatementsByLabel[statement.Label] = append( + schemaStatementsByLabel[statement.Label], + statement, + ) } // insert add columns statements after create table statements @@ -291,28 +333,52 @@ func (d *MysqlSchemaManager) ReconcileDestinationSchema(ctx context.Context, uni for _, statement := range initblocks { statementBlocks = append(statementBlocks, statement) if statement.Label == sqlmanager_shared.SchemasLabel { - statementBlocks = append(statementBlocks, schemaStatementsByLabel[sqlmanager_shared.DropFunctionsLabel]...) + statementBlocks = append( + statementBlocks, + schemaStatementsByLabel[sqlmanager_shared.DropFunctionsLabel]...) 
} if statement.Label == sqlmanager_shared.CreateTablesLabel { - statementBlocks = append(statementBlocks, schemaStatementsByLabel[sqlmanager_shared.DropTriggersLabel]...) - statementBlocks = append(statementBlocks, schemaStatementsByLabel[sqlmanager_shared.DropForeignKeyConstraintsLabel]...) - statementBlocks = append(statementBlocks, schemaStatementsByLabel[sqlmanager_shared.DropNonForeignKeyConstraintsLabel]...) - statementBlocks = append(statementBlocks, schemaStatementsByLabel[sqlmanager_shared.DropColumnsLabel]...) - statementBlocks = append(statementBlocks, schemaStatementsByLabel[sqlmanager_shared.AddColumnsLabel]...) - statementBlocks = append(statementBlocks, schemaStatementsByLabel[sqlmanager_shared.UpdateColumnsLabel]...) + statementBlocks = append( + statementBlocks, + schemaStatementsByLabel[sqlmanager_shared.DropTriggersLabel]...) + statementBlocks = append( + statementBlocks, + schemaStatementsByLabel[sqlmanager_shared.DropForeignKeyConstraintsLabel]...) + statementBlocks = append( + statementBlocks, + schemaStatementsByLabel[sqlmanager_shared.DropNonForeignKeyConstraintsLabel]...) + statementBlocks = append( + statementBlocks, + schemaStatementsByLabel[sqlmanager_shared.DropColumnsLabel]...) + statementBlocks = append( + statementBlocks, + schemaStatementsByLabel[sqlmanager_shared.AddColumnsLabel]...) + statementBlocks = append( + statementBlocks, + schemaStatementsByLabel[sqlmanager_shared.UpdateColumnsLabel]...) } } for _, block := range statementBlocks { - d.logger.Info(fmt.Sprintf("[%s] found %d statements to execute during schema initialization", block.Label, len(block.Statements))) + d.logger.Info( + fmt.Sprintf( + "[%s] found %d statements to execute during schema initialization", + block.Label, + len(block.Statements), + ), + ) if len(block.Statements) == 0 { continue } - err = d.destdb.Db().BatchExec(ctx, shared.BatchSizeConst, block.Statements, &sqlmanager_shared.BatchExecOpts{}) + err = d.destdb.Db(). 
+ BatchExec(ctx, shared.BatchSizeConst, block.Statements, &sqlmanager_shared.BatchExecOpts{}) if err != nil { - d.logger.Error(fmt.Sprintf("unable to exec mysql %s statements: %s", block.Label, err.Error())) + d.logger.Error( + fmt.Sprintf("unable to exec mysql %s statements: %s", block.Label, err.Error()), + ) for _, stmt := range block.Statements { - err = d.destdb.Db().BatchExec(ctx, 1, []string{stmt}, &sqlmanager_shared.BatchExecOpts{}) + err = d.destdb.Db(). + BatchExec(ctx, 1, []string{stmt}, &sqlmanager_shared.BatchExecOpts{}) if err != nil { initErrors = append(initErrors, &shared.InitSchemaError{ Statement: stmt, @@ -325,29 +391,41 @@ func (d *MysqlSchemaManager) ReconcileDestinationSchema(ctx context.Context, uni return initErrors, nil } -func (d *MysqlSchemaManager) TruncateTables(ctx context.Context, schemaDiff *shared.SchemaDifferences) error { +func (d *MysqlSchemaManager) TruncateTables( + ctx context.Context, + schemaDiff *shared.SchemaDifferences, +) error { if !d.destOpts.GetTruncateTable().GetTruncateBeforeInsert() { d.logger.Info("skipping truncate as it is not enabled") return nil } tableTruncate := []string{} for _, schemaTable := range schemaDiff.ExistsInBoth.Tables { - stmt, err := sqlmanager_mysql.BuildMysqlTruncateStatement(schemaTable.Schema, schemaTable.Table) + stmt, err := sqlmanager_mysql.BuildMysqlTruncateStatement( + schemaTable.Schema, + schemaTable.Table, + ) if err != nil { return err } tableTruncate = append(tableTruncate, stmt) } - d.logger.Info(fmt.Sprintf("executing %d sql statements that will truncate tables", len(tableTruncate))) + d.logger.Info( + fmt.Sprintf("executing %d sql statements that will truncate tables", len(tableTruncate)), + ) disableFkChecks := sqlmanager_shared.DisableForeignKeyChecks - err := d.destdb.Db().BatchExec(ctx, shared.BatchSizeConst, tableTruncate, &sqlmanager_shared.BatchExecOpts{Prefix: &disableFkChecks}) + err := d.destdb.Db(). 
+ BatchExec(ctx, shared.BatchSizeConst, tableTruncate, &sqlmanager_shared.BatchExecOpts{Prefix: &disableFkChecks}) if err != nil { return err } return nil } -func (d *MysqlSchemaManager) InitializeSchema(ctx context.Context, uniqueTables map[string]struct{}) ([]*shared.InitSchemaError, error) { +func (d *MysqlSchemaManager) InitializeSchema( + ctx context.Context, + uniqueTables map[string]struct{}, +) ([]*shared.InitSchemaError, error) { initErrors := []*shared.InitSchemaError{} if !d.destOpts.GetInitTableSchema() { d.logger.Info("skipping schema init as it is not enabled") @@ -365,18 +443,28 @@ func (d *MysqlSchemaManager) InitializeSchema(ctx context.Context, uniqueTables } for _, block := range initblocks { - d.logger.Info(fmt.Sprintf("[%s] found %d statements to execute during schema initialization", block.Label, len(block.Statements))) + d.logger.Info( + fmt.Sprintf( + "[%s] found %d statements to execute during schema initialization", + block.Label, + len(block.Statements), + ), + ) if len(block.Statements) == 0 { continue } - err = d.destdb.Db().BatchExec(ctx, shared.BatchSizeConst, block.Statements, &sqlmanager_shared.BatchExecOpts{}) + err = d.destdb.Db(). + BatchExec(ctx, shared.BatchSizeConst, block.Statements, &sqlmanager_shared.BatchExecOpts{}) if err != nil { - d.logger.Error(fmt.Sprintf("unable to exec mysql %s statements: %s", block.Label, err.Error())) + d.logger.Error( + fmt.Sprintf("unable to exec mysql %s statements: %s", block.Label, err.Error()), + ) if block.Label != sqlmanager_shared.SchemasLabel { return nil, fmt.Errorf("unable to exec mysql %s statements: %w", block.Label, err) } for _, stmt := range block.Statements { - err = d.destdb.Db().BatchExec(ctx, 1, []string{stmt}, &sqlmanager_shared.BatchExecOpts{}) + err = d.destdb.Db(). 
+ BatchExec(ctx, 1, []string{stmt}, &sqlmanager_shared.BatchExecOpts{}) if err != nil { initErrors = append(initErrors, &shared.InitSchemaError{ Statement: stmt, @@ -389,7 +477,11 @@ func (d *MysqlSchemaManager) InitializeSchema(ctx context.Context, uniqueTables return initErrors, nil } -func (d *MysqlSchemaManager) TruncateData(ctx context.Context, uniqueTables map[string]struct{}, uniqueSchemas []string) error { +func (d *MysqlSchemaManager) TruncateData( + ctx context.Context, + uniqueTables map[string]struct{}, + uniqueSchemas []string, +) error { if !d.destOpts.GetTruncateTable().GetTruncateBeforeInsert() { d.logger.Info("skipping truncate as it is not enabled") return nil @@ -403,9 +495,12 @@ func (d *MysqlSchemaManager) TruncateData(ctx context.Context, uniqueTables map[ } tableTruncate = append(tableTruncate, stmt) } - d.logger.Info(fmt.Sprintf("executing %d sql statements that will truncate tables", len(tableTruncate))) + d.logger.Info( + fmt.Sprintf("executing %d sql statements that will truncate tables", len(tableTruncate)), + ) disableFkChecks := sqlmanager_shared.DisableForeignKeyChecks - err := d.destdb.Db().BatchExec(ctx, shared.BatchSizeConst, tableTruncate, &sqlmanager_shared.BatchExecOpts{Prefix: &disableFkChecks}) + err := d.destdb.Db(). 
+ BatchExec(ctx, shared.BatchSizeConst, tableTruncate, &sqlmanager_shared.BatchExecOpts{Prefix: &disableFkChecks}) if err != nil { return err } diff --git a/internal/schema-manager/not-supported/not-supported.go b/internal/schema-manager/not-supported/not-supported.go index 1781df0d90..c503c1ead9 100644 --- a/internal/schema-manager/not-supported/not-supported.go +++ b/internal/schema-manager/not-supported/not-supported.go @@ -14,27 +14,47 @@ func NewNotSupportedSchemaManager() (*NotSupportedSchemaManager, error) { return &NotSupportedSchemaManager{}, nil } -func (d *NotSupportedSchemaManager) InitializeSchema(ctx context.Context, uniqueTables map[string]struct{}) ([]*shared.InitSchemaError, error) { +func (d *NotSupportedSchemaManager) InitializeSchema( + ctx context.Context, + uniqueTables map[string]struct{}, +) ([]*shared.InitSchemaError, error) { return []*shared.InitSchemaError{}, nil } -func (d *NotSupportedSchemaManager) TruncateData(ctx context.Context, uniqueTables map[string]struct{}, uniqueSchemas []string) error { +func (d *NotSupportedSchemaManager) TruncateData( + ctx context.Context, + uniqueTables map[string]struct{}, + uniqueSchemas []string, +) error { return nil } -func (d *NotSupportedSchemaManager) CalculateSchemaDiff(ctx context.Context, uniqueTables map[string]*sqlmanager_shared.SchemaTable) (*shared.SchemaDifferences, error) { +func (d *NotSupportedSchemaManager) CalculateSchemaDiff( + ctx context.Context, + uniqueTables map[string]*sqlmanager_shared.SchemaTable, +) (*shared.SchemaDifferences, error) { return nil, nil } -func (d *NotSupportedSchemaManager) BuildSchemaDiffStatements(ctx context.Context, diff *shared.SchemaDifferences) ([]*sqlmanager_shared.InitSchemaStatements, error) { +func (d *NotSupportedSchemaManager) BuildSchemaDiffStatements( + ctx context.Context, + diff *shared.SchemaDifferences, +) ([]*sqlmanager_shared.InitSchemaStatements, error) { return nil, nil } -func (d *NotSupportedSchemaManager) 
ReconcileDestinationSchema(ctx context.Context, uniqueTables map[string]*sqlmanager_shared.SchemaTable, schemaStatements []*sqlmanager_shared.InitSchemaStatements) ([]*shared.InitSchemaError, error) { +func (d *NotSupportedSchemaManager) ReconcileDestinationSchema( + ctx context.Context, + uniqueTables map[string]*sqlmanager_shared.SchemaTable, + schemaStatements []*sqlmanager_shared.InitSchemaStatements, +) ([]*shared.InitSchemaError, error) { return []*shared.InitSchemaError{}, nil } -func (d *NotSupportedSchemaManager) TruncateTables(ctx context.Context, schemaDiff *shared.SchemaDifferences) error { +func (d *NotSupportedSchemaManager) TruncateTables( + ctx context.Context, + schemaDiff *shared.SchemaDifferences, +) error { return nil } diff --git a/internal/schema-manager/postgres/postgres.go b/internal/schema-manager/postgres/postgres.go index 134baf528d..8db9571cf4 100644 --- a/internal/schema-manager/postgres/postgres.go +++ b/internal/schema-manager/postgres/postgres.go @@ -57,7 +57,10 @@ func NewPostgresSchemaManager( }, nil } -func (d *PostgresSchemaManager) CalculateSchemaDiff(ctx context.Context, uniqueTables map[string]*sqlmanager_shared.SchemaTable) (*shared.SchemaDifferences, error) { +func (d *PostgresSchemaManager) CalculateSchemaDiff( + ctx context.Context, + uniqueTables map[string]*sqlmanager_shared.SchemaTable, +) (*shared.SchemaDifferences, error) { d.logger.Debug("calculating schema diff") tables := []*sqlmanager_shared.SchemaTable{} schemaMap := map[string][]*sqlmanager_shared.SchemaTable{} @@ -127,7 +130,11 @@ func getDatabaseDataForSchemaDiff( errgrp.Go(func() error { tableconstraints, err := db.Db().GetTableConstraintsByTables(errctx, schema, tableNames) if err != nil { - return fmt.Errorf("failed to retrieve database table constraints for schema %s: %w", schema, err) + return fmt.Errorf( + "failed to retrieve database table constraints for schema %s: %w", + schema, + err, + ) } mu.Lock() defer mu.Unlock() @@ -155,7 +162,10 @@ func 
getDatabaseDataForSchemaDiff( }, nil } -func (d *PostgresSchemaManager) BuildSchemaDiffStatements(ctx context.Context, diff *shared.SchemaDifferences) ([]*sqlmanager_shared.InitSchemaStatements, error) { +func (d *PostgresSchemaManager) BuildSchemaDiffStatements( + ctx context.Context, + diff *shared.SchemaDifferences, +) ([]*sqlmanager_shared.InitSchemaStatements, error) { d.logger.Debug("building schema diff statements") if !d.destOpts.GetInitTableSchema() { d.logger.Info("skipping schema init as it is not enabled") @@ -164,23 +174,45 @@ func (d *PostgresSchemaManager) BuildSchemaDiffStatements(ctx context.Context, d addColumnStatements := []string{} for _, column := range diff.ExistsInSource.Columns { stmt := sqlmanager_postgres.BuildAddColumnStatement(column) - commentStmt := sqlmanager_postgres.BuildUpdateCommentStatement(column.Schema, column.Table, column.Name, column.Comment) + commentStmt := sqlmanager_postgres.BuildUpdateCommentStatement( + column.Schema, + column.Table, + column.Name, + column.Comment, + ) addColumnStatements = append(addColumnStatements, stmt, commentStmt) } dropNonFkConstraintStatements := []string{} for _, constraint := range diff.ExistsInDestination.NonForeignKeyConstraints { - dropNonFkConstraintStatements = append(dropNonFkConstraintStatements, sqlmanager_postgres.BuildDropConstraintStatement(constraint.SchemaName, constraint.TableName, constraint.ConstraintName)) + dropNonFkConstraintStatements = append( + dropNonFkConstraintStatements, + sqlmanager_postgres.BuildDropConstraintStatement( + constraint.SchemaName, + constraint.TableName, + constraint.ConstraintName, + ), + ) } dropFkConstraintStatements := []string{} for _, constraint := range diff.ExistsInDestination.ForeignKeyConstraints { - dropFkConstraintStatements = append(dropFkConstraintStatements, sqlmanager_postgres.BuildDropConstraintStatement(constraint.ReferencingSchema, constraint.ReferencingTable, constraint.ConstraintName)) + dropFkConstraintStatements = append( + 
dropFkConstraintStatements, + sqlmanager_postgres.BuildDropConstraintStatement( + constraint.ReferencingSchema, + constraint.ReferencingTable, + constraint.ConstraintName, + ), + ) } dropColumnStatements := []string{} for _, column := range diff.ExistsInDestination.Columns { - dropColumnStatements = append(dropColumnStatements, sqlmanager_postgres.BuildDropColumnStatement(column.Schema, column.Table, column.Name)) + dropColumnStatements = append( + dropColumnStatements, + sqlmanager_postgres.BuildDropColumnStatement(column.Schema, column.Table, column.Name), + ) } return []*sqlmanager_shared.InitSchemaStatements{ @@ -203,7 +235,11 @@ func (d *PostgresSchemaManager) BuildSchemaDiffStatements(ctx context.Context, d }, nil } -func (d *PostgresSchemaManager) ReconcileDestinationSchema(ctx context.Context, uniqueTables map[string]*sqlmanager_shared.SchemaTable, schemaStatements []*sqlmanager_shared.InitSchemaStatements) ([]*shared.InitSchemaError, error) { +func (d *PostgresSchemaManager) ReconcileDestinationSchema( + ctx context.Context, + uniqueTables map[string]*sqlmanager_shared.SchemaTable, + schemaStatements []*sqlmanager_shared.InitSchemaStatements, +) ([]*shared.InitSchemaError, error) { d.logger.Debug("reconciling destination schema") initErrors := []*shared.InitSchemaError{} if !d.destOpts.GetInitTableSchema() { @@ -222,7 +258,10 @@ func (d *PostgresSchemaManager) ReconcileDestinationSchema(ctx context.Context, schemaStatementsByLabel := map[string][]*sqlmanager_shared.InitSchemaStatements{} for _, statement := range schemaStatements { - schemaStatementsByLabel[statement.Label] = append(schemaStatementsByLabel[statement.Label], statement) + schemaStatementsByLabel[statement.Label] = append( + schemaStatementsByLabel[statement.Label], + statement, + ) } // insert add columns statements after create table statements @@ -232,23 +271,41 @@ func (d *PostgresSchemaManager) ReconcileDestinationSchema(ctx context.Context, for _, statement := range initblocks { 
statementBlocks = append(statementBlocks, statement) if statement.Label == sqlmanager_shared.CreateTablesLabel { - statementBlocks = append(statementBlocks, schemaStatementsByLabel[sqlmanager_shared.DropForeignKeyConstraintsLabel]...) - statementBlocks = append(statementBlocks, schemaStatementsByLabel[sqlmanager_shared.DropNonForeignKeyConstraintsLabel]...) - statementBlocks = append(statementBlocks, schemaStatementsByLabel[sqlmanager_shared.DropColumnsLabel]...) - statementBlocks = append(statementBlocks, schemaStatementsByLabel[sqlmanager_shared.AddColumnsLabel]...) + statementBlocks = append( + statementBlocks, + schemaStatementsByLabel[sqlmanager_shared.DropForeignKeyConstraintsLabel]...) + statementBlocks = append( + statementBlocks, + schemaStatementsByLabel[sqlmanager_shared.DropNonForeignKeyConstraintsLabel]...) + statementBlocks = append( + statementBlocks, + schemaStatementsByLabel[sqlmanager_shared.DropColumnsLabel]...) + statementBlocks = append( + statementBlocks, + schemaStatementsByLabel[sqlmanager_shared.AddColumnsLabel]...) } } for _, block := range statementBlocks { - d.logger.Info(fmt.Sprintf("[%s] found %d statements to execute during schema initialization", block.Label, len(block.Statements))) + d.logger.Info( + fmt.Sprintf( + "[%s] found %d statements to execute during schema initialization", + block.Label, + len(block.Statements), + ), + ) if len(block.Statements) == 0 { continue } - err = d.destdb.Db().BatchExec(ctx, shared.BatchSizeConst, block.Statements, &sqlmanager_shared.BatchExecOpts{}) + err = d.destdb.Db(). 
+ BatchExec(ctx, shared.BatchSizeConst, block.Statements, &sqlmanager_shared.BatchExecOpts{}) if err != nil { - d.logger.Error(fmt.Sprintf("unable to exec postgres %s statements: %s", block.Label, err.Error())) + d.logger.Error( + fmt.Sprintf("unable to exec postgres %s statements: %s", block.Label, err.Error()), + ) for _, stmt := range block.Statements { - err = d.destdb.Db().BatchExec(ctx, 1, []string{stmt}, &sqlmanager_shared.BatchExecOpts{}) + err = d.destdb.Db(). + BatchExec(ctx, 1, []string{stmt}, &sqlmanager_shared.BatchExecOpts{}) if err != nil { initErrors = append(initErrors, &shared.InitSchemaError{ Statement: stmt, @@ -261,8 +318,12 @@ func (d *PostgresSchemaManager) ReconcileDestinationSchema(ctx context.Context, return initErrors, nil } -func (d *PostgresSchemaManager) TruncateTables(ctx context.Context, schemaDiff *shared.SchemaDifferences) error { - if !d.destOpts.GetTruncateTable().GetTruncateBeforeInsert() && !d.destOpts.GetTruncateTable().GetCascade() { +func (d *PostgresSchemaManager) TruncateTables( + ctx context.Context, + schemaDiff *shared.SchemaDifferences, +) error { + if !d.destOpts.GetTruncateTable().GetTruncateBeforeInsert() && + !d.destOpts.GetTruncateTable().GetCascade() { d.logger.Info("skipping truncate as it is not enabled") return nil } @@ -283,7 +344,10 @@ func (d *PostgresSchemaManager) TruncateTables(ctx context.Context, schemaDiff * return d.TruncateData(ctx, uniqueTables, uniqueSchemas) } -func (d *PostgresSchemaManager) InitializeSchema(ctx context.Context, uniqueTables map[string]struct{}) ([]*shared.InitSchemaError, error) { +func (d *PostgresSchemaManager) InitializeSchema( + ctx context.Context, + uniqueTables map[string]struct{}, +) ([]*shared.InitSchemaError, error) { initErrors := []*shared.InitSchemaError{} if !d.destOpts.GetInitTableSchema() { d.logger.Info("skipping schema init as it is not enabled") @@ -301,14 +365,24 @@ func (d *PostgresSchemaManager) InitializeSchema(ctx context.Context, uniqueTabl } for _, 
block := range initblocks { - d.logger.Info(fmt.Sprintf("[%s] found %d statements to execute during schema initialization", block.Label, len(block.Statements))) + d.logger.Info( + fmt.Sprintf( + "[%s] found %d statements to execute during schema initialization", + block.Label, + len(block.Statements), + ), + ) if len(block.Statements) == 0 { continue } - err = d.destdb.Db().BatchExec(ctx, shared.BatchSizeConst, block.Statements, &sqlmanager_shared.BatchExecOpts{}) + err = d.destdb.Db(). + BatchExec(ctx, shared.BatchSizeConst, block.Statements, &sqlmanager_shared.BatchExecOpts{}) if err != nil { - d.logger.Error(fmt.Sprintf("unable to exec pg %s statements: %s", block.Label, err.Error())) - if block.Label != sqlmanager_shared.SchemasLabel && block.Label != sqlmanager_shared.ExtensionsLabel { + d.logger.Error( + fmt.Sprintf("unable to exec pg %s statements: %s", block.Label, err.Error()), + ) + if block.Label != sqlmanager_shared.SchemasLabel && + block.Label != sqlmanager_shared.ExtensionsLabel { return nil, fmt.Errorf("unable to exec pg %s statements: %w", block.Label, err) } for _, stmt := range block.Statements { @@ -325,8 +399,13 @@ func (d *PostgresSchemaManager) InitializeSchema(ctx context.Context, uniqueTabl return initErrors, nil } -func (d *PostgresSchemaManager) TruncateData(ctx context.Context, uniqueTables map[string]struct{}, uniqueSchemas []string) error { - if !d.destOpts.GetTruncateTable().GetTruncateBeforeInsert() && !d.destOpts.GetTruncateTable().GetCascade() { +func (d *PostgresSchemaManager) TruncateData( + ctx context.Context, + uniqueTables map[string]struct{}, + uniqueSchemas []string, +) error { + if !d.destOpts.GetTruncateTable().GetTruncateBeforeInsert() && + !d.destOpts.GetTruncateTable().GetCascade() { d.logger.Info("skipping truncate as it is not enabled") return nil } @@ -340,8 +419,14 @@ func (d *PostgresSchemaManager) TruncateData(ctx context.Context, uniqueTables m } tableTruncateStmts = append(tableTruncateStmts, stmt) } - 
d.logger.Info(fmt.Sprintf("executing %d sql statements that will truncate cascade tables", len(tableTruncateStmts))) - err := d.destdb.Db().BatchExec(ctx, shared.BatchSizeConst, tableTruncateStmts, &sqlmanager_shared.BatchExecOpts{}) + d.logger.Info( + fmt.Sprintf( + "executing %d sql statements that will truncate cascade tables", + len(tableTruncateStmts), + ), + ) + err := d.destdb.Db(). + BatchExec(ctx, shared.BatchSizeConst, tableTruncateStmts, &sqlmanager_shared.BatchExecOpts{}) if err != nil { return fmt.Errorf("unable to exec truncate cascade statements: %w", err) } @@ -367,7 +452,8 @@ func (d *PostgresSchemaManager) TruncateData(ctx context.Context, uniqueTables m return fmt.Errorf("unable to exec ordered truncate statements: %w", err) } } - if d.destOpts.GetTruncateTable().GetTruncateBeforeInsert() || d.destOpts.GetTruncateTable().GetCascade() { + if d.destOpts.GetTruncateTable().GetTruncateBeforeInsert() || + d.destOpts.GetTruncateTable().GetCascade() { // reset serial counts // identity counts are automatically reset with truncate identity restart clause schemaTableMap := map[string][]string{} @@ -383,14 +469,21 @@ func (d *PostgresSchemaManager) TruncateData(ctx context.Context, uniqueTables m } resetSeqStmts := []string{} for _, seq := range sequences { - resetSeqStmts = append(resetSeqStmts, sqlmanager_postgres.BuildPgResetSequenceSql(seq.Schema, seq.Name)) + resetSeqStmts = append( + resetSeqStmts, + sqlmanager_postgres.BuildPgResetSequenceSql(seq.Schema, seq.Name), + ) } if len(resetSeqStmts) > 0 { - err = d.destdb.Db().BatchExec(ctx, 10, resetSeqStmts, &sqlmanager_shared.BatchExecOpts{}) + err = d.destdb.Db(). 
+ BatchExec(ctx, 10, resetSeqStmts, &sqlmanager_shared.BatchExecOpts{}) if err != nil { // handle not found errors if !strings.Contains(err.Error(), `does not exist`) { - return fmt.Errorf("unable to exec postgres sequence reset statements: %w", err) + return fmt.Errorf( + "unable to exec postgres sequence reset statements: %w", + err, + ) } } } diff --git a/internal/schema-manager/schema-manager.go b/internal/schema-manager/schema-manager.go index 39b72eab97..c515458fc8 100644 --- a/internal/schema-manager/schema-manager.go +++ b/internal/schema-manager/schema-manager.go @@ -18,12 +18,29 @@ import ( ) type SchemaManagerService interface { - InitializeSchema(ctx context.Context, uniqueTables map[string]struct{}) ([]*schema_shared.InitSchemaError, error) - TruncateData(ctx context.Context, uniqueTables map[string]struct{}, uniqueSchemas []string) error + InitializeSchema( + ctx context.Context, + uniqueTables map[string]struct{}, + ) ([]*schema_shared.InitSchemaError, error) + TruncateData( + ctx context.Context, + uniqueTables map[string]struct{}, + uniqueSchemas []string, + ) error - CalculateSchemaDiff(ctx context.Context, uniqueTables map[string]*sqlmanager_shared.SchemaTable) (*schema_shared.SchemaDifferences, error) - BuildSchemaDiffStatements(ctx context.Context, diff *schema_shared.SchemaDifferences) ([]*sqlmanager_shared.InitSchemaStatements, error) - ReconcileDestinationSchema(ctx context.Context, uniqueTables map[string]*sqlmanager_shared.SchemaTable, schemaStatements []*sqlmanager_shared.InitSchemaStatements) ([]*schema_shared.InitSchemaError, error) + CalculateSchemaDiff( + ctx context.Context, + uniqueTables map[string]*sqlmanager_shared.SchemaTable, + ) (*schema_shared.SchemaDifferences, error) + BuildSchemaDiffStatements( + ctx context.Context, + diff *schema_shared.SchemaDifferences, + ) ([]*sqlmanager_shared.InitSchemaStatements, error) + ReconcileDestinationSchema( + ctx context.Context, + uniqueTables map[string]*sqlmanager_shared.SchemaTable, + 
schemaStatements []*sqlmanager_shared.InitSchemaStatements, + ) ([]*schema_shared.InitSchemaError, error) TruncateTables(ctx context.Context, schemaDiff *schema_shared.SchemaDifferences) error CloseConnections() @@ -51,7 +68,12 @@ func NewSchemaManager( logger *slog.Logger, eelicense license.EEInterface, ) SchemaManager { - return &DefaultSchemaManager{sqlmanagerclient: sqlmanagerclient, session: session, logger: logger, eelicense: eelicense} + return &DefaultSchemaManager{ + sqlmanagerclient: sqlmanagerclient, + session: session, + logger: logger, + eelicense: eelicense, + } } func (d *DefaultSchemaManager) New( diff --git a/internal/schema-manager/shared/foreign-keys.go b/internal/schema-manager/shared/foreign-keys.go index 42700c421e..7cfdde6a2c 100644 --- a/internal/schema-manager/shared/foreign-keys.go +++ b/internal/schema-manager/shared/foreign-keys.go @@ -94,7 +94,9 @@ func BuildOrderedForeignKeyConstraintsToDrop( // Either forcibly drop them or return an error. Here, we forcibly handle them (like your code). hadCycle := (len(topoOrder) < len(inDegree)) if hadCycle { - logger.Warn("Cycle detected among foreign keys. Forcibly dropping all remaining constraints.") + logger.Warn( + "Cycle detected among foreign keys. 
Forcibly dropping all remaining constraints.", + ) } // If no cycle, we can produce a stable drop order for "normal" FKs: diff --git a/internal/schema-manager/shared/schema-diff.go b/internal/schema-manager/shared/schema-diff.go index 18a5615651..2d4d8d56d8 100644 --- a/internal/schema-manager/shared/schema-diff.go +++ b/internal/schema-manager/shared/schema-diff.go @@ -130,25 +130,37 @@ func (b *SchemaDifferencesBuilder) buildTableColumnDifferences() { } func (b *SchemaDifferencesBuilder) buildTableForeignKeyConstraintDifferences() { - existsInSource, existsInDestination := buildDifferencesByFingerprint(b.source.ForeignKeyConstraints, b.destination.ForeignKeyConstraints) + existsInSource, existsInDestination := buildDifferencesByFingerprint( + b.source.ForeignKeyConstraints, + b.destination.ForeignKeyConstraints, + ) b.diff.ExistsInSource.ForeignKeyConstraints = existsInSource b.diff.ExistsInDestination.ForeignKeyConstraints = existsInDestination } func (b *SchemaDifferencesBuilder) buildTableNonForeignKeyConstraintDifferences() { - existsInSource, existsInDestination := buildDifferencesByFingerprint(b.source.NonForeignKeyConstraints, b.destination.NonForeignKeyConstraints) + existsInSource, existsInDestination := buildDifferencesByFingerprint( + b.source.NonForeignKeyConstraints, + b.destination.NonForeignKeyConstraints, + ) b.diff.ExistsInSource.NonForeignKeyConstraints = existsInSource b.diff.ExistsInDestination.NonForeignKeyConstraints = existsInDestination } func (b *SchemaDifferencesBuilder) buildTableTriggerDifferences() { - existsInSource, existsInDestination := buildDifferencesByFingerprint(b.source.Triggers, b.destination.Triggers) + existsInSource, existsInDestination := buildDifferencesByFingerprint( + b.source.Triggers, + b.destination.Triggers, + ) b.diff.ExistsInSource.Triggers = existsInSource b.diff.ExistsInDestination.Triggers = existsInDestination } func (b *SchemaDifferencesBuilder) buildSchemaFunctionDifferences() { - existsInSource, 
existsInDestination := buildDifferencesByFingerprint(b.source.Functions, b.destination.Functions) + existsInSource, existsInDestination := buildDifferencesByFingerprint( + b.source.Functions, + b.destination.Functions, + ) b.diff.ExistsInSource.Functions = existsInSource b.diff.ExistsInDestination.Functions = existsInDestination } diff --git a/internal/schema-manager/shared/shared.go b/internal/schema-manager/shared/shared.go index 7c2f68b1c1..7a3b39c49d 100644 --- a/internal/schema-manager/shared/shared.go +++ b/internal/schema-manager/shared/shared.go @@ -12,7 +12,10 @@ type InitSchemaError struct { } // filtered by tables found in job mappings -func GetFilteredForeignToPrimaryTableMap(td map[string][]*sqlmanager_shared.ForeignConstraint, uniqueTables map[string]struct{}) map[string][]string { +func GetFilteredForeignToPrimaryTableMap( + td map[string][]*sqlmanager_shared.ForeignConstraint, + uniqueTables map[string]struct{}, +) map[string][]string { dpMap := map[string][]string{} for table := range uniqueTables { _, dpOk := dpMap[table] diff --git a/internal/sshtunnel/connectors/postgrestunconnector/connector.go b/internal/sshtunnel/connectors/postgrestunconnector/connector.go index f990619b54..7c56343d21 100644 --- a/internal/sshtunnel/connectors/postgrestunconnector/connector.go +++ b/internal/sshtunnel/connectors/postgrestunconnector/connector.go @@ -81,7 +81,13 @@ func New( addrs := []string{name} resp, err := net.DefaultResolver.LookupHost(ctx, name) if err != nil { - cfg.logger.Error("unable to lookup addrs for hostname during postgres tunnel dial", "name", name, "err", err) + cfg.logger.Error( + "unable to lookup addrs for hostname during postgres tunnel dial", + "name", + name, + "err", + err, + ) } else { addrs = append(addrs, resp...) 
} diff --git a/internal/sshtunnel/dialer.go b/internal/sshtunnel/dialer.go index 8a01c1c833..a4744a856c 100644 --- a/internal/sshtunnel/dialer.go +++ b/internal/sshtunnel/dialer.go @@ -57,7 +57,13 @@ func DefaultSSHDialerConfig() *SSHDialerConfig { backoff.WithMaxTries(10), backoff.WithMaxElapsedTime(5 * time.Minute), backoff.WithNotify(func(err error, d time.Duration) { - logger.Warn(fmt.Sprintf("ssh error with retry: %s, retrying in %s", err.Error(), d.String())) + logger.Warn( + fmt.Sprintf( + "ssh error with retry: %s, retrying in %s", + err.Error(), + d.String(), + ), + ) }), } }, @@ -67,11 +73,22 @@ func DefaultSSHDialerConfig() *SSHDialerConfig { } } -func NewLazySSHDialer(addr string, ccfg *ssh.ClientConfig, dialCfg *SSHDialerConfig, logger *slog.Logger) *SSHDialer { +func NewLazySSHDialer( + addr string, + ccfg *ssh.ClientConfig, + dialCfg *SSHDialerConfig, + logger *slog.Logger, +) *SSHDialer { if dialCfg == nil { dialCfg = DefaultSSHDialerConfig() } - return &SSHDialer{addr: addr, ccfg: ccfg, clientmu: &sync.Mutex{}, dialCfg: dialCfg, logger: logger} + return &SSHDialer{ + addr: addr, + ccfg: ccfg, + clientmu: &sync.Mutex{}, + dialCfg: dialCfg, + logger: logger, + } } func NewSSHDialer(client *ssh.Client, logger *slog.Logger) *SSHDialer { @@ -122,7 +139,12 @@ func (s *SSHDialer) getClient(ctx context.Context) (*ssh.Client, error) { if err == nil { return s.client, nil } - s.logger.Warn(fmt.Sprintf("SSH client was dead, closing and attempting to re-open connection: %s", err.Error())) + s.logger.Warn( + fmt.Sprintf( + "SSH client was dead, closing and attempting to re-open connection: %s", + err.Error(), + ), + ) s.client.Close() s.client = nil } diff --git a/internal/sshtunnel/utils.go b/internal/sshtunnel/utils.go index 6129157c58..893f6f3f4c 100644 --- a/internal/sshtunnel/utils.go +++ b/internal/sshtunnel/utils.go @@ -46,7 +46,9 @@ func parseSshKey(keyString string) (ssh.PublicKey, error) { // Auth Method is optional and will return nil if there is no 
valid method. // Will only return error if unable to parse the private key into an auth method -func getTunnelAuthMethodFromSshConfig(auth *mgmtv1alpha1.SSHAuthentication) (ssh.AuthMethod, error) { +func getTunnelAuthMethodFromSshConfig( + auth *mgmtv1alpha1.SSHAuthentication, +) (ssh.AuthMethod, error) { if auth == nil { return nil, nil } diff --git a/internal/temporal/clientmanager/client_factory.go b/internal/temporal/clientmanager/client_factory.go index 25bec8489f..07a5bbe646 100644 --- a/internal/temporal/clientmanager/client_factory.go +++ b/internal/temporal/clientmanager/client_factory.go @@ -8,8 +8,16 @@ import ( ) type ClientFactory interface { - CreateNamespaceClient(ctx context.Context, config *TemporalConfig, logger *slog.Logger) (temporalclient.NamespaceClient, error) - CreateWorkflowClient(ctx context.Context, config *TemporalConfig, logger *slog.Logger) (temporalclient.Client, error) + CreateNamespaceClient( + ctx context.Context, + config *TemporalConfig, + logger *slog.Logger, + ) (temporalclient.NamespaceClient, error) + CreateWorkflowClient( + ctx context.Context, + config *TemporalConfig, + logger *slog.Logger, + ) (temporalclient.Client, error) } type TemporalClientFactory struct{} @@ -34,7 +42,10 @@ func (f *TemporalClientFactory) CreateWorkflowClient( return temporalclient.NewLazyClient(f.getClientOptions(config, logger)) } -func (f *TemporalClientFactory) getClientOptions(config *TemporalConfig, logger *slog.Logger) temporalclient.Options { +func (f *TemporalClientFactory) getClientOptions( + config *TemporalConfig, + logger *slog.Logger, +) temporalclient.Options { opts := temporalclient.Options{ Logger: logger, HostPort: config.Url, diff --git a/internal/temporal/clientmanager/config_provider.go b/internal/temporal/clientmanager/config_provider.go index 1078759ba7..7e68f29f9a 100644 --- a/internal/temporal/clientmanager/config_provider.go +++ b/internal/temporal/clientmanager/config_provider.go @@ -15,7 +15,11 @@ type ConfigProvider 
interface { } type DB interface { - GetTemporalConfigByAccount(ctx context.Context, db db_queries.DBTX, accountId pgtype.UUID) (*pg_models.TemporalConfig, error) + GetTemporalConfigByAccount( + ctx context.Context, + db db_queries.DBTX, + accountId pgtype.UUID, + ) (*pg_models.TemporalConfig, error) } type DBConfigProvider struct { @@ -24,7 +28,11 @@ type DBConfigProvider struct { dbtx db_queries.DBTX } -func NewDBConfigProvider(defaultConfig *TemporalConfig, db DB, dbtx db_queries.DBTX) *DBConfigProvider { +func NewDBConfigProvider( + defaultConfig *TemporalConfig, + db DB, + dbtx db_queries.DBTX, +) *DBConfigProvider { return &DBConfigProvider{ defaultConfig: defaultConfig, db: db, @@ -32,7 +40,10 @@ func NewDBConfigProvider(defaultConfig *TemporalConfig, db DB, dbtx db_queries.D } } -func (p *DBConfigProvider) GetConfig(ctx context.Context, accountID string) (*TemporalConfig, error) { +func (p *DBConfigProvider) GetConfig( + ctx context.Context, + accountID string, +) (*TemporalConfig, error) { accountUuid, err := neosyncdb.ToUuid(accountID) if err != nil { return nil, fmt.Errorf("invalid account ID: %w", err) diff --git a/internal/temporal/clientmanager/manager.go b/internal/temporal/clientmanager/manager.go index fcf7b8543b..95826c4a07 100644 --- a/internal/temporal/clientmanager/manager.go +++ b/internal/temporal/clientmanager/manager.go @@ -25,24 +25,107 @@ type DescribeSchedulesResponse struct { } type Interface interface { - DoesAccountHaveNamespace(ctx context.Context, accountId string, logger *slog.Logger) (bool, error) + DoesAccountHaveNamespace( + ctx context.Context, + accountId string, + logger *slog.Logger, + ) (bool, error) GetSyncJobTaskQueue(ctx context.Context, accountId string, logger *slog.Logger) (string, error) - CreateSchedule(ctx context.Context, accountId string, opts *temporalclient.ScheduleOptions, logger *slog.Logger) (string, error) - TriggerSchedule(ctx context.Context, accountId string, scheduleId string, opts 
*temporalclient.ScheduleTriggerOptions, logger *slog.Logger) error - PauseSchedule(ctx context.Context, accountId string, scheduleId string, opts *temporalclient.SchedulePauseOptions, logger *slog.Logger) error - UnpauseSchedule(ctx context.Context, accountId string, scheduleId string, opts *temporalclient.ScheduleUnpauseOptions, logger *slog.Logger) error - UpdateSchedule(ctx context.Context, accountId string, scheduleId string, opts *temporalclient.ScheduleUpdateOptions, logger *slog.Logger) error - DescribeSchedule(ctx context.Context, accountId string, scheduleId string, logger *slog.Logger) (*temporalclient.ScheduleDescription, error) - DescribeSchedules(ctx context.Context, accountId string, scheduleIds []string, logger *slog.Logger) ([]*DescribeSchedulesResponse, error) - DeleteSchedule(ctx context.Context, accountId string, scheduleId string, logger *slog.Logger) error - GetWorkflowExecutionById(ctx context.Context, accountId string, workflowId string, logger *slog.Logger) (*workflowpb.WorkflowExecutionInfo, error) - DeleteWorkflowExecution(ctx context.Context, accountId string, workflowId string, logger *slog.Logger) error - GetWorkflowExecutionsByScheduleIds(ctx context.Context, accountId string, scheduleIds []string, logger *slog.Logger) ([]*workflowpb.WorkflowExecutionInfo, error) - DescribeWorklowExecution(ctx context.Context, accountId string, workflowId string, logger *slog.Logger) (*workflowservice.DescribeWorkflowExecutionResponse, error) - CancelWorkflow(ctx context.Context, accountId string, workflowId string, logger *slog.Logger) error - TerminateWorkflow(ctx context.Context, accountId string, workflowId string, logger *slog.Logger) error - GetWorkflowHistory(ctx context.Context, accountId string, workflowId string, logger *slog.Logger) (temporalclient.HistoryEventIterator, error) + CreateSchedule( + ctx context.Context, + accountId string, + opts *temporalclient.ScheduleOptions, + logger *slog.Logger, + ) (string, error) + TriggerSchedule( + 
ctx context.Context, + accountId string, + scheduleId string, + opts *temporalclient.ScheduleTriggerOptions, + logger *slog.Logger, + ) error + PauseSchedule( + ctx context.Context, + accountId string, + scheduleId string, + opts *temporalclient.SchedulePauseOptions, + logger *slog.Logger, + ) error + UnpauseSchedule( + ctx context.Context, + accountId string, + scheduleId string, + opts *temporalclient.ScheduleUnpauseOptions, + logger *slog.Logger, + ) error + UpdateSchedule( + ctx context.Context, + accountId string, + scheduleId string, + opts *temporalclient.ScheduleUpdateOptions, + logger *slog.Logger, + ) error + DescribeSchedule( + ctx context.Context, + accountId string, + scheduleId string, + logger *slog.Logger, + ) (*temporalclient.ScheduleDescription, error) + DescribeSchedules( + ctx context.Context, + accountId string, + scheduleIds []string, + logger *slog.Logger, + ) ([]*DescribeSchedulesResponse, error) + DeleteSchedule( + ctx context.Context, + accountId string, + scheduleId string, + logger *slog.Logger, + ) error + GetWorkflowExecutionById( + ctx context.Context, + accountId string, + workflowId string, + logger *slog.Logger, + ) (*workflowpb.WorkflowExecutionInfo, error) + DeleteWorkflowExecution( + ctx context.Context, + accountId string, + workflowId string, + logger *slog.Logger, + ) error + GetWorkflowExecutionsByScheduleIds( + ctx context.Context, + accountId string, + scheduleIds []string, + logger *slog.Logger, + ) ([]*workflowpb.WorkflowExecutionInfo, error) + DescribeWorklowExecution( + ctx context.Context, + accountId string, + workflowId string, + logger *slog.Logger, + ) (*workflowservice.DescribeWorkflowExecutionResponse, error) + CancelWorkflow( + ctx context.Context, + accountId string, + workflowId string, + logger *slog.Logger, + ) error + TerminateWorkflow( + ctx context.Context, + accountId string, + workflowId string, + logger *slog.Logger, + ) error + GetWorkflowHistory( + ctx context.Context, + accountId string, + 
workflowId string, + logger *slog.Logger, + ) (temporalclient.HistoryEventIterator, error) } var _ Interface = (*ClientManager)(nil) @@ -64,7 +147,11 @@ func NewClientManager( } } -func (m *ClientManager) getClients(ctx context.Context, accountId string, logger *slog.Logger) (*clientHandle, error) { +func (m *ClientManager) getClients( + ctx context.Context, + accountId string, + logger *slog.Logger, +) (*clientHandle, error) { config, err := m.configProvider.GetConfig(ctx, accountId) if err != nil { return nil, fmt.Errorf("failed to get temporal config: %w", err) @@ -100,7 +187,11 @@ func (m *ClientManager) DoesAccountHaveNamespace( return true, nil } -func (m *ClientManager) GetSyncJobTaskQueue(ctx context.Context, accountId string, logger *slog.Logger) (string, error) { +func (m *ClientManager) GetSyncJobTaskQueue( + ctx context.Context, + accountId string, + logger *slog.Logger, +) (string, error) { config, err := m.configProvider.GetConfig(ctx, accountId) if err != nil { return "", fmt.Errorf("failed to get temporal config: %w", err) @@ -283,7 +374,12 @@ func (m *ClientManager) GetWorkflowExecutionsByScheduleIds( } defer clients.Release() - return getWorfklowsByScheduleIds(ctx, clients.WorkflowClient(), clients.config.Namespace, scheduleIds) + return getWorfklowsByScheduleIds( + ctx, + clients.WorkflowClient(), + clients.config.Namespace, + scheduleIds, + ) } func (m *ClientManager) GetWorkflowExecutionById( @@ -334,11 +430,17 @@ func (m *ClientManager) DescribeWorklowExecution( } defer clients.Release() - wf, err := getLatestWorkflow(ctx, clients.WorkflowClient(), clients.config.Namespace, workflowId) + wf, err := getLatestWorkflow( + ctx, + clients.WorkflowClient(), + clients.config.Namespace, + workflowId, + ) if err != nil { return nil, err } - return clients.WorkflowClient().DescribeWorkflowExecution(ctx, wf.GetExecution().GetWorkflowId(), wf.GetExecution().GetRunId()) + return clients.WorkflowClient(). 
+ DescribeWorkflowExecution(ctx, wf.GetExecution().GetWorkflowId(), wf.GetExecution().GetRunId()) } func (m *ClientManager) DeleteWorkflowExecution( @@ -359,10 +461,11 @@ func (m *ClientManager) DeleteWorkflowExecution( clients.config.Namespace, func(ctx context.Context, namespace string) ([]*workflowpb.WorkflowExecutionInfo, error) { // todo: should technically paginate this, but the amount of workflows + unique run ids should be only ever 1 - resp, err := clients.WorkflowClient().ListWorkflow(ctx, &workflowservice.ListWorkflowExecutionsRequest{ - Namespace: namespace, - Query: fmt.Sprintf("WorkflowId = %q", workflowId), - }) + resp, err := clients.WorkflowClient(). + ListWorkflow(ctx, &workflowservice.ListWorkflowExecutionsRequest{ + Namespace: namespace, + Query: fmt.Sprintf("WorkflowId = %q", workflowId), + }) if err != nil { return nil, err } @@ -394,10 +497,13 @@ func (m *ClientManager) deleteWorkflows( for _, wf := range workflowExecs { wf := wf errgrp.Go(func() error { - _, err := svc.DeleteWorkflowExecution(ctx, &workflowservice.DeleteWorkflowExecutionRequest{ - Namespace: namespace, - WorkflowExecution: wf.GetExecution(), - }) + _, err := svc.DeleteWorkflowExecution( + ctx, + &workflowservice.DeleteWorkflowExecutionRequest{ + Namespace: namespace, + WorkflowExecution: wf.GetExecution(), + }, + ) return err }) } @@ -416,11 +522,17 @@ func (m *ClientManager) CancelWorkflow( } defer clients.Release() - wf, err := getLatestWorkflow(ctx, clients.WorkflowClient(), clients.config.Namespace, workflowId) + wf, err := getLatestWorkflow( + ctx, + clients.WorkflowClient(), + clients.config.Namespace, + workflowId, + ) if err != nil { return err } - return clients.WorkflowClient().CancelWorkflow(ctx, wf.GetExecution().GetWorkflowId(), wf.GetExecution().GetRunId()) + return clients.WorkflowClient(). 
+ CancelWorkflow(ctx, wf.GetExecution().GetWorkflowId(), wf.GetExecution().GetRunId()) } func (m *ClientManager) TerminateWorkflow( @@ -435,11 +547,17 @@ func (m *ClientManager) TerminateWorkflow( } defer clients.Release() - wf, err := getLatestWorkflow(ctx, clients.WorkflowClient(), clients.config.Namespace, workflowId) + wf, err := getLatestWorkflow( + ctx, + clients.WorkflowClient(), + clients.config.Namespace, + workflowId, + ) if err != nil { return err } - return clients.WorkflowClient().TerminateWorkflow(ctx, wf.GetExecution().GetWorkflowId(), wf.GetExecution().GetRunId(), "terminated by user") + return clients.WorkflowClient(). + TerminateWorkflow(ctx, wf.GetExecution().GetWorkflowId(), wf.GetExecution().GetRunId(), "terminated by user") } func (m *ClientManager) GetWorkflowHistory( @@ -454,7 +572,12 @@ func (m *ClientManager) GetWorkflowHistory( } defer clients.Release() - wf, err := getLatestWorkflow(ctx, clients.WorkflowClient(), clients.config.Namespace, workflowId) + wf, err := getLatestWorkflow( + ctx, + clients.WorkflowClient(), + clients.config.Namespace, + workflowId, + ) if err != nil { return nil, err } diff --git a/internal/testutil/testcontainers/dynamodb/dynamodb.go b/internal/testutil/testcontainers/dynamodb/dynamodb.go index 4c78c3df33..9b1ee32a3c 100644 --- a/internal/testutil/testcontainers/dynamodb/dynamodb.go +++ b/internal/testutil/testcontainers/dynamodb/dynamodb.go @@ -20,7 +20,11 @@ type DynamoDBTestSyncContainer struct { Target *DynamoDBTestContainer } -func NewDynamoDBTestSyncContainer(ctx context.Context, t *testing.T, sourceOpts, destOpts []Option) (*DynamoDBTestSyncContainer, error) { +func NewDynamoDBTestSyncContainer( + ctx context.Context, + t *testing.T, + sourceOpts, destOpts []Option, +) (*DynamoDBTestSyncContainer, error) { tc := &DynamoDBTestSyncContainer{} errgrp := errgroup.Group{} errgrp.Go(func() error { @@ -105,7 +109,11 @@ func WithAwsToken(token string) Option { } // NewDynamoDBTestContainer initializes a new 
DynamoDB Test Container with functional options -func NewDynamoDBTestContainer(ctx context.Context, t *testing.T, opts ...Option) (*DynamoDBTestContainer, error) { +func NewDynamoDBTestContainer( + ctx context.Context, + t *testing.T, + opts ...Option, +) (*DynamoDBTestContainer, error) { d := &DynamoDBTestContainer{ awsId: "fakeid", // default value awsSecret: "fakesecret", // default value @@ -118,7 +126,10 @@ func NewDynamoDBTestContainer(ctx context.Context, t *testing.T, opts ...Option) } // Creates and starts a DynamoDB test container -func (d *DynamoDBTestContainer) Setup(ctx context.Context, t *testing.T) (*DynamoDBTestContainer, error) { +func (d *DynamoDBTestContainer) Setup( + ctx context.Context, + t *testing.T, +) (*DynamoDBTestContainer, error) { port := nat.Port("8000/tcp") container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ ContainerRequest: testcontainers.ContainerRequest{ @@ -182,12 +193,19 @@ func (d *DynamoDBTestContainer) TearDown(ctx context.Context) error { return nil } -func (d *DynamoDBTestContainer) SetupDynamoDbTable(ctx context.Context, tableName, primaryKey string) error { +func (d *DynamoDBTestContainer) SetupDynamoDbTable( + ctx context.Context, + tableName, primaryKey string, +) error { out, err := d.Client.CreateTable(ctx, &dynamodb.CreateTableInput{ - TableName: &tableName, - KeySchema: []dyntypes.KeySchemaElement{{KeyType: dyntypes.KeyTypeHash, AttributeName: &primaryKey}}, - AttributeDefinitions: []dyntypes.AttributeDefinition{{AttributeName: &primaryKey, AttributeType: dyntypes.ScalarAttributeTypeS}}, - BillingMode: dyntypes.BillingModePayPerRequest, + TableName: &tableName, + KeySchema: []dyntypes.KeySchemaElement{ + {KeyType: dyntypes.KeyTypeHash, AttributeName: &primaryKey}, + }, + AttributeDefinitions: []dyntypes.AttributeDefinition{ + {AttributeName: &primaryKey, AttributeType: dyntypes.ScalarAttributeTypeS}, + }, + BillingMode: dyntypes.BillingModePayPerRequest, }) if err != nil { 
return err @@ -198,10 +216,17 @@ func (d *DynamoDBTestContainer) SetupDynamoDbTable(ctx context.Context, tableNam if out.TableDescription.TableStatus == dyntypes.TableStatusCreating { return d.waitUntilDynamoTableExists(ctx, tableName) } - return fmt.Errorf("%s dynamo table created but unexpected table status: %s", tableName, out.TableDescription.TableStatus) + return fmt.Errorf( + "%s dynamo table created but unexpected table status: %s", + tableName, + out.TableDescription.TableStatus, + ) } -func (d *DynamoDBTestContainer) waitUntilDynamoTableExists(ctx context.Context, tableName string) error { +func (d *DynamoDBTestContainer) waitUntilDynamoTableExists( + ctx context.Context, + tableName string, +) error { input := &dynamodb.DescribeTableInput{TableName: &tableName} for { out, err := d.Client.DescribeTable(ctx, input) @@ -227,7 +252,10 @@ func (d *DynamoDBTestContainer) DestroyDynamoDbTable(ctx context.Context, tableN return d.waitUntilDynamoTableDestroy(ctx, tableName) } -func (d *DynamoDBTestContainer) waitUntilDynamoTableDestroy(ctx context.Context, tableName string) error { +func (d *DynamoDBTestContainer) waitUntilDynamoTableDestroy( + ctx context.Context, + tableName string, +) error { input := &dynamodb.DescribeTableInput{TableName: &tableName} for { _, err := d.Client.DescribeTable(ctx, input) @@ -240,7 +268,11 @@ func (d *DynamoDBTestContainer) waitUntilDynamoTableDestroy(ctx context.Context, } } -func (d *DynamoDBTestContainer) InsertDynamoDBRecords(ctx context.Context, tableName string, data []map[string]dyntypes.AttributeValue) error { +func (d *DynamoDBTestContainer) InsertDynamoDBRecords( + ctx context.Context, + tableName string, + data []map[string]dyntypes.AttributeValue, +) error { writeRequests := make([]dyntypes.WriteRequest, len(data)) for i, record := range data { writeRequests[i] = dyntypes.WriteRequest{ diff --git a/internal/testutil/testcontainers/mongodb/mongodb.go b/internal/testutil/testcontainers/mongodb/mongodb.go index 
606e0af793..a2f325a1ca 100644 --- a/internal/testutil/testcontainers/mongodb/mongodb.go +++ b/internal/testutil/testcontainers/mongodb/mongodb.go @@ -17,7 +17,10 @@ type MongoDBTestSyncContainer struct { Target *MongoDBTestContainer } -func NewMongoDBTestSyncContainer(ctx context.Context, t *testing.T) (*MongoDBTestSyncContainer, error) { +func NewMongoDBTestSyncContainer( + ctx context.Context, + t *testing.T, +) (*MongoDBTestSyncContainer, error) { tc := &MongoDBTestSyncContainer{} errgrp := errgroup.Group{} errgrp.Go(func() error { @@ -77,7 +80,10 @@ func NewMongoDBTestContainer(ctx context.Context, t *testing.T) (*MongoDBTestCon return m.Setup(ctx, t) } -func (m *MongoDBTestContainer) Setup(ctx context.Context, t *testing.T) (*MongoDBTestContainer, error) { +func (m *MongoDBTestContainer) Setup( + ctx context.Context, + t *testing.T, +) (*MongoDBTestContainer, error) { container, err := testmongodb.Run(ctx, "mongo:6") if err != nil { return nil, err @@ -112,7 +118,11 @@ func (m *MongoDBTestContainer) TearDown(ctx context.Context) error { return nil } -func (m *MongoDBTestContainer) InsertMongoDbRecords(ctx context.Context, database, collection string, documents []any) (int, error) { +func (m *MongoDBTestContainer) InsertMongoDbRecords( + ctx context.Context, + database, collection string, + documents []any, +) (int, error) { db := m.Client.Database(database) col := db.Collection(collection) @@ -124,7 +134,10 @@ func (m *MongoDBTestContainer) InsertMongoDbRecords(ctx context.Context, databas return len(result.InsertedIDs), nil } -func (m *MongoDBTestContainer) DropMongoDbCollection(ctx context.Context, database, collection string) error { +func (m *MongoDBTestContainer) DropMongoDbCollection( + ctx context.Context, + database, collection string, +) error { db := m.Client.Database(database) collections, err := db.ListCollectionNames(ctx, map[string]any{"name": collection}) if err != nil { diff --git a/internal/testutil/testcontainers/mysql/mysql.go 
b/internal/testutil/testcontainers/mysql/mysql.go index a09e71c008..f3caadd2d4 100644 --- a/internal/testutil/testcontainers/mysql/mysql.go +++ b/internal/testutil/testcontainers/mysql/mysql.go @@ -22,7 +22,10 @@ type MysqlTestSyncContainer struct { Target *MysqlTestContainer } -func NewMysqlTestSyncContainer(ctx context.Context, sourceOpts, destOpts []Option) (*MysqlTestSyncContainer, error) { +func NewMysqlTestSyncContainer( + ctx context.Context, + sourceOpts, destOpts []Option, +) (*MysqlTestSyncContainer, error) { tc := &MysqlTestSyncContainer{} errgrp := errgroup.Group{} errgrp.Go(func() error { @@ -134,7 +137,9 @@ func setup(ctx context.Context, cfg *mysqlTestContainerConfig) (*MysqlTestContai testmysql.WithUsername(cfg.username), testmysql.WithPassword(cfg.password), testcontainers.WithWaitStrategy( - wait.ForLog("port: 3306 MySQL Community Server").WithOccurrence(1).WithStartupTimeout(20 * time.Second), + wait.ForLog("port: 3306 MySQL Community Server"). + WithOccurrence(1). + WithStartupTimeout(20 * time.Second), ), } if cfg.useTls { @@ -252,7 +257,11 @@ func (m *MysqlTestContainer) TearDown(ctx context.Context) error { } // Executes SQL files within the test container -func (m *MysqlTestContainer) RunSqlFiles(ctx context.Context, folder *string, files []string) error { +func (m *MysqlTestContainer) RunSqlFiles( + ctx context.Context, + folder *string, + files []string, +) error { for _, file := range files { filePath := file if folder != nil && *folder != "" { @@ -271,7 +280,12 @@ func (m *MysqlTestContainer) RunSqlFiles(ctx context.Context, folder *string, fi } // Creates schema and sets USE to schema before running SQL files -func (m *MysqlTestContainer) RunCreateStmtsInDatabase(ctx context.Context, folder string, files []string, database string) error { +func (m *MysqlTestContainer) RunCreateStmtsInDatabase( + ctx context.Context, + folder string, + files []string, + database string, +) error { for _, file := range files { filePath := file if folder 
!= "" { @@ -282,7 +296,11 @@ func (m *MysqlTestContainer) RunCreateStmtsInDatabase(ctx context.Context, folde return err } - setSchemaSql := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`; \n USE `%s`; \n", database, database) + setSchemaSql := fmt.Sprintf( + "CREATE DATABASE IF NOT EXISTS `%s`; \n USE `%s`; \n", + database, + database, + ) _, err = m.DB.ExecContext(ctx, setSchemaSql+string(sqlStr)) if err != nil { return fmt.Errorf("unable to exec sql when running mysql sql files: %w", err) @@ -293,7 +311,10 @@ func (m *MysqlTestContainer) RunCreateStmtsInDatabase(ctx context.Context, folde func (m *MysqlTestContainer) CreateDatabases(ctx context.Context, databases []string) error { for _, database := range databases { - _, err := m.DB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`;", database)) + _, err := m.DB.ExecContext( + ctx, + fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`;", database), + ) if err != nil { return fmt.Errorf("unable to create database %s: %w", database, err) } @@ -311,8 +332,14 @@ func (m *MysqlTestContainer) DropDatabases(ctx context.Context, databases []stri return nil } -func (m *MysqlTestContainer) GetTableRowCount(ctx context.Context, schema, table string) (int, error) { - rows := m.DB.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`%s`;", schema, table)) +func (m *MysqlTestContainer) GetTableRowCount( + ctx context.Context, + schema, table string, +) (int, error) { + rows := m.DB.QueryRowContext( + ctx, + fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`%s`;", schema, table), + ) var count int err := rows.Scan(&count) if err != nil { diff --git a/internal/testutil/testcontainers/postgres/postgres.go b/internal/testutil/testcontainers/postgres/postgres.go index 024c7848f8..5f010d564d 100644 --- a/internal/testutil/testcontainers/postgres/postgres.go +++ b/internal/testutil/testcontainers/postgres/postgres.go @@ -21,7 +21,10 @@ type PostgresTestSyncContainer struct { Target *PostgresTestContainer } -func 
NewPostgresTestSyncContainer(ctx context.Context, sourceOpts, destOpts []Option) (*PostgresTestSyncContainer, error) { +func NewPostgresTestSyncContainer( + ctx context.Context, + sourceOpts, destOpts []Option, +) (*PostgresTestSyncContainer, error) { tc := &PostgresTestSyncContainer{} errgrp := errgroup.Group{} errgrp.Go(func() error { @@ -257,7 +260,11 @@ func (p *PostgresTestContainer) TearDown(ctx context.Context) error { } // Executes SQL files within the test container -func (p *PostgresTestContainer) RunSqlFiles(ctx context.Context, folder *string, files []string) error { +func (p *PostgresTestContainer) RunSqlFiles( + ctx context.Context, + folder *string, + files []string, +) error { for _, file := range files { filePath := file if folder != nil && *folder != "" { @@ -276,7 +283,12 @@ func (p *PostgresTestContainer) RunSqlFiles(ctx context.Context, folder *string, } // Creates schema and sets search_path to schema before running SQL files -func (p *PostgresTestContainer) RunCreateStmtsInSchema(ctx context.Context, folder string, files []string, schema string) error { +func (p *PostgresTestContainer) RunCreateStmtsInSchema( + ctx context.Context, + folder string, + files []string, + schema string, +) error { for _, file := range files { filePath := file if folder != "" { @@ -286,7 +298,11 @@ func (p *PostgresTestContainer) RunCreateStmtsInSchema(ctx context.Context, fold if err != nil { return err } - setSchemaSql := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %q; \n SET search_path TO %q; \n", schema, schema) + setSchemaSql := fmt.Sprintf( + "CREATE SCHEMA IF NOT EXISTS %q; \n SET search_path TO %q; \n", + schema, + schema, + ) _, err = p.DB.Exec(ctx, setSchemaSql+string(sqlStr)) if err != nil { return fmt.Errorf("unable to exec postgres create stmts in schema: %w", err) @@ -315,7 +331,10 @@ func (p *PostgresTestContainer) DropSchemas(ctx context.Context, schemas []strin return nil } -func (p *PostgresTestContainer) GetTableRowCount(ctx context.Context, 
schema, table string) (int, error) { +func (p *PostgresTestContainer) GetTableRowCount( + ctx context.Context, + schema, table string, +) (int, error) { rows := p.DB.QueryRow(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %q.%q;", schema, table)) var count int err := rows.Scan(&count) diff --git a/internal/testutil/testcontainers/sqlserver/sqlserver.go b/internal/testutil/testcontainers/sqlserver/sqlserver.go index 7ba4c11dbf..1dfc9d177d 100644 --- a/internal/testutil/testcontainers/sqlserver/sqlserver.go +++ b/internal/testutil/testcontainers/sqlserver/sqlserver.go @@ -24,7 +24,10 @@ type MssqlTestSyncContainer struct { Target *MssqlTestContainer } -func NewMssqlTestSyncContainer(ctx context.Context, sourceOpts, destOpts []Option) (*MssqlTestSyncContainer, error) { +func NewMssqlTestSyncContainer( + ctx context.Context, + sourceOpts, destOpts []Option, +) (*MssqlTestSyncContainer, error) { tc := &MssqlTestSyncContainer{} errgrp := errgroup.Group{} errgrp.Go(func() error { @@ -139,7 +142,8 @@ func setup(ctx context.Context, cfg *mssqlTestContainerConfig) (*MssqlTestContai testutil.WithDockerFile(mssqlDf), ) } - mssqlcontainer, err := testmssql.Run(ctx, + mssqlcontainer, err := testmssql.Run( + ctx, "mcr.microsoft.com/mssql/server:2022-latest", // WithDockerFile overrides the image and updates it to be empty tcOpts..., ) @@ -234,7 +238,11 @@ func (m *MssqlTestContainer) TearDown(ctx context.Context) error { } // Executes SQL files within the test container -func (m *MssqlTestContainer) RunSqlFiles(ctx context.Context, folder *string, files []string) error { +func (m *MssqlTestContainer) RunSqlFiles( + ctx context.Context, + folder *string, + files []string, +) error { for _, file := range files { filePath := file if folder != nil && *folder != "" { @@ -252,7 +260,12 @@ func (m *MssqlTestContainer) RunSqlFiles(ctx context.Context, folder *string, fi return nil } -func (m *MssqlTestContainer) RunCreateStmtsInSchema(ctx context.Context, folder string, files []string, schema 
string) error { +func (m *MssqlTestContainer) RunCreateStmtsInSchema( + ctx context.Context, + folder string, + files []string, + schema string, +) error { for _, file := range files { filePath := file if folder != "" { @@ -295,7 +308,10 @@ func (m *MssqlTestContainer) DropSchemas(ctx context.Context, schemas []string) return nil } -func (m *MssqlTestContainer) GetTableRowCount(ctx context.Context, schema, table string) (int, error) { +func (m *MssqlTestContainer) GetTableRowCount( + ctx context.Context, + schema, table string, +) (int, error) { rows := m.DB.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM [%s].[%s];", schema, table)) var count int err := rows.Scan(&count) diff --git a/worker/internal/cmds/worker/serve/serve.go b/worker/internal/cmds/worker/serve/serve.go index b8df416949..ec7c325df9 100644 --- a/worker/internal/cmds/worker/serve/serve.go +++ b/worker/internal/cmds/worker/serve/serve.go @@ -79,7 +79,9 @@ func NewCmd() *cobra.Command { func serve(ctx context.Context) error { logger, loglogger := neosynclogger.NewLoggers() - slog.SetDefault(logger) // set default logger for methods that can't easily access the configured logger + slog.SetDefault( + logger, + ) // set default logger for methods that can't easily access the configured logger eelicense, err := license.NewFromEnv() if err != nil { @@ -115,20 +117,28 @@ func serve(ctx context.Context) error { if otelconfig.IsEnabled { logger.Debug("otel is enabled") tmPropagator := neosyncotel.NewDefaultPropagator() - otelconnopts := []otelconnect.Option{otelconnect.WithoutServerPeerAttributes(), otelconnect.WithPropagator(tmPropagator)} + otelconnopts := []otelconnect.Option{ + otelconnect.WithoutServerPeerAttributes(), + otelconnect.WithPropagator(tmPropagator), + } meterProviders := []neosyncotel.MeterProvider{} traceProviders := []neosyncotel.TracerProvider{} // Meter Provider that uses delta temporality for use with Benthos metrics // This meter provider is setup expire metrics after a 
specified time period for easy computation - benthosMeterProvider, err := neosyncotel.NewMeterProvider(ctx, &neosyncotel.MeterProviderConfig{ - Exporter: otelconfig.MeterExporter, - AppVersion: otelconfig.ServiceVersion, - Opts: neosyncotel.MeterExporterOpts{ - Otlp: []otlpmetricgrpc.Option{neosyncotel.GetBenthosMetricTemporalityOption()}, - Console: []stdoutmetric.Option{stdoutmetric.WithPrettyPrint()}, + benthosMeterProvider, err := neosyncotel.NewMeterProvider( + ctx, + &neosyncotel.MeterProviderConfig{ + Exporter: otelconfig.MeterExporter, + AppVersion: otelconfig.ServiceVersion, + Opts: neosyncotel.MeterExporterOpts{ + Otlp: []otlpmetricgrpc.Option{ + neosyncotel.GetBenthosMetricTemporalityOption(), + }, + Console: []stdoutmetric.Option{stdoutmetric.WithPrettyPrint()}, + }, }, - }) + ) if err != nil { return err } @@ -138,36 +148,44 @@ func serve(ctx context.Context) error { syncActivityMeter = benthosMeterProvider.Meter("sync_activity") } - temporalMeterProvider, err := neosyncotel.NewMeterProvider(ctx, &neosyncotel.MeterProviderConfig{ - Exporter: otelconfig.MeterExporter, - AppVersion: otelconfig.ServiceVersion, - Opts: neosyncotel.MeterExporterOpts{ - Otlp: []otlpmetricgrpc.Option{}, - Console: []stdoutmetric.Option{stdoutmetric.WithPrettyPrint()}, + temporalMeterProvider, err := neosyncotel.NewMeterProvider( + ctx, + &neosyncotel.MeterProviderConfig{ + Exporter: otelconfig.MeterExporter, + AppVersion: otelconfig.ServiceVersion, + Opts: neosyncotel.MeterExporterOpts{ + Otlp: []otlpmetricgrpc.Option{}, + Console: []stdoutmetric.Option{stdoutmetric.WithPrettyPrint()}, + }, }, - }) + ) if err != nil { return err } if temporalMeterProvider != nil { logger.Debug("otel metering for temporal has been configured") meterProviders = append(meterProviders, temporalMeterProvider) - temopralMeterHandler = temporalotel.NewMetricsHandler(temporalotel.MetricsHandlerOptions{ - Meter: temporalMeterProvider.Meter("neosync-temporal-sdk"), - OnError: func(err error) { - 
logger.Error(fmt.Errorf("error with temporal metering: %w", err).Error()) + temopralMeterHandler = temporalotel.NewMetricsHandler( + temporalotel.MetricsHandlerOptions{ + Meter: temporalMeterProvider.Meter("neosync-temporal-sdk"), + OnError: func(err error) { + logger.Error(fmt.Errorf("error with temporal metering: %w", err).Error()) + }, }, - }) + ) } - neosyncMeterProvider, err := neosyncotel.NewMeterProvider(ctx, &neosyncotel.MeterProviderConfig{ - Exporter: otelconfig.MeterExporter, - AppVersion: otelconfig.ServiceVersion, - Opts: neosyncotel.MeterExporterOpts{ - Otlp: []otlpmetricgrpc.Option{}, - Console: []stdoutmetric.Option{stdoutmetric.WithPrettyPrint()}, + neosyncMeterProvider, err := neosyncotel.NewMeterProvider( + ctx, + &neosyncotel.MeterProviderConfig{ + Exporter: otelconfig.MeterExporter, + AppVersion: otelconfig.ServiceVersion, + Opts: neosyncotel.MeterExporterOpts{ + Otlp: []otlpmetricgrpc.Option{}, + Console: []stdoutmetric.Option{stdoutmetric.WithPrettyPrint()}, + }, }, - }) + ) if err != nil { return err } @@ -179,41 +197,55 @@ func serve(ctx context.Context) error { otelconnopts = append(otelconnopts, otelconnect.WithoutMetrics()) } - temporalTraceProvider, err := neosyncotel.NewTraceProvider(ctx, &neosyncotel.TraceProviderConfig{ - Exporter: otelconfig.TraceExporter, - Opts: neosyncotel.TraceExporterOpts{ - Otlp: []otlptracegrpc.Option{}, - Console: []stdouttrace.Option{stdouttrace.WithPrettyPrint()}, + temporalTraceProvider, err := neosyncotel.NewTraceProvider( + ctx, + &neosyncotel.TraceProviderConfig{ + Exporter: otelconfig.TraceExporter, + Opts: neosyncotel.TraceExporterOpts{ + Otlp: []otlptracegrpc.Option{}, + Console: []stdouttrace.Option{stdouttrace.WithPrettyPrint()}, + }, }, - }) + ) if err != nil { return err } if temporalTraceProvider != nil { logger.Debug("otel tracing for temporal has been configured") - temporalTraceInterceptor, err := temporalotel.NewTracingInterceptor(temporalotel.TracerOptions{ - Tracer: 
temporalTraceProvider.Tracer("neosync-temporal-sdk"), - }) + temporalTraceInterceptor, err := temporalotel.NewTracingInterceptor( + temporalotel.TracerOptions{ + Tracer: temporalTraceProvider.Tracer("neosync-temporal-sdk"), + }, + ) if err != nil { return err } - temporalClientInterceptors = append(temporalClientInterceptors, temporalTraceInterceptor) + temporalClientInterceptors = append( + temporalClientInterceptors, + temporalTraceInterceptor, + ) traceProviders = append(traceProviders, temporalTraceProvider) } - neosyncTraceProvider, err := neosyncotel.NewTraceProvider(ctx, &neosyncotel.TraceProviderConfig{ - Exporter: otelconfig.TraceExporter, - Opts: neosyncotel.TraceExporterOpts{ - Otlp: []otlptracegrpc.Option{}, - Console: []stdouttrace.Option{stdouttrace.WithPrettyPrint()}, + neosyncTraceProvider, err := neosyncotel.NewTraceProvider( + ctx, + &neosyncotel.TraceProviderConfig{ + Exporter: otelconfig.TraceExporter, + Opts: neosyncotel.TraceExporterOpts{ + Otlp: []otlptracegrpc.Option{}, + Console: []stdouttrace.Option{stdouttrace.WithPrettyPrint()}, + }, }, - }) + ) if err != nil { return err } if neosyncTraceProvider != nil { logger.Debug("otel tracing for neosync clients has been configured") - otelconnopts = append(otelconnopts, otelconnect.WithTracerProvider(neosyncTraceProvider)) + otelconnopts = append( + otelconnopts, + otelconnect.WithTracerProvider(neosyncTraceProvider), + ) } else { otelconnopts = append(otelconnopts, otelconnect.WithoutTracing(), otelconnect.WithoutTraceEvents()) } @@ -232,13 +264,18 @@ func serve(ctx context.Context) error { }) defer func() { if err := otelshutdown(context.Background()); err != nil { - logger.Error(fmt.Errorf("unable to gracefully shutdown otel providers: %w", err).Error()) + logger.Error( + fmt.Errorf("unable to gracefully shutdown otel providers: %w", err).Error(), + ) } }() } // Ensure that the retry interceptor comes after the otel interceptor - connectInterceptors = append(connectInterceptors, 
retry_interceptor.DefaultRetryInterceptor(logger)) + connectInterceptors = append( + connectInterceptors, + retry_interceptor.DefaultRetryInterceptor(logger), + ) temporalUrl := viper.GetString("TEMPORAL_URL") if temporalUrl == "" { @@ -296,11 +333,31 @@ func serve(ctx context.Context) error { neosyncurl := shared.GetNeosyncUrl() httpclient := shared.GetNeosyncHttpClient() connectInterceptorOption := connect.WithInterceptors(connectInterceptors...) - userclient := mgmtv1alpha1connect.NewUserAccountServiceClient(httpclient, neosyncurl, connectInterceptorOption) - connclient := mgmtv1alpha1connect.NewConnectionServiceClient(httpclient, neosyncurl, connectInterceptorOption) - jobclient := mgmtv1alpha1connect.NewJobServiceClient(httpclient, neosyncurl, connectInterceptorOption) - transformerclient := mgmtv1alpha1connect.NewTransformersServiceClient(httpclient, neosyncurl, connectInterceptorOption) - accounthookclient := mgmtv1alpha1connect.NewAccountHookServiceClient(httpclient, neosyncurl, connectInterceptorOption) + userclient := mgmtv1alpha1connect.NewUserAccountServiceClient( + httpclient, + neosyncurl, + connectInterceptorOption, + ) + connclient := mgmtv1alpha1connect.NewConnectionServiceClient( + httpclient, + neosyncurl, + connectInterceptorOption, + ) + jobclient := mgmtv1alpha1connect.NewJobServiceClient( + httpclient, + neosyncurl, + connectInterceptorOption, + ) + transformerclient := mgmtv1alpha1connect.NewTransformersServiceClient( + httpclient, + neosyncurl, + connectInterceptorOption, + ) + accounthookclient := mgmtv1alpha1connect.NewAccountHookServiceClient( + httpclient, + neosyncurl, + connectInterceptorOption, + ) sqlConnector := &sqlconnect.SqlOpenConnector{} sqlconnmanager := connectionmanager.NewConnectionManager(sqlprovider.NewProvider(sqlConnector)) @@ -370,7 +427,15 @@ func serve(ctx context.Context) error { neosynctyperegistry, ) - piidetect_workflow_register.Register(w, connclient, jobclient, openaiclient, conndatabuilder, cascadelicense, 
temporalClient.ScheduleClient()) + piidetect_workflow_register.Register( + w, + connclient, + jobclient, + openaiclient, + conndatabuilder, + cascadelicense, + temporalClient.ScheduleClient(), + ) } if err := w.Start(); err != nil { diff --git a/worker/internal/temporal-logger/logger.go b/worker/internal/temporal-logger/logger.go index 53be4ee60b..d4c44123e1 100644 --- a/worker/internal/temporal-logger/logger.go +++ b/worker/internal/temporal-logger/logger.go @@ -20,7 +20,10 @@ func (h *temporalLogHandler) Enabled(ctx context.Context, level slog.Level) bool return true } -func (h *temporalLogHandler) Handle(ctx context.Context, r slog.Record) error { //nolint:gocritic // Needs to conform to the slog.Handler interface +func (h *temporalLogHandler) Handle( + ctx context.Context, + r slog.Record, +) error { //nolint:gocritic // Needs to conform to the slog.Handler interface // Combine pre-defined attrs with record attrs allAttrs := make([]slog.Attr, 0, len(h.attrs)+r.NumAttrs()) allAttrs = append(allAttrs, h.attrs...) 
diff --git a/worker/pkg/benthos/config.go b/worker/pkg/benthos/config.go index d23f5b2ec6..946071622b 100644 --- a/worker/pkg/benthos/config.go +++ b/worker/pkg/benthos/config.go @@ -5,158 +5,158 @@ type BenthosConfig struct { } type StreamConfig struct { - Logger *LoggerConfig `json:"logger" yaml:"logger,omitempty"` - Input *InputConfig `json:"input" yaml:"input"` - Pipeline *PipelineConfig `json:"pipeline" yaml:"pipeline"` - Output *OutputConfig `json:"output" yaml:"output"` + Logger *LoggerConfig `json:"logger" yaml:"logger,omitempty"` + Input *InputConfig `json:"input" yaml:"input"` + Pipeline *PipelineConfig `json:"pipeline" yaml:"pipeline"` + Output *OutputConfig `json:"output" yaml:"output"` Metrics *Metrics `json:"metrics,omitempty" yaml:"metrics,omitempty"` } type LoggerConfig struct { - Level string `json:"level" yaml:"level"` + Level string `json:"level" yaml:"level"` AddTimestamp bool `json:"add_timestamp" yaml:"add_timestamp"` } type Metrics struct { OtelCollector *MetricsOtelCollector `json:"otel_collector,omitempty" yaml:"otel_collector,omitempty"` - Mapping string `json:"mapping,omitempty" yaml:"mapping,omitempty"` + Mapping string `json:"mapping,omitempty" yaml:"mapping,omitempty"` } type MetricsOtelCollector struct { } type InputConfig struct { - Label string `json:"label" yaml:"label"` - Inputs `json:",inline" yaml:",inline"` + Label string `json:"label" yaml:"label"` + Inputs ` json:",inline" yaml:",inline"` } type Inputs struct { - PooledSqlRaw *InputPooledSqlRaw `json:"pooled_sql_raw,omitempty" yaml:"pooled_sql_raw,omitempty"` - Generate *Generate `json:"generate,omitempty" yaml:"generate,omitempty"` - OpenAiGenerate *OpenAiGenerate `json:"openai_generate,omitempty" yaml:"openai_generate,omitempty"` - PooledMongoDB *InputMongoDb `json:"pooled_mongodb,omitempty" yaml:"pooled_mongodb,omitempty"` - AwsDynamoDB *InputAwsDynamoDB `json:"aws_dynamodb,omitempty" yaml:"aws_dynamodb,omitempty"` + PooledSqlRaw *InputPooledSqlRaw 
`json:"pooled_sql_raw,omitempty" yaml:"pooled_sql_raw,omitempty"` + Generate *Generate `json:"generate,omitempty" yaml:"generate,omitempty"` + OpenAiGenerate *OpenAiGenerate `json:"openai_generate,omitempty" yaml:"openai_generate,omitempty"` + PooledMongoDB *InputMongoDb `json:"pooled_mongodb,omitempty" yaml:"pooled_mongodb,omitempty"` + AwsDynamoDB *InputAwsDynamoDB `json:"aws_dynamodb,omitempty" yaml:"aws_dynamodb,omitempty"` NeosyncConnectionData *NeosyncConnectionData `json:"neosync_connection_data,omitempty" yaml:"neosync_connection_data,omitempty"` - Broker *InputBrokerConfig `json:"broker,omitempty" yaml:"broker,omitempty"` + Broker *InputBrokerConfig `json:"broker,omitempty" yaml:"broker,omitempty"` } type NeosyncConnectionData struct { - ConnectionId string `json:"connection_id" yaml:"connection_id"` - ConnectionType string `json:"connection_type" yaml:"connection_type"` - JobId *string `json:"job_id,omitempty" yaml:"job_id,omitempty"` + ConnectionId string `json:"connection_id" yaml:"connection_id"` + ConnectionType string `json:"connection_type" yaml:"connection_type"` + JobId *string `json:"job_id,omitempty" yaml:"job_id,omitempty"` JobRunId *string `json:"job_run_id,omitempty" yaml:"job_run_id,omitempty"` - Schema string `json:"schema" yaml:"schema"` - Table string `json:"table" yaml:"table"` + Schema string `json:"schema" yaml:"schema"` + Table string `json:"table" yaml:"table"` } type InputAwsDynamoDB struct { - Table string `json:"table" yaml:"table"` + Table string `json:"table" yaml:"table"` Where *string `json:"where,omitempty" yaml:"where,omitempty"` ConsistentRead bool `json:"consistent_read" yaml:"consistent_read"` - Region string `json:"region,omitempty" yaml:"region,omitempty"` + Region string `json:"region,omitempty" yaml:"region,omitempty"` Endpoint string `json:"endpoint,omitempty" yaml:"endpoint,omitempty"` Credentials *AwsCredentials `json:"credentials,omitempty" yaml:"credentials,omitempty"` } type OutputAwsDynamoDB struct { - Table 
string `json:"table" yaml:"table"` + Table string `json:"table" yaml:"table"` JsonMapColumns map[string]string `json:"json_map_columns,omitempty" yaml:"json_map_columns,omitempty"` - Region string `json:"region,omitempty" yaml:"region,omitempty"` + Region string `json:"region,omitempty" yaml:"region,omitempty"` Endpoint string `json:"endpoint,omitempty" yaml:"endpoint,omitempty"` Credentials *AwsCredentials `json:"credentials,omitempty" yaml:"credentials,omitempty"` MaxInFlight *int `json:"max_in_flight,omitempty" yaml:"max_in_flight,omitempty"` - Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"` + Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"` } type InputMongoDb struct { - ConnectionId string `json:"connection_id" yaml:"connection_id"` - Database string `json:"database" yaml:"database"` - Username string `json:"username,omitempty" yaml:"username,omitempty"` - Password string `json:"password,omitempty" yaml:"password,omitempty"` - Operation *string `json:"operation,omitempty" yaml:"operation,omitempty"` - Collection string `json:"collection" yaml:"collection"` + ConnectionId string `json:"connection_id" yaml:"connection_id"` + Database string `json:"database" yaml:"database"` + Username string `json:"username,omitempty" yaml:"username,omitempty"` + Password string `json:"password,omitempty" yaml:"password,omitempty"` + Operation *string `json:"operation,omitempty" yaml:"operation,omitempty"` + Collection string `json:"collection" yaml:"collection"` JsonMarshalMode *string `json:"json_marshal_mode,omitempty" yaml:"json_marshal_mode,omitempty"` - Query string `json:"query" yaml:"query"` + Query string `json:"query" yaml:"query"` AutoReplayNacks *bool `json:"auto_replay_nacks,omitempty" yaml:"auto_replay_nacks,omitempty"` - BatchSize *int32 `json:"batch_size,omitempty" yaml:"batch_size,omitempty"` - Sort map[string]int `json:"sort,omitempty" yaml:"sort,omitempty"` - Limit *int32 `json:"limit,omitempty" 
yaml:"limit,omitempty"` + BatchSize *int32 `json:"batch_size,omitempty" yaml:"batch_size,omitempty"` + Sort map[string]int `json:"sort,omitempty" yaml:"sort,omitempty"` + Limit *int32 `json:"limit,omitempty" yaml:"limit,omitempty"` } type OutputMongoDb struct { - ConnectionId string `json:"connection_id" yaml:"connection_id"` - Database string `json:"database" yaml:"database"` - Username string `json:"username,omitempty" yaml:"username,omitempty"` - Password string `json:"password,omitempty" yaml:"password,omitempty"` - Operation string `json:"operation" yaml:"operation"` - Collection string `json:"collection" yaml:"collection"` - DocumentMap string `json:"document_map" yaml:"document_map"` - FilterMap string `json:"filter_map" yaml:"filter_map"` - HintMap string `json:"hint_map" yaml:"hint_map"` - Upsert bool `json:"upsert" yaml:"upsert"` + ConnectionId string `json:"connection_id" yaml:"connection_id"` + Database string `json:"database" yaml:"database"` + Username string `json:"username,omitempty" yaml:"username,omitempty"` + Password string `json:"password,omitempty" yaml:"password,omitempty"` + Operation string `json:"operation" yaml:"operation"` + Collection string `json:"collection" yaml:"collection"` + DocumentMap string `json:"document_map" yaml:"document_map"` + FilterMap string `json:"filter_map" yaml:"filter_map"` + HintMap string `json:"hint_map" yaml:"hint_map"` + Upsert bool `json:"upsert" yaml:"upsert"` MaxInFlight *int `json:"max_in_flight,omitempty" yaml:"max_in_flight,omitempty"` - Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"` + Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"` WriteConcern *MongoWriteConcern `json:"write_concern,omitempty" yaml:"write_concern,omitempty"` } type MongoWriteConcern struct { - W string `json:"w,omitempty" yaml:"w,omitempty"` - J string `json:"j,omitempty" yaml:"j,omitempty"` + W string `json:"w,omitempty" yaml:"w,omitempty"` + J string `json:"j,omitempty" 
yaml:"j,omitempty"` WTimeout string `json:"w_timeout,omitempty" yaml:"w_timeout,omitempty"` } type OpenAiGenerate struct { - ApiUrl string `json:"api_url" yaml:"api_url"` - ApiKey string `json:"api_key" yaml:"api_key"` + ApiUrl string `json:"api_url" yaml:"api_url"` + ApiKey string `json:"api_key" yaml:"api_key"` UserPrompt *string `json:"user_prompt,omitempty" yaml:"user_prompt,omitempty"` - Columns []string `json:"columns" yaml:"columns"` - DataTypes []string `json:"data_types" yaml:"data_types"` - Model string `json:"model" yaml:"model"` - Count int `json:"count" yaml:"count"` - BatchSize int `json:"batch_size" yaml:"batch_size"` + Columns []string `json:"columns" yaml:"columns"` + DataTypes []string `json:"data_types" yaml:"data_types"` + Model string `json:"model" yaml:"model"` + Count int `json:"count" yaml:"count"` + BatchSize int `json:"batch_size" yaml:"batch_size"` } type Generate struct { Mapping string `json:"mapping" yaml:"mapping"` - Interval string `json:"interval" yaml:"interval"` - Count int `json:"count" yaml:"count"` + Interval string `json:"interval" yaml:"interval"` + Count int `json:"count" yaml:"count"` BatchSize *int `json:"batch_size,omitempty" yaml:"batch_size,omitempty"` } type InputPooledSqlRaw struct { - ConnectionId string `json:"connection_id" yaml:"connection_id"` - Query string `json:"query" yaml:"query"` + ConnectionId string `json:"connection_id" yaml:"connection_id"` + Query string `json:"query" yaml:"query"` PagedQuery string `json:"paged_query,omitempty" yaml:"paged_query,omitempty"` ExpectedTotalRows *int `json:"expected_total_rows,omitempty" yaml:"expected_total_rows,omitempty"` - OrderByColumns []string `json:"order_by_columns,omitempty" yaml:"order_by_columns,omitempty"` + OrderByColumns []string `json:"order_by_columns,omitempty" yaml:"order_by_columns,omitempty"` } type PipelineConfig struct { - Threads int `json:"threads" yaml:"threads"` + Threads int `json:"threads" yaml:"threads"` Processors []ProcessorConfig 
`json:"processors" yaml:"processors"` } type ProcessorConfig struct { - Mutation *string `json:"mutation,omitempty" yaml:"mutation,omitempty"` - NeosyncJavascript *NeosyncJavascriptConfig `json:"neosync_javascript,omitempty" yaml:"neosync_javascript,omitempty"` - Branch *BranchConfig `json:"branch,omitempty" yaml:"branch,omitempty"` - Mapping *string `json:"mapping,omitempty" yaml:"mapping,omitempty"` - Redis *RedisProcessorConfig `json:"redis,omitempty" yaml:"redis,omitempty"` - Error *ErrorProcessorConfig `json:"error,omitempty" yaml:"error,omitempty"` - Catch []*ProcessorConfig `json:"catch,omitempty" yaml:"catch,omitempty"` + Mutation *string `json:"mutation,omitempty" yaml:"mutation,omitempty"` + NeosyncJavascript *NeosyncJavascriptConfig `json:"neosync_javascript,omitempty" yaml:"neosync_javascript,omitempty"` + Branch *BranchConfig `json:"branch,omitempty" yaml:"branch,omitempty"` + Mapping *string `json:"mapping,omitempty" yaml:"mapping,omitempty"` + Redis *RedisProcessorConfig `json:"redis,omitempty" yaml:"redis,omitempty"` + Error *ErrorProcessorConfig `json:"error,omitempty" yaml:"error,omitempty"` + Catch []*ProcessorConfig `json:"catch,omitempty" yaml:"catch,omitempty"` NeosyncDefaultTransformer *NeosyncDefaultTransformerConfig `json:"neosync_default_transformer,omitempty" yaml:"neosync_default_transformer,omitempty"` } type NeosyncDefaultTransformerConfig struct { JobSourceOptionsString string `json:"job_source_options_string" yaml:"job_source_options_string"` - MappedKeys []string `json:"mapped_keys" yaml:"mapped_keys"` + MappedKeys []string `json:"mapped_keys" yaml:"mapped_keys"` } type NeosyncJavascriptConfig struct { @@ -168,192 +168,192 @@ type ErrorProcessorConfig struct { } type RedisProcessorConfig struct { - Url string `json:"url" yaml:"url"` - Command string `json:"command" yaml:"command"` - ArgsMapping string `json:"args_mapping" yaml:"args_mapping"` - Kind *string `json:"kind,omitempty" yaml:"kind,omitempty"` + Url string `json:"url" 
yaml:"url"` + Command string `json:"command" yaml:"command"` + ArgsMapping string `json:"args_mapping" yaml:"args_mapping"` + Kind *string `json:"kind,omitempty" yaml:"kind,omitempty"` Master *string `json:"master,omitempty" yaml:"master,omitempty"` - Tls *RedisTlsConfig `json:"tls,omitempty" yaml:"tls,omitempty"` + Tls *RedisTlsConfig `json:"tls,omitempty" yaml:"tls,omitempty"` } type RedisTlsConfig struct { - Enabled bool `json:"enabled" yaml:"enabled"` - SkipCertVerify bool `json:"skip_cert_verify" yaml:"skip_cert_verify"` - EnableRenegotiation bool `json:"enable_renegotiation" yaml:"enable_renegotiation"` - RootCas *string `json:"root_cas,omitempty" yaml:"root_cas,omitempty"` - RootCasFile *string `json:"root_cas_file,omitempty" yaml:"root_cas_file,omitempty"` + Enabled bool `json:"enabled" yaml:"enabled"` + SkipCertVerify bool `json:"skip_cert_verify" yaml:"skip_cert_verify"` + EnableRenegotiation bool `json:"enable_renegotiation" yaml:"enable_renegotiation"` + RootCas *string `json:"root_cas,omitempty" yaml:"root_cas,omitempty"` + RootCasFile *string `json:"root_cas_file,omitempty" yaml:"root_cas_file,omitempty"` } type BranchConfig struct { - Processors []ProcessorConfig `json:"processors" yaml:"processors"` + Processors []ProcessorConfig `json:"processors" yaml:"processors"` RequestMap *string `json:"request_map,omitempty" yaml:"request_map,omitempty"` - ResultMap *string `json:"result_map,omitempty" yaml:"result_map,omitempty"` + ResultMap *string `json:"result_map,omitempty" yaml:"result_map,omitempty"` } type OutputConfig struct { - Label string `json:"label" yaml:"label"` - Outputs `json:",inline" yaml:",inline"` + Label string `json:"label" yaml:"label"` + Outputs ` json:",inline" yaml:",inline"` Processors []ProcessorConfig `json:"processors,omitempty" yaml:"processors,omitempty"` } type Outputs struct { PooledSqlInsert *PooledSqlInsert `json:"pooled_sql_insert,omitempty" yaml:"pooled_sql_insert,omitempty"` PooledSqlUpdate *PooledSqlUpdate 
`json:"pooled_sql_update,omitempty" yaml:"pooled_sql_update,omitempty"` - AwsS3 *AwsS3Insert `json:"aws_s3,omitempty" yaml:"aws_s3,omitempty"` + AwsS3 *AwsS3Insert `json:"aws_s3,omitempty" yaml:"aws_s3,omitempty"` GcpCloudStorage *GcpCloudStorageOutput `json:"gcp_cloud_storage,omitempty" yaml:"gcp_cloud_storage,omitempty"` - Retry *RetryConfig `json:"retry,omitempty" yaml:"retry,omitempty"` - Broker *OutputBrokerConfig `json:"broker,omitempty" yaml:"broker,omitempty"` - Fallback []Outputs `json:"fallback,omitempty" yaml:"fallback,omitempty"` + Retry *RetryConfig `json:"retry,omitempty" yaml:"retry,omitempty"` + Broker *OutputBrokerConfig `json:"broker,omitempty" yaml:"broker,omitempty"` + Fallback []Outputs `json:"fallback,omitempty" yaml:"fallback,omitempty"` RedisHashOutput *RedisHashOutputConfig `json:"redis_hash_output,omitempty" yaml:"redis_hash_output,omitempty"` - Error *ErrorOutputConfig `json:"error,omitempty" yaml:"error,omitempty"` - PooledMongoDB *OutputMongoDb `json:"pooled_mongodb,omitempty" yaml:"pooled_mongodb,omitempty"` - AwsDynamoDB *OutputAwsDynamoDB `json:"aws_dynamodb,omitempty" yaml:"aws_dynamodb,omitempty"` + Error *ErrorOutputConfig `json:"error,omitempty" yaml:"error,omitempty"` + PooledMongoDB *OutputMongoDb `json:"pooled_mongodb,omitempty" yaml:"pooled_mongodb,omitempty"` + AwsDynamoDB *OutputAwsDynamoDB `json:"aws_dynamodb,omitempty" yaml:"aws_dynamodb,omitempty"` } type ErrorOutputConfig struct { - ErrorMsg string `json:"error_msg" yaml:"error_msg"` - IsGenerateJob bool `json:"is_generate_job" yaml:"is_generate_job"` + ErrorMsg string `json:"error_msg" yaml:"error_msg"` + IsGenerateJob bool `json:"is_generate_job" yaml:"is_generate_job"` Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"` } type RedisHashOutputConfig struct { - Url string `json:"url" yaml:"url"` - Key string `json:"key" yaml:"key"` - WalkMetadata bool `json:"walk_metadata" yaml:"walk_metadata"` - WalkJsonObject bool `json:"walk_json_object" 
yaml:"walk_json_object"` - FieldsMapping string `json:"fields_mapping" yaml:"fields_mapping"` + Url string `json:"url" yaml:"url"` + Key string `json:"key" yaml:"key"` + WalkMetadata bool `json:"walk_metadata" yaml:"walk_metadata"` + WalkJsonObject bool `json:"walk_json_object" yaml:"walk_json_object"` + FieldsMapping string `json:"fields_mapping" yaml:"fields_mapping"` MaxInFlight *int `json:"max_in_flight,omitempty" yaml:"max_in_flight,omitempty"` - Kind *string `json:"kind,omitempty" yaml:"kind,omitempty"` - Master *string `json:"master,omitempty" yaml:"master,omitempty"` - Tls *RedisTlsConfig `json:"tls,omitempty" yaml:"tls,omitempty"` + Kind *string `json:"kind,omitempty" yaml:"kind,omitempty"` + Master *string `json:"master,omitempty" yaml:"master,omitempty"` + Tls *RedisTlsConfig `json:"tls,omitempty" yaml:"tls,omitempty"` } type RetryConfig struct { - Output OutputConfig `json:"output" yaml:"output"` - InlineRetryConfig `json:",inline" yaml:",inline"` + Output OutputConfig `json:"output" yaml:"output"` + InlineRetryConfig ` json:",inline" yaml:",inline"` } type InlineRetryConfig struct { MaxRetries uint64 `json:"max_retries" yaml:"max_retries"` - Backoff Backoff `json:"backoff" yaml:"backoff"` + Backoff Backoff `json:"backoff" yaml:"backoff"` } type Backoff struct { InitialInterval string `json:"initial_interval,omitempty" yaml:"initial_interval,omitempty"` - MaxInterval string `json:"max_interval,omitempty" yaml:"max_interval,omitempty"` + MaxInterval string `json:"max_interval,omitempty" yaml:"max_interval,omitempty"` MaxElapsedTime string `json:"max_elapsed_time,omitempty" yaml:"max_elapsed_time,omitempty"` } type PooledSqlUpdate struct { - ConnectionId string `json:"connection_id" yaml:"connection_id"` - Schema string `json:"schema" yaml:"schema"` - Table string `json:"table" yaml:"table"` - Columns []string `json:"columns" yaml:"columns"` - WhereColumns []string `json:"where_columns" yaml:"where_columns"` + ConnectionId string `json:"connection_id" 
yaml:"connection_id"` + Schema string `json:"schema" yaml:"schema"` + Table string `json:"table" yaml:"table"` + Columns []string `json:"columns" yaml:"columns"` + WhereColumns []string `json:"where_columns" yaml:"where_columns"` SkipForeignKeyViolations bool `json:"skip_foreign_key_violations" yaml:"skip_foreign_key_violations"` - Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"` - MaxInFlight int `json:"max_in_flight,omitempty" yaml:"max_in_flight,omitempty"` + Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"` + MaxInFlight int `json:"max_in_flight,omitempty" yaml:"max_in_flight,omitempty"` } type ColumnDefaultProperties struct { - NeedsReset bool `json:"needs_reset" yaml:"needs_reset"` - NeedsOverride bool `json:"needs_override" yaml:"needs_override"` + NeedsReset bool `json:"needs_reset" yaml:"needs_reset"` + NeedsOverride bool `json:"needs_override" yaml:"needs_override"` HasDefaultTransformer bool `json:"has_default_transformer" yaml:"has_default_transformer"` } type PooledSqlInsert struct { - ConnectionId string `json:"connection_id" yaml:"connection_id"` - Schema string `json:"schema" yaml:"schema"` - Table string `json:"table" yaml:"table"` - PrimaryKeyColumns []string `json:"primary_key_columns" yaml:"primary_key_columns"` - ColumnUpdatesDisallowed []string `json:"column_updates_disallowed" yaml:"column_updates_disallowed"` - OnConflictDoNothing bool `json:"on_conflict_do_nothing" yaml:"on_conflict_do_nothing"` - OnConflictDoUpdate bool `json:"on_conflict_do_update" yaml:"on_conflict_do_update"` - TruncateOnRetry bool `json:"truncate_on_retry" yaml:"truncate_on_retry"` - SkipForeignKeyViolations bool `json:"skip_foreign_key_violations" yaml:"skip_foreign_key_violations"` + ConnectionId string `json:"connection_id" yaml:"connection_id"` + Schema string `json:"schema" yaml:"schema"` + Table string `json:"table" yaml:"table"` + PrimaryKeyColumns []string `json:"primary_key_columns" yaml:"primary_key_columns"` + 
ColumnUpdatesDisallowed []string `json:"column_updates_disallowed" yaml:"column_updates_disallowed"` + OnConflictDoNothing bool `json:"on_conflict_do_nothing" yaml:"on_conflict_do_nothing"` + OnConflictDoUpdate bool `json:"on_conflict_do_update" yaml:"on_conflict_do_update"` + TruncateOnRetry bool `json:"truncate_on_retry" yaml:"truncate_on_retry"` + SkipForeignKeyViolations bool `json:"skip_foreign_key_violations" yaml:"skip_foreign_key_violations"` ShouldOverrideColumnDefault bool `json:"should_override_column_default" yaml:"should_override_column_default"` - Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"` - Prefix *string `json:"prefix,omitempty" yaml:"prefix,omitempty"` - Suffix *string `json:"suffix,omitempty" yaml:"suffix,omitempty"` - MaxInFlight int `json:"max_in_flight,omitempty" yaml:"max_in_flight,omitempty"` + Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"` + Prefix *string `json:"prefix,omitempty" yaml:"prefix,omitempty"` + Suffix *string `json:"suffix,omitempty" yaml:"suffix,omitempty"` + MaxInFlight int `json:"max_in_flight,omitempty" yaml:"max_in_flight,omitempty"` } type AwsS3Insert struct { - Bucket string `json:"bucket" yaml:"bucket"` - MaxInFlight int `json:"max_in_flight" yaml:"max_in_flight"` - Path string `json:"path" yaml:"path"` - Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"` - Timeout string `json:"timeout,omitempty" yaml:"timeout,omitempty"` + Bucket string `json:"bucket" yaml:"bucket"` + MaxInFlight int `json:"max_in_flight" yaml:"max_in_flight"` + Path string `json:"path" yaml:"path"` + Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"` + Timeout string `json:"timeout,omitempty" yaml:"timeout,omitempty"` StorageClass string `json:"storage_class,omitempty" yaml:"storage_class,omitempty"` - ContentType string `json:"content_type,omitempty" yaml:"content_type,omitempty"` + ContentType string `json:"content_type,omitempty" 
yaml:"content_type,omitempty"` - Region string `json:"region,omitempty" yaml:"region,omitempty"` + Region string `json:"region,omitempty" yaml:"region,omitempty"` Endpoint string `json:"endpoint,omitempty" yaml:"endpoint,omitempty"` Credentials *AwsCredentials `json:"credentials,omitempty" yaml:"credentials,omitempty"` } type AwsCredentials struct { - Profile string `json:"profile,omitempty" yaml:"profile,omitempty"` - Id string `json:"id,omitempty" yaml:"id,omitempty"` - Secret string `json:"secret,omitempty" yaml:"secret,omitempty"` - Token string `json:"token,omitempty" yaml:"token,omitempty"` - FromEc2Role bool `json:"from_ec2_role,omitempty" yaml:"from_ec2_role,omitempty"` - Role string `json:"role,omitempty" yaml:"role,omitempty"` + Profile string `json:"profile,omitempty" yaml:"profile,omitempty"` + Id string `json:"id,omitempty" yaml:"id,omitempty"` + Secret string `json:"secret,omitempty" yaml:"secret,omitempty"` + Token string `json:"token,omitempty" yaml:"token,omitempty"` + FromEc2Role bool `json:"from_ec2_role,omitempty" yaml:"from_ec2_role,omitempty"` + Role string `json:"role,omitempty" yaml:"role,omitempty"` RoleExternalId string `json:"role_external_id,omitempty" yaml:"role_external_id,omitempty"` } type GcpCloudStorageOutput struct { - Bucket string `json:"bucket" yaml:"bucket"` - Path string `json:"path" yaml:"path"` - MaxInFlight int `json:"max_in_flight" yaml:"max_in_flight"` + Bucket string `json:"bucket" yaml:"bucket"` + Path string `json:"path" yaml:"path"` + MaxInFlight int `json:"max_in_flight" yaml:"max_in_flight"` Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"` - ContentType *string `json:"content_type,omitempty" yaml:"content_type,omitempty"` + ContentType *string `json:"content_type,omitempty" yaml:"content_type,omitempty"` ContentEncoding *string `json:"content_encoding,omitempty" yaml:"content_encoding,omitempty"` - CollisionMode *string `json:"collision_mode,omitempty" yaml:"collision_mode,omitempty"` - 
ChunkSize *int `json:"chunk_size,omitempty" yaml:"chunk_size,omitempty"` - Timeout *string `json:"timeout,omitempty" yaml:"timeout,omitempty"` + CollisionMode *string `json:"collision_mode,omitempty" yaml:"collision_mode,omitempty"` + ChunkSize *int `json:"chunk_size,omitempty" yaml:"chunk_size,omitempty"` + Timeout *string `json:"timeout,omitempty" yaml:"timeout,omitempty"` } type Batching struct { - Count int `json:"count" yaml:"count"` - ByteSize int `json:"byte_size" yaml:"byte_size"` - Period string `json:"period" yaml:"period"` - Check string `json:"check" yaml:"check"` + Count int `json:"count" yaml:"count"` + ByteSize int `json:"byte_size" yaml:"byte_size"` + Period string `json:"period" yaml:"period"` + Check string `json:"check" yaml:"check"` Processors []*BatchProcessor `json:"processors" yaml:"processors"` } type BatchProcessor struct { - Archive *ArchiveProcessor `json:"archive,omitempty" yaml:"archive,omitempty"` - Compress *CompressProcessor `json:"compress,omitempty" yaml:"compress,omitempty"` - NeosyncToJson *NeosyncToJsonConfig `json:"neosync_to_json,omitempty" yaml:"neosync_to_json,omitempty"` - NeosyncToPgx *NeosyncToPgxConfig `json:"neosync_to_pgx,omitempty" yaml:"neosync_to_pgx,omitempty"` + Archive *ArchiveProcessor `json:"archive,omitempty" yaml:"archive,omitempty"` + Compress *CompressProcessor `json:"compress,omitempty" yaml:"compress,omitempty"` + NeosyncToJson *NeosyncToJsonConfig `json:"neosync_to_json,omitempty" yaml:"neosync_to_json,omitempty"` + NeosyncToPgx *NeosyncToPgxConfig `json:"neosync_to_pgx,omitempty" yaml:"neosync_to_pgx,omitempty"` NeosyncToMysql *NeosyncToMysqlConfig `json:"neosync_to_mysql,omitempty" yaml:"neosync_to_mysql,omitempty"` NeosyncToMssql *NeosyncToMssqlConfig `json:"neosync_to_mssql,omitempty" yaml:"neosync_to_mssql,omitempty"` } type NeosyncToPgxConfig struct { - Columns []string `json:"columns" yaml:"columns"` - ColumnDataTypes map[string]string `json:"column_data_types" yaml:"column_data_types"` + Columns 
[]string `json:"columns" yaml:"columns"` + ColumnDataTypes map[string]string `json:"column_data_types" yaml:"column_data_types"` ColumnDefaultProperties map[string]*ColumnDefaultProperties `json:"column_default_properties" yaml:"column_default_properties"` } type NeosyncToMysqlConfig struct { - Columns []string `json:"columns" yaml:"columns"` - ColumnDataTypes map[string]string `json:"column_data_types" yaml:"column_data_types"` + Columns []string `json:"columns" yaml:"columns"` + ColumnDataTypes map[string]string `json:"column_data_types" yaml:"column_data_types"` ColumnDefaultProperties map[string]*ColumnDefaultProperties `json:"column_default_properties" yaml:"column_default_properties"` } type NeosyncToMssqlConfig struct { - Columns []string `json:"columns" yaml:"columns"` - ColumnDataTypes map[string]string `json:"column_data_types" yaml:"column_data_types"` + Columns []string `json:"columns" yaml:"columns"` + ColumnDataTypes map[string]string `json:"column_data_types" yaml:"column_data_types"` ColumnDefaultProperties map[string]*ColumnDefaultProperties `json:"column_default_properties" yaml:"column_default_properties"` } type NeosyncToJsonConfig struct{} type ArchiveProcessor struct { - Format string `json:"format" yaml:"format"` + Format string `json:"format" yaml:"format"` Path *string `json:"path,omitempty" yaml:"path,omitempty"` } @@ -368,5 +368,5 @@ type OutputBrokerConfig struct { type InputBrokerConfig struct { Pattern string `json:"pattern" yaml:"pattern"` - Inputs []Inputs `json:"inputs" yaml:"inputs"` + Inputs []Inputs `json:"inputs" yaml:"inputs"` } diff --git a/worker/pkg/benthos/default_transform/processor.go b/worker/pkg/benthos/default_transform/processor.go index af508a4623..3302d7dc5e 100644 --- a/worker/pkg/benthos/default_transform/processor.go +++ b/worker/pkg/benthos/default_transform/processor.go @@ -49,7 +49,10 @@ type defaultTransformerProcessor struct { logger *service.Logger } -func newDefaultTransformerProcessor(conf 
*service.ParsedConfig, mgr *service.Resources) (*defaultTransformerProcessor, error) { +func newDefaultTransformerProcessor( + conf *service.ParsedConfig, + mgr *service.Resources, +) (*defaultTransformerProcessor, error) { mappedKeys, err := conf.FieldStringList("mapped_keys") if err != nil { return nil, err @@ -82,7 +85,9 @@ func newDefaultTransformerProcessor(conf *service.ParsedConfig, mgr *service.Res }, nil } -func getDefaultTransformerMap(jobSourceOptions *mgmtv1alpha1.JobSourceOptions) map[primitiveType]*mgmtv1alpha1.JobMappingTransformer { +func getDefaultTransformerMap( + jobSourceOptions *mgmtv1alpha1.JobSourceOptions, +) map[primitiveType]*mgmtv1alpha1.JobMappingTransformer { switch cfg := jobSourceOptions.Config.(type) { case *mgmtv1alpha1.JobSourceOptions_Dynamodb: unmappedTransformers := cfg.Dynamodb.UnmappedTransforms @@ -100,7 +105,10 @@ func getDefaultTransformerMap(jobSourceOptions *mgmtv1alpha1.JobSourceOptions) m } } -func (m *defaultTransformerProcessor) ProcessBatch(ctx context.Context, batch service.MessageBatch) ([]service.MessageBatch, error) { +func (m *defaultTransformerProcessor) ProcessBatch( + ctx context.Context, + batch service.MessageBatch, +) ([]service.MessageBatch, error) { newBatch := make(service.MessageBatch, 0, len(batch)) for _, msg := range batch { root, err := msg.AsStructuredMut() @@ -195,7 +203,11 @@ func (m *defaultTransformerProcessor) transformRoot(path string, root any) (any, } } -func (m *defaultTransformerProcessor) getValue(transformerKey primitiveType, value any, shouldMutate bool) (any, error) { +func (m *defaultTransformerProcessor) getValue( + transformerKey primitiveType, + value any, + shouldMutate bool, +) (any, error) { t := m.defaultTransformersInitMap[transformerKey] if t != nil && shouldMutate { return t.Mutate(value, t.Opts) @@ -203,7 +215,9 @@ func (m *defaultTransformerProcessor) getValue(transformerKey primitiveType, val return value, nil } -func initDefaultTransformers(defaultTransformerMap 
map[primitiveType]*mgmtv1alpha1.JobMappingTransformer) (map[primitiveType]*transformer_executor.TransformerExecutor, error) { +func initDefaultTransformers( + defaultTransformerMap map[primitiveType]*mgmtv1alpha1.JobMappingTransformer, +) (map[primitiveType]*transformer_executor.TransformerExecutor, error) { transformersInit := map[primitiveType]*transformer_executor.TransformerExecutor{} for k, t := range defaultTransformerMap { if !shouldProcess(t) { diff --git a/worker/pkg/benthos/dynamodb/input.go b/worker/pkg/benthos/dynamodb/input.go index 54108541ed..8adaa19fa3 100644 --- a/worker/pkg/benthos/dynamodb/input.go +++ b/worker/pkg/benthos/dynamodb/input.go @@ -46,11 +46,22 @@ func RegisterDynamoDbInput(env *service.Environment) error { } type dynamoDBAPIV2 interface { - DescribeTable(ctx context.Context, params *dynamodb.DescribeTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DescribeTableOutput, error) - ExecuteStatement(ctx context.Context, params *dynamodb.ExecuteStatementInput, optFns ...func(*dynamodb.Options)) (*dynamodb.ExecuteStatementOutput, error) + DescribeTable( + ctx context.Context, + params *dynamodb.DescribeTableInput, + optFns ...func(*dynamodb.Options), + ) (*dynamodb.DescribeTableOutput, error) + ExecuteStatement( + ctx context.Context, + params *dynamodb.ExecuteStatementInput, + optFns ...func(*dynamodb.Options), + ) (*dynamodb.ExecuteStatementOutput, error) } -func newDynamoDbBatchInput(conf *service.ParsedConfig, logger *service.Logger) (service.BatchInput, error) { +func newDynamoDbBatchInput( + conf *service.ParsedConfig, + logger *service.Logger, +) (service.BatchInput, error) { table, err := conf.FieldString("table") if err != nil { return nil, err @@ -131,10 +142,13 @@ func (d *dynamodbInput) Connect(ctx context.Context) error { } func isTableActive(output *dynamodb.DescribeTableOutput) bool { - return output != nil && output.Table != nil && output.Table.TableStatus == types.TableStatusActive + return output != nil && 
output.Table != nil && + output.Table.TableStatus == types.TableStatusActive } -func (d *dynamodbInput) ReadBatch(ctx context.Context) (service.MessageBatch, service.AckFunc, error) { +func (d *dynamodbInput) ReadBatch( + ctx context.Context, +) (service.MessageBatch, service.AckFunc, error) { d.readMu.Lock() defer d.readMu.Unlock() if d.client == nil { @@ -197,15 +211,24 @@ func (d *dynamodbInput) Close(ctx context.Context) error { return nil } -func getAwsSession(ctx context.Context, parsedConf *service.ParsedConfig, opts ...func(*config.LoadOptions) error) (*aws.Config, error) { - awsCfg, err := awsmanager.GetAwsConfig(ctx, getAwsCredentialsConfigFromParsedConf(parsedConf), opts...) +func getAwsSession( + ctx context.Context, + parsedConf *service.ParsedConfig, + opts ...func(*config.LoadOptions) error, +) (*aws.Config, error) { + awsCfg, err := awsmanager.GetAwsConfig( + ctx, + getAwsCredentialsConfigFromParsedConf(parsedConf), + opts...) if err != nil { return aws.NewConfig(), err } return awsCfg, nil } -func getAwsCredentialsConfigFromParsedConf(parsedConf *service.ParsedConfig) *awsmanager.AwsCredentialsConfig { +func getAwsCredentialsConfigFromParsedConf( + parsedConf *service.ParsedConfig, +) *awsmanager.AwsCredentialsConfig { output := &awsmanager.AwsCredentialsConfig{} if parsedConf == nil { return output diff --git a/worker/pkg/benthos/dynamodb/output.go b/worker/pkg/benthos/dynamodb/output.go index dedbe0d28a..bfb6b2063d 100644 --- a/worker/pkg/benthos/dynamodb/output.go +++ b/worker/pkg/benthos/dynamodb/output.go @@ -125,7 +125,9 @@ func dynamoOutputConfigSpec() *service.ConfigSpec { } func RegisterDynamoDbOutput(env *service.Environment) error { - return env.RegisterBatchOutput("aws_dynamodb", dynamoOutputConfigSpec(), + return env.RegisterBatchOutput( + "aws_dynamodb", + dynamoOutputConfigSpec(), func(conf *service.ParsedConfig, mgr *service.Resources) (out service.BatchOutput, batchPolicy service.BatchPolicy, maxInFlight int, err error) { if 
maxInFlight, err = conf.FieldMaxInFlight(); err != nil { return @@ -139,16 +141,41 @@ func RegisterDynamoDbOutput(env *service.Environment) error { } out, err = newDynamoDBWriter(wConf, mgr) return - }) + }, + ) } type dynamoDBAPI interface { - PutItem(ctx context.Context, params *dynamodb.PutItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.PutItemOutput, error) - BatchWriteItem(ctx context.Context, params *dynamodb.BatchWriteItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.BatchWriteItemOutput, error) - BatchExecuteStatement(ctx context.Context, params *dynamodb.BatchExecuteStatementInput, optFns ...func(*dynamodb.Options)) (*dynamodb.BatchExecuteStatementOutput, error) - DescribeTable(ctx context.Context, params *dynamodb.DescribeTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DescribeTableOutput, error) - GetItem(ctx context.Context, params *dynamodb.GetItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.GetItemOutput, error) - DeleteItem(ctx context.Context, params *dynamodb.DeleteItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DeleteItemOutput, error) + PutItem( + ctx context.Context, + params *dynamodb.PutItemInput, + optFns ...func(*dynamodb.Options), + ) (*dynamodb.PutItemOutput, error) + BatchWriteItem( + ctx context.Context, + params *dynamodb.BatchWriteItemInput, + optFns ...func(*dynamodb.Options), + ) (*dynamodb.BatchWriteItemOutput, error) + BatchExecuteStatement( + ctx context.Context, + params *dynamodb.BatchExecuteStatementInput, + optFns ...func(*dynamodb.Options), + ) (*dynamodb.BatchExecuteStatementOutput, error) + DescribeTable( + ctx context.Context, + params *dynamodb.DescribeTableInput, + optFns ...func(*dynamodb.Options), + ) (*dynamodb.DescribeTableOutput, error) + GetItem( + ctx context.Context, + params *dynamodb.GetItemInput, + optFns ...func(*dynamodb.Options), + ) (*dynamodb.GetItemOutput, error) + DeleteItem( + ctx context.Context, + params *dynamodb.DeleteItemInput, + optFns 
...func(*dynamodb.Options), + ) (*dynamodb.DeleteItemOutput, error) } type dynamoDBWriter struct { @@ -407,7 +434,9 @@ func commonRetryBackOffFields( } } -func commonRetryBackOffCtorFromParsed(pConf *service.ParsedConfig) (ctor func() backoff.BackOff, err error) { +func commonRetryBackOffCtorFromParsed( + pConf *service.ParsedConfig, +) (ctor func() backoff.BackOff, err error) { var maxRetries int if maxRetries, err = pConf.FieldInt(crboFieldMaxRetries); err != nil { return nil, err @@ -447,7 +476,11 @@ func fieldDurationOrEmptyStr(pConf *service.ParsedConfig, path ...string) (time. return pConf.FieldDuration(path...) } -func marshalToAttributeValue(key string, root any, keyTypeMap map[string]neosync_types.KeyType) (types.AttributeValue, error) { +func marshalToAttributeValue( + key string, + root any, + keyTypeMap map[string]neosync_types.KeyType, +) (types.AttributeValue, error) { if typeStr, ok := keyTypeMap[key]; ok { switch typeStr { case neosync_types.StringSet: @@ -551,7 +584,11 @@ func formatFloat(f float64) string { return s } -func marshalJSONToDynamoDBAttribute(key, path string, root any, keyTypeMap map[string]neosync_types.KeyType) (types.AttributeValue, error) { +func marshalJSONToDynamoDBAttribute( + key, path string, + root any, + keyTypeMap map[string]neosync_types.KeyType, +) (types.AttributeValue, error) { gObj := gabs.Wrap(root) if path != "" { gObj = gObj.Path(path) diff --git a/worker/pkg/benthos/environment/environment.go b/worker/pkg/benthos/environment/environment.go index 0bb36e9be0..d0c3cd2e05 100644 --- a/worker/pkg/benthos/environment/environment.go +++ b/worker/pkg/benthos/environment/environment.go @@ -89,7 +89,11 @@ func NewEnvironment(logger *slog.Logger, opts ...Option) (*service.Environment, return NewWithEnvironment(service.NewEnvironment(), logger, opts...) 
} -func NewWithEnvironment(env *service.Environment, logger *slog.Logger, opts ...Option) (*service.Environment, error) { +func NewWithEnvironment( + env *service.Environment, + logger *slog.Logger, + opts ...Option, +) (*service.Environment, error) { if env == nil { env = service.NewEnvironment() } @@ -106,18 +110,32 @@ func NewWithEnvironment(env *service.Environment, logger *slog.Logger, opts ...O if config.meter != nil { err := benthos_metrics.RegisterOtelMetricsExporter(env, config.meter) if err != nil { - return nil, fmt.Errorf("unable to register otel_collector for benthos metering: %w", err) + return nil, fmt.Errorf( + "unable to register otel_collector for benthos metering: %w", + err, + ) } } if config.sqlConfig != nil { - err := neosync_benthos_sql.RegisterPooledSqlInsertOutput(env, config.sqlConfig.Provider, config.sqlConfig.IsRetry, logger) + err := neosync_benthos_sql.RegisterPooledSqlInsertOutput( + env, + config.sqlConfig.Provider, + config.sqlConfig.IsRetry, + logger, + ) if err != nil { - return nil, fmt.Errorf("unable to register pooled_sql_insert output to benthos instance: %w", err) + return nil, fmt.Errorf( + "unable to register pooled_sql_insert output to benthos instance: %w", + err, + ) } err = neosync_benthos_sql.RegisterPooledSqlUpdateOutput(env, config.sqlConfig.Provider) if err != nil { - return nil, fmt.Errorf("unable to register pooled_sql_update output to benthos instance: %w", err) + return nil, fmt.Errorf( + "unable to register pooled_sql_update output to benthos instance: %w", + err, + ) } err = neosync_benthos_sql.RegisterPooledSqlRawInput( env, @@ -127,23 +145,36 @@ func NewWithEnvironment(env *service.Environment, logger *slog.Logger, opts ...O config.sqlConfig.InputContinuationToken, ) if err != nil { - return nil, fmt.Errorf("unable to register pooled_sql_raw input to benthos instance: %w", err) + return nil, fmt.Errorf( + "unable to register pooled_sql_raw input to benthos instance: %w", + err, + ) } } if config.mongoConfig 
!= nil { err := neosync_benthos_mongodb.RegisterPooledMongoDbInput(env, config.mongoConfig.Provider) if err != nil { - return nil, fmt.Errorf("unable to register pooled_mongodb input to benthos instance: %w", err) + return nil, fmt.Errorf( + "unable to register pooled_mongodb input to benthos instance: %w", + err, + ) } err = neosync_benthos_mongodb.RegisterPooledMongoDbOutput(env, config.mongoConfig.Provider) if err != nil { - return nil, fmt.Errorf("unable to register pooled_mongodb output to benthos instance: %w", err) + return nil, fmt.Errorf( + "unable to register pooled_mongodb output to benthos instance: %w", + err, + ) } } if config.connectionDataConfig != nil { - err := neosync_benthos_connectiondata.RegisterNeosyncConnectionDataInput(env, config.connectionDataConfig.NeosyncConnectionDataApi, logger) + err := neosync_benthos_connectiondata.RegisterNeosyncConnectionDataInput( + env, + config.connectionDataConfig.NeosyncConnectionDataApi, + logger, + ) if err != nil { return nil, fmt.Errorf("unable to register neosync_connection_data input: %w", err) } @@ -151,7 +182,10 @@ func NewWithEnvironment(env *service.Environment, logger *slog.Logger, opts ...O err := openaigenerate.RegisterOpenaiGenerate(env) if err != nil { - return nil, fmt.Errorf("unable to register openai_generate input to benthos instance: %w", err) + return nil, fmt.Errorf( + "unable to register openai_generate input to benthos instance: %w", + err, + ) } err = neosync_benthos_error.RegisterErrorProcessor(env, config.stopChannel) @@ -176,32 +210,50 @@ func NewWithEnvironment(env *service.Environment, logger *slog.Logger, opts ...O err = neosync_benthos_defaulttransform.ReisterDefaultTransformerProcessor(env) if err != nil { - return nil, fmt.Errorf("unable to register default mapping processor to benthos instance: %w", err) + return nil, fmt.Errorf( + "unable to register default mapping processor to benthos instance: %w", + err, + ) } err = 
neosync_benthos_json.RegisterNeosyncToJsonProcessor(env) if err != nil { - return nil, fmt.Errorf("unable to register Neosync to JSON processor to benthos instance: %w", err) + return nil, fmt.Errorf( + "unable to register Neosync to JSON processor to benthos instance: %w", + err, + ) } err = neosync_benthos_sql.RegisterNeosyncToPgxProcessor(env) if err != nil { - return nil, fmt.Errorf("unable to register Neosync to PGX processor to benthos instance: %w", err) + return nil, fmt.Errorf( + "unable to register Neosync to PGX processor to benthos instance: %w", + err, + ) } err = neosync_benthos_sql.RegisterNeosyncToMysqlProcessor(env) if err != nil { - return nil, fmt.Errorf("unable to register Neosync to MYSQL processor to benthos instance: %w", err) + return nil, fmt.Errorf( + "unable to register Neosync to MYSQL processor to benthos instance: %w", + err, + ) } err = neosync_benthos_sql.RegisterNeosyncToMssqlProcessor(env) if err != nil { - return nil, fmt.Errorf("unable to register Neosync to MSSQL processor to benthos instance: %w", err) + return nil, fmt.Errorf( + "unable to register Neosync to MSSQL processor to benthos instance: %w", + err, + ) } err = javascript_processor.RegisterNeosyncJavascriptProcessor(env) if err != nil { - return nil, fmt.Errorf("unable to register javascript processor to benthos instance: %w", err) + return nil, fmt.Errorf( + "unable to register javascript processor to benthos instance: %w", + err, + ) } if config.blobEnv != nil { diff --git a/worker/pkg/benthos/error/output_error.go b/worker/pkg/benthos/error/output_error.go index e0a6aeeb4b..2485daf2b8 100644 --- a/worker/pkg/benthos/error/output_error.go +++ b/worker/pkg/benthos/error/output_error.go @@ -21,7 +21,8 @@ func errorOutputSpec() *service.ConfigSpec { // Registers an output on a benthos environment called error func RegisterErrorOutput(env *service.Environment, stopActivityChannel chan<- error) error { return env.RegisterBatchOutput( - "error", errorOutputSpec(), + 
"error", + errorOutputSpec(), func(conf *service.ParsedConfig, mgr *service.Resources) (service.BatchOutput, service.BatchPolicy, int, error) { batchPolicy, err := conf.FieldBatchPolicy("batching") if err != nil { @@ -37,10 +38,15 @@ func RegisterErrorOutput(env *service.Environment, stopActivityChannel chan<- er return nil, service.BatchPolicy{}, -1, err } return out, batchPolicy, maxInFlight, nil - }) + }, + ) } -func newErrorOutput(conf *service.ParsedConfig, mgr *service.Resources, channel chan<- error) (*errorOutput, error) { +func newErrorOutput( + conf *service.ParsedConfig, + mgr *service.Resources, + channel chan<- error, +) (*errorOutput, error) { errMsg, err := conf.FieldInterpolatedString("error_msg") if err != nil { return nil, err @@ -79,7 +85,9 @@ func (e *errorOutput) WriteBatch(ctx context.Context, batch service.MessageBatch return errors.New(errMsg) } // kill activity - e.logger.Error(fmt.Sprintf("Benthos Error output - sending stop activity signal: %s ", errMsg)) + e.logger.Error( + fmt.Sprintf("Benthos Error output - sending stop activity signal: %s ", errMsg), + ) e.stopActivityChannel <- fmt.Errorf("%s", errMsg) } return nil diff --git a/worker/pkg/benthos/error/processor_error.go b/worker/pkg/benthos/error/processor_error.go index ed9731aece..bab52568eb 100644 --- a/worker/pkg/benthos/error/processor_error.go +++ b/worker/pkg/benthos/error/processor_error.go @@ -32,7 +32,11 @@ type errorProcessor struct { errorMsg *service.InterpolatedString } -func newErrorProcessor(conf *service.ParsedConfig, mgr *service.Resources, channel chan<- error) (*errorProcessor, error) { +func newErrorProcessor( + conf *service.ParsedConfig, + mgr *service.Resources, + channel chan<- error, +) (*errorProcessor, error) { errMsg, err := conf.FieldInterpolatedString("error_msg") if err != nil { return nil, err @@ -44,14 +48,19 @@ func newErrorProcessor(conf *service.ParsedConfig, mgr *service.Resources, chann }, nil } -func (r *errorProcessor) ProcessBatch(_ 
context.Context, batch service.MessageBatch) ([]service.MessageBatch, error) { +func (r *errorProcessor) ProcessBatch( + _ context.Context, + batch service.MessageBatch, +) ([]service.MessageBatch, error) { for i := range batch { errMsg, err := batch.TryInterpolatedString(i, r.errorMsg) if err != nil { return nil, fmt.Errorf("error message interpolation error: %w", err) } // kill activity - r.logger.Error(fmt.Sprintf("Benthos Error processor - sending stop activity signal: %s ", errMsg)) + r.logger.Error( + fmt.Sprintf("Benthos Error processor - sending stop activity signal: %s ", errMsg), + ) r.stopActivityChannel <- fmt.Errorf("%s", errMsg) } return []service.MessageBatch{}, nil diff --git a/worker/pkg/benthos/javascript/processor.go b/worker/pkg/benthos/javascript/processor.go index 42f09c234b..87364f4cdc 100644 --- a/worker/pkg/benthos/javascript/processor.go +++ b/worker/pkg/benthos/javascript/processor.go @@ -37,7 +37,10 @@ type javascriptProcessor struct { vmPool sync.Pool } -func newJavascriptProcessorFromConfig(conf *service.ParsedConfig, mgr *service.Resources) (*javascriptProcessor, error) { +func newJavascriptProcessorFromConfig( + conf *service.ParsedConfig, + mgr *service.Resources, +) (*javascriptProcessor, error) { code, err := conf.FieldString(codeField) if err != nil { return nil, err @@ -72,7 +75,10 @@ type vmPoolItem struct { valueApi *benthosValueApi } -func (j *javascriptProcessor) ProcessBatch(ctx context.Context, batch service.MessageBatch) (result []service.MessageBatch, err error) { +func (j *javascriptProcessor) ProcessBatch( + ctx context.Context, + batch service.MessageBatch, +) (result []service.MessageBatch, err error) { var runner *javascript_vm.Runner var valueApi *benthosValueApi @@ -96,7 +102,11 @@ func (j *javascriptProcessor) ProcessBatch(ctx context.Context, batch service.Me // This here acts as a final catch-all defense for anything we missed so prevent the process from crashing. 
defer func() { if r := recover(); r != nil { - j.slogger.Error("recovered from panic in neosync_javascript batch processor", "error", fmt.Sprintf("%v", r)) + j.slogger.Error( + "recovered from panic in neosync_javascript batch processor", + "error", + fmt.Sprintf("%v", r), + ) // Set the named return value 'err' err = fmt.Errorf("neosync_javascript batch processor panic recovered: %v", r) return diff --git a/worker/pkg/benthos/json/processor_neosync_json.go b/worker/pkg/benthos/json/processor_neosync_json.go index 74107415fe..f9f25c09c2 100644 --- a/worker/pkg/benthos/json/processor_neosync_json.go +++ b/worker/pkg/benthos/json/processor_neosync_json.go @@ -25,13 +25,19 @@ type neosyncToJsonProcessor struct { logger *service.Logger } -func newNeosyncToJsonProcessor(_ *service.ParsedConfig, mgr *service.Resources) *neosyncToJsonProcessor { +func newNeosyncToJsonProcessor( + _ *service.ParsedConfig, + mgr *service.Resources, +) *neosyncToJsonProcessor { return &neosyncToJsonProcessor{ logger: mgr.Logger(), } } -func (m *neosyncToJsonProcessor) ProcessBatch(ctx context.Context, batch service.MessageBatch) ([]service.MessageBatch, error) { +func (m *neosyncToJsonProcessor) ProcessBatch( + ctx context.Context, + batch service.MessageBatch, +) ([]service.MessageBatch, error) { newBatch := make(service.MessageBatch, 0, len(batch)) for _, msg := range batch { root, err := msg.AsStructuredMut() diff --git a/worker/pkg/benthos/metrics/otel_metrics.go b/worker/pkg/benthos/metrics/otel_metrics.go index 5a2d4ce4fd..16157a46de 100644 --- a/worker/pkg/benthos/metrics/otel_metrics.go +++ b/worker/pkg/benthos/metrics/otel_metrics.go @@ -43,7 +43,10 @@ func (c *otlpCounter) Incr(count int64) { c.counter.Add(context.Background(), count, metric.WithAttributes(c.labels...)) } -func (om *otlpMetricsExporter) NewCounterCtor(path string, labelNames ...string) service.MetricsExporterCounterCtor { +func (om *otlpMetricsExporter) NewCounterCtor( + path string, + labelNames ...string, +) 
service.MetricsExporterCounterCtor { return func(labelValues ...string) service.MetricsExporterCounter { var attrs []attribute.KeyValue for idx, label := range labelNames { @@ -71,7 +74,10 @@ func (c *otlpTimer) Timing(delta int64) { c.timer.Record(context.Background(), delta, metric.WithAttributes(c.labels...)) } -func (om *otlpMetricsExporter) NewTimerCtor(path string, labelNames ...string) service.MetricsExporterTimerCtor { +func (om *otlpMetricsExporter) NewTimerCtor( + path string, + labelNames ...string, +) service.MetricsExporterTimerCtor { return func(labelValues ...string) service.MetricsExporterTimer { var attrs []attribute.KeyValue for idx, label := range labelNames { @@ -103,7 +109,10 @@ func (c *otlpGauge) Set(value int64) { c.gaugeChan <- value } -func (om *otlpMetricsExporter) NewGaugeCtor(path string, labelNames ...string) service.MetricsExporterGaugeCtor { +func (om *otlpMetricsExporter) NewGaugeCtor( + path string, + labelNames ...string, +) service.MetricsExporterGaugeCtor { return func(labelValues ...string) service.MetricsExporterGauge { var attrs []attribute.KeyValue for idx, label := range labelNames { diff --git a/worker/pkg/benthos/mongodb/common.go b/worker/pkg/benthos/mongodb/common.go index a140d110d2..4869ca843b 100644 --- a/worker/pkg/benthos/mongodb/common.go +++ b/worker/pkg/benthos/mongodb/common.go @@ -194,7 +194,9 @@ func writeConcernDocs() *service.ConfigField { ).Description("The write concern settings for the mongo connection.") } -func writeConcernCollectionOptionFromParsed(pConf *service.ParsedConfig) (opt *options.CollectionOptions, err error) { +func writeConcernCollectionOptionFromParsed( + pConf *service.ParsedConfig, +) (opt *options.CollectionOptions, err error) { pConf = pConf.Namespace(commonFieldWriteConcern) var w string @@ -245,7 +247,10 @@ func operationFromParsed(pConf *service.ParsedConfig) (operation Operation, err } if operation = NewOperation(operationStr); operation == OperationInvalid { - err = 
fmt.Errorf("mongodb operation '%s' unknown: must be insert-one, delete-one, delete-many, replace-one or update-one", operationStr) + err = fmt.Errorf( + "mongodb operation '%s' unknown: must be insert-one, delete-one, delete-many, replace-one or update-one", + operationStr, + ) } return } @@ -257,7 +262,10 @@ type writeMaps struct { upsert bool } -func writeMapsFromParsed(conf *service.ParsedConfig, operation Operation) (maps writeMaps, err error) { +func writeMapsFromParsed( + conf *service.ParsedConfig, + operation Operation, +) (maps writeMaps, err error) { if probeStr, _ := conf.FieldString(commonFieldFilterMap); probeStr != "" { if maps.filterMap, err = conf.FieldBloblang(commonFieldFilterMap); err != nil { return maps, err @@ -387,7 +395,11 @@ func convertToMapStringKeyType(i any) (map[string]neosync_types.KeyType, error) return nil, errors.New("input is not of type map[string]KeyType") } -func marshalToBSONValue(key string, root any, keyTypeMap map[string]neosync_types.KeyType) (any, error) { +func marshalToBSONValue( + key string, + root any, + keyTypeMap map[string]neosync_types.KeyType, +) (any, error) { if root == nil { return nil, nil } @@ -511,7 +523,10 @@ func marshalToBSONValue(key string, root any, keyTypeMap map[string]neosync_type } } -func marshalJSONToBSONDocument(root any, keyTypeMap map[string]neosync_types.KeyType) (bson.D, error) { +func marshalJSONToBSONDocument( + root any, + keyTypeMap map[string]neosync_types.KeyType, +) (bson.D, error) { m, ok := root.(map[string]any) if !ok { return bson.D{}, fmt.Errorf("expected map[string]any, got %T", root) diff --git a/worker/pkg/benthos/mongodb/input.go b/worker/pkg/benthos/mongodb/input.go index ae143623b3..3d64d063b2 100644 --- a/worker/pkg/benthos/mongodb/input.go +++ b/worker/pkg/benthos/mongodb/input.go @@ -88,7 +88,11 @@ func RegisterPooledMongoDbInput(env *service.Environment, clientProvider MongoPo ) } -func newMongoInput(conf *service.ParsedConfig, clientProvider MongoPoolProvider, logger 
*service.Logger) (service.BatchInput, error) { +func newMongoInput( + conf *service.ParsedConfig, + clientProvider MongoPoolProvider, + logger *service.Logger, +) (service.BatchInput, error) { var ( limit, batchSize int sort map[string]int @@ -206,7 +210,10 @@ func (m *mongoInput) Connect(ctx context.Context) error { } m.cursor, err = collection.Aggregate(ctx, m.query, aggregateOptions) default: - return fmt.Errorf("operation '%s' not supported. the supported values are 'find' and 'aggregate'", m.operation) + return fmt.Errorf( + "operation '%s' not supported. the supported values are 'find' and 'aggregate'", + m.operation, + ) } if err != nil { _ = m.client.Disconnect(ctx) diff --git a/worker/pkg/benthos/mongodb/output.go b/worker/pkg/benthos/mongodb/output.go index 0afe8092e8..4ea1980467 100644 --- a/worker/pkg/benthos/mongodb/output.go +++ b/worker/pkg/benthos/mongodb/output.go @@ -79,7 +79,8 @@ func outputSpec() *service.ConfigSpec { func RegisterPooledMongoDbOutput(env *service.Environment, clientProvider MongoPoolProvider) error { return env.RegisterBatchOutput( - "pooled_mongodb", outputSpec(), + "pooled_mongodb", + outputSpec(), func(conf *service.ParsedConfig, mgr *service.Resources) (out service.BatchOutput, batchPolicy service.BatchPolicy, maxInFlight int, err error) { if batchPolicy, err = conf.FieldBatchPolicy(moFieldBatching); err != nil { return @@ -110,7 +111,11 @@ type outputWriter struct { mu sync.Mutex } -func newOutputWriter(conf *service.ParsedConfig, res *service.Resources, clientProvider MongoPoolProvider) (db *outputWriter, err error) { +func newOutputWriter( + conf *service.ParsedConfig, + res *service.Resources, + clientProvider MongoPoolProvider, +) (db *outputWriter, err error) { db = &outputWriter{ log: res.Logger(), } diff --git a/worker/pkg/benthos/neosync_connection_data/neosync_connection_data_input.go b/worker/pkg/benthos/neosync_connection_data/neosync_connection_data_input.go index 9d793ff3cf..5e487d7c77 100644 --- 
a/worker/pkg/benthos/neosync_connection_data/neosync_connection_data_input.go +++ b/worker/pkg/benthos/neosync_connection_data/neosync_connection_data_input.go @@ -136,7 +136,9 @@ func (g *neosyncInput) Connect(ctx context.Context) error { awsS3Cfg := &mgmtv1alpha1.AwsS3StreamConfig{} if g.connectionOpts != nil { if g.connectionOpts.jobRunId != nil && *g.connectionOpts.jobRunId != "" { - awsS3Cfg.Id = &mgmtv1alpha1.AwsS3StreamConfig_JobRunId{JobRunId: *g.connectionOpts.jobRunId} + awsS3Cfg.Id = &mgmtv1alpha1.AwsS3StreamConfig_JobRunId{ + JobRunId: *g.connectionOpts.jobRunId, + } } else if g.connectionOpts.jobId != nil && *g.connectionOpts.jobId != "" { awsS3Cfg.Id = &mgmtv1alpha1.AwsS3StreamConfig_JobId{JobId: *g.connectionOpts.jobId} } @@ -164,12 +166,15 @@ func (g *neosyncInput) Connect(ctx context.Context) error { } } - resp, err := g.neosyncConnectApi.GetConnectionDataStream(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionDataStreamRequest{ - ConnectionId: g.connectionId, - Schema: g.schema, - Table: g.table, - StreamConfig: streamCfg, - })) + resp, err := g.neosyncConnectApi.GetConnectionDataStream( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionDataStreamRequest{ + ConnectionId: g.connectionId, + Schema: g.schema, + Table: g.table, + StreamConfig: streamCfg, + }), + ) if err != nil { return err } @@ -203,7 +208,10 @@ func (g *neosyncInput) Read(ctx context.Context) (*service.Message, service.AckF decoder := gob.NewDecoder(bytes.NewReader(rowBytes)) err := decoder.Decode(&dynamoDBItem) if err != nil { - return nil, nil, fmt.Errorf("error decoding data connection stream response with gob decoder: %w", err) + return nil, nil, fmt.Errorf( + "error decoding data connection stream response with gob decoder: %w", + err, + ) } resMap, keyTypeMap := unmarshalDynamoDBItem(dynamoDBItem) @@ -220,7 +228,10 @@ func (g *neosyncInput) Read(ctx context.Context) (*service.Message, service.AckF decoder := gob.NewDecoder(bytes.NewReader(rowBytes)) err := 
decoder.Decode(&valuesMap) if err != nil { - return nil, nil, fmt.Errorf("error decoding data connection stream response with gob decoder: %w", err) + return nil, nil, fmt.Errorf( + "error decoding data connection stream response with gob decoder: %w", + err, + ) } msg := service.NewMessage(nil) msg.SetStructuredMut(valuesMap) @@ -245,7 +256,9 @@ func (g *neosyncInput) Close(ctx context.Context) error { return nil } -func unmarshalDynamoDBItem(item map[string]any) (standardMap map[string]any, keyTypeMap map[string]neosync_types.KeyType) { +func unmarshalDynamoDBItem( + item map[string]any, +) (standardMap map[string]any, keyTypeMap map[string]neosync_types.KeyType) { result := make(map[string]any) ktm := make(map[string]neosync_types.KeyType) for key, value := range item { @@ -255,7 +268,11 @@ func unmarshalDynamoDBItem(item map[string]any) (standardMap map[string]any, key return result, ktm } -func parseDynamoDBAttributeValue(key string, value any, keyTypeMap map[string]neosync_types.KeyType) any { +func parseDynamoDBAttributeValue( + key string, + value any, + keyTypeMap map[string]neosync_types.KeyType, +) any { if m, ok := value.(map[string]any); ok { for dynamoType, dynamoValue := range m { switch dynamoType { @@ -288,7 +305,11 @@ func parseDynamoDBAttributeValue(key string, value any, keyTypeMap map[string]ne list := dynamoValue.([]any) result := make([]any, len(list)) for i, item := range list { - result[i] = parseDynamoDBAttributeValue(fmt.Sprintf("%s[%d]", key, i), item, keyTypeMap) + result[i] = parseDynamoDBAttributeValue( + fmt.Sprintf("%s[%d]", key, i), + item, + keyTypeMap, + ) } return result case "M": diff --git a/worker/pkg/benthos/openai_generate/openai_generate.go b/worker/pkg/benthos/openai_generate/openai_generate.go index d777dc93a9..499e24678b 100644 --- a/worker/pkg/benthos/openai_generate/openai_generate.go +++ b/worker/pkg/benthos/openai_generate/openai_generate.go @@ -33,13 +33,17 @@ func getSpec() *service.ConfigSpec { } func 
RegisterOpenaiGenerate(env *service.Environment) error { - return env.RegisterBatchInput("openai_generate", getSpec(), func(conf *service.ParsedConfig, mgr *service.Resources) (service.BatchInput, error) { - rdr, err := newGenerateReader(conf, mgr) - if err != nil { - return nil, err - } - return service.AutoRetryNacksBatched(rdr), nil - }) + return env.RegisterBatchInput( + "openai_generate", + getSpec(), + func(conf *service.ParsedConfig, mgr *service.Resources) (service.BatchInput, error) { + rdr, err := newGenerateReader(conf, mgr) + if err != nil { + return nil, err + } + return service.AutoRetryNacksBatched(rdr), nil + }, + ) } type generateReader struct { @@ -63,7 +67,10 @@ type generateReader struct { log *service.Logger } -func newGenerateReader(conf *service.ParsedConfig, mgr *service.Resources) (*generateReader, error) { +func newGenerateReader( + conf *service.ParsedConfig, + mgr *service.Resources, +) (*generateReader, error) { apiUrl, err := conf.FieldString("api_url") if err != nil { return nil, err @@ -90,7 +97,11 @@ func newGenerateReader(conf *service.ParsedConfig, mgr *service.Resources) (*gen return nil, err } if len(columns) != len(dataTypes) { - return nil, fmt.Errorf("length of columns and data types was not the same: %d v %d", len(columns), len(dataTypes)) + return nil, fmt.Errorf( + "length of columns and data types was not the same: %d v %d", + len(columns), + len(dataTypes), + ) } count, err := conf.FieldInt("count") @@ -165,7 +176,11 @@ func (b *generateReader) Connect(ctx context.Context) error { if b.client != nil { return nil } - client, err := azopenai.NewClientForOpenAI(b.apiUrl, azcore.NewKeyCredential(b.apikey), &azopenai.ClientOptions{}) + client, err := azopenai.NewClientForOpenAI( + b.apiUrl, + azcore.NewKeyCredential(b.apikey), + &azopenai.ClientOptions{}, + ) if err != nil { return err } @@ -173,7 +188,9 @@ func (b *generateReader) Connect(ctx context.Context) error { return nil } -func (b *generateReader) ReadBatch(ctx 
context.Context) (service.MessageBatch, service.AckFunc, error) { +func (b *generateReader) ReadBatch( + ctx context.Context, +) (service.MessageBatch, service.AckFunc, error) { b.promptMut.Lock() defer b.promptMut.Unlock() if b.client == nil { @@ -209,18 +226,29 @@ func (b *generateReader) ReadBatch(ctx context.Context) (service.MessageBatch, s b.log.Warn("openai_generate: hit token limit reached, trimmed conversation") b.conversation = trimNonSystemMessages(b.conversation, 1) case azopenai.CompletionsFinishReasonContentFiltered: - return nil, nil, errors.New("openai: generation stopped due to openai content being filtered due to moderation policies") + return nil, nil, errors.New( + "openai: generation stopped due to openai content being filtered due to moderation policies", + ) default: } b.conversation = append( b.conversation, - &azopenai.ChatRequestAssistantMessage{Content: azopenai.NewChatRequestAssistantMessageContent(*choice.Message.Content)}, - &azopenai.ChatRequestUserMessage{Content: azopenai.NewChatRequestUserMessageContent(fmt.Sprintf("%d more records", batchSize))}, + &azopenai.ChatRequestAssistantMessage{ + Content: azopenai.NewChatRequestAssistantMessageContent(*choice.Message.Content), + }, + &azopenai.ChatRequestUserMessage{ + Content: azopenai.NewChatRequestUserMessageContent( + fmt.Sprintf("%d more records", batchSize), + ), + }, ) records, err := getCsvRecordsFromInput(*choice.Message.Content, b.log) if err != nil { - return nil, nil, fmt.Errorf("openai_generate: unable to fully process records retrieved from openai: %w", err) + return nil, nil, fmt.Errorf( + "openai_generate: unable to fully process records retrieved from openai: %w", + err, + ) } if len(records) == 0 { b.log.Warn("openai_generate: no records were returned from message") @@ -235,7 +263,9 @@ func (b *generateReader) ReadBatch(ctx context.Context) (service.MessageBatch, s // skipping the first record as it returns the headers for _, record := range records[1:] { if b.count == 0 
{ - b.log.Infof("stopping openai_generate as we've reached a count of 0 even though we had more records to process") + b.log.Infof( + "stopping openai_generate as we've reached a count of 0 even though we had more records to process", + ) break } structuredRecord, err := convertCsvToStructuredRecord(record, b.columns, b.dataTypes) @@ -249,7 +279,9 @@ func (b *generateReader) ReadBatch(ctx context.Context) (service.MessageBatch, s b.count -= 1 } if len(messageBatch) == 0 { - return nil, nil, errors.New("openai_generate: received response from openai but was unable to successfully process records to a structured format. see logs for more details") + return nil, nil, errors.New( + "openai_generate: received response from openai but was unable to successfully process records to a structured format. see logs for more details", + ) } return messageBatch, emptyAck, nil } @@ -274,7 +306,9 @@ func ptr[T any](val T) *T { func convertCsvToStructuredRecord(record, headers, types []string) (map[string]any, error) { if len(record) != len(headers) && len(headers) != len(types) && len(record) != len(types) { - return nil, fmt.Errorf("error converting csv record to structured record, record headers and types not equivalent in length") + return nil, fmt.Errorf( + "error converting csv record to structured record, record headers and types not equivalent in length", + ) } output := map[string]any{} for idx, value := range record { @@ -323,7 +357,9 @@ func toStructuredRecordValueType(value, dataType string) (any, error) { // return time.Parse("15:04:05Z07:00", value) return strings.TrimSpace(value), nil case "interval": - return strings.TrimSpace(value), nil // Parsing intervals can be complex; keeping it as string + return strings.TrimSpace( + value, + ), nil // Parsing intervals can be complex; keeping it as string case "boolean": return strconv.ParseBool(strings.TrimSpace(value)) case "uuid": @@ -368,7 +404,9 @@ func getCsvRecordsFromInput(input string, logger *service.Logger) 
([][]string, e row, err := reader.Read() if err != nil { if errors.Is(err, io.EOF) { - return nil, errors.New("openai_generate: unable to process generated csv record response from openai") + return nil, errors.New( + "openai_generate: unable to process generated csv record response from openai", + ) } return nil, fmt.Errorf("unable to process CSV row to retrieve headers: %w", err) } @@ -416,7 +454,10 @@ func getCsvRecordsFromInput(input string, logger *service.Logger) ([][]string, e } } -func trimNonSystemMessages(messages []azopenai.ChatRequestMessageClassification, count int) []azopenai.ChatRequestMessageClassification { +func trimNonSystemMessages( + messages []azopenai.ChatRequestMessageClassification, + count int, +) []azopenai.ChatRequestMessageClassification { if len(messages) <= count { return messages[:0] // Return an empty slice } diff --git a/worker/pkg/benthos/redis/output_hash.go b/worker/pkg/benthos/redis/output_hash.go index a530aafd12..60d24c70a0 100644 --- a/worker/pkg/benthos/redis/output_hash.go +++ b/worker/pkg/benthos/redis/output_hash.go @@ -43,14 +43,16 @@ func redisHashOutputConfig() *service.ConfigSpec { func init() { err := service.RegisterOutput( - "redis_hash_output", redisHashOutputConfig(), + "redis_hash_output", + redisHashOutputConfig(), func(conf *service.ParsedConfig, mgr *service.Resources) (out service.Output, maxInFlight int, err error) { if maxInFlight, err = conf.FieldMaxInFlight(); err != nil { return } out, err = newRedisHashWriter(conf, mgr) return - }) + }, + ) if err != nil { panic(err) } @@ -69,7 +71,10 @@ type redisHashWriter struct { connMut sync.RWMutex } -func newRedisHashWriter(conf *service.ParsedConfig, mgr *service.Resources) (r *redisHashWriter, err error) { +func newRedisHashWriter( + conf *service.ParsedConfig, + mgr *service.Resources, +) (r *redisHashWriter, err error) { r = &redisHashWriter{ clientCtor: func() (redis.UniversalClient, error) { return getClient(conf) @@ -175,7 +180,7 @@ func (r 
*redisHashWriter) Write(ctx context.Context, msg *service.Message) error } if mapVal != nil { - fieldMappings, ok := mapVal.(map[string]interface{}) //nolint:gofmt + fieldMappings, ok := mapVal.(map[string]any) //nolint:gofmt if !ok { return fmt.Errorf("fieldMappings resulted in a non-object mapping: %T", mapVal) } diff --git a/worker/pkg/benthos/sql/input_sql_raw.go b/worker/pkg/benthos/sql/input_sql_raw.go index 3fd71786c5..3a18dfa151 100644 --- a/worker/pkg/benthos/sql/input_sql_raw.go +++ b/worker/pkg/benthos/sql/input_sql_raw.go @@ -36,7 +36,14 @@ func RegisterPooledSqlRawInput( "pooled_sql_raw", sqlRawInputSpec(), func(conf *service.ParsedConfig, mgr *service.Resources) (service.Input, error) { - input, err := newInput(conf, mgr, dbprovider, stopActivityChannel, onHasMorePages, continuationToken) + input, err := newInput( + conf, + mgr, + dbprovider, + stopActivityChannel, + onHasMorePages, + continuationToken, + ) if err != nil { return nil, err } @@ -157,7 +164,9 @@ func (s *pooledInput) Connect(ctx context.Context) error { var args []any if s.pagedQueryStatic != nil && s.continuationToken != nil && s.expectedTotalRows != nil { if len(s.orderByColumns) != len(s.continuationToken.Contents.LastReadOrderValues) { - columnMisMatchErr := fmt.Errorf("order by columns and last read order values must be the same length") + columnMisMatchErr := fmt.Errorf( + "order by columns and last read order values must be the same length", + ) s.logger.Error(columnMisMatchErr.Error()) s.stopActivityChannel <- columnMisMatchErr return columnMisMatchErr @@ -193,7 +202,9 @@ func (s *pooledInput) Connect(ctx context.Context) error { rows, err := db.QueryContext(ctx, query, args...) 
if err != nil { if neosync_benthos.IsCriticalError(err.Error()) { - s.logger.Error(fmt.Sprintf("Benthos input error - sending stop activity signal: %s ", err.Error())) + s.logger.Error( + fmt.Sprintf("Benthos input error - sending stop activity signal: %s ", err.Error()), + ) s.stopActivityChannel <- err } return err @@ -214,7 +225,13 @@ func (s *pooledInput) Read(ctx context.Context) (*service.Message, service.AckFu if s.rows == nil { if s.expectedTotalRows != nil && s.onHasMorePages != nil && len(s.orderByColumns) > 0 { // emit order by column values if ok - s.logger.Debug(fmt.Sprintf("rows read: %d, expected total rows: %d", s.rowsRead, *s.expectedTotalRows)) + s.logger.Debug( + fmt.Sprintf( + "rows read: %d, expected total rows: %d", + s.rowsRead, + *s.expectedTotalRows, + ), + ) if s.rowsRead >= *s.expectedTotalRows { s.logger.Debug("emitting order by column values") s.onHasMorePages(s.lastReadOrderValues) @@ -233,9 +250,17 @@ func (s *pooledInput) Read(ctx context.Context) (*service.Message, service.AckFu if s.expectedTotalRows != nil && s.onHasMorePages != nil && len(s.orderByColumns) > 0 { // emit order by column values if ok - s.logger.Debug(fmt.Sprintf("[ROW END] rows read: %d, expected total rows: %d", s.rowsRead, *s.expectedTotalRows)) + s.logger.Debug( + fmt.Sprintf( + "[ROW END] rows read: %d, expected total rows: %d", + s.rowsRead, + *s.expectedTotalRows, + ), + ) if s.rowsRead >= *s.expectedTotalRows { - s.logger.Debug("[ROW END] emitting onHasMorePages as rows read >= expected total rows") + s.logger.Debug( + "[ROW END] emitting onHasMorePages as rows read >= expected total rows", + ) s.onHasMorePages(s.lastReadOrderValues) } } diff --git a/worker/pkg/benthos/sql/output_sql_insert.go b/worker/pkg/benthos/sql/output_sql_insert.go index 55a473bbe0..6790e23df5 100644 --- a/worker/pkg/benthos/sql/output_sql_insert.go +++ b/worker/pkg/benthos/sql/output_sql_insert.go @@ -33,9 +33,15 @@ func sqlInsertOutputSpec() *service.ConfigSpec { } // Registers an 
output on a benthos environment called pooled_sql_raw -func RegisterPooledSqlInsertOutput(env *service.Environment, dbprovider ConnectionProvider, isRetry bool, logger *slog.Logger) error { +func RegisterPooledSqlInsertOutput( + env *service.Environment, + dbprovider ConnectionProvider, + isRetry bool, + logger *slog.Logger, +) error { return env.RegisterBatchOutput( - "pooled_sql_insert", sqlInsertOutputSpec(), + "pooled_sql_insert", + sqlInsertOutputSpec(), func(conf *service.ParsedConfig, mgr *service.Resources) (service.BatchOutput, service.BatchPolicy, int, error) { batchPolicy, err := conf.FieldBatchPolicy("batching") if err != nil { @@ -74,7 +80,13 @@ type pooledInsertOutput struct { isRetry bool } -func newInsertOutput(conf *service.ParsedConfig, mgr *service.Resources, provider ConnectionProvider, isRetry bool, logger *slog.Logger) (*pooledInsertOutput, error) { +func newInsertOutput( + conf *service.ParsedConfig, + mgr *service.Resources, + provider ConnectionProvider, + isRetry bool, + logger *slog.Logger, +) (*pooledInsertOutput, error) { connectionId, err := conf.FieldString("connection_id") if err != nil { return nil, err @@ -242,9 +254,19 @@ func (s *pooledInsertOutput) WriteBatch(ctx context.Context, batch service.Messa if !shouldRetry { return fmt.Errorf("failed to execute insert query: %w", err) } - s.logger.Infof("received error during batch write that is retryable, proceeding with row by row insert: %s", err.Error()) - - err = retryInsertRowByRow(ctx, db, s.queryBuilder, rows, s.skipForeignKeyViolations, s.logger) + s.logger.Infof( + "received error during batch write that is retryable, proceeding with row by row insert: %s", + err.Error(), + ) + + err = retryInsertRowByRow( + ctx, + db, + s.queryBuilder, + rows, + s.skipForeignKeyViolations, + s.logger, + ) if err != nil { return fmt.Errorf("failed to retry insert query: %w", err) } @@ -283,7 +305,12 @@ func retryInsertRowByRow( insertCount++ } } - logger.Infof("Completed row-by-row insert with 
%d foreign key violations. Total Skipped rows: %d, Successfully inserted: %d", fkErrorCount, otherErrorCount, insertCount) + logger.Infof( + "Completed row-by-row insert with %d foreign key violations. Total Skipped rows: %d, Successfully inserted: %d", + fkErrorCount, + otherErrorCount, + insertCount, + ) return nil } diff --git a/worker/pkg/benthos/sql/output_sql_update.go b/worker/pkg/benthos/sql/output_sql_update.go index 78681c6273..886b15e3ba 100644 --- a/worker/pkg/benthos/sql/output_sql_update.go +++ b/worker/pkg/benthos/sql/output_sql_update.go @@ -42,7 +42,8 @@ func sqlUpdateOutputSpec() *service.ConfigSpec { // Registers an output on a benthos environment called pooled_sql_raw func RegisterPooledSqlUpdateOutput(env *service.Environment, dbprovider ConnectionProvider) error { return env.RegisterBatchOutput( - "pooled_sql_update", sqlUpdateOutputSpec(), + "pooled_sql_update", + sqlUpdateOutputSpec(), func(conf *service.ParsedConfig, mgr *service.Resources) (service.BatchOutput, service.BatchPolicy, int, error) { batchPolicy, err := conf.FieldBatchPolicy("batching") if err != nil { @@ -79,7 +80,11 @@ type pooledUpdateOutput struct { skipForeignKeyViolations bool } -func newUpdateOutput(conf *service.ParsedConfig, mgr *service.Resources, provider ConnectionProvider) (*pooledUpdateOutput, error) { +func newUpdateOutput( + conf *service.ParsedConfig, + mgr *service.Resources, + provider ConnectionProvider, +) (*pooledUpdateOutput, error) { connectionId, err := conf.FieldString("connection_id") if err != nil { return nil, err @@ -168,12 +173,20 @@ func (s *pooledUpdateOutput) WriteBatch(ctx context.Context, batch service.Messa return fmt.Errorf("message returned non-map result: %T", msgMap) } - query, err := querybuilder.BuildUpdateQuery(s.driver, s.schema, s.table, s.columns, s.whereCols, msgMap) + query, err := querybuilder.BuildUpdateQuery( + s.driver, + s.schema, + s.table, + s.columns, + s.whereCols, + msgMap, + ) if err != nil { return err } if _, err := 
db.ExecContext(ctx, query); err != nil { - if !s.skipForeignKeyViolations || !neosync_benthos.IsForeignKeyViolationError(err.Error()) { + if !s.skipForeignKeyViolations || + !neosync_benthos.IsForeignKeyViolationError(err.Error()) { return err } } diff --git a/worker/pkg/benthos/sql/processor_neosync_mssql.go b/worker/pkg/benthos/sql/processor_neosync_mssql.go index d49102100e..c3678ab740 100644 --- a/worker/pkg/benthos/sql/processor_neosync_mssql.go +++ b/worker/pkg/benthos/sql/processor_neosync_mssql.go @@ -38,7 +38,10 @@ type neosyncToMssqlProcessor struct { columnDefaultProperties map[string]*neosync_benthos.ColumnDefaultProperties } -func newNeosyncToMssqlProcessor(conf *service.ParsedConfig, mgr *service.Resources) (*neosyncToMssqlProcessor, error) { +func newNeosyncToMssqlProcessor( + conf *service.ParsedConfig, + mgr *service.Resources, +) (*neosyncToMssqlProcessor, error) { columns, err := conf.FieldStringList("columns") if err != nil { return nil, err @@ -67,7 +70,10 @@ func newNeosyncToMssqlProcessor(conf *service.ParsedConfig, mgr *service.Resourc }, nil } -func (p *neosyncToMssqlProcessor) ProcessBatch(ctx context.Context, batch service.MessageBatch) ([]service.MessageBatch, error) { +func (p *neosyncToMssqlProcessor) ProcessBatch( + ctx context.Context, + batch service.MessageBatch, +) ([]service.MessageBatch, error) { newBatch := make(service.MessageBatch, 0, len(batch)) for _, msg := range batch { root, err := msg.AsStructuredMut() @@ -149,7 +155,10 @@ func getMssqlNeosyncValue(root any) (value any, isNeosyncValue bool, err error) if valuer, ok := root.(neosynctypes.NeosyncMssqlValuer); ok { value, err := valuer.ValueMssql() if err != nil { - return nil, false, fmt.Errorf("unable to get MSSQL value from NeosyncMssqlValuer: %w", err) + return nil, false, fmt.Errorf( + "unable to get MSSQL value from NeosyncMssqlValuer: %w", + err, + ) } return value, true, nil } diff --git a/worker/pkg/benthos/sql/processor_neosync_mysql.go 
b/worker/pkg/benthos/sql/processor_neosync_mysql.go index e3c3a23b79..454e570008 100644 --- a/worker/pkg/benthos/sql/processor_neosync_mysql.go +++ b/worker/pkg/benthos/sql/processor_neosync_mysql.go @@ -38,7 +38,10 @@ type neosyncToMysqlProcessor struct { columnDefaultProperties map[string]*neosync_benthos.ColumnDefaultProperties } -func newNeosyncToMysqlProcessor(conf *service.ParsedConfig, mgr *service.Resources) (*neosyncToMysqlProcessor, error) { +func newNeosyncToMysqlProcessor( + conf *service.ParsedConfig, + mgr *service.Resources, +) (*neosyncToMysqlProcessor, error) { columns, err := conf.FieldStringList("columns") if err != nil { return nil, err @@ -67,14 +70,22 @@ func newNeosyncToMysqlProcessor(conf *service.ParsedConfig, mgr *service.Resourc }, nil } -func (p *neosyncToMysqlProcessor) ProcessBatch(ctx context.Context, batch service.MessageBatch) ([]service.MessageBatch, error) { +func (p *neosyncToMysqlProcessor) ProcessBatch( + ctx context.Context, + batch service.MessageBatch, +) ([]service.MessageBatch, error) { newBatch := make(service.MessageBatch, 0, len(batch)) for _, msg := range batch { root, err := msg.AsStructuredMut() if err != nil { return nil, err } - newRoot, err := transformNeosyncToMysql(root, p.columns, p.columnDataTypes, p.columnDefaultProperties) + newRoot, err := transformNeosyncToMysql( + root, + p.columns, + p.columnDataTypes, + p.columnDefaultProperties, + ) if err != nil { return nil, err } @@ -122,7 +133,11 @@ func transformNeosyncToMysql( return newMap, nil } -func getMysqlValue(value any, colDefaults *neosync_benthos.ColumnDefaultProperties, datatype string) (any, error) { +func getMysqlValue( + value any, + colDefaults *neosync_benthos.ColumnDefaultProperties, + datatype string, +) (any, error) { if colDefaults != nil && colDefaults.HasDefaultTransformer { return goqu.Default(), nil } @@ -165,7 +180,10 @@ func getMysqlNeosyncValue(root any) (value any, isNeosyncValue bool, err error) if valuer, ok := 
root.(neosynctypes.NeosyncMysqlValuer); ok { value, err := valuer.ValueMysql() if err != nil { - return nil, false, fmt.Errorf("unable to get MYSQL value from NeosyncMysqlValuer: %w", err) + return nil, false, fmt.Errorf( + "unable to get MYSQL value from NeosyncMysqlValuer: %w", + err, + ) } return value, true, nil } diff --git a/worker/pkg/benthos/sql/processor_neosync_pgx.go b/worker/pkg/benthos/sql/processor_neosync_pgx.go index a7ab732c12..56ad9a7770 100644 --- a/worker/pkg/benthos/sql/processor_neosync_pgx.go +++ b/worker/pkg/benthos/sql/processor_neosync_pgx.go @@ -43,7 +43,10 @@ type neosyncToPgxProcessor struct { columnDefaultProperties map[string]*neosync_benthos.ColumnDefaultProperties } -func newNeosyncToPgxProcessor(conf *service.ParsedConfig, mgr *service.Resources) (*neosyncToPgxProcessor, error) { +func newNeosyncToPgxProcessor( + conf *service.ParsedConfig, + mgr *service.Resources, +) (*neosyncToPgxProcessor, error) { columnDataTypes, err := conf.FieldStringMap("column_data_types") if err != nil { return nil, err @@ -72,14 +75,22 @@ func newNeosyncToPgxProcessor(conf *service.ParsedConfig, mgr *service.Resources }, nil } -func (p *neosyncToPgxProcessor) ProcessBatch(ctx context.Context, batch service.MessageBatch) ([]service.MessageBatch, error) { +func (p *neosyncToPgxProcessor) ProcessBatch( + ctx context.Context, + batch service.MessageBatch, +) ([]service.MessageBatch, error) { newBatch := make(service.MessageBatch, 0, len(batch)) for _, msg := range batch { root, err := msg.AsStructuredMut() if err != nil { return nil, err } - newRoot, err := transformNeosyncToPgx(root, p.columns, p.columnDataTypes, p.columnDefaultProperties) + newRoot, err := transformNeosyncToPgx( + root, + p.columns, + p.columnDataTypes, + p.columnDefaultProperties, + ) if err != nil { return nil, err } @@ -126,7 +137,11 @@ func transformNeosyncToPgx( return newMap, nil } -func getPgxValue(value any, colDefaults *neosync_benthos.ColumnDefaultProperties, datatype string) 
(any, error) { +func getPgxValue( + value any, + colDefaults *neosync_benthos.ColumnDefaultProperties, + datatype string, +) (any, error) { value, isNeosyncValue, err := getPgxNeosyncValue(value) if err != nil { return nil, err @@ -241,7 +256,9 @@ func isColumnInList(column string, columns []string) bool { return slices.Contains(columns, column) } -func getColumnDefaultProperties(columnDefaultPropertiesConfig map[string]*service.ParsedConfig) (map[string]*neosync_benthos.ColumnDefaultProperties, error) { +func getColumnDefaultProperties( + columnDefaultPropertiesConfig map[string]*service.ParsedConfig, +) (map[string]*neosync_benthos.ColumnDefaultProperties, error) { columnDefaultProperties := map[string]*neosync_benthos.ColumnDefaultProperties{} for key, properties := range columnDefaultPropertiesConfig { props, err := properties.FieldAny() diff --git a/worker/pkg/benthos/transformer_executor/executor.go b/worker/pkg/benthos/transformer_executor/executor.go index b9ae14abeb..1203ca6f2d 100644 --- a/worker/pkg/benthos/transformer_executor/executor.go +++ b/worker/pkg/benthos/transformer_executor/executor.go @@ -37,7 +37,12 @@ type transformPiiTextConfig struct { defaultLanguage *string } -func WithTransformPiiTextConfig(analyze presidioapi.AnalyzeInterface, anonymize presidioapi.AnonymizeInterface, neosyncOperatorApi ee_transformer_fns.NeosyncOperatorApi, defaultLanguage *string) TransformerExecutorOption { +func WithTransformPiiTextConfig( + analyze presidioapi.AnalyzeInterface, + anonymize presidioapi.AnonymizeInterface, + neosyncOperatorApi ee_transformer_fns.NeosyncOperatorApi, + defaultLanguage *string, +) TransformerExecutorOption { return func(c *TransformerExecutorConfig) { c.transformPiiText = &transformPiiTextConfig{ analyze: analyze, @@ -54,21 +59,32 @@ func WithLogger(logger *slog.Logger) TransformerExecutorOption { } } -func InitializeTransformer(transformerMapping *mgmtv1alpha1.JobMappingTransformer, opts ...TransformerExecutorOption) 
(*TransformerExecutor, error) { +func InitializeTransformer( + transformerMapping *mgmtv1alpha1.JobMappingTransformer, + opts ...TransformerExecutorOption, +) (*TransformerExecutor, error) { return InitializeTransformerByConfigType(transformerMapping.GetConfig(), opts...) } type UserDefinedTransformerResolver interface { - GetUserDefinedTransformer(ctx context.Context, id string) (*mgmtv1alpha1.TransformerConfig, error) + GetUserDefinedTransformer( + ctx context.Context, + id string, + ) (*mgmtv1alpha1.TransformerConfig, error) } -func WithUserDefinedTransformerResolver(resolver UserDefinedTransformerResolver) TransformerExecutorOption { +func WithUserDefinedTransformerResolver( + resolver UserDefinedTransformerResolver, +) TransformerExecutorOption { return func(c *TransformerExecutorConfig) { c.userDefinedTransformerResolver = resolver } } -func InitializeTransformerByConfigType(transformerConfig *mgmtv1alpha1.TransformerConfig, opts ...TransformerExecutorOption) (*TransformerExecutor, error) { +func InitializeTransformerByConfigType( + transformerConfig *mgmtv1alpha1.TransformerConfig, + opts ...TransformerExecutorOption, +) (*TransformerExecutor, error) { execCfg := &TransformerExecutorConfig{logger: slog.Default()} for _, opt := range opts { opt(execCfg) diff --git a/worker/pkg/benthos/transformers/generate_bool.go b/worker/pkg/benthos/transformers/generate_bool.go index fc43f0dca2..052a5e6da1 100644 --- a/worker/pkg/benthos/transformers/generate_bool.go +++ b/worker/pkg/benthos/transformers/generate_bool.go @@ -17,22 +17,26 @@ func init() { Category("boolean"). 
Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_bool", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } - - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - randomizer := rng.New(seed) - - return func() (any, error) { - return generateRandomBool(randomizer), nil - }, nil - }) + err := bloblang.RegisterFunctionV2( + "generate_bool", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } + + seed, err := transformer_utils.GetSeedOrDefault(seedArg) + if err != nil { + return nil, err + } + randomizer := rng.New(seed) + + return func() (any, error) { + return generateRandomBool(randomizer), nil + }, nil + }, + ) if err != nil { panic(err) } diff --git a/worker/pkg/benthos/transformers/generate_business_name.go b/worker/pkg/benthos/transformers/generate_business_name.go index 617a55b37b..b3eee35088 100644 --- a/worker/pkg/benthos/transformers/generate_business_name.go +++ b/worker/pkg/benthos/transformers/generate_business_name.go @@ -19,37 +19,44 @@ func init() { Param(bloblang.NewInt64Param("max_length").Default(100).Description("Specifies the maximum length for the generated data. This field ensures that the output does not exceed a certain number of characters.")). 
Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_business_name", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - maxLength, err := args.GetInt64("max_length") - if err != nil { - return nil, err - } - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } - - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - - randomizer := rng.New(seed) + err := bloblang.RegisterFunctionV2( + "generate_business_name", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + maxLength, err := args.GetInt64("max_length") + if err != nil { + return nil, err + } + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - output, err := generateRandomBusinessName(randomizer, nil, maxLength) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("unable to run generate_business_name: %w", err) + return nil, err } - return output, nil - }, nil - }) + + randomizer := rng.New(seed) + + return func() (any, error) { + output, err := generateRandomBusinessName(randomizer, nil, maxLength) + if err != nil { + return nil, fmt.Errorf("unable to run generate_business_name: %w", err) + } + return output, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateBusinessNameOptsFromConfig(config *mgmtv1alpha1.GenerateBusinessName, maxLength *int64) (*GenerateBusinessNameOpts, error) { +func NewGenerateBusinessNameOptsFromConfig( + config *mgmtv1alpha1.GenerateBusinessName, + maxLength *int64, +) (*GenerateBusinessNameOpts, error) { if config == nil { return NewGenerateBusinessNameOpts(nil, nil) } @@ -67,7 +74,11 @@ func (t *GenerateBusinessName) Generate(opts any) (any, error) { return generateRandomBusinessName(parsedOpts.randomizer, nil, 
parsedOpts.maxLength) } -func generateRandomBusinessName(randomizer rng.Rand, minLength *int64, maxLength int64) (string, error) { +func generateRandomBusinessName( + randomizer rng.Rand, + minLength *int64, + maxLength int64, +) (string, error) { return transformer_utils.GenerateStringFromCorpus( randomizer, transformers_dataset.BusinessNames, diff --git a/worker/pkg/benthos/transformers/generate_card_number.go b/worker/pkg/benthos/transformers/generate_card_number.go index 3ae72837e6..d11636ef9b 100644 --- a/worker/pkg/benthos/transformers/generate_card_number.go +++ b/worker/pkg/benthos/transformers/generate_card_number.go @@ -22,39 +22,45 @@ func init() { Param(bloblang.NewBoolParam("valid_luhn").Default(false).Description("A boolean indicating whether the generated value should pass the Luhn algorithm check.")). Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_card_number", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - luhn, err := args.GetBool("valid_luhn") - if err != nil { - return nil, err - } - - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } - - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } + err := bloblang.RegisterFunctionV2( + "generate_card_number", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + luhn, err := args.GetBool("valid_luhn") + if err != nil { + return nil, err + } - randomizer := rng.New(seed) + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - res, err := generateCardNumber(randomizer, luhn) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("unable to run generate_card_number: %w", err) + return nil, err } - return res, nil - }, nil - }) + + randomizer := 
rng.New(seed) + + return func() (any, error) { + res, err := generateCardNumber(randomizer, luhn) + if err != nil { + return nil, fmt.Errorf("unable to run generate_card_number: %w", err) + } + return res, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateCardNumberOptsFromConfig(config *mgmtv1alpha1.GenerateCardNumber) (*GenerateCardNumberOpts, error) { +func NewGenerateCardNumberOptsFromConfig( + config *mgmtv1alpha1.GenerateCardNumber, +) (*GenerateCardNumberOpts, error) { if config == nil { return NewGenerateCardNumberOpts(nil, nil) } diff --git a/worker/pkg/benthos/transformers/generate_categorical.go b/worker/pkg/benthos/transformers/generate_categorical.go index 8978c02480..ccb57f0269 100644 --- a/worker/pkg/benthos/transformers/generate_categorical.go +++ b/worker/pkg/benthos/transformers/generate_categorical.go @@ -19,36 +19,42 @@ func init() { Param(bloblang.NewStringParam("categories").Default("ultimo,proximo,semper").Description("A list of comma-separated string values to randomly select from.")). 
Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_categorical", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - // get stringified categories - catString, err := args.GetString("categories") - if err != nil { - return nil, err - } - categories := strings.Split(catString, ",") + err := bloblang.RegisterFunctionV2( + "generate_categorical", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + // get stringified categories + catString, err := args.GetString("categories") + if err != nil { + return nil, err + } + categories := strings.Split(catString, ",") - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - randomizer := rng.New(seed) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) + if err != nil { + return nil, err + } + randomizer := rng.New(seed) - return func() (any, error) { - res := generateCategorical(randomizer, categories) - return res, nil - }, nil - }) + return func() (any, error) { + res := generateCategorical(randomizer, categories) + return res, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateCategoricalOptsFromConfig(config *mgmtv1alpha1.GenerateCategorical) (*GenerateCategoricalOpts, error) { +func NewGenerateCategoricalOptsFromConfig( + config *mgmtv1alpha1.GenerateCategorical, +) (*GenerateCategoricalOpts, error) { if config == nil { return NewGenerateCategoricalOpts(nil, nil) } @@ -61,7 +67,10 @@ func (t *GenerateCategorical) Generate(opts any) (any, error) { return nil, fmt.Errorf("invalid parsed opts: %T", opts) } - return generateCategorical(parsedOpts.randomizer, strings.Split(parsedOpts.categories, ",")), nil + return 
generateCategorical( + parsedOpts.randomizer, + strings.Split(parsedOpts.categories, ","), + ), nil } // Generates a randomly selected value from the user-provided list of categories. We don't account for the maxLength param here because the input is user-provided. We assume that they values they provide in the set abide by the maxCharacterLength constraint. diff --git a/worker/pkg/benthos/transformers/generate_city.go b/worker/pkg/benthos/transformers/generate_city.go index 0fcaa6ad80..4d6646db9c 100644 --- a/worker/pkg/benthos/transformers/generate_city.go +++ b/worker/pkg/benthos/transformers/generate_city.go @@ -19,37 +19,44 @@ func init() { Param(bloblang.NewInt64Param("max_length").Default(100).Description("Specifies the maximum length for the generated data. This field ensures that the output does not exceed a certain number of characters.")). Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_city", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - maxLength, err := args.GetInt64("max_length") - if err != nil { - return nil, err - } - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } - - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - - randomizer := rng.New(seed) + err := bloblang.RegisterFunctionV2( + "generate_city", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + maxLength, err := args.GetInt64("max_length") + if err != nil { + return nil, err + } + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - res, err := generateRandomCity(randomizer, maxLength) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("unable to run generate_city: %w", err) + return nil, err } - return res, nil - }, 
nil - }) + + randomizer := rng.New(seed) + + return func() (any, error) { + res, err := generateRandomCity(randomizer, maxLength) + if err != nil { + return nil, fmt.Errorf("unable to run generate_city: %w", err) + } + return res, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateCityOptsFromConfig(config *mgmtv1alpha1.GenerateCity, maxLength *int64) (*GenerateCityOpts, error) { +func NewGenerateCityOptsFromConfig( + config *mgmtv1alpha1.GenerateCity, + maxLength *int64, +) (*GenerateCityOpts, error) { if config == nil { return NewGenerateCityOpts( nil, diff --git a/worker/pkg/benthos/transformers/generate_country.go b/worker/pkg/benthos/transformers/generate_country.go index 58805340f3..aa4709c64b 100644 --- a/worker/pkg/benthos/transformers/generate_country.go +++ b/worker/pkg/benthos/transformers/generate_country.go @@ -13,41 +13,48 @@ import ( // +neosyncTransformerBuilder:generate:generateCountry func init() { - spec := bloblang.NewPluginSpec().Description("Randomly selects a country and by default, returns it as a 2-letter country code."). + spec := bloblang.NewPluginSpec(). + Description("Randomly selects a country and by default, returns it as a 2-letter country code."). Category("string"). Param(bloblang.NewBoolParam("generate_full_name").Default(false).Description("If true returns the full country name instead of the two character country code.")). 
Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_country", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - generateFullName, err := args.GetBool("generate_full_name") - if err != nil { - return nil, err - } - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } - - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - randomizer := rng.New(seed) + err := bloblang.RegisterFunctionV2( + "generate_country", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + generateFullName, err := args.GetBool("generate_full_name") + if err != nil { + return nil, err + } + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - val, err := generateRandomCountry(randomizer, generateFullName) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("failed to generate_country: %w", err) + return nil, err } - return val, nil - }, nil - }) + randomizer := rng.New(seed) + + return func() (any, error) { + val, err := generateRandomCountry(randomizer, generateFullName) + if err != nil { + return nil, fmt.Errorf("failed to generate_country: %w", err) + } + return val, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateCountryOptsFromConfig(config *mgmtv1alpha1.GenerateCountry) (*GenerateCountryOpts, error) { +func NewGenerateCountryOptsFromConfig( + config *mgmtv1alpha1.GenerateCountry, +) (*GenerateCountryOpts, error) { if config == nil { return NewGenerateCountryOpts( nil, diff --git a/worker/pkg/benthos/transformers/generate_email.go b/worker/pkg/benthos/transformers/generate_email.go index 0dbb4f2586..84b2e52b66 100644 --- a/worker/pkg/benthos/transformers/generate_email.go +++ 
b/worker/pkg/benthos/transformers/generate_email.go @@ -28,7 +28,8 @@ func (g GenerateEmailType) String() string { } func isValidEmailType(emailType string) bool { - return emailType == string(GenerateEmailType_UuidV4) || emailType == string(GenerateEmailType_FullName) + return emailType == string(GenerateEmailType_UuidV4) || + emailType == string(GenerateEmailType_FullName) } func init() { @@ -39,45 +40,57 @@ func init() { Param(bloblang.NewStringParam("email_type").Default(GenerateEmailType_UuidV4.String()).Description("Specifies the type of email type to generate, with options including `uuidv4`, `fullname`, or `any`.")). Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_email", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - maxLength, err := args.GetInt64("max_length") - if err != nil { - return nil, err - } - emailTypeArg, err := args.GetString("email_type") - if err != nil { - return nil, err - } - emailType := getEmailTypeOrDefault(emailTypeArg) - - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } - - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - randomizer := rng.New(seed) + err := bloblang.RegisterFunctionV2( + "generate_email", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + maxLength, err := args.GetInt64("max_length") + if err != nil { + return nil, err + } + emailTypeArg, err := args.GetString("email_type") + if err != nil { + return nil, err + } + emailType := getEmailTypeOrDefault(emailTypeArg) - var excludedDomains []string + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - output, err := generateRandomEmail(randomizer, maxLength, emailType, excludedDomains) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if 
err != nil { - return nil, fmt.Errorf("unable to run generate_email: %w", err) + return nil, err } - return output, nil - }, nil - }) + randomizer := rng.New(seed) + + var excludedDomains []string + + return func() (any, error) { + output, err := generateRandomEmail( + randomizer, + maxLength, + emailType, + excludedDomains, + ) + if err != nil { + return nil, fmt.Errorf("unable to run generate_email: %w", err) + } + return output, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateEmailOptsFromConfig(config *mgmtv1alpha1.GenerateEmail, maxLength *int64) (*GenerateEmailOpts, error) { +func NewGenerateEmailOptsFromConfig( + config *mgmtv1alpha1.GenerateEmail, + maxLength *int64, +) (*GenerateEmailOpts, error) { if config == nil { return NewGenerateEmailOpts(nil, nil, nil) } @@ -100,7 +113,12 @@ func (t *GenerateEmail) Generate(opts any) (any, error) { var excludedDomains []string - return generateRandomEmail(parsedOpts.randomizer, parsedOpts.maxLength, getEmailTypeOrDefault(parsedOpts.emailType), excludedDomains) + return generateRandomEmail( + parsedOpts.randomizer, + parsedOpts.maxLength, + getEmailTypeOrDefault(parsedOpts.emailType), + excludedDomains, + ) } func getEmailTypeOrDefault(input string) GenerateEmailType { @@ -110,7 +128,11 @@ func getEmailTypeOrDefault(input string) GenerateEmailType { return GenerateEmailType_UuidV4 } -func getRandomEmailDomain(randomizer rng.Rand, maxLength int64, excludedDomains []string) (string, error) { +func getRandomEmailDomain( + randomizer rng.Rand, + maxLength int64, + excludedDomains []string, +) (string, error) { return transformer_utils.GenerateStringFromCorpus( randomizer, transformers_dataset.EmailDomains, @@ -123,7 +145,12 @@ func getRandomEmailDomain(randomizer rng.Rand, maxLength int64, excludedDomains } /* Generates an email in the format such as jdoe@gmail.com */ -func generateRandomEmail(randomizer rng.Rand, maxLength int64, emailType GenerateEmailType, excludedDomains []string) (string, 
error) { +func generateRandomEmail( + randomizer rng.Rand, + maxLength int64, + emailType GenerateEmailType, + excludedDomains []string, +) (string, error) { if emailType == GenerateEmailType_Any { emailType = getRandomEmailType(randomizer) } @@ -141,10 +168,17 @@ func getRandomEmailType(randomizer rng.Rand) GenerateEmailType { return GenerateEmailType_FullName } -func generateFullnameEmail(randomizer rng.Rand, maxLength int64, excludedDomains []string) (string, error) { +func generateFullnameEmail( + randomizer rng.Rand, + maxLength int64, + excludedDomains []string, +) (string, error) { domainMaxLength := maxLength - 2 // is there enough room for at least one character and an @ sign if (domainMaxLength) <= 0 { - return "", fmt.Errorf("for the given max length, unable to generate an email of sufficient length: %d", maxLength) + return "", fmt.Errorf( + "for the given max length, unable to generate an email of sufficient length: %d", + maxLength, + ) } domain, err := getRandomEmailDomain(randomizer, domainMaxLength, excludedDomains) @@ -152,7 +186,9 @@ func generateFullnameEmail(randomizer rng.Rand, maxLength int64, excludedDomains return "", err } - fullNameMaxLength := maxLength - int64(len(domain)) - 1 // original full length, minus the computed domain, minus an @ sign + fullNameMaxLength := maxLength - int64( + len(domain), + ) - 1 // original full length, minus the computed domain, minus an @ sign generatename, err := generateNameForEmail(randomizer, nil, fullNameMaxLength) if err != nil { @@ -198,8 +234,12 @@ func generateNameForEmail(randomizer rng.Rand, minLength *int64, maxLength int64 } } - randomFirstName = strings.ToLower(transformer_utils.WithoutCharacters(randomFirstName, transformer_utils.SpecialChars)) - randomLastName = strings.ToLower(transformer_utils.WithoutCharacters(randomLastName, transformer_utils.SpecialChars)) + randomFirstName = strings.ToLower( + transformer_utils.WithoutCharacters(randomFirstName, transformer_utils.SpecialChars), + ) + 
randomLastName = strings.ToLower( + transformer_utils.WithoutCharacters(randomLastName, transformer_utils.SpecialChars), + ) if randomFirstName == "" && randomLastName == "" { return "", errors.New("unable to generate random first and/or last name for email") @@ -221,19 +261,32 @@ func generateNameForEmail(randomizer rng.Rand, minLength *int64, maxLength int64 return fullname, nil } -func generateUuidEmail(randomizer rng.Rand, maxLength int64, excludedDomains []string) (string, error) { +func generateUuidEmail( + randomizer rng.Rand, + maxLength int64, + excludedDomains []string, +) (string, error) { domainMaxLength := maxLength - 2 // is there enough room for at least one character and an @ sign if (domainMaxLength) <= 0 { - return "", fmt.Errorf("for the given max length, unable to generate an email of sufficient length: %d", maxLength) + return "", fmt.Errorf( + "for the given max length, unable to generate an email of sufficient length: %d", + maxLength, + ) } domain, err := getRandomEmailDomain(randomizer, domainMaxLength, excludedDomains) if err != nil { - return "", fmt.Errorf("unable to generate random email domain given the max length when generating a uuid email: %d", maxLength) + return "", fmt.Errorf( + "unable to generate random email domain given the max length when generating a uuid email: %d", + maxLength, + ) } newuuid := strings.ReplaceAll(uuid.NewString(), "-", "") trimmedUuid := transformer_utils.TrimStringIfExceeds(newuuid, maxLength-int64(len(domain))-1) if trimmedUuid == "" { // todo: if this doesn't work, we should try with a different email domain to see if there is one that works. 
Maybe we could use the closest pair algorithm to find this - return "", fmt.Errorf("for the given max length, unable to use a uuid to generate an email for the given length: %d", maxLength) + return "", fmt.Errorf( + "for the given max length, unable to use a uuid to generate an email for the given length: %d", + maxLength, + ) } return fmt.Sprintf("%s@%s", trimmedUuid, domain), nil diff --git a/worker/pkg/benthos/transformers/generate_first_name.go b/worker/pkg/benthos/transformers/generate_first_name.go index 94c39f2176..ae67e1984b 100644 --- a/worker/pkg/benthos/transformers/generate_first_name.go +++ b/worker/pkg/benthos/transformers/generate_first_name.go @@ -19,37 +19,44 @@ func init() { Param(bloblang.NewInt64Param("max_length").Default(100).Description("Specifies the maximum length for the generated data. This field ensures that the output does not exceed a certain number of characters.")). Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_first_name", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - maxLength, err := args.GetInt64("max_length") - if err != nil { - return nil, err - } - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } - - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - - randomizer := rng.New(seed) + err := bloblang.RegisterFunctionV2( + "generate_first_name", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + maxLength, err := args.GetInt64("max_length") + if err != nil { + return nil, err + } + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - output, err := generateRandomFirstName(randomizer, nil, maxLength) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("unable to 
run generate_first_name: %w", err) + return nil, err } - return output, nil - }, nil - }) + + randomizer := rng.New(seed) + + return func() (any, error) { + output, err := generateRandomFirstName(randomizer, nil, maxLength) + if err != nil { + return nil, fmt.Errorf("unable to run generate_first_name: %w", err) + } + return output, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateFirstNameOptsFromConfig(config *mgmtv1alpha1.GenerateFirstName, maxLength *int64) (*GenerateFirstNameOpts, error) { +func NewGenerateFirstNameOptsFromConfig( + config *mgmtv1alpha1.GenerateFirstName, + maxLength *int64, +) (*GenerateFirstNameOpts, error) { if config == nil { return NewGenerateFirstNameOpts(nil, nil) } @@ -67,7 +74,11 @@ func (t *GenerateFirstName) Generate(opts any) (any, error) { return generateRandomFirstName(parsedOpts.randomizer, nil, parsedOpts.maxLength) } -func generateRandomFirstName(randomizer rng.Rand, minLength *int64, maxLength int64) (string, error) { +func generateRandomFirstName( + randomizer rng.Rand, + minLength *int64, + maxLength int64, +) (string, error) { return transformer_utils.GenerateStringFromCorpus( randomizer, transformers_dataset.FirstNames, diff --git a/worker/pkg/benthos/transformers/generate_float.go b/worker/pkg/benthos/transformers/generate_float.go index 26b9686cc1..7778531d74 100644 --- a/worker/pkg/benthos/transformers/generate_float.go +++ b/worker/pkg/benthos/transformers/generate_float.go @@ -25,56 +25,70 @@ func init() { Param(bloblang.NewInt64Param("scale").Optional().Description("An optional parameter that defines the number of decimal places for the generated float.")). 
Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_float64", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - randomizeSign, err := args.GetBool("randomize_sign") - if err != nil { - return nil, err - } - - minVal, err := args.GetFloat64("min") - if err != nil { - return nil, err - } + err := bloblang.RegisterFunctionV2( + "generate_float64", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + randomizeSign, err := args.GetBool("randomize_sign") + if err != nil { + return nil, err + } - maxVal, err := args.GetFloat64("max") - if err != nil { - return nil, err - } + minVal, err := args.GetFloat64("min") + if err != nil { + return nil, err + } - precision, err := args.GetOptionalInt64("precision") - if err != nil { - return nil, err - } - scale, err := args.GetOptionalInt64("scale") - if err != nil { - return nil, err - } - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } + maxVal, err := args.GetFloat64("max") + if err != nil { + return nil, err + } - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - randomizer := rng.New(seed) + precision, err := args.GetOptionalInt64("precision") + if err != nil { + return nil, err + } + scale, err := args.GetOptionalInt64("scale") + if err != nil { + return nil, err + } + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - res, err := generateRandomFloat64(randomizer, randomizeSign, minVal, maxVal, precision, scale) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("unable to run generate_float: %w", err) + return nil, err } - return res, nil - }, nil - }) + randomizer := rng.New(seed) + + return func() (any, error) { + res, err := generateRandomFloat64( + 
randomizer, + randomizeSign, + minVal, + maxVal, + precision, + scale, + ) + if err != nil { + return nil, fmt.Errorf("unable to run generate_float: %w", err) + } + return res, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateFloat64OptsFromConfig(config *mgmtv1alpha1.GenerateFloat64, scale *int64) (*GenerateFloat64Opts, error) { +func NewGenerateFloat64OptsFromConfig( + config *mgmtv1alpha1.GenerateFloat64, + scale *int64, +) (*GenerateFloat64Opts, error) { if config == nil { return NewGenerateFloat64Opts(nil, nil, nil, nil, nil, nil) } @@ -111,7 +125,11 @@ func generateRandomFloat64( minValue, maxValue float64, precision, scale *int64, ) (float64, error) { - randomFloat, err := transformer_utils.GenerateRandomFloat64WithInclusiveBounds(randomizer, minValue, maxValue) + randomFloat, err := transformer_utils.GenerateRandomFloat64WithInclusiveBounds( + randomizer, + minValue, + maxValue, + ) if err != nil { return 0, err } diff --git a/worker/pkg/benthos/transformers/generate_full_address.go b/worker/pkg/benthos/transformers/generate_full_address.go index fa19cb88c6..be1792dd34 100644 --- a/worker/pkg/benthos/transformers/generate_full_address.go +++ b/worker/pkg/benthos/transformers/generate_full_address.go @@ -19,37 +19,44 @@ func init() { Param(bloblang.NewInt64Param("max_length").Default(100).Description("Specifies the maximum length for the generated data. This field ensures that the output does not exceed a certain number of characters.")). 
Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_full_address", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - maxLength, err := args.GetInt64("max_length") - if err != nil { - return nil, err - } - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } - - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - - randomizer := rng.New(seed) + err := bloblang.RegisterFunctionV2( + "generate_full_address", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + maxLength, err := args.GetInt64("max_length") + if err != nil { + return nil, err + } + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - res, err := generateRandomFullAddress(randomizer, maxLength) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { return nil, err } - return res, nil - }, nil - }) + + randomizer := rng.New(seed) + + return func() (any, error) { + res, err := generateRandomFullAddress(randomizer, maxLength) + if err != nil { + return nil, err + } + return res, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateFullAddressOptsFromConfig(config *mgmtv1alpha1.GenerateFullAddress, maxLength *int64) (*GenerateFullAddressOpts, error) { +func NewGenerateFullAddressOptsFromConfig( + config *mgmtv1alpha1.GenerateFullAddress, + maxLength *int64, +) (*GenerateFullAddressOpts, error) { if config == nil { return NewGenerateFullAddressOpts( nil, @@ -85,7 +92,9 @@ func generateRandomFullAddress(randomizer rng.Rand, maxLength int64) (string, er // we have a finite set of zipcodes and states so we basically know the max length for the city and street address for each generated permutation. 
remainder := int64(int(maxLength) - len(state) - len(zipcode) - 4) // -4 for spaces and comma if remainder <= 0 { - return "", fmt.Errorf("the state and zipcode combined are longer than the max length allowed") + return "", fmt.Errorf( + "the state and zipcode combined are longer than the max length allowed", + ) } maxCityIdx, maxAddr1Idx := transformer_utils.FindClosestPair( @@ -94,19 +103,29 @@ func generateRandomFullAddress(randomizer rng.Rand, maxLength int64) (string, er remainder, ) if maxCityIdx == -1 || maxAddr1Idx == -1 { - randStr, err := transformer_utils.GenerateRandomStringWithInclusiveBounds(randomizer, 1, remainder) + randStr, err := transformer_utils.GenerateRandomStringWithInclusiveBounds( + randomizer, + 1, + remainder, + ) if err != nil { return "", err } return fmt.Sprintf(`%s %s, %s`, randStr, state, zipcode), nil } - city, err := generateRandomCity(randomizer, transformers_dataset.Address_CityIndices[maxCityIdx]) + city, err := generateRandomCity( + randomizer, + transformers_dataset.Address_CityIndices[maxCityIdx], + ) if err != nil { return "", err } - street, err := generateRandomStreetAddress(randomizer, transformers_dataset.Address_Address1Indices[maxAddr1Idx]) + street, err := generateRandomStreetAddress( + randomizer, + transformers_dataset.Address_Address1Indices[maxAddr1Idx], + ) if err != nil { return "", err } diff --git a/worker/pkg/benthos/transformers/generate_full_name.go b/worker/pkg/benthos/transformers/generate_full_name.go index 8355030ce1..15afdc3443 100644 --- a/worker/pkg/benthos/transformers/generate_full_name.go +++ b/worker/pkg/benthos/transformers/generate_full_name.go @@ -19,37 +19,44 @@ func init() { Param(bloblang.NewInt64Param("max_length").Default(100).Description("Specifies the maximum length for the generated data. This field ensures that the output does not exceed a certain number of characters.")). 
Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_full_name", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - maxLength, err := args.GetInt64("max_length") - if err != nil { - return nil, err - } - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } - - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - randomizer := rng.New(seed) + err := bloblang.RegisterFunctionV2( + "generate_full_name", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + maxLength, err := args.GetInt64("max_length") + if err != nil { + return nil, err + } + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - res, err := generateRandomFullName(randomizer, maxLength) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("unable to run generate_full_name: %w", err) + return nil, err } - return res, nil - }, nil - }) + randomizer := rng.New(seed) + + return func() (any, error) { + res, err := generateRandomFullName(randomizer, maxLength) + if err != nil { + return nil, fmt.Errorf("unable to run generate_full_name: %w", err) + } + return res, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateFullNameOptsFromConfig(config *mgmtv1alpha1.GenerateFullName, maxLength *int64) (*GenerateFullNameOpts, error) { +func NewGenerateFullNameOptsFromConfig( + config *mgmtv1alpha1.GenerateFullName, + maxLength *int64, +) (*GenerateFullNameOpts, error) { if config == nil { return NewGenerateFullNameOpts( nil, @@ -74,25 +81,37 @@ func (t *GenerateFullName) Generate(opts any) (any, error) { func generateRandomFullName(randomizer rng.Rand, maxLength int64) (string, error) { maxLengthMinusSpace := maxLength - 1 if maxLengthMinusSpace <= 0 { 
- return "", fmt.Errorf("unable to generate full name including space with provided max length: %d", maxLength) + return "", fmt.Errorf( + "unable to generate full name including space with provided max length: %d", + maxLength, + ) } maxFirstNameIdx, maxLastNameIdx := transformer_utils.FindClosestPair( transformers_dataset.FirstNameIndices, transformers_dataset.LastNameIndices, maxLengthMinusSpace, ) if maxFirstNameIdx == -1 || maxLastNameIdx == -1 { - return "", fmt.Errorf("unable to generate a full name with the provided max length: %d", maxLength) + return "", fmt.Errorf( + "unable to generate a full name with the provided max length: %d", + maxLength, + ) } maxFirstNameLength := transformers_dataset.FirstNameIndices[maxFirstNameIdx] maxLastNameLength := transformers_dataset.LastNameIndices[maxLastNameIdx] firstname, err := generateRandomFirstName(randomizer, nil, maxFirstNameLength) if err != nil { - return "", fmt.Errorf("unable to generate random first name with length: %d", maxFirstNameLength) + return "", fmt.Errorf( + "unable to generate random first name with length: %d", + maxFirstNameLength, + ) } lastname, err := generateRandomLastName(randomizer, nil, maxLastNameLength) if err != nil { - return "", fmt.Errorf("unable to generate random last name with length: %d", maxLastNameLength) + return "", fmt.Errorf( + "unable to generate random last name with length: %d", + maxLastNameLength, + ) } return fmt.Sprintf("%s %s", firstname, lastname), nil diff --git a/worker/pkg/benthos/transformers/generate_gender.go b/worker/pkg/benthos/transformers/generate_gender.go index 8d0cb195f7..20335bdbec 100644 --- a/worker/pkg/benthos/transformers/generate_gender.go +++ b/worker/pkg/benthos/transformers/generate_gender.go @@ -19,39 +19,46 @@ func init() { Param(bloblang.NewInt64Param("max_length").Default(100).Description("Specifies the maximum length for the generated data. This field ensures that the output does not exceed a certain number of characters.")). 
Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_gender", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - shouldAbbreviate, err := args.GetBool("abbreviate") - if err != nil { - return nil, err - } + err := bloblang.RegisterFunctionV2( + "generate_gender", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + shouldAbbreviate, err := args.GetBool("abbreviate") + if err != nil { + return nil, err + } - maxLength, err := args.GetInt64("max_length") - if err != nil { - return nil, err - } + maxLength, err := args.GetInt64("max_length") + if err != nil { + return nil, err + } - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - randomizer := rng.New(seed) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) + if err != nil { + return nil, err + } + randomizer := rng.New(seed) - return func() (any, error) { - res := generateRandomGender(randomizer, shouldAbbreviate, maxLength) - return res, nil - }, nil - }) + return func() (any, error) { + res := generateRandomGender(randomizer, shouldAbbreviate, maxLength) + return res, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateGenderOptsFromConfig(config *mgmtv1alpha1.GenerateGender, maxLength *int64) (*GenerateGenderOpts, error) { +func NewGenerateGenderOptsFromConfig( + config *mgmtv1alpha1.GenerateGender, + maxLength *int64, +) (*GenerateGenderOpts, error) { if config == nil { return NewGenerateGenderOpts( nil, @@ -71,7 +78,11 @@ func (t *GenerateGender) Generate(opts any) (any, error) { return nil, fmt.Errorf("invalid parsed opts: %T", opts) } - return generateRandomGender(parsedOpts.randomizer, 
parsedOpts.abbreviate, parsedOpts.maxLength), nil + return generateRandomGender( + parsedOpts.randomizer, + parsedOpts.abbreviate, + parsedOpts.maxLength, + ), nil } var genders = []string{"undefined", "nonbinary", "female", "male"} diff --git a/worker/pkg/benthos/transformers/generate_int64.go b/worker/pkg/benthos/transformers/generate_int64.go index fc447f27f9..a996adaad0 100644 --- a/worker/pkg/benthos/transformers/generate_int64.go +++ b/worker/pkg/benthos/transformers/generate_int64.go @@ -36,49 +36,55 @@ func init() { Param(bloblang.NewInt64Param("max").Default(10000).Description("Specifies the maximum value for the generated int.")). Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_int64", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - randomizeSign, err := args.GetBool("randomize_sign") - if err != nil { - return nil, err - } - - min, err := args.GetInt64("min") - if err != nil { - return nil, err - } - - max, err := args.GetInt64("max") - if err != nil { - return nil, err - } - - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } - - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - - randomizer := rng.New(seed) - - return func() (any, error) { - res, err := generateRandomInt64(randomizer, randomizeSign, min, max) + err := bloblang.RegisterFunctionV2( + "generate_int64", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + randomizeSign, err := args.GetBool("randomize_sign") if err != nil { - return nil, fmt.Errorf("unable to run generate_int64: %w", err) + return nil, err } - return res, nil - }, nil - }) + + min, err := args.GetInt64("min") + if err != nil { + return nil, err + } + + max, err := args.GetInt64("max") + if err != nil { + return nil, err + } + + seedArg, err := args.GetOptionalInt64("seed") 
+ if err != nil { + return nil, err + } + + seed, err := transformer_utils.GetSeedOrDefault(seedArg) + if err != nil { + return nil, err + } + + randomizer := rng.New(seed) + + return func() (any, error) { + res, err := generateRandomInt64(randomizer, randomizeSign, min, max) + if err != nil { + return nil, fmt.Errorf("unable to run generate_int64: %w", err) + } + return res, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateInt64OptsFromConfig(config *mgmtv1alpha1.GenerateInt64) (*GenerateInt64Opts, error) { +func NewGenerateInt64OptsFromConfig( + config *mgmtv1alpha1.GenerateInt64, +) (*GenerateInt64Opts, error) { if config == nil { return NewGenerateInt64Opts( nil, @@ -100,13 +106,22 @@ func (t *GenerateInt64) Generate(opts any) (any, error) { return nil, fmt.Errorf("invalid parsed opts: %T", opts) } - return generateRandomInt64(parsedOpts.randomizer, parsedOpts.randomizeSign, parsedOpts.min, parsedOpts.max) + return generateRandomInt64( + parsedOpts.randomizer, + parsedOpts.randomizeSign, + parsedOpts.min, + parsedOpts.max, + ) } /* Generates a random int64 in the interval [min, max]. */ -func generateRandomInt64(randomizer rng.Rand, randomizeSign bool, minValue, maxValue int64) (int64, error) { +func generateRandomInt64( + randomizer rng.Rand, + randomizeSign bool, + minValue, maxValue int64, +) (int64, error) { output, err := transformer_utils.GenerateRandomInt64InValueRange(randomizer, minValue, maxValue) if err != nil { return 0, err diff --git a/worker/pkg/benthos/transformers/generate_int64_phone_number.go b/worker/pkg/benthos/transformers/generate_int64_phone_number.go index f17d7ddbb8..8646ae4cbc 100644 --- a/worker/pkg/benthos/transformers/generate_int64_phone_number.go +++ b/worker/pkg/benthos/transformers/generate_int64_phone_number.go @@ -16,38 +16,45 @@ import ( var defaultPhoneNumberLength = int64(10) func init() { - spec := bloblang.NewPluginSpec().Description("Generates a new int64 phone number with a default length of 10."). 
+ spec := bloblang.NewPluginSpec(). + Description("Generates a new int64 phone number with a default length of 10."). Category("int64"). Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_int64_phone_number", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } - - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - - randomizer := rng.New(seed) + err := bloblang.RegisterFunctionV2( + "generate_int64_phone_number", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - res, err := generateRandomInt64PhoneNumber(randomizer) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("unable to run generate_int64_phone_number: %w", err) + return nil, err } - return res, nil - }, nil - }) + + randomizer := rng.New(seed) + + return func() (any, error) { + res, err := generateRandomInt64PhoneNumber(randomizer) + if err != nil { + return nil, fmt.Errorf("unable to run generate_int64_phone_number: %w", err) + } + return res, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateInt64PhoneNumberOptsFromConfig(config *mgmtv1alpha1.GenerateInt64PhoneNumber) (*GenerateInt64PhoneNumberOpts, error) { +func NewGenerateInt64PhoneNumberOptsFromConfig( + config *mgmtv1alpha1.GenerateInt64PhoneNumber, +) (*GenerateInt64PhoneNumberOpts, error) { return NewGenerateInt64PhoneNumberOpts(nil) } @@ -62,7 +69,10 @@ func (t *GenerateInt64PhoneNumber) Generate(opts any) (any, error) { /* Generates a random 10 digit phone number with a valid US area code and returns it as an int64. 
*/ func generateRandomInt64PhoneNumber(randomizer rng.Rand) (int64, error) { // get a random area code from the areacodes data set - randAreaCodeStr, err := transformer_utils.GetRandomValueFromSlice(randomizer, transformers_dataset.UsAreaCodes) + randAreaCodeStr, err := transformer_utils.GetRandomValueFromSlice( + randomizer, + transformers_dataset.UsAreaCodes, + ) if err != nil { return 0, err } @@ -73,7 +83,10 @@ func generateRandomInt64PhoneNumber(randomizer rng.Rand) (int64, error) { } // generate the rest of the phone number - pn, err := transformer_utils.GenerateRandomInt64FixedLength(randomizer, defaultPhoneNumberLength-3) + pn, err := transformer_utils.GenerateRandomInt64FixedLength( + randomizer, + defaultPhoneNumberLength-3, + ) if err != nil { return 0, err } diff --git a/worker/pkg/benthos/transformers/generate_international_phone_number.go b/worker/pkg/benthos/transformers/generate_international_phone_number.go index 305113d862..efbd3d2221 100644 --- a/worker/pkg/benthos/transformers/generate_international_phone_number.go +++ b/worker/pkg/benthos/transformers/generate_international_phone_number.go @@ -20,44 +20,53 @@ func init() { Param(bloblang.NewInt64Param("max").Default(15).Description("Specifies the maximum value for the generated phone number.")). 
Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_e164_phone_number", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - min, err := args.GetInt64("min") - if err != nil { - return nil, err - } - - max, err := args.GetInt64("max") - if err != nil { - return nil, err - } - - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } - - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - - randomizer := rng.New(seed) - - return func() (any, error) { - res, err := generateInternationalPhoneNumber(randomizer, min, max) + err := bloblang.RegisterFunctionV2( + "generate_e164_phone_number", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + min, err := args.GetInt64("min") if err != nil { - return nil, fmt.Errorf("unable to run generate_international_phone_number: %w", err) + return nil, err } - return res, nil - }, nil - }) + + max, err := args.GetInt64("max") + if err != nil { + return nil, err + } + + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } + + seed, err := transformer_utils.GetSeedOrDefault(seedArg) + if err != nil { + return nil, err + } + + randomizer := rng.New(seed) + + return func() (any, error) { + res, err := generateInternationalPhoneNumber(randomizer, min, max) + if err != nil { + return nil, fmt.Errorf( + "unable to run generate_international_phone_number: %w", + err, + ) + } + return res, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateInternationalPhoneNumberOptsFromConfig(config *mgmtv1alpha1.GenerateE164PhoneNumber) (*GenerateInternationalPhoneNumberOpts, error) { +func NewGenerateInternationalPhoneNumberOptsFromConfig( + config *mgmtv1alpha1.GenerateE164PhoneNumber, +) (*GenerateInternationalPhoneNumberOpts, error) { if config == nil { return 
NewGenerateInternationalPhoneNumberOpts( nil, @@ -83,7 +92,10 @@ func (t *GenerateInternationalPhoneNumber) Generate(opts any) (any, error) { /* Generates a random phone number in e164 format in the length interval [min, max] with the min length == 9 and the max length == 15. */ -func generateInternationalPhoneNumber(randomizer rng.Rand, minValue, maxValue int64) (string, error) { +func generateInternationalPhoneNumber( + randomizer rng.Rand, + minValue, maxValue int64, +) (string, error) { if minValue < 9 || maxValue > 15 { return "", errors.New("the length has between 9 and 15 characters long") } diff --git a/worker/pkg/benthos/transformers/generate_ip_address.go b/worker/pkg/benthos/transformers/generate_ip_address.go index 183a45a92b..dc9f866659 100644 --- a/worker/pkg/benthos/transformers/generate_ip_address.go +++ b/worker/pkg/benthos/transformers/generate_ip_address.go @@ -49,39 +49,46 @@ func init() { Param(bloblang.NewStringParam("ip_type").Default(string(IpV4_Public)).Description("IP type to generate.")). 
Param(bloblang.NewInt64Param("seed").Optional().Description("Optional seed for deterministic generation")) - err := bloblang.RegisterFunctionV2("generate_ip", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - maxLength, err := args.GetInt64("max_length") - if err != nil { - return nil, err - } - ipType, err := args.GetString("ip_type") - if err != nil { - return nil, err - } - - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } - - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - - randomizer := rng.New(seed) - - return func() (any, error) { - return generateIpAddress(randomizer, IpType(ipType), maxLength) - }, nil - }) + err := bloblang.RegisterFunctionV2( + "generate_ip", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + maxLength, err := args.GetInt64("max_length") + if err != nil { + return nil, err + } + ipType, err := args.GetString("ip_type") + if err != nil { + return nil, err + } + + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } + + seed, err := transformer_utils.GetSeedOrDefault(seedArg) + if err != nil { + return nil, err + } + + randomizer := rng.New(seed) + + return func() (any, error) { + return generateIpAddress(randomizer, IpType(ipType), maxLength) + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateIpAddressOptsFromConfig(config *mgmtv1alpha1.GenerateIpAddress, maxlength *int64) (*GenerateIpAddressOpts, error) { +func NewGenerateIpAddressOptsFromConfig( + config *mgmtv1alpha1.GenerateIpAddress, + maxlength *int64, +) (*GenerateIpAddressOpts, error) { if config == nil { return NewGenerateIpAddressOpts(nil, nil, nil) } diff --git a/worker/pkg/benthos/transformers/generate_last_name.go b/worker/pkg/benthos/transformers/generate_last_name.go index 6b1807a64c..71a622eac4 100644 --- a/worker/pkg/benthos/transformers/generate_last_name.go +++ 
b/worker/pkg/benthos/transformers/generate_last_name.go @@ -19,37 +19,44 @@ func init() { Param(bloblang.NewInt64Param("max_length").Default(100).Description("Specifies the maximum length for the generated data. This field ensures that the output does not exceed a certain number of characters.")). Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_last_name", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - maxLength, err := args.GetInt64("max_length") - if err != nil { - return nil, err - } - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } - - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - randomizer := rng.New(seed) + err := bloblang.RegisterFunctionV2( + "generate_last_name", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + maxLength, err := args.GetInt64("max_length") + if err != nil { + return nil, err + } + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - output, err := generateRandomLastName(randomizer, nil, maxLength) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("unable to run generate_last_name") + return nil, err } - return output, nil - }, nil - }) + randomizer := rng.New(seed) + + return func() (any, error) { + output, err := generateRandomLastName(randomizer, nil, maxLength) + if err != nil { + return nil, fmt.Errorf("unable to run generate_last_name") + } + return output, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateLastNameOptsFromConfig(config *mgmtv1alpha1.GenerateLastName, maxLength *int64) (*GenerateLastNameOpts, error) { +func NewGenerateLastNameOptsFromConfig( + config *mgmtv1alpha1.GenerateLastName, + maxLength *int64, +) 
(*GenerateLastNameOpts, error) { if config == nil { return NewGenerateLastNameOpts( nil, @@ -70,7 +77,11 @@ func (t *GenerateLastName) Generate(opts any) (any, error) { return generateRandomLastName(parsedOpts.randomizer, nil, parsedOpts.maxLength) } -func generateRandomLastName(randomizer rng.Rand, minLength *int64, maxLength int64) (string, error) { +func generateRandomLastName( + randomizer rng.Rand, + minLength *int64, + maxLength int64, +) (string, error) { return transformer_utils.GenerateStringFromCorpus( randomizer, transformers_dataset.LastNames, diff --git a/worker/pkg/benthos/transformers/generate_random_string.go b/worker/pkg/benthos/transformers/generate_random_string.go index 56075a3d80..934a1e7d42 100644 --- a/worker/pkg/benthos/transformers/generate_random_string.go +++ b/worker/pkg/benthos/transformers/generate_random_string.go @@ -19,44 +19,55 @@ func init() { Param(bloblang.NewInt64Param("max").Default(100).Description("Specifies the maximum length for the generated string.")). 
Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_string", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - min, err := args.GetInt64("min") - if err != nil { - return nil, err - } - - max, err := args.GetInt64("max") - if err != nil { - return nil, err - } - - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } + err := bloblang.RegisterFunctionV2( + "generate_string", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + min, err := args.GetInt64("min") + if err != nil { + return nil, err + } - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } + max, err := args.GetInt64("max") + if err != nil { + return nil, err + } - randomizer := rng.New(seed) + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - out, err := transformer_utils.GenerateRandomStringWithInclusiveBounds(randomizer, min, max) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("unable to run generate_string: %w", err) + return nil, err } - return out, nil - }, nil - }) + + randomizer := rng.New(seed) + + return func() (any, error) { + out, err := transformer_utils.GenerateRandomStringWithInclusiveBounds( + randomizer, + min, + max, + ) + if err != nil { + return nil, fmt.Errorf("unable to run generate_string: %w", err) + } + return out, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateRandomStringOptsFromConfig(config *mgmtv1alpha1.GenerateString, maxLen *int64) (*GenerateRandomStringOpts, error) { +func NewGenerateRandomStringOptsFromConfig( + config *mgmtv1alpha1.GenerateString, + maxLen *int64, +) (*GenerateRandomStringOpts, error) { if config == nil { return NewGenerateRandomStringOpts( nil, @@ -75,7 +86,10 @@ func 
NewGenerateRandomStringOptsFromConfig(config *mgmtv1alpha1.GenerateString, } } if minValue != nil { - newMin := transformer_utils.MinInt(*minValue, *maxValue) // ensure the min is not larger than the max allowed length + newMin := transformer_utils.MinInt( + *minValue, + *maxValue, + ) // ensure the min is not larger than the max allowed length minValue = &newMin } @@ -92,5 +106,9 @@ func (t *GenerateRandomString) Generate(opts any) (any, error) { return nil, fmt.Errorf("invalid parsed opts: %T", opts) } - return transformer_utils.GenerateRandomStringWithInclusiveBounds(parsedOpts.randomizer, parsedOpts.min, parsedOpts.max) + return transformer_utils.GenerateRandomStringWithInclusiveBounds( + parsedOpts.randomizer, + parsedOpts.min, + parsedOpts.max, + ) } diff --git a/worker/pkg/benthos/transformers/generate_sha256hash.go b/worker/pkg/benthos/transformers/generate_sha256hash.go index 088f7434f8..37ed6442ff 100644 --- a/worker/pkg/benthos/transformers/generate_sha256hash.go +++ b/worker/pkg/benthos/transformers/generate_sha256hash.go @@ -13,24 +13,31 @@ import ( // +neosyncTransformerBuilder:generate:generateSHA256Hash func init() { - spec := bloblang.NewPluginSpec().Description("Generates a random SHA256 hash and returns it as a string."). + spec := bloblang.NewPluginSpec(). + Description("Generates a random SHA256 hash and returns it as a string."). 
Category("string") - err := bloblang.RegisterFunctionV2("generate_sha256hash", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - return func() (any, error) { - val, err := generateRandomSHA256Hash(uuid.NewString()) - if err != nil { - return false, fmt.Errorf("unable to run generate_sha256hash: %w", err) - } - return val, nil - }, nil - }) + err := bloblang.RegisterFunctionV2( + "generate_sha256hash", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + return func() (any, error) { + val, err := generateRandomSHA256Hash(uuid.NewString()) + if err != nil { + return false, fmt.Errorf("unable to run generate_sha256hash: %w", err) + } + return val, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateSHA256HashOptsFromConfig(config *mgmtv1alpha1.GenerateSha256Hash) (*GenerateSHA256HashOpts, error) { +func NewGenerateSHA256HashOptsFromConfig( + config *mgmtv1alpha1.GenerateSha256Hash, +) (*GenerateSHA256HashOpts, error) { return NewGenerateSHA256HashOpts() } diff --git a/worker/pkg/benthos/transformers/generate_ssn.go b/worker/pkg/benthos/transformers/generate_ssn.go index 83c3bf15a3..051993318c 100644 --- a/worker/pkg/benthos/transformers/generate_ssn.go +++ b/worker/pkg/benthos/transformers/generate_ssn.go @@ -17,22 +17,26 @@ func init() { Category("string"). 
Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_ssn", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } + err := bloblang.RegisterFunctionV2( + "generate_ssn", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - randomizer := rng.New(seed) - return func() (any, error) { - val := generateRandomSSN(randomizer) - return val, nil - }, nil - }) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) + if err != nil { + return nil, err + } + randomizer := rng.New(seed) + return func() (any, error) { + val := generateRandomSSN(randomizer) + return val, nil + }, nil + }, + ) if err != nil { panic(err) } diff --git a/worker/pkg/benthos/transformers/generate_state.go b/worker/pkg/benthos/transformers/generate_state.go index eb87086a8a..db3f4aa980 100644 --- a/worker/pkg/benthos/transformers/generate_state.go +++ b/worker/pkg/benthos/transformers/generate_state.go @@ -13,41 +13,48 @@ import ( // +neosyncTransformerBuilder:generate:generateState func init() { - spec := bloblang.NewPluginSpec().Description("Randomly selects a US state and by default, returns it as a 2-letter state code."). + spec := bloblang.NewPluginSpec(). + Description("Randomly selects a US state and by default, returns it as a 2-letter state code."). Category("string"). Param(bloblang.NewBoolParam("generate_full_name").Default(false).Description("If true returns the full state name instead of the two character state code.")). 
Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_state", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - generateFullName, err := args.GetBool("generate_full_name") - if err != nil { - return nil, err - } - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } - - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - randomizer := rng.New(seed) + err := bloblang.RegisterFunctionV2( + "generate_state", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + generateFullName, err := args.GetBool("generate_full_name") + if err != nil { + return nil, err + } + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - val, err := generateRandomState(randomizer, generateFullName) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("unable to run generate_state: %w", err) + return nil, err } - return val, nil - }, nil - }) + randomizer := rng.New(seed) + + return func() (any, error) { + val, err := generateRandomState(randomizer, generateFullName) + if err != nil { + return nil, fmt.Errorf("unable to run generate_state: %w", err) + } + return val, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateStateOptsFromConfig(config *mgmtv1alpha1.GenerateState) (*GenerateStateOpts, error) { +func NewGenerateStateOptsFromConfig( + config *mgmtv1alpha1.GenerateState, +) (*GenerateStateOpts, error) { if config == nil { return NewGenerateStateOpts(nil, nil) } diff --git a/worker/pkg/benthos/transformers/generate_street_address.go b/worker/pkg/benthos/transformers/generate_street_address.go index 85a466ddd8..e3ea419ff3 100644 --- a/worker/pkg/benthos/transformers/generate_street_address.go +++ 
b/worker/pkg/benthos/transformers/generate_street_address.go @@ -27,37 +27,44 @@ func init() { Param(bloblang.NewInt64Param("max_length").Default(100).Description("Specifies the maximum length for the generated data. This field ensures that the output does not exceed a certain number of characters.")). Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_street_address", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - maxLength, err := args.GetInt64("max_length") - if err != nil { - return nil, err - } - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } - - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - - randomizer := rng.New(seed) + err := bloblang.RegisterFunctionV2( + "generate_street_address", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + maxLength, err := args.GetInt64("max_length") + if err != nil { + return nil, err + } + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - res, err := generateRandomStreetAddress(randomizer, maxLength) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("unable to run generate_street_address: %w", err) + return nil, err } - return res, nil - }, nil - }) + + randomizer := rng.New(seed) + + return func() (any, error) { + res, err := generateRandomStreetAddress(randomizer, maxLength) + if err != nil { + return nil, fmt.Errorf("unable to run generate_street_address: %w", err) + } + return res, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateStreetAddressOptsFromConfig(config *mgmtv1alpha1.GenerateStreetAddress, maxLength *int64) (*GenerateStreetAddressOpts, error) { +func NewGenerateStreetAddressOptsFromConfig( + config 
*mgmtv1alpha1.GenerateStreetAddress, + maxLength *int64, +) (*GenerateStreetAddressOpts, error) { if config == nil { return NewGenerateStreetAddressOpts(nil, nil) } diff --git a/worker/pkg/benthos/transformers/generate_string_phone_number.go b/worker/pkg/benthos/transformers/generate_string_phone_number.go index 129ad238aa..0ca8333768 100644 --- a/worker/pkg/benthos/transformers/generate_string_phone_number.go +++ b/worker/pkg/benthos/transformers/generate_string_phone_number.go @@ -19,44 +19,50 @@ func init() { Param(bloblang.NewInt64Param("max").Default(15).Description("Specifies the maximum length for the generated phone number.")). Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_string_phone_number", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - min, err := args.GetInt64("min") - if err != nil { - return nil, err - } - - max, err := args.GetInt64("max") - if err != nil { - return nil, err - } - - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } - - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - - randomizer := rng.New(seed) - - return func() (any, error) { - res, err := generateStringPhoneNumber(randomizer, min, max) + err := bloblang.RegisterFunctionV2( + "generate_string_phone_number", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + min, err := args.GetInt64("min") if err != nil { - return nil, fmt.Errorf("unable to run generate_string_phone_number: %w", err) + return nil, err } - return res, nil - }, nil - }) + + max, err := args.GetInt64("max") + if err != nil { + return nil, err + } + + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } + + seed, err := transformer_utils.GetSeedOrDefault(seedArg) + if err != nil { + return nil, err + } + + randomizer := 
rng.New(seed) + + return func() (any, error) { + res, err := generateStringPhoneNumber(randomizer, min, max) + if err != nil { + return nil, fmt.Errorf("unable to run generate_string_phone_number: %w", err) + } + return res, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateStringPhoneNumberOptsFromConfig(config *mgmtv1alpha1.GenerateStringPhoneNumber) (*GenerateStringPhoneNumberOpts, error) { +func NewGenerateStringPhoneNumberOptsFromConfig( + config *mgmtv1alpha1.GenerateStringPhoneNumber, +) (*GenerateStringPhoneNumberOpts, error) { if config == nil { return NewGenerateStringPhoneNumberOpts(nil, nil, nil) } diff --git a/worker/pkg/benthos/transformers/generate_unix_timestamp.go b/worker/pkg/benthos/transformers/generate_unix_timestamp.go index cbf5384a0e..16803fbe43 100644 --- a/worker/pkg/benthos/transformers/generate_unix_timestamp.go +++ b/worker/pkg/benthos/transformers/generate_unix_timestamp.go @@ -13,31 +13,39 @@ import ( // +neosyncTransformerBuilder:generate:generateUnixTimestamp func init() { - spec := bloblang.NewPluginSpec().Description("Randomly generates a Unix timestamp that is in the past."). + spec := bloblang.NewPluginSpec(). + Description("Randomly generates a Unix timestamp that is in the past."). Category("int64"). 
Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_unixtimestamp", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } - - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - randomizer := rng.New(seed) - - return func() (any, error) { - return generateRandomUnixTimestamp(randomizer), nil - }, nil - }) + err := bloblang.RegisterFunctionV2( + "generate_unixtimestamp", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } + + seed, err := transformer_utils.GetSeedOrDefault(seedArg) + if err != nil { + return nil, err + } + randomizer := rng.New(seed) + + return func() (any, error) { + return generateRandomUnixTimestamp(randomizer), nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateUnixTimestampOptsFromConfig(config *mgmtv1alpha1.GenerateUnixTimestamp) (*GenerateUnixTimestampOpts, error) { + +func NewGenerateUnixTimestampOptsFromConfig( + config *mgmtv1alpha1.GenerateUnixTimestamp, +) (*GenerateUnixTimestampOpts, error) { return NewGenerateUnixTimestampOpts(nil) } diff --git a/worker/pkg/benthos/transformers/generate_username.go b/worker/pkg/benthos/transformers/generate_username.go index fdfad9317e..81c05d091a 100644 --- a/worker/pkg/benthos/transformers/generate_username.go +++ b/worker/pkg/benthos/transformers/generate_username.go @@ -19,36 +19,44 @@ func init() { Param(bloblang.NewInt64Param("max_length").Default(100).Description("Specifies the maximum length for the generated data. This field ensures that the output does not exceed a certain number of characters.")). 
Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_username", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - maxLength, err := args.GetInt64("max_length") - if err != nil { - return nil, err - } - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } - - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - randomizer := rng.New(seed) + err := bloblang.RegisterFunctionV2( + "generate_username", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + maxLength, err := args.GetInt64("max_length") + if err != nil { + return nil, err + } + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - res, err := generateUsername(randomizer, maxLength) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("unable to run generate_username: %w", err) + return nil, err } - return res, nil - }, nil - }) + randomizer := rng.New(seed) + + return func() (any, error) { + res, err := generateUsername(randomizer, maxLength) + if err != nil { + return nil, fmt.Errorf("unable to run generate_username: %w", err) + } + return res, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateUsernameOptsFromConfig(config *mgmtv1alpha1.GenerateUsername, maxLength *int64) (*GenerateUsernameOpts, error) { + +func NewGenerateUsernameOptsFromConfig( + config *mgmtv1alpha1.GenerateUsername, + maxLength *int64, +) (*GenerateUsernameOpts, error) { if config == nil { return NewGenerateUsernameOpts(nil, nil) } diff --git a/worker/pkg/benthos/transformers/generate_utc_timestamp.go b/worker/pkg/benthos/transformers/generate_utc_timestamp.go index 6327736c32..a684940faf 100644 --- a/worker/pkg/benthos/transformers/generate_utc_timestamp.go +++ 
b/worker/pkg/benthos/transformers/generate_utc_timestamp.go @@ -17,28 +17,34 @@ func init() { Category("int64"). Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_utctimestamp", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } - - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - randomizer := rng.New(seed) - - return func() (any, error) { - return generateRandomUTCTimestamp(randomizer), nil - }, nil - }) + err := bloblang.RegisterFunctionV2( + "generate_utctimestamp", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } + + seed, err := transformer_utils.GetSeedOrDefault(seedArg) + if err != nil { + return nil, err + } + randomizer := rng.New(seed) + + return func() (any, error) { + return generateRandomUTCTimestamp(randomizer), nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateUTCTimestampOptsFromConfig(config *mgmtv1alpha1.GenerateUtcTimestamp) (*GenerateUTCTimestampOpts, error) { +func NewGenerateUTCTimestampOptsFromConfig( + config *mgmtv1alpha1.GenerateUtcTimestamp, +) (*GenerateUTCTimestampOpts, error) { return NewGenerateUTCTimestampOpts(nil) } diff --git a/worker/pkg/benthos/transformers/generate_uuid.go b/worker/pkg/benthos/transformers/generate_uuid.go index c09318a175..5c363b1738 100644 --- a/worker/pkg/benthos/transformers/generate_uuid.go +++ b/worker/pkg/benthos/transformers/generate_uuid.go @@ -17,19 +17,24 @@ func init() { Category("string"). Param(bloblang.NewBoolParam("include_hyphens"). Default(true). - Description("Determines whether the generated UUID should include hyphens. 
If set to true, the UUID will be formatted with hyphens (e.g., d853d251-e135-4fe4-a4eb-0aea6bfaf645). If set to false, the hyphens will be omitted (e.g., d853d251e1354fe4a4eb0aea6bfaf645).")) + Description("Determines whether the generated UUID should include hyphens. If set to true, the UUID will be formatted with hyphens (e.g., d853d251-e135-4fe4-a4eb-0aea6bfaf645). If set to false, the hyphens will be omitted (e.g., d853d251e1354fe4a4eb0aea6bfaf645)."), + ) - err := bloblang.RegisterFunctionV2("generate_uuid", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - include_hyphen, err := args.GetBool("include_hyphens") - if err != nil { - return nil, err - } + err := bloblang.RegisterFunctionV2( + "generate_uuid", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + include_hyphen, err := args.GetBool("include_hyphens") + if err != nil { + return nil, err + } - return func() (any, error) { - val := generateUuid(include_hyphen) - return val, nil - }, nil - }) + return func() (any, error) { + val := generateUuid(include_hyphen) + return val, nil + }, nil + }, + ) if err != nil { panic(err) } diff --git a/worker/pkg/benthos/transformers/generate_zipcode.go b/worker/pkg/benthos/transformers/generate_zipcode.go index 51544eeaba..7d07660bf8 100644 --- a/worker/pkg/benthos/transformers/generate_zipcode.go +++ b/worker/pkg/benthos/transformers/generate_zipcode.go @@ -17,33 +17,39 @@ func init() { Category("string"). 
Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("generate_zipcode", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } - - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - - randomizer := rng.New(seed) + err := bloblang.RegisterFunctionV2( + "generate_zipcode", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - val, err := generateRandomZipcode(randomizer) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("failed to generate_zipcode: %w", err) + return nil, err } - return val, nil - }, nil - }) + + randomizer := rng.New(seed) + + return func() (any, error) { + val, err := generateRandomZipcode(randomizer) + if err != nil { + return nil, fmt.Errorf("failed to generate_zipcode: %w", err) + } + return val, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewGenerateZipcodeOptsFromConfig(config *mgmtv1alpha1.GenerateZipcode) (*GenerateZipcodeOpts, error) { +func NewGenerateZipcodeOptsFromConfig( + config *mgmtv1alpha1.GenerateZipcode, +) (*GenerateZipcodeOpts, error) { return NewGenerateZipcodeOpts(nil) } diff --git a/worker/pkg/benthos/transformers/generator_utils.go b/worker/pkg/benthos/transformers/generator_utils.go index 751604b03e..4321378e38 100644 --- a/worker/pkg/benthos/transformers/generator_utils.go +++ b/worker/pkg/benthos/transformers/generator_utils.go @@ -96,7 +96,9 @@ func ExtractBenthosSpec(fileSet *token.FileSet) ([]*BenthosSpec, error) { } func ParseBloblangSpec(benthosSpec *BenthosSpec) (*ParsedBenthosSpec, error) { - paramRegex := 
regexp.MustCompile(`bloblang\.New(\w+)Param\("(\w+)"\)(?:\.Optional\(\))?(?:\.Default\(([^()]*(?:\([^()]*\))?[^()]*)\))?(?:\.Description\("([^"]*)"\))?`) + paramRegex := regexp.MustCompile( + `bloblang\.New(\w+)Param\("(\w+)"\)(?:\.Optional\(\))?(?:\.Default\(([^()]*(?:\([^()]*\))?[^()]*)\))?(?:\.Description\("([^"]*)"\))?`, + ) specDescriptionRegex := regexp.MustCompile(`\.Description\("([^"]*)"\)`) params := []*BenthosSpecParam{} readFile, err := os.Open(benthosSpec.SourceFile) @@ -126,7 +128,9 @@ func ParseBloblangSpec(benthosSpec *BenthosSpec) (*ParsedBenthosSpec, error) { categoryRegex := regexp.MustCompile(`\.Category\("([^"]*)"\)`) var category string - if categoryMatches := categoryRegex.FindStringSubmatch(benthosSpecStr); len(categoryMatches) > 0 { + if categoryMatches := categoryRegex.FindStringSubmatch(benthosSpecStr); len( + categoryMatches, + ) > 0 { category = categoryMatches[1] } if category == "" { @@ -149,7 +153,9 @@ func ParseBloblangSpec(benthosSpec *BenthosSpec) (*ParsedBenthosSpec, error) { // seed hack if strings.Contains(line, "Default(time.Now().UnixNano())") { defaultVal = "time.Now().UnixNano()" - if specMatches := specDescriptionRegex.FindStringSubmatch(line); len(specMatches) > 0 { + if specMatches := specDescriptionRegex.FindStringSubmatch(line); len( + specMatches, + ) > 0 { description = specMatches[1] } } diff --git a/worker/pkg/benthos/transformers/transform_character_scramble.go b/worker/pkg/benthos/transformers/transform_character_scramble.go index 4dc795c44a..b1599edc87 100644 --- a/worker/pkg/benthos/transformers/transform_character_scramble.go +++ b/worker/pkg/benthos/transformers/transform_character_scramble.go @@ -28,49 +28,55 @@ func init() { Param(bloblang.NewStringParam("user_provided_regex").Optional().Description("A custom regular expression. This regex is used to manipulate input data during the transformation process.")). 
Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("transform_character_scramble", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - value, err := args.GetOptionalString("value") - if err != nil { - return nil, err - } - - regexPtr, err := args.GetOptionalString("user_provided_regex") - if err != nil { - return nil, err - } - - var regex string - if regexPtr != nil { - regex = *regexPtr - } + err := bloblang.RegisterFunctionV2( + "transform_character_scramble", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + value, err := args.GetOptionalString("value") + if err != nil { + return nil, err + } - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } + regexPtr, err := args.GetOptionalString("user_provided_regex") + if err != nil { + return nil, err + } - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } + var regex string + if regexPtr != nil { + regex = *regexPtr + } - randomizer := rng.New(seed) + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - res, err := transformCharacterScramble(randomizer, value, regex) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("unable to run transform_character_scramble: %w", err) + return nil, err } - return res, nil - }, nil - }) + + randomizer := rng.New(seed) + + return func() (any, error) { + res, err := transformCharacterScramble(randomizer, value, regex) + if err != nil { + return nil, fmt.Errorf("unable to run transform_character_scramble: %w", err) + } + return res, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewTransformCharacterScrambleOptsFromConfig(config *mgmtv1alpha1.TransformCharacterScramble) (*TransformCharacterScrambleOpts, error) { +func 
NewTransformCharacterScrambleOptsFromConfig( + config *mgmtv1alpha1.TransformCharacterScramble, +) (*TransformCharacterScrambleOpts, error) { if config == nil { return NewTransformCharacterScrambleOpts(nil, nil) } @@ -137,7 +143,10 @@ func transformCharacterScramble(randomizer rng.Rand, value *string, regex string for _, match := range matches { start, end := match[0], match[1] // run the scrambler for the substring - matchTransformed := strings.Map(randomizedScrambleChar(randomizer), transformedString[start:end]) + matchTransformed := strings.Map( + randomizedScrambleChar(randomizer), + transformedString[start:end], + ) // replace the original substring with its transformed version transformedString = transformedString[:start] + matchTransformed + transformedString[end:] } diff --git a/worker/pkg/benthos/transformers/transform_e164_phone_number.go b/worker/pkg/benthos/transformers/transform_e164_phone_number.go index 8ecb431a9a..b1e954c989 100644 --- a/worker/pkg/benthos/transformers/transform_e164_phone_number.go +++ b/worker/pkg/benthos/transformers/transform_e164_phone_number.go @@ -22,54 +22,61 @@ func init() { Param(bloblang.NewInt64Param("max_length").Default(15).Description("Specifies the maximum length for the transformed data. This field ensures that the output does not exceed a certain number of characters.")). 
Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("transform_e164_phone_number", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - valuePtr, err := args.GetOptionalString("value") - if err != nil { - return nil, err - } - - var value string - if valuePtr != nil { - value = *valuePtr - } - - preserveLength, err := args.GetBool("preserve_length") - if err != nil { - return nil, err - } + err := bloblang.RegisterFunctionV2( + "transform_e164_phone_number", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + valuePtr, err := args.GetOptionalString("value") + if err != nil { + return nil, err + } - maxLength, err := args.GetOptionalInt64("max_length") - if err != nil { - return nil, err - } + var value string + if valuePtr != nil { + value = *valuePtr + } - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } + preserveLength, err := args.GetBool("preserve_length") + if err != nil { + return nil, err + } - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } + maxLength, err := args.GetOptionalInt64("max_length") + if err != nil { + return nil, err + } - randomizer := rng.New(seed) + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - res, err := transformE164PhoneNumber(randomizer, value, preserveLength, maxLength) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("unable to run transform_e164_phone_number: %w", err) + return nil, err } - return res, nil - }, nil - }) + + randomizer := rng.New(seed) + + return func() (any, error) { + res, err := transformE164PhoneNumber(randomizer, value, preserveLength, maxLength) + if err != nil { + return nil, fmt.Errorf("unable to run transform_e164_phone_number: %w", err) + } + return 
res, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewTransformE164PhoneNumberOptsFromConfig(config *mgmtv1alpha1.TransformE164PhoneNumber, maxLength *int64) (*TransformE164PhoneNumberOpts, error) { +func NewTransformE164PhoneNumberOptsFromConfig( + config *mgmtv1alpha1.TransformE164PhoneNumber, + maxLength *int64, +) (*TransformE164PhoneNumberOpts, error) { if config == nil { return NewTransformE164PhoneNumberOpts(nil, nil, nil) } @@ -87,11 +94,21 @@ func (t *TransformE164PhoneNumber) Transform(value, opts any) (any, error) { return nil, errors.New("value is not a string") } - return transformE164PhoneNumber(parsedOpts.randomizer, valueStr, parsedOpts.preserveLength, &parsedOpts.maxLength) + return transformE164PhoneNumber( + parsedOpts.randomizer, + valueStr, + parsedOpts.preserveLength, + &parsedOpts.maxLength, + ) } // Generates a random phone number and returns it as a string -func transformE164PhoneNumber(randomizer rng.Rand, phone string, preserveLength bool, maxLength *int64) (*string, error) { +func transformE164PhoneNumber( + randomizer rng.Rand, + phone string, + preserveLength bool, + maxLength *int64, +) (*string, error) { var returnValue string if phone == "" { @@ -123,7 +140,10 @@ func transformE164PhoneNumber(randomizer rng.Rand, phone string, preserveLength } // generates a random E164 phone number and returns it as a string -func generateE164FormatPhoneNumberPreserveLength(randomizer rng.Rand, number string) (string, error) { +func generateE164FormatPhoneNumberPreserveLength( + randomizer rng.Rand, + number string, +) (string, error) { val := strings.Split(number, "+") length := int64(len(val[1])) diff --git a/worker/pkg/benthos/transformers/transform_email.go b/worker/pkg/benthos/transformers/transform_email.go index 0613fef078..2371854fa4 100644 --- a/worker/pkg/benthos/transformers/transform_email.go +++ b/worker/pkg/benthos/transformers/transform_email.go @@ -50,89 +50,96 @@ func init() { 
Param(bloblang.NewStringParam("email_type").Default(GenerateEmailType_UuidV4.String()).Description("Specifies the type of email to transform, with options including `uuidv4`, `fullname`, or `any`.")). Param(bloblang.NewStringParam("invalid_email_action").Default(InvalidEmailAction_Reject.String()).Description("Specifies the action to take when an invalid email is encountered, with options including `reject`, `passthrough`, `null`, or `generate`.")) - err := bloblang.RegisterFunctionV2("transform_email", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - emailPtr, err := args.GetOptionalString("value") - if err != nil { - return nil, err - } + err := bloblang.RegisterFunctionV2( + "transform_email", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + emailPtr, err := args.GetOptionalString("value") + if err != nil { + return nil, err + } - var email string - if emailPtr != nil { - email = *emailPtr - } + var email string + if emailPtr != nil { + email = *emailPtr + } - preserveLength, err := args.GetBool("preserve_length") - if err != nil { - return nil, err - } + preserveLength, err := args.GetBool("preserve_length") + if err != nil { + return nil, err + } - preserveDomain, err := args.GetBool("preserve_domain") - if err != nil { - return nil, err - } + preserveDomain, err := args.GetBool("preserve_domain") + if err != nil { + return nil, err + } - maxLength, err := args.GetInt64("max_length") - if err != nil { - return nil, err - } + maxLength, err := args.GetInt64("max_length") + if err != nil { + return nil, err + } - excludedDomainsArg, err := args.Get("excluded_domains") - if err != nil { - return nil, err - } + excludedDomainsArg, err := args.Get("excluded_domains") + if err != nil { + return nil, err + } - excludedDomains, err := fromAnyToStringSlice(excludedDomainsArg) - if err != nil { - return nil, err - } + excludedDomains, err := fromAnyToStringSlice(excludedDomainsArg) + if err != nil { + return nil, err + } - 
emailTypeArg, err := args.GetString("email_type") - if err != nil { - return nil, err - } - emailType := getEmailTypeOrDefault(emailTypeArg) + emailTypeArg, err := args.GetString("email_type") + if err != nil { + return nil, err + } + emailType := getEmailTypeOrDefault(emailTypeArg) - invalidEmailActionArg, err := args.GetString("invalid_email_action") - if err != nil { - return nil, err - } - if !isValidInvalidEmailAction(invalidEmailActionArg) { - return nil, errors.New("not a valid invalid_email_action argument") - } + invalidEmailActionArg, err := args.GetString("invalid_email_action") + if err != nil { + return nil, err + } + if !isValidInvalidEmailAction(invalidEmailActionArg) { + return nil, errors.New("not a valid invalid_email_action argument") + } - invalidEmailAction := InvalidEmailAction(invalidEmailActionArg) + invalidEmailAction := InvalidEmailAction(invalidEmailActionArg) - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - randomizer := rng.New(seed) - return func() (any, error) { - output, err := transformEmail(randomizer, email, transformeEmailOptions{ - PreserveLength: preserveLength, - PreserveDomain: preserveDomain, - MaxLength: maxLength, - ExcludedDomains: excludedDomains, - EmailType: emailType, - InvalidEmailAction: invalidEmailAction, - }) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("unable to run transform_email: %w", err) + return nil, err } - return output, nil - }, nil - }) + randomizer := rng.New(seed) + return func() (any, error) { + output, err := transformEmail(randomizer, email, transformeEmailOptions{ + PreserveLength: preserveLength, + PreserveDomain: preserveDomain, + MaxLength: maxLength, + ExcludedDomains: excludedDomains, + EmailType: emailType, + 
InvalidEmailAction: invalidEmailAction, + }) + if err != nil { + return nil, fmt.Errorf("unable to run transform_email: %w", err) + } + return output, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewTransformEmailOptsFromConfig(config *mgmtv1alpha1.TransformEmail, maxLength *int64) (*TransformEmailOpts, error) { +func NewTransformEmailOptsFromConfig( + config *mgmtv1alpha1.TransformEmail, + maxLength *int64, +) (*TransformEmailOpts, error) { if config == nil { var excludedDomains any = "[]" return NewTransformEmailOpts(nil, nil, &excludedDomains, nil, nil, nil, nil) @@ -144,7 +151,9 @@ func NewTransformEmailOptsFromConfig(config *mgmtv1alpha1.TransformEmail, maxLen } var invalidEmailAction *string if config.InvalidEmailAction != nil { - invalidEmailActionStr := dtoInvalidEmailActionToTransformerInvalidEmailAction(config.GetInvalidEmailAction()).String() + invalidEmailActionStr := dtoInvalidEmailActionToTransformerInvalidEmailAction( + config.GetInvalidEmailAction(), + ).String() invalidEmailAction = &invalidEmailActionStr } excludedDomainsStr, err := convertStringSliceToString(config.GetExcludedDomains()) @@ -260,7 +269,12 @@ func transformEmail( case InvalidEmailAction_Null: return nil, nil case InvalidEmailAction_Generate: - newEmail, err := generateRandomEmail(randomizer, opts.MaxLength, opts.EmailType, opts.ExcludedDomains) + newEmail, err := generateRandomEmail( + randomizer, + opts.MaxLength, + opts.EmailType, + opts.ExcludedDomains, + ) if err != nil { return nil, err } @@ -293,7 +307,10 @@ func transformEmail( domainMaxLength = int64(len(email)) - 3 } if (domainMaxLength) <= 0 { - return nil, fmt.Errorf("for the given max length, unable to generate an email of sufficient length: %d", maxLength) + return nil, fmt.Errorf( + "for the given max length, unable to generate an email of sufficient length: %d", + maxLength, + ) } newdomain := domain @@ -319,7 +336,10 @@ func transformEmail( newuuid := strings.ReplaceAll(uuid.NewString(), "-", "") 
trimmeduuid := transformer_utils.TrimStringIfExceeds(newuuid, maxNameLength) if trimmeduuid == "" { - return nil, fmt.Errorf("for the given max length, unable to use uuid to generate transformed email: %d", maxNameLength) + return nil, fmt.Errorf( + "for the given max length, unable to use uuid to generate transformed email: %d", + maxNameLength, + ) } newname = trimmeduuid } else { @@ -343,7 +363,9 @@ func dtoEmailTypeToTransformerEmailType(dto mgmtv1alpha1.GenerateEmailType) Gene } } -func dtoInvalidEmailActionToTransformerInvalidEmailAction(dto mgmtv1alpha1.InvalidEmailAction) InvalidEmailAction { +func dtoInvalidEmailActionToTransformerInvalidEmailAction( + dto mgmtv1alpha1.InvalidEmailAction, +) InvalidEmailAction { switch dto { case mgmtv1alpha1.InvalidEmailAction_INVALID_EMAIL_ACTION_GENERATE: return InvalidEmailAction_Generate diff --git a/worker/pkg/benthos/transformers/transform_first_name.go b/worker/pkg/benthos/transformers/transform_first_name.go index 87401db946..41efee3ecc 100644 --- a/worker/pkg/benthos/transformers/transform_first_name.go +++ b/worker/pkg/benthos/transformers/transform_first_name.go @@ -21,54 +21,61 @@ func init() { Param(bloblang.NewBoolParam("preserve_length").Default(false).Description("Whether the original length of the input data should be preserved during transformation. If set to true, the transformation logic will ensure that the output data has the same length as the input data.")). 
Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used for generating deterministic transformations.")) - err := bloblang.RegisterFunctionV2("transform_first_name", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - valuePtr, err := args.GetOptionalString("value") - if err != nil { - return nil, err - } - - var value string - if valuePtr != nil { - value = *valuePtr - } - - preserveLength, err := args.GetBool("preserve_length") - if err != nil { - return nil, err - } + err := bloblang.RegisterFunctionV2( + "transform_first_name", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + valuePtr, err := args.GetOptionalString("value") + if err != nil { + return nil, err + } - maxLength, err := args.GetInt64("max_length") - if err != nil { - return nil, err - } + var value string + if valuePtr != nil { + value = *valuePtr + } - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } + preserveLength, err := args.GetBool("preserve_length") + if err != nil { + return nil, err + } - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } + maxLength, err := args.GetInt64("max_length") + if err != nil { + return nil, err + } - randomizer := rng.New(seed) + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - res, err := transformFirstName(randomizer, value, preserveLength, maxLength) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("unable to run transform_first_name: %w", err) + return nil, err } - return res, nil - }, nil - }) + + randomizer := rng.New(seed) + + return func() (any, error) { + res, err := transformFirstName(randomizer, value, preserveLength, maxLength) + if err != nil { + return nil, fmt.Errorf("unable to run transform_first_name: %w", err) + } + return res, nil + }, nil + }, + ) if err != nil { 
panic(err) } } -func NewTransformFirstNameOptsFromConfig(config *mgmtv1alpha1.TransformFirstName, maxLength *int64) (*TransformFirstNameOpts, error) { +func NewTransformFirstNameOptsFromConfig( + config *mgmtv1alpha1.TransformFirstName, + maxLength *int64, +) (*TransformFirstNameOpts, error) { if config == nil { return NewTransformFirstNameOpts(nil, nil, nil) } @@ -90,11 +97,21 @@ func (t *TransformFirstName) Transform(value, opts any) (any, error) { return nil, errors.New("value is not a string") } - return transformFirstName(parsedOpts.randomizer, valueStr, parsedOpts.preserveLength, parsedOpts.maxLength) + return transformFirstName( + parsedOpts.randomizer, + valueStr, + parsedOpts.preserveLength, + parsedOpts.maxLength, + ) } // Generates a random first name which can be of either random length or as long as the input name -func transformFirstName(randomizer rng.Rand, value string, preserveLength bool, maxLength int64) (*string, error) { +func transformFirstName( + randomizer rng.Rand, + value string, + preserveLength bool, + maxLength int64, +) (*string, error) { if value == "" { return &value, nil } @@ -119,7 +136,10 @@ func transformFirstName(randomizer rng.Rand, value string, preserveLength bool, // pad the string so that we can get the correct value if preserveLength && int64(len(output)) != maxValue { - output += transformer_utils.GetRandomCharacterString(randomizer, maxValue-int64(len(output))) + output += transformer_utils.GetRandomCharacterString( + randomizer, + maxValue-int64(len(output)), + ) } return &output, nil } diff --git a/worker/pkg/benthos/transformers/transform_float.go b/worker/pkg/benthos/transformers/transform_float.go index cb9d2bb1e5..0876fb7bfb 100644 --- a/worker/pkg/benthos/transformers/transform_float.go +++ b/worker/pkg/benthos/transformers/transform_float.go @@ -25,58 +25,73 @@ func init() { Param(bloblang.NewInt64Param("scale").Optional().Description("An optional parameter that defines the number of decimal places for the 
float.")). Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used for generating deterministic transformations.")) - err := bloblang.RegisterFunctionV2("transform_float64", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - value, err := args.Get("value") - if err != nil { - return nil, err - } - - rMin, err := args.GetFloat64("randomization_range_min") - if err != nil { - return nil, err - } - - rMax, err := args.GetFloat64("randomization_range_max") - if err != nil { - return nil, err - } + err := bloblang.RegisterFunctionV2( + "transform_float64", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + value, err := args.Get("value") + if err != nil { + return nil, err + } - precision, err := args.GetOptionalInt64("precision") - if err != nil { - return nil, err - } - scale, err := args.GetOptionalInt64("scale") - if err != nil { - return nil, err - } - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } + rMin, err := args.GetFloat64("randomization_range_min") + if err != nil { + return nil, err + } - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - randomizer := rng.New(seed) + rMax, err := args.GetFloat64("randomization_range_max") + if err != nil { + return nil, err + } - maxnumgetter := newMaxNumCache() + precision, err := args.GetOptionalInt64("precision") + if err != nil { + return nil, err + } + scale, err := args.GetOptionalInt64("scale") + if err != nil { + return nil, err + } + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - res, err := transformFloat(randomizer, maxnumgetter, value, rMin, rMax, precision, scale) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("unable to run transform_float64: %w", err) + return nil, err } - return res, nil - }, nil - }) + randomizer := 
rng.New(seed) + + maxnumgetter := newMaxNumCache() + + return func() (any, error) { + res, err := transformFloat( + randomizer, + maxnumgetter, + value, + rMin, + rMax, + precision, + scale, + ) + if err != nil { + return nil, fmt.Errorf("unable to run transform_float64: %w", err) + } + return res, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewTransformFloat64OptsFromConfig(config *mgmtv1alpha1.TransformFloat64, scale, precision *int64) (*TransformFloat64Opts, error) { +func NewTransformFloat64OptsFromConfig( + config *mgmtv1alpha1.TransformFloat64, + scale, precision *int64, +) (*TransformFloat64Opts, error) { if config == nil { return NewTransformFloat64Opts(nil, nil, nil, nil, nil) } @@ -108,7 +123,13 @@ func (t *TransformFloat64) Transform(value, opts any) (any, error) { ) } -func transformFloat(randomizer rng.Rand, maxnumgetter maxNum, value any, rMin, rMax float64, precision, scale *int64) (*float64, error) { +func transformFloat( + randomizer rng.Rand, + maxnumgetter maxNum, + value any, + rMin, rMax float64, + precision, scale *int64, +) (*float64, error) { if value == nil { return nil, nil } @@ -136,7 +157,12 @@ func transformFloat(randomizer rng.Rand, maxnumgetter maxNum, value any, rMin, r newVal, err := generateRandomFloat64(randomizer, false, minValue, maxValue, precision, scale) if err != nil { - return nil, fmt.Errorf("unable to generate a random float64 with inclusive bounds with length [%f:%f]: %w", minValue, maxValue, err) + return nil, fmt.Errorf( + "unable to generate a random float64 with inclusive bounds with length [%f:%f]: %w", + minValue, + maxValue, + err, + ) } return &newVal, nil } diff --git a/worker/pkg/benthos/transformers/transform_full_name.go b/worker/pkg/benthos/transformers/transform_full_name.go index 52632978fe..c4d3fa2ad9 100644 --- a/worker/pkg/benthos/transformers/transform_full_name.go +++ b/worker/pkg/benthos/transformers/transform_full_name.go @@ -22,54 +22,61 @@ func init() { 
Param(bloblang.NewBoolParam("preserve_length").Default(false).Description("Whether the original length of the input data should be preserved during transformation. If set to true, the transformation logic will ensure that the output data has the same length as the input data.")). Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used for generating deterministic transformations.")) - err := bloblang.RegisterFunctionV2("transform_full_name", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - valuePtr, err := args.GetOptionalString("value") - if err != nil { - return nil, err - } - - var value string - if valuePtr != nil { - value = *valuePtr - } - - preserveLength, err := args.GetBool("preserve_length") - if err != nil { - return nil, err - } + err := bloblang.RegisterFunctionV2( + "transform_full_name", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + valuePtr, err := args.GetOptionalString("value") + if err != nil { + return nil, err + } - maxLength, err := args.GetInt64("max_length") - if err != nil { - return nil, err - } + var value string + if valuePtr != nil { + value = *valuePtr + } - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } + preserveLength, err := args.GetBool("preserve_length") + if err != nil { + return nil, err + } - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } + maxLength, err := args.GetInt64("max_length") + if err != nil { + return nil, err + } - randomizer := rng.New(seed) + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - res, err := transformFullName(randomizer, value, preserveLength, maxLength) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("unable to run transform_full_name: %w", err) + return nil, err } - return res, nil - }, nil - }) + + randomizer 
:= rng.New(seed) + + return func() (any, error) { + res, err := transformFullName(randomizer, value, preserveLength, maxLength) + if err != nil { + return nil, fmt.Errorf("unable to run transform_full_name: %w", err) + } + return res, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewTransformFullNameOptsFromConfig(config *mgmtv1alpha1.TransformFullName, maxLength *int64) (*TransformFullNameOpts, error) { +func NewTransformFullNameOptsFromConfig( + config *mgmtv1alpha1.TransformFullName, + maxLength *int64, +) (*TransformFullNameOpts, error) { if config == nil { return NewTransformFullNameOpts(nil, nil, nil) } @@ -91,10 +98,20 @@ func (t *TransformFullName) Transform(value, opts any) (any, error) { return nil, errors.New("value is not a string") } - return transformFullName(parsedOpts.randomizer, valueStr, parsedOpts.preserveLength, parsedOpts.maxLength) + return transformFullName( + parsedOpts.randomizer, + valueStr, + parsedOpts.preserveLength, + parsedOpts.maxLength, + ) } -func transformFullName(randomizer rng.Rand, name string, preserveLength bool, maxLength int64) (*string, error) { +func transformFullName( + randomizer rng.Rand, + name string, + preserveLength bool, + maxLength int64, +) (*string, error) { if name == "" { return nil, nil } @@ -108,7 +125,10 @@ func transformFullName(randomizer rng.Rand, name string, preserveLength bool, ma if newfirstname == "" { newfirstname, _ = generateRandomFirstName(randomizer, nil, minFirst) if int64(len(newfirstname)) != minFirst { - newfirstname += transformer_utils.GetRandomCharacterString(randomizer, minFirst-int64(len(newfirstname))) + newfirstname += transformer_utils.GetRandomCharacterString( + randomizer, + minFirst-int64(len(newfirstname)), + ) } } minLast := int64(len(lastname)) @@ -116,7 +136,10 @@ func transformFullName(randomizer rng.Rand, name string, preserveLength bool, ma if newlastname == "" { newfirstname, _ = generateRandomLastName(randomizer, nil, minLast) if int64(len(newlastname)) != 
minLast { - newlastname += transformer_utils.GetRandomCharacterString(randomizer, minFirst-int64(len(newlastname))) + newlastname += transformer_utils.GetRandomCharacterString( + randomizer, + minFirst-int64(len(newlastname)), + ) } } if newfirstname != "" && newlastname != "" { @@ -130,7 +153,10 @@ func transformFullName(randomizer rng.Rand, name string, preserveLength bool, ma return nil, err } if preserveLength && len(output) != int(maxLength) { - output += transformer_utils.GetRandomCharacterString(randomizer, maxLength-int64(len(output))) + output += transformer_utils.GetRandomCharacterString( + randomizer, + maxLength-int64(len(output)), + ) } return &output, nil } diff --git a/worker/pkg/benthos/transformers/transform_identity_scramble.go b/worker/pkg/benthos/transformers/transform_identity_scramble.go index 549c006658..48ec444b7b 100644 --- a/worker/pkg/benthos/transformers/transform_identity_scramble.go +++ b/worker/pkg/benthos/transformers/transform_identity_scramble.go @@ -11,26 +11,32 @@ import ( "github.com/redpanda-data/benthos/v4/public/bloblang" ) -func RegisterTransformIdentityScramble(env *bloblang.Environment, allocator tablesync_shared.IdentityAllocator) error { +func RegisterTransformIdentityScramble( + env *bloblang.Environment, + allocator tablesync_shared.IdentityAllocator, +) error { spec := bloblang.NewPluginSpec(). Description("Scrambles the identity of the input"). Category("int64"). Param(bloblang.NewAnyParam("value").Description("The value to scramble").Optional()). 
Param(bloblang.NewStringParam("token").Description("The token used to exchange for a block of identity values")) - err := env.RegisterFunctionV2("transform_identity_scramble", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - value, err := args.Get("value") - if err != nil { - return nil, err - } - token, err := args.GetString("token") - if err != nil { - return nil, err - } - return func() (any, error) { - return transformIdentityScramble(allocator, token, value) - }, nil - }, + err := env.RegisterFunctionV2( + "transform_identity_scramble", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + value, err := args.Get("value") + if err != nil { + return nil, err + } + token, err := args.GetString("token") + if err != nil { + return nil, err + } + return func() (any, error) { + return transformIdentityScramble(allocator, token, value) + }, nil + }, ) if err != nil { return fmt.Errorf("unable to register transform_identity_scramble: %w", err) @@ -38,14 +44,18 @@ func RegisterTransformIdentityScramble(env *bloblang.Environment, allocator tabl return nil } -func NewTransformIdentityScrambleOptsFromConfig(config *mgmtv1alpha1.TransformScrambleIdentity) (*TransformIdentityScrambleOpts, error) { +func NewTransformIdentityScrambleOptsFromConfig( + config *mgmtv1alpha1.TransformScrambleIdentity, +) (*TransformIdentityScrambleOpts, error) { if config == nil { return NewTransformIdentityScrambleOpts("token-not-implemented") } return NewTransformIdentityScrambleOpts("token-not-implemented") } -func NewTransformIdentityScrambleOptsFromConfigWithToken(token string) (*TransformIdentityScrambleOpts, error) { +func NewTransformIdentityScrambleOptsFromConfigWithToken( + token string, +) (*TransformIdentityScrambleOpts, error) { return NewTransformIdentityScrambleOpts(token) } @@ -58,7 +68,11 @@ func (t *TransformIdentityScramble) Transform(value, opts any) (any, error) { return transformIdentityScramble(nil, "token-not-implemented", value) } 
-func transformIdentityScramble(allocator tablesync_shared.IdentityAllocator, token string, value any) (any, error) { +func transformIdentityScramble( + allocator tablesync_shared.IdentityAllocator, + token string, + value any, +) (any, error) { if value == nil { return nil, nil // todo: we should instead return a new scrambled value } @@ -76,13 +90,21 @@ func transformIdentityScramble(allocator tablesync_shared.IdentityAllocator, tok } identity, err = allocator.GetIdentity(context.Background(), token, &input) if err != nil { - return nil, fmt.Errorf("unable to get identity from value with token %s: %w", token, err) + return nil, fmt.Errorf( + "unable to get identity from value with token %s: %w", + token, + err, + ) } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: input := uint(v.Uint()) identity, err = allocator.GetIdentity(context.Background(), token, &input) if err != nil { - return nil, fmt.Errorf("unable to get identity from value with token %s: %w", token, err) + return nil, fmt.Errorf( + "unable to get identity from value with token %s: %w", + token, + err, + ) } default: return nil, fmt.Errorf("unable to get identity from value as input was %T", value) diff --git a/worker/pkg/benthos/transformers/transform_int64.go b/worker/pkg/benthos/transformers/transform_int64.go index 042d230a97..0376ffd9d5 100644 --- a/worker/pkg/benthos/transformers/transform_int64.go +++ b/worker/pkg/benthos/transformers/transform_int64.go @@ -21,49 +21,55 @@ func init() { Param(bloblang.NewInt64Param("randomization_range_max").Default(10000).Description("Specifies the maximum value for the range of the int.")). 
Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("transform_int64", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - valuePtr, err := args.GetOptionalInt64("value") - if err != nil { - return nil, err - } - - rMin, err := args.GetInt64("randomization_range_min") - if err != nil { - return nil, err - } - - rMax, err := args.GetInt64("randomization_range_max") - if err != nil { - return nil, err - } + err := bloblang.RegisterFunctionV2( + "transform_int64", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + valuePtr, err := args.GetOptionalInt64("value") + if err != nil { + return nil, err + } - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } + rMin, err := args.GetInt64("randomization_range_min") + if err != nil { + return nil, err + } - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } + rMax, err := args.GetInt64("randomization_range_max") + if err != nil { + return nil, err + } - randomizer := rng.New(seed) + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - res, err := transformInt(randomizer, valuePtr, rMin, rMax) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("unable to run transform_int64: %w", err) + return nil, err } - return res, nil - }, nil - }) + + randomizer := rng.New(seed) + + return func() (any, error) { + res, err := transformInt(randomizer, valuePtr, rMin, rMax) + if err != nil { + return nil, fmt.Errorf("unable to run transform_int64: %w", err) + } + return res, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewTransformInt64OptsFromConfig(config *mgmtv1alpha1.TransformInt64) (*TransformInt64Opts, error) { +func NewTransformInt64OptsFromConfig( + config 
*mgmtv1alpha1.TransformInt64, +) (*TransformInt64Opts, error) { if config == nil { return NewTransformInt64Opts(nil, nil, nil) } @@ -80,7 +86,12 @@ func (t *TransformInt64) Transform(value, opts any) (any, error) { return nil, fmt.Errorf("invalid parsed opts: %T", opts) } - return transformInt(parsedOpts.randomizer, value, parsedOpts.randomizationRangeMin, parsedOpts.randomizationRangeMax) + return transformInt( + parsedOpts.randomizer, + value, + parsedOpts.randomizationRangeMin, + parsedOpts.randomizationRangeMax, + ) } func transformInt(randomizer rng.Rand, value any, rMin, rMax int64) (*int64, error) { @@ -105,7 +116,12 @@ func transformInt(randomizer rng.Rand, value any, rMin, rMax int64) (*int64, err val, err := transformer_utils.GenerateRandomInt64InValueRange(randomizer, minRange, maxRange) if err != nil { - return nil, fmt.Errorf("unable to generate a random int64 with length [%d:%d]:%w", minRange, maxRange, err) + return nil, fmt.Errorf( + "unable to generate a random int64 with length [%d:%d]:%w", + minRange, + maxRange, + err, + ) } return &val, nil } diff --git a/worker/pkg/benthos/transformers/transform_int64_phone_number.go b/worker/pkg/benthos/transformers/transform_int64_phone_number.go index 4bb24169ed..0be51d6b3c 100644 --- a/worker/pkg/benthos/transformers/transform_int64_phone_number.go +++ b/worker/pkg/benthos/transformers/transform_int64_phone_number.go @@ -22,49 +22,55 @@ func init() { Param(bloblang.NewBoolParam("preserve_length").Default(false).Description("Whether the original length of the input data should be preserved during transformation. If set to true, the transformation logic will ensure that the output data has the same length as the input data.")). 
Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("transform_int64_phone_number", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - valuePtr, err := args.GetOptionalInt64("value") - if err != nil { - return nil, err - } - - var value int64 - if valuePtr != nil { - value = *valuePtr - } - - preserveLength, err := args.GetBool("preserve_length") - if err != nil { - return nil, err - } + err := bloblang.RegisterFunctionV2( + "transform_int64_phone_number", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + valuePtr, err := args.GetOptionalInt64("value") + if err != nil { + return nil, err + } - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } + var value int64 + if valuePtr != nil { + value = *valuePtr + } - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } + preserveLength, err := args.GetBool("preserve_length") + if err != nil { + return nil, err + } - randomizer := rng.New(seed) + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - res, err := transformInt64PhoneNumber(randomizer, value, preserveLength) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("unable to run transform_int64_phone_number: %w", err) + return nil, err } - return res, nil - }, nil - }) + + randomizer := rng.New(seed) + + return func() (any, error) { + res, err := transformInt64PhoneNumber(randomizer, value, preserveLength) + if err != nil { + return nil, fmt.Errorf("unable to run transform_int64_phone_number: %w", err) + } + return res, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewTransformInt64PhoneNumberOptsFromConfig(config *mgmtv1alpha1.TransformInt64PhoneNumber) (*TransformInt64PhoneNumberOpts, error) { +func 
NewTransformInt64PhoneNumberOptsFromConfig( + config *mgmtv1alpha1.TransformInt64PhoneNumber, +) (*TransformInt64PhoneNumberOpts, error) { if config == nil { return NewTransformInt64PhoneNumberOpts(nil, nil) } @@ -81,7 +87,11 @@ func (t *TransformInt64PhoneNumber) Transform(value, opts any) (any, error) { } // generates a random phone number and returns it as an int64 -func transformInt64PhoneNumber(randomizer rng.Rand, value any, preserveLength bool) (*int64, error) { +func transformInt64PhoneNumber( + randomizer rng.Rand, + value any, + preserveLength bool, +) (*int64, error) { if value == nil { return nil, nil } @@ -119,7 +129,10 @@ func transformInt64PhoneNumber(randomizer rng.Rand, value any, preserveLength bo func generateIntPhoneNumberPreserveLength(randomizer rng.Rand, number int64) (int64, error) { // get a random area code from the areacodes data set - randAreaCodeStr, err := transformer_utils.GetRandomValueFromSlice(randomizer, transformers_dataset.UsAreaCodes) + randAreaCodeStr, err := transformer_utils.GetRandomValueFromSlice( + randomizer, + transformers_dataset.UsAreaCodes, + ) if err != nil { return 0, err } @@ -129,7 +142,10 @@ func generateIntPhoneNumberPreserveLength(randomizer rng.Rand, number int64) (in return 0, err } - pn, err := transformer_utils.GenerateRandomInt64FixedLength(randomizer, transformer_utils.GetInt64Length(number)-3) + pn, err := transformer_utils.GenerateRandomInt64FixedLength( + randomizer, + transformer_utils.GetInt64Length(number)-3, + ) if err != nil { return 0, err } diff --git a/worker/pkg/benthos/transformers/transform_lastname.go b/worker/pkg/benthos/transformers/transform_lastname.go index e841e19300..0db2100be7 100644 --- a/worker/pkg/benthos/transformers/transform_lastname.go +++ b/worker/pkg/benthos/transformers/transform_lastname.go @@ -21,53 +21,60 @@ func init() { Param(bloblang.NewBoolParam("preserve_length").Default(false).Description("Whether the original length of the input data should be preserved during 
transformation. If set to true, the transformation logic will ensure that the output data has the same length as the input data.")). Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used for generating deterministic transformations.")) - err := bloblang.RegisterFunctionV2("transform_last_name", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - valuePtr, err := args.GetOptionalString("value") - if err != nil { - return nil, err - } - - var value string - if valuePtr != nil { - value = *valuePtr - } - preserveLength, err := args.GetBool("preserve_length") - if err != nil { - return nil, err - } - - maxLength, err := args.GetInt64("max_length") - if err != nil { - return nil, err - } + err := bloblang.RegisterFunctionV2( + "transform_last_name", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + valuePtr, err := args.GetOptionalString("value") + if err != nil { + return nil, err + } - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } + var value string + if valuePtr != nil { + value = *valuePtr + } + preserveLength, err := args.GetBool("preserve_length") + if err != nil { + return nil, err + } - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } + maxLength, err := args.GetInt64("max_length") + if err != nil { + return nil, err + } - randomizer := rng.New(seed) + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - res, err := transformLastName(randomizer, value, preserveLength, maxLength) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("unable to run transform_last_name: %w", err) + return nil, err } - return res, nil - }, nil - }) + + randomizer := rng.New(seed) + + return func() (any, error) { + res, err := transformLastName(randomizer, value, preserveLength, maxLength) + if err != nil { + 
return nil, fmt.Errorf("unable to run transform_last_name: %w", err) + } + return res, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewTransformLastNameOptsFromConfig(config *mgmtv1alpha1.TransformLastName, maxLength *int64) (*TransformLastNameOpts, error) { +func NewTransformLastNameOptsFromConfig( + config *mgmtv1alpha1.TransformLastName, + maxLength *int64, +) (*TransformLastNameOpts, error) { if config == nil { return NewTransformLastNameOpts(nil, nil, nil) } @@ -89,11 +96,21 @@ func (t *TransformLastName) Transform(value, opts any) (any, error) { return nil, errors.New("value is not a string") } - return transformLastName(parsedOpts.randomizer, valueStr, parsedOpts.preserveLength, parsedOpts.maxLength) + return transformLastName( + parsedOpts.randomizer, + valueStr, + parsedOpts.preserveLength, + parsedOpts.maxLength, + ) } // Generates a random last name which can be of either random length between [2,12] characters or as long as the input name -func transformLastName(randomizer rng.Rand, name string, preserveLength bool, maxLength int64) (*string, error) { +func transformLastName( + randomizer rng.Rand, + name string, + preserveLength bool, + maxLength int64, +) (*string, error) { if name == "" { return nil, nil } @@ -118,7 +135,10 @@ func transformLastName(randomizer rng.Rand, name string, preserveLength bool, ma // pad the string so that we can get the correct value if preserveLength && int64(len(output)) != maxValue { - output += transformer_utils.GetRandomCharacterString(randomizer, maxValue-int64(len(output))) + output += transformer_utils.GetRandomCharacterString( + randomizer, + maxValue-int64(len(output)), + ) } return &output, nil } diff --git a/worker/pkg/benthos/transformers/transform_string.go b/worker/pkg/benthos/transformers/transform_string.go index 8b20d7c834..cc43edb720 100644 --- a/worker/pkg/benthos/transformers/transform_string.go +++ b/worker/pkg/benthos/transformers/transform_string.go @@ -22,54 +22,61 @@ func init() { 
Param(bloblang.NewInt64Param("max_length").Default(100).Description("Specifies the maximum length of the transformed value.")). Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("transform_string", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - value, err := args.GetOptionalString("value") - if err != nil { - return nil, err - } - - preserveLength, err := args.GetBool("preserve_length") - if err != nil { - return nil, err - } - - minLength, err := args.GetInt64("min_length") - if err != nil { - return nil, err - } + err := bloblang.RegisterFunctionV2( + "transform_string", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + value, err := args.GetOptionalString("value") + if err != nil { + return nil, err + } - maxLength, err := args.GetInt64("max_length") - if err != nil { - return nil, err - } + preserveLength, err := args.GetBool("preserve_length") + if err != nil { + return nil, err + } - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } + minLength, err := args.GetInt64("min_length") + if err != nil { + return nil, err + } - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } + maxLength, err := args.GetInt64("max_length") + if err != nil { + return nil, err + } - randomizer := rng.New(seed) + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - res, err := transformString(randomizer, value, preserveLength, minLength, maxLength) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("unable to run transform_string: %w", err) + return nil, err } - return res, nil - }, nil - }) + + randomizer := rng.New(seed) + + return func() (any, error) { + res, err := transformString(randomizer, value, preserveLength, minLength, 
maxLength) + if err != nil { + return nil, fmt.Errorf("unable to run transform_string: %w", err) + } + return res, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewTransformStringOptsFromConfig(config *mgmtv1alpha1.TransformString, minLength, maxLength *int64) (*TransformStringOpts, error) { +func NewTransformStringOptsFromConfig( + config *mgmtv1alpha1.TransformString, + minLength, maxLength *int64, +) (*TransformStringOpts, error) { if config == nil { return NewTransformStringOpts(nil, nil, nil, nil) } @@ -92,11 +99,22 @@ func (t *TransformString) Transform(value, opts any) (any, error) { return nil, errors.New("value is not a string") } - return transformString(parsedOpts.randomizer, &valueStr, parsedOpts.preserveLength, parsedOpts.minLength, parsedOpts.maxLength) + return transformString( + parsedOpts.randomizer, + &valueStr, + parsedOpts.preserveLength, + parsedOpts.minLength, + parsedOpts.maxLength, + ) } // Transforms an existing string value into another string. Does not account for numbers and other characters. If you want to preserve spaces, capitalization and other characters, use the Transform_Characters transformer. 
-func transformString(randomizer rng.Rand, value *string, preserveLength bool, minLength, maxLength int64) (*string, error) { +func transformString( + randomizer rng.Rand, + value *string, + preserveLength bool, + minLength, maxLength int64, +) (*string, error) { if value == nil || *value == "" { return value, nil } @@ -115,7 +133,12 @@ func transformString(randomizer rng.Rand, value *string, preserveLength bool, mi } val, err := transformer_utils.GenerateRandomStringWithInclusiveBounds(randomizer, minL, maxL) if err != nil { - return nil, fmt.Errorf("unable to transform a random string with length: [%d:%d]: %w", minL, maxL, err) + return nil, fmt.Errorf( + "unable to transform a random string with length: [%d:%d]: %w", + minL, + maxL, + err, + ) } return &val, nil } diff --git a/worker/pkg/benthos/transformers/transform_string_phone_number.go b/worker/pkg/benthos/transformers/transform_string_phone_number.go index e7c7aa9793..fd760dc73a 100644 --- a/worker/pkg/benthos/transformers/transform_string_phone_number.go +++ b/worker/pkg/benthos/transformers/transform_string_phone_number.go @@ -21,49 +21,56 @@ func init() { Param(bloblang.NewInt64Param("max_length").Default(100).Description("Specifies the maximum length for the transformed data. This field ensures that the output does not exceed a certain number of characters.")). 
Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used to generate deterministic outputs.")) - err := bloblang.RegisterFunctionV2("transform_phone_number", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - value, err := args.GetOptionalString("value") - if err != nil { - return nil, err - } - - preserveLength, err := args.GetBool("preserve_length") - if err != nil { - return nil, err - } - - maxLength, err := args.GetInt64("max_length") - if err != nil { - return nil, err - } + err := bloblang.RegisterFunctionV2( + "transform_phone_number", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + value, err := args.GetOptionalString("value") + if err != nil { + return nil, err + } - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } + preserveLength, err := args.GetBool("preserve_length") + if err != nil { + return nil, err + } - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } + maxLength, err := args.GetInt64("max_length") + if err != nil { + return nil, err + } - randomizer := rng.New(seed) + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } - return func() (any, error) { - res, err := transformPhoneNumber(randomizer, value, preserveLength, maxLength) + seed, err := transformer_utils.GetSeedOrDefault(seedArg) if err != nil { - return nil, fmt.Errorf("unable to run transform_phone_number: %w", err) + return nil, err } - return res, nil - }, nil - }) + + randomizer := rng.New(seed) + + return func() (any, error) { + res, err := transformPhoneNumber(randomizer, value, preserveLength, maxLength) + if err != nil { + return nil, fmt.Errorf("unable to run transform_phone_number: %w", err) + } + return res, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewTransformStringPhoneNumberOptsFromConfig(config *mgmtv1alpha1.TransformPhoneNumber, maxLength *int64) 
(*TransformStringPhoneNumberOpts, error) { +func NewTransformStringPhoneNumberOptsFromConfig( + config *mgmtv1alpha1.TransformPhoneNumber, + maxLength *int64, +) (*TransformStringPhoneNumberOpts, error) { if config == nil { return NewTransformStringPhoneNumberOpts(nil, nil, nil) } @@ -84,11 +91,21 @@ func (t *TransformStringPhoneNumber) Transform(value, opts any) (any, error) { return nil, errors.New("value is not a string") } - return transformPhoneNumber(parsedOpts.randomizer, &valueStr, parsedOpts.preserveLength, parsedOpts.maxLength) + return transformPhoneNumber( + parsedOpts.randomizer, + &valueStr, + parsedOpts.preserveLength, + parsedOpts.maxLength, + ) } // Generates a random phone number and returns it as a string -func transformPhoneNumber(randomizer rng.Rand, value *string, preserveLength bool, maxLength int64) (*string, error) { +func transformPhoneNumber( + randomizer rng.Rand, + value *string, + preserveLength bool, + maxLength int64, +) (*string, error) { if value == nil || *value == "" { return value, nil } @@ -106,7 +123,12 @@ func transformPhoneNumber(randomizer rng.Rand, value *string, preserveLength boo } val, err := generateStringPhoneNumber(randomizer, minL, maxL) if err != nil { - return nil, fmt.Errorf("unable to transform phone number with length: [%d:%d]: %w", minL, maxL, err) + return nil, fmt.Errorf( + "unable to transform phone number with length: [%d:%d]: %w", + minL, + maxL, + err, + ) } return &val, nil } diff --git a/worker/pkg/benthos/transformers/transform_uuid.go b/worker/pkg/benthos/transformers/transform_uuid.go index b3befbeb68..e50821c9d5 100644 --- a/worker/pkg/benthos/transformers/transform_uuid.go +++ b/worker/pkg/benthos/transformers/transform_uuid.go @@ -21,44 +21,50 @@ func init() { Param(bloblang.NewAnyParam("value").Optional()). 
Param(bloblang.NewInt64Param("seed").Optional().Description("An optional seed value used for generating deterministic transformations.")) - err := bloblang.RegisterFunctionV2("transform_uuid", spec, func(args *bloblang.ParsedParams) (bloblang.Function, error) { - valuePtr, err := args.GetOptionalString("value") - if err != nil { - return nil, err - } - - var value string - if valuePtr != nil { - value = *valuePtr - } - - seedArg, err := args.GetOptionalInt64("seed") - if err != nil { - return nil, err - } - - seed, err := transformer_utils.GetSeedOrDefault(seedArg) - if err != nil { - return nil, err - } - - randomizer := rng.New(seed) - - return func() (any, error) { - res := transformUuid(randomizer, value) + err := bloblang.RegisterFunctionV2( + "transform_uuid", + spec, + func(args *bloblang.ParsedParams) (bloblang.Function, error) { + valuePtr, err := args.GetOptionalString("value") if err != nil { - return nil, fmt.Errorf("unable to run transform_uuid: %w", err) + return nil, err } - return res, nil - }, nil - }) + + var value string + if valuePtr != nil { + value = *valuePtr + } + + seedArg, err := args.GetOptionalInt64("seed") + if err != nil { + return nil, err + } + + seed, err := transformer_utils.GetSeedOrDefault(seedArg) + if err != nil { + return nil, err + } + + randomizer := rng.New(seed) + + return func() (any, error) { + res := transformUuid(randomizer, value) + if err != nil { + return nil, fmt.Errorf("unable to run transform_uuid: %w", err) + } + return res, nil + }, nil + }, + ) if err != nil { panic(err) } } -func NewTransformUuidOptsFromConfig(config *mgmtv1alpha1.TransformUuid) (*TransformUuidOpts, error) { +func NewTransformUuidOptsFromConfig( + config *mgmtv1alpha1.TransformUuid, +) (*TransformUuidOpts, error) { if config == nil { return NewTransformUuidOpts(nil) } diff --git a/worker/pkg/benthos/transformers/utils/float_utils.go b/worker/pkg/benthos/transformers/utils/float_utils.go index 421654f078..7782243277 100644 --- 
a/worker/pkg/benthos/transformers/utils/float_utils.go +++ b/worker/pkg/benthos/transformers/utils/float_utils.go @@ -12,7 +12,10 @@ import ( /* FLOAT MANIPULATION UTILS */ // Generates a random float64 in the range of the min and max float64 values -func GenerateRandomFloat64WithInclusiveBounds(randomizer rng.Rand, minValue, maxValue float64) (float64, error) { +func GenerateRandomFloat64WithInclusiveBounds( + randomizer rng.Rand, + minValue, maxValue float64, +) (float64, error) { if minValue > maxValue { minValue, maxValue = maxValue, minValue } diff --git a/worker/pkg/benthos/transformers/utils/integer_utils.go b/worker/pkg/benthos/transformers/utils/integer_utils.go index 9f4bc5c99c..ca3362cf29 100644 --- a/worker/pkg/benthos/transformers/utils/integer_utils.go +++ b/worker/pkg/benthos/transformers/utils/integer_utils.go @@ -32,7 +32,10 @@ func GenerateRandomInt64FixedLength(randomizer rng.Rand, l int64) (int64, error) /* Generates a random int64 with length in the inclusive range of [min, max]. For example, given a length range of [4, 7], possible values will have a length ranging from 4 -> 7 digits. 
*/ -func GenerateRandomInt64InLengthRange(randomizer rng.Rand, minValue, maxValue int64) (int64, error) { +func GenerateRandomInt64InLengthRange( + randomizer rng.Rand, + minValue, maxValue int64, +) (int64, error) { if minValue > maxValue { minValue, maxValue = maxValue, minValue } @@ -44,12 +47,22 @@ func GenerateRandomInt64InLengthRange(randomizer rng.Rand, minValue, maxValue in val, err := GenerateRandomInt64InValueRange(randomizer, minValue, maxValue) if err != nil { - return 0, fmt.Errorf("unable to generate a value in the range provided [%d:%d]: %w", minValue, maxValue, err) + return 0, fmt.Errorf( + "unable to generate a value in the range provided [%d:%d]: %w", + minValue, + maxValue, + err, + ) } res, err := GenerateRandomInt64FixedLength(randomizer, val) if err != nil { - return 0, fmt.Errorf("unable to generate fixed int64 in the range provided [%d:%d: %w]", minValue, maxValue, err) + return 0, fmt.Errorf( + "unable to generate fixed int64 in the range provided [%d:%d: %w]", + minValue, + maxValue, + err, + ) } return res, nil @@ -68,7 +81,11 @@ func GenerateRandomInt64InValueRange(randomizer rng.Rand, minValue, maxValue int // Calculate range without the +1 to avoid overflow rangeVal := maxValue - minValue if rangeVal < 0 { - return 0, fmt.Errorf("invalid range: difference between max (%d) and min (%d) would result in non-positive range", maxValue, minValue) + return 0, fmt.Errorf( + "invalid range: difference between max (%d) and min (%d) would result in non-positive range", + maxValue, + minValue, + ) } // Special case when maxValue is MaxInt64 to avoid overflow diff --git a/worker/pkg/benthos/transformers/utils/slice_utils.go b/worker/pkg/benthos/transformers/utils/slice_utils.go index 54517b4fc6..98028f4a4d 100644 --- a/worker/pkg/benthos/transformers/utils/slice_utils.go +++ b/worker/pkg/benthos/transformers/utils/slice_utils.go @@ -30,7 +30,9 @@ func FindClosestPair(sortedSlice1, sortedSlice2 []int64, maxValue int64) (leftid // Initialize 
variables to track the best pair found so far and the best individual value. bestPair := [2]int64{-1, -1} // Initialize to (-1, -1) to indicate failure. closestDiff := int64(math.MaxInt64) // Initialize with the largest int64 value. - maxSum := int64(0) // Track the maximum sum less than or equal to maxLength with the smallest difference. + maxSum := int64( + 0, + ) // Track the maximum sum less than or equal to maxLength with the smallest difference. // Check if any of the lists is empty and handle accordingly if len(sortedSlice1) == 0 || len(sortedSlice2) == 0 { diff --git a/worker/pkg/benthos/transformers/utils/string_utils.go b/worker/pkg/benthos/transformers/utils/string_utils.go index cd6a3a752e..2087039d16 100644 --- a/worker/pkg/benthos/transformers/utils/string_utils.go +++ b/worker/pkg/benthos/transformers/utils/string_utils.go @@ -80,9 +80,16 @@ var ( ) // Generate a random alphanumeric string within the interval [min, max] -func GenerateRandomStringWithInclusiveBounds(randomizer rng.Rand, minValue, maxValue int64) (string, error) { +func GenerateRandomStringWithInclusiveBounds( + randomizer rng.Rand, + minValue, maxValue int64, +) (string, error) { if minValue < 0 || maxValue < 0 || minValue > maxValue { - return "", fmt.Errorf("invalid bounds when attempting to generate random string: [%d:%d]", minValue, maxValue) + return "", fmt.Errorf( + "invalid bounds when attempting to generate random string: [%d:%d]", + minValue, + maxValue, + ) } // Cap the maximum length @@ -176,7 +183,8 @@ func IsValidEmail(email string) bool { // use MaxASCII to ensure that the unicode value is only within the ASCII block which only contains latin numbers, letters and characters. 
func IsValidChar(s string) bool { for _, r := range s { - if r > unicode.MaxASCII || (!unicode.IsNumber(r) && !unicode.IsLetter(r) && !unicode.IsSpace(r) && !IsAllowedSpecialChar(r)) { + if r > unicode.MaxASCII || + (!unicode.IsNumber(r) && !unicode.IsLetter(r) && !unicode.IsSpace(r) && !IsAllowedSpecialChar(r)) { return false } } @@ -250,7 +258,10 @@ func GenerateStringFromCorpus( excludedset := ToSet(exclusions) idxCandidates := ClampInts(mapKeys, minLength, &maxLength) if len(idxCandidates) == 0 { - return "", fmt.Errorf("unable to find candidates with range %s", getRangeText(minLength, maxLength)) + return "", fmt.Errorf( + "unable to find candidates with range %s", + getRangeText(minLength, maxLength), + ) } rangeIdxs := getRangeFromCandidates(idxCandidates, lengthMap) @@ -258,7 +269,9 @@ func GenerateStringFromCorpus( rightIdx := rangeIdxs[1] if leftIdx == -1 || rightIdx == -1 { - return "", errors.New("unable to generate string from corpus due to invalid dictionary ranges") + return "", errors.New( + "unable to generate string from corpus due to invalid dictionary ranges", + ) } attemptedValues := map[int64]struct{}{} @@ -276,7 +289,9 @@ func GenerateStringFromCorpus( } return value, nil } - return "", errors.New("unable to generate random value given the max length and excluded values") + return "", errors.New( + "unable to generate random value given the max length and excluded values", + ) } func getRangeFromCandidates(candidates []int64, lengthMap map[int64][2]int) [2]int64 { diff --git a/worker/pkg/query-builder/insert-query-builder.go b/worker/pkg/query-builder/insert-query-builder.go index 7364ae1908..a4c8a7c6e2 100644 --- a/worker/pkg/query-builder/insert-query-builder.go +++ b/worker/pkg/query-builder/insert-query-builder.go @@ -133,7 +133,9 @@ type PostgresDriver struct { options *InsertOptions } -func (d *PostgresDriver) BuildInsertQuery(rows []map[string]any) (query string, queryargs []any, err error) { +func (d *PostgresDriver) BuildInsertQuery( 
+ rows []map[string]any, +) (query string, queryargs []any, err error) { insertQuery, args, err := d.buildInsertQuery(rows) if err != nil { return "", nil, fmt.Errorf("failed to build postgres insert query: %w", err) @@ -145,7 +147,9 @@ func (d *PostgresDriver) BuildInsertQuery(rows []map[string]any) (query string, return insertQuery, args, nil } -func (d *PostgresDriver) buildInsertQuery(rows []map[string]any) (sql string, args []any, err error) { +func (d *PostgresDriver) buildInsertQuery( + rows []map[string]any, +) (sql string, args []any, err error) { goquRows := toGoquRecords(rows) if d.options.conflictConfig.onConflictDoUpdate != nil { if len(rows) == 0 { @@ -160,21 +164,43 @@ func (d *PostgresDriver) buildInsertQuery(rows []map[string]any) (sql string, ar updateColumns = append(updateColumns, col) } } - if len(d.options.conflictConfig.onConflictDoUpdate.conflictColumns) == 0 || len(updateColumns) == 0 { - d.logger.Warn("no conflict columns specified for on conflict do update, defaulting to on conflict do nothing") + if len(d.options.conflictConfig.onConflictDoUpdate.conflictColumns) == 0 || + len(updateColumns) == 0 { + d.logger.Warn( + "no conflict columns specified for on conflict do update, defaulting to on conflict do nothing", + ) onConflictDoNothing := true - insertQuery, args, err := BuildInsertQuery(d.driver, d.schema, d.table, goquRows, &onConflictDoNothing) + insertQuery, args, err := BuildInsertQuery( + d.driver, + d.schema, + d.table, + goquRows, + &onConflictDoNothing, + ) if err != nil { - return "", nil, fmt.Errorf("failed to build insert query on conflict do nothing fallback: %w", err) + return "", nil, fmt.Errorf( + "failed to build insert query on conflict do nothing fallback: %w", + err, + ) } return insertQuery, args, nil } - return d.buildInsertOnConflictDoUpdateQuery(goquRows, d.options.conflictConfig.onConflictDoUpdate.conflictColumns, updateColumns) + return d.buildInsertOnConflictDoUpdateQuery( + goquRows, + 
d.options.conflictConfig.onConflictDoUpdate.conflictColumns, + updateColumns, + ) } onConflictDoNothing := d.options.conflictConfig.onConflictDoNothing != nil - insertQuery, args, err := BuildInsertQuery(d.driver, d.schema, d.table, goquRows, &onConflictDoNothing) + insertQuery, args, err := BuildInsertQuery( + d.driver, + d.schema, + d.table, + goquRows, + &onConflictDoNothing, + ) if err != nil { return "", nil, fmt.Errorf("failed to build insert query: %w", err) } @@ -212,7 +238,9 @@ type MysqlDriver struct { options *InsertOptions } -func (d *MysqlDriver) BuildInsertQuery(rows []map[string]any) (query string, queryargs []any, err error) { +func (d *MysqlDriver) BuildInsertQuery( + rows []map[string]any, +) (query string, queryargs []any, err error) { goquRows := toGoquRecords(rows) if d.options.conflictConfig.onConflictDoUpdate != nil { @@ -235,7 +263,13 @@ func (d *MysqlDriver) BuildInsertQuery(rows []map[string]any) (query string, que } onConflictDoNothing := d.options.conflictConfig.onConflictDoNothing != nil - insertQuery, args, err := BuildInsertQuery(d.driver, d.schema, d.table, goquRows, &onConflictDoNothing) + insertQuery, args, err := BuildInsertQuery( + d.driver, + d.schema, + d.table, + goquRows, + &onConflictDoNothing, + ) if err != nil { return "", nil, err } @@ -279,7 +313,9 @@ type MssqlDriver struct { options *InsertOptions } -func (d *MssqlDriver) BuildInsertQuery(rows []map[string]any) (query string, queryargs []any, err error) { +func (d *MssqlDriver) BuildInsertQuery( + rows []map[string]any, +) (query string, queryargs []any, err error) { if len(rows) == 0 || areAllRowsEmpty(rows) { return getSqlServerDefaultValuesInsertSql(d.schema, d.table, len(rows)), []any{}, nil } @@ -287,7 +323,13 @@ func (d *MssqlDriver) BuildInsertQuery(rows []map[string]any) (query string, que goquRows := toGoquRecords(rows) onConflictDoNothing := d.options.conflictConfig.onConflictDoNothing != nil - insertQuery, args, err := BuildInsertQuery(d.driver, d.schema, 
d.table, goquRows, &onConflictDoNothing) + insertQuery, args, err := BuildInsertQuery( + d.driver, + d.schema, + d.table, + goquRows, + &onConflictDoNothing, + ) if err != nil { return "", nil, err } diff --git a/worker/pkg/select-query-builder/querybuilder.go b/worker/pkg/select-query-builder/querybuilder.go index c0d3e644fb..a9a3c1257f 100644 --- a/worker/pkg/select-query-builder/querybuilder.go +++ b/worker/pkg/select-query-builder/querybuilder.go @@ -27,7 +27,11 @@ type QueryBuilder struct { pageLimit uint } -func NewSelectQueryBuilder(defaultSchema, driver string, subsetByForeignKeyConstraints bool, pageLimit int) *QueryBuilder { +func NewSelectQueryBuilder( + defaultSchema, driver string, + subsetByForeignKeyConstraints bool, + pageLimit int, +) *QueryBuilder { limit := uint(0) if pageLimit > 0 { limit = uint(pageLimit) @@ -53,10 +57,15 @@ func (qb *QueryBuilder) getDialect() goqu.DialectWrapper { // BuildQuery constructs a SQL Select query from a RunConfig, returning the query string, // returns initial select and paged select queries, a flag indicating foreign key safety -func (qb *QueryBuilder) BuildQuery(runconfig *runconfigs.RunConfig) (sqlstatement string, args []any, pagesql string, isNotForeignKeySafeSubset bool, err error) { +func (qb *QueryBuilder) BuildQuery( + runconfig *runconfigs.RunConfig, +) (sqlstatement string, args []any, pagesql string, isNotForeignKeySafeSubset bool, err error) { query, pageQuery, notFkSafe, err := qb.buildFlattenedQuery(runconfig) if query == nil { - return "", nil, "", false, fmt.Errorf("received no error, but query was nil for %s", runconfig.Id()) + return "", nil, "", false, fmt.Errorf( + "received no error, but query was nil for %s", + runconfig.Id(), + ) } if err != nil { return "", nil, "", false, err @@ -64,12 +73,20 @@ func (qb *QueryBuilder) BuildQuery(runconfig *runconfigs.RunConfig) (sqlstatemen sql, args, err := query.Limit(qb.pageLimit).ToSQL() if err != nil { - return "", nil, "", false, fmt.Errorf("unable 
to convert structured query to string for %s: %w", runconfig.Id(), err) + return "", nil, "", false, fmt.Errorf( + "unable to convert structured query to string for %s: %w", + runconfig.Id(), + err, + ) } pageSql, _, err := pageQuery.Limit(qb.pageLimit).ToSQL() if err != nil { - return "", nil, "", false, fmt.Errorf("unable to convert structured page query to string for %s: %w", runconfig.Id(), err) + return "", nil, "", false, fmt.Errorf( + "unable to convert structured page query to string for %s: %w", + runconfig.Id(), + err, + ) } if qb.driver == sqlmanager_shared.MssqlDriver { // MSSQL TOP needs to be cast to int @@ -79,10 +96,14 @@ func (qb *QueryBuilder) BuildQuery(runconfig *runconfigs.RunConfig) (sqlstatemen } // buildFlattenedQuery builds the query for the root table, adding joins if needed. -func (qb *QueryBuilder) buildFlattenedQuery(rootTable *runconfigs.RunConfig) (sql, pageSql *goqu.SelectDataset, isNotForeignKeySafeSubset bool, err error) { +func (qb *QueryBuilder) buildFlattenedQuery( + rootTable *runconfigs.RunConfig, +) (sql, pageSql *goqu.SelectDataset, isNotForeignKeySafeSubset bool, err error) { dialect := qb.getDialect() rootAlias := rootTable.SchemaTable().Table - rootAliasExpression := goqu.S(rootTable.SchemaTable().Schema).Table(rootTable.SchemaTable().Table).As(rootAlias) + rootAliasExpression := goqu.S(rootTable.SchemaTable().Schema). + Table(rootTable.SchemaTable().Table). + As(rootAlias) query := dialect.From(rootAliasExpression) // Select columns for the root table @@ -124,7 +145,11 @@ func (qb *QueryBuilder) buildFlattenedQuery(rootTable *runconfigs.RunConfig) (sq } // buildPageQuery builds a pagination version of the query. 
-func (qb *QueryBuilder) buildPageQuery(query *goqu.SelectDataset, rootAlias string, orderByColumns []string) *goqu.SelectDataset { +func (qb *QueryBuilder) buildPageQuery( + query *goqu.SelectDataset, + rootAlias string, + orderByColumns []string, +) *goqu.SelectDataset { if len(orderByColumns) > 0 { // Build lexicographical ordering conditions var conditions []exp.Expression @@ -132,10 +157,16 @@ func (qb *QueryBuilder) buildPageQuery(query *goqu.SelectDataset, rootAlias stri var subConditions []exp.Expression // Add equality conditions for all columns before current for j := 0; j < i; j++ { - subConditions = append(subConditions, goqu.T(rootAlias).Col(orderByColumns[j]).Eq(goqu.L("?", 0))) + subConditions = append( + subConditions, + goqu.T(rootAlias).Col(orderByColumns[j]).Eq(goqu.L("?", 0)), + ) } // Add greater than condition for current column - subConditions = append(subConditions, goqu.T(rootAlias).Col(orderByColumns[i]).Gt(goqu.L("?", 0))) + subConditions = append( + subConditions, + goqu.T(rootAlias).Col(orderByColumns[i]).Gt(goqu.L("?", 0)), + ) conditions = append(conditions, goqu.And(subConditions...)) } query = query.Where(goqu.Or(conditions...)) @@ -145,7 +176,11 @@ func (qb *QueryBuilder) buildPageQuery(query *goqu.SelectDataset, rootAlias stri // addSubsetJoins adds joins to the query based on foreign key relationships defined in the subset paths. // returns the modified query, a boolean indicating if the subset is not foreign key safe. 
-func (qb *QueryBuilder) addSubsetJoins(query *goqu.SelectDataset, rootTable *runconfigs.RunConfig, rootAlias string) (*goqu.SelectDataset, bool, error) { +func (qb *QueryBuilder) addSubsetJoins( + query *goqu.SelectDataset, + rootTable *runconfigs.RunConfig, + rootAlias string, +) (*goqu.SelectDataset, bool, error) { subsets := rootTable.SubsetPaths() isSubset := false @@ -193,7 +228,9 @@ func (qb *QueryBuilder) addSubsetJoins(query *goqu.SelectDataset, rootTable *run // Build join conditions based on the foreign key. joinConditions := make([]exp.Expression, len(step.ForeignKey.Columns)) for i, col := range step.ForeignKey.Columns { - joinConditions[i] = goqu.T(childAlias).Col(step.ForeignKey.ReferenceColumns[i]).Eq(goqu.T(parentAlias).Col(col)) + joinConditions[i] = goqu.T(childAlias). + Col(step.ForeignKey.ReferenceColumns[i]). + Eq(goqu.T(parentAlias).Col(col)) } query = query.InnerJoin( goqu.I(childTable).As(childAlias), @@ -227,7 +264,10 @@ func getClippedHash(input string) string { return hex.EncodeToString(hash[:][:8]) } -func (qb *QueryBuilder) qualifyWhereCondition(schema *string, table, condition string) (string, error) { +func (qb *QueryBuilder) qualifyWhereCondition( + schema *string, + table, condition string, +) (string, error) { query := qb.getDialect().From(goqu.T(table)).Select(goqu.Star()).Where(goqu.L(condition)) sql, _, err := query.ToSQL() if err != nil { diff --git a/worker/pkg/select-query-builder/tsql/query-qualifier.go b/worker/pkg/select-query-builder/tsql/query-qualifier.go index 1eb5385606..32160e26c7 100644 --- a/worker/pkg/select-query-builder/tsql/query-qualifier.go +++ b/worker/pkg/select-query-builder/tsql/query-qualifier.go @@ -9,7 +9,8 @@ import ( ) /* - Updates columns names in where clause to be fully qualified +Updates columns names in where clause to be fully qualified + ex: SELECT * FROM users WHERE name = 'John' becomes SELECT * FROM users WHERE "users"."name" = 'John' To view query tree use @@ -44,7 +45,8 @@ import ( 
(comparison_operator =) (expression (primitive_expression - (primitive_constant 'John')))))))))))) ) + +(primitive_constant 'John')))))))))))) ) */ func QualifyWhereCondition(sql string) (string, error) { inputStream := antlr.NewInputStream(sql) @@ -85,7 +87,13 @@ func newTSqlErrorListener() *tSqlErrorListener { } } -func (l *tSqlErrorListener) SyntaxError(recognizer antlr.Recognizer, offendingSymbol any, line, column int, msg string, e antlr.RecognitionException) { +func (l *tSqlErrorListener) SyntaxError( + recognizer antlr.Recognizer, + offendingSymbol any, + line, column int, + msg string, + e antlr.RecognitionException, +) { errorMessage := fmt.Sprintf("line %d:%d %s", line, column, msg) l.Errors = append(l.Errors, errorMessage) } diff --git a/worker/pkg/workflows/datasync/activities/account-status/activity.go b/worker/pkg/workflows/datasync/activities/account-status/activity.go index 9811dde736..59faa5189c 100644 --- a/worker/pkg/workflows/datasync/activities/account-status/activity.go +++ b/worker/pkg/workflows/datasync/activities/account-status/activity.go @@ -61,10 +61,13 @@ func (a *Activity) CheckAccountStatus( logger.Debug("checking account status") - resp, err := a.userclient.IsAccountStatusValid(ctx, connect.NewRequest(&mgmtv1alpha1.IsAccountStatusValidRequest{ - AccountId: req.AccountId, - RequestedRecordCount: req.RequestedRecordCount, - })) + resp, err := a.userclient.IsAccountStatusValid( + ctx, + connect.NewRequest(&mgmtv1alpha1.IsAccountStatusValidRequest{ + AccountId: req.AccountId, + RequestedRecordCount: req.RequestedRecordCount, + }), + ) if err != nil { return nil, fmt.Errorf("unable to retrieve account status: %w", err) } @@ -74,7 +77,11 @@ func (a *Activity) CheckAccountStatus( "reason", withReasonOrDefault(resp.Msg.GetReason()), ) - return &CheckAccountStatusResponse{IsValid: resp.Msg.GetIsValid(), Reason: resp.Msg.Reason, ShouldPoll: resp.Msg.GetShouldPoll()}, nil + return &CheckAccountStatusResponse{ + IsValid: resp.Msg.GetIsValid(), + 
Reason: resp.Msg.Reason, + ShouldPoll: resp.Msg.GetShouldPoll(), + }, nil } const defaultReason = "no reason provided" diff --git a/worker/pkg/workflows/datasync/activities/gen-benthos-configs/activity.go b/worker/pkg/workflows/datasync/activities/gen-benthos-configs/activity.go index bfbb1ff1be..96f90887d0 100644 --- a/worker/pkg/workflows/datasync/activities/gen-benthos-configs/activity.go +++ b/worker/pkg/workflows/datasync/activities/gen-benthos-configs/activity.go @@ -94,5 +94,13 @@ func (a *Activity) GenerateBenthosConfigs( a.pageLimit, ) slogger := temporallogger.NewSlogger(logger) - return bbuilder.GenerateBenthosConfigsNew(ctx, req, &workflowMetadata{WorkflowId: info.WorkflowExecution.ID, RunId: info.WorkflowExecution.RunID}, slogger) + return bbuilder.GenerateBenthosConfigsNew( + ctx, + req, + &workflowMetadata{ + WorkflowId: info.WorkflowExecution.ID, + RunId: info.WorkflowExecution.RunID, + }, + slogger, + ) } diff --git a/worker/pkg/workflows/datasync/activities/gen-benthos-configs/benthos-builder.go b/worker/pkg/workflows/datasync/activities/gen-benthos-configs/benthos-builder.go index 47c94f3a5b..02bedece1c 100644 --- a/worker/pkg/workflows/datasync/activities/gen-benthos-configs/benthos-builder.go +++ b/worker/pkg/workflows/datasync/activities/gen-benthos-configs/benthos-builder.go @@ -92,9 +92,17 @@ func (b *benthosBuilder) GenerateBenthosConfigsNew( destConnections := []*mgmtv1alpha1.Connection{} for _, destination := range job.Destinations { - destinationConnection, err := shared.GetConnectionById(ctx, b.connclient, destination.ConnectionId) + destinationConnection, err := shared.GetConnectionById( + ctx, + b.connclient, + destination.ConnectionId, + ) if err != nil { - return nil, fmt.Errorf("unable to get destination connection (%s) by id: %w", destination.ConnectionId, err) + return nil, fmt.Errorf( + "unable to get destination connection (%s) by id: %w", + destination.ConnectionId, + err, + ) } destConnections = append(destConnections, 
destinationConnection) } @@ -112,8 +120,10 @@ func (b *benthosBuilder) GenerateBenthosConfigsNew( SelectQueryBuilder: &selectquerybuilder.QueryMapBuilderWrapper{}, MetricsEnabled: b.metricsEnabled, MetricLabelKeyVals: map[string]string{ - metrics.TemporalWorkflowId: bb_shared.WithEnvInterpolation(metrics.TemporalWorkflowIdEnvKey), - metrics.TemporalRunId: bb_shared.WithEnvInterpolation(metrics.TemporalRunIdEnvKey), + metrics.TemporalWorkflowId: bb_shared.WithEnvInterpolation( + metrics.TemporalWorkflowIdEnvKey, + ), + metrics.TemporalRunId: bb_shared.WithEnvInterpolation(metrics.TemporalRunIdEnvKey), }, PageLimit: &b.pageLimit, } @@ -135,7 +145,10 @@ func (b *benthosBuilder) GenerateBenthosConfigsNew( postTableSyncRunCtx := buildPostTableSyncRunCtx(responses, job.Destinations) err = b.setPostTableSyncRunCtx(ctx, postTableSyncRunCtx, job.GetAccountId()) if err != nil { - return nil, fmt.Errorf("unable to set all run contexts for post table sync configs: %w", err) + return nil, fmt.Errorf( + "unable to set all run contexts for post table sync configs: %w", + err, + ) } outputConfigs, err := b.setRunContexts(ctx, responses, job.GetAccountId()) @@ -210,7 +223,10 @@ func (b *benthosBuilder) setRunContexts( _, err := rcstream.CloseAndReceive() if err != nil { - return nil, fmt.Errorf("unable to receive response from benthos runcontext request: %w", err) + return nil, fmt.Errorf( + "unable to receive response from benthos runcontext request: %w", + err, + ) } return responses, nil } @@ -242,7 +258,10 @@ func (b *benthosBuilder) setPostTableSyncRunCtx( _, err := rcstream.CloseAndReceive() if err != nil { - return fmt.Errorf("unable to receive response from post table sync runcontext request: %w", err) + return fmt.Errorf( + "unable to receive response from post table sync runcontext request: %w", + err, + ) } return nil } @@ -261,7 +280,10 @@ func (b *benthosBuilder) getJobById( return getjobResp.Msg.Job, nil } -func buildPostTableSyncRunCtx(benthosConfigs 
[]*benthosbuilder.BenthosConfigResponse, destinations []*mgmtv1alpha1.JobDestination) map[string]*shared.PostTableSyncConfig { +func buildPostTableSyncRunCtx( + benthosConfigs []*benthosbuilder.BenthosConfigResponse, + destinations []*mgmtv1alpha1.JobDestination, +) map[string]*shared.PostTableSyncConfig { postTableSyncRunCtx := map[string]*shared.PostTableSyncConfig{} // benthos_config_name -> config for _, bc := range benthosConfigs { destConfigs := map[string]*shared.PostTableSyncDestConfig{} @@ -297,7 +319,11 @@ func buildPgPostTableSyncStatement(bc *benthosbuilder.BenthosConfigResponse) []s for colName, p := range colDefaultProps { if p.NeedsReset && !p.HasDefaultTransformer { // resets sequences and identities - resetSql := sqlmanager_postgres.BuildPgIdentityColumnResetCurrentSql(bc.TableSchema, bc.TableName, colName) + resetSql := sqlmanager_postgres.BuildPgIdentityColumnResetCurrentSql( + bc.TableSchema, + bc.TableName, + colName, + ) statements = append(statements, resetSql) } } @@ -313,7 +339,10 @@ func buildMssqlPostTableSyncStatement(bc *benthosbuilder.BenthosConfigResponse) for _, p := range colDefaultProps { if p.NeedsOverride { // reset identity - resetSql := sqlmanager_mssql.BuildMssqlIdentityColumnResetCurrent(bc.TableSchema, bc.TableName) + resetSql := sqlmanager_mssql.BuildMssqlIdentityColumnResetCurrent( + bc.TableSchema, + bc.TableName, + ) statements = append(statements, resetSql) } } diff --git a/worker/pkg/workflows/datasync/activities/jobhooks-by-timing/activity.go b/worker/pkg/workflows/datasync/activities/jobhooks-by-timing/activity.go index 6528f6d9b9..183bc447b8 100644 --- a/worker/pkg/workflows/datasync/activities/jobhooks-by-timing/activity.go +++ b/worker/pkg/workflows/datasync/activities/jobhooks-by-timing/activity.go @@ -34,7 +34,12 @@ func New( sqlmanagerclient sqlmanager.SqlManagerClient, license License, ) *Activity { - return &Activity{jobclient: jobclient, connclient: connclient, sqlmanagerclient: sqlmanagerclient, license: 
license} + return &Activity{ + jobclient: jobclient, + connclient: connclient, + sqlmanagerclient: sqlmanagerclient, + license: license, + } } type RunJobHooksByTimingRequest struct { @@ -85,10 +90,13 @@ func (a *Activity) RunJobHooksByTiming( logger.Debug(fmt.Sprintf("retrieving job hooks by timing %q", req.Timing)) - resp, err := a.jobclient.GetActiveJobHooksByTiming(ctx, connect.NewRequest(&mgmtv1alpha1.GetActiveJobHooksByTimingRequest{ - JobId: req.JobId, - Timing: req.Timing, - })) + resp, err := a.jobclient.GetActiveJobHooksByTiming( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetActiveJobHooksByTimingRequest{ + JobId: req.JobId, + Timing: req.Timing, + }), + ) if err != nil { return nil, fmt.Errorf("unable to retrieve active hooks by timing: %w", err) } @@ -102,7 +110,9 @@ func (a *Activity) RunJobHooksByTiming( } }() - session := connectionmanager.NewUniqueSession(connectionmanager.WithSessionGroup(activityInfo.WorkflowExecution.ID)) + session := connectionmanager.NewUniqueSession( + connectionmanager.WithSessionGroup(activityInfo.WorkflowExecution.ID), + ) execCount := uint(0) for _, hook := range hooks { @@ -147,9 +157,12 @@ func (a *Activity) getCachedConnectionFn( return conn.Db(), nil } logger.Debug("initializing connection for hook") - connectionResp, err := a.connclient.GetConnection(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ - Id: connectionId, - })) + connectionResp, err := a.connclient.GetConnection( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ + Id: connectionId, + }), + ) if err != nil { return nil, err } diff --git a/worker/pkg/workflows/datasync/activities/post-table-sync/activity.go b/worker/pkg/workflows/datasync/activities/post-table-sync/activity.go index fb7ce15f57..61ae4c451e 100644 --- a/worker/pkg/workflows/datasync/activities/post-table-sync/activity.go +++ b/worker/pkg/workflows/datasync/activities/post-table-sync/activity.go @@ -59,7 +59,9 @@ func (a *Activity) RunPostTableSync( req 
*RunPostTableSyncRequest, ) (*RunPostTableSyncResponse, error) { activityInfo := activity.GetInfo(ctx) - session := connectionmanager.NewUniqueSession(connectionmanager.WithSessionGroup(activityInfo.WorkflowExecution.ID)) + session := connectionmanager.NewUniqueSession( + connectionmanager.WithSessionGroup(activityInfo.WorkflowExecution.ID), + ) externalId := shared.GetPostTableSyncConfigExternalId(req.Name) loggerKeyVals := []any{ "accountId", req.AccountId, @@ -74,13 +76,16 @@ func (a *Activity) RunPostTableSync( logger.Debug("running post table sync activity") slogger := temporallogger.NewSlogger(logger) - rcResp, err := a.jobclient.GetRunContext(ctx, connect.NewRequest(&mgmtv1alpha1.GetRunContextRequest{ - Id: &mgmtv1alpha1.RunContextKey{ - JobRunId: activityInfo.WorkflowExecution.ID, - ExternalId: externalId, - AccountId: req.AccountId, - }, - })) + rcResp, err := a.jobclient.GetRunContext( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetRunContextRequest{ + Id: &mgmtv1alpha1.RunContextKey{ + JobRunId: activityInfo.WorkflowExecution.ID, + ExternalId: externalId, + AccountId: req.AccountId, + }, + }), + ) if err != nil && runContextNotFound(err) { slogger.Debug("no runcontext found. 
continuing") return nil, nil @@ -97,7 +102,11 @@ func (a *Activity) RunPostTableSync( var config *shared.PostTableSyncConfig err = json.Unmarshal(configBits, &config) if err != nil { - return nil, fmt.Errorf("unable to unmarshal posttablesync runcontext for %s: %w", req.Name, err) + return nil, fmt.Errorf( + "unable to unmarshal posttablesync runcontext for %s: %w", + req.Name, + err, + ) } if len(config.DestinationConfigs) == 0 { @@ -114,13 +123,21 @@ func (a *Activity) RunPostTableSync( errors := []*PostTableSyncError{} for destConnectionId, destCfg := range config.DestinationConfigs { - slogger.Debug(fmt.Sprintf("found %d post table sync statements", len(destCfg.Statements)), "destinationConnectionId", destConnectionId) + slogger.Debug( + fmt.Sprintf("found %d post table sync statements", len(destCfg.Statements)), + "destinationConnectionId", + destConnectionId, + ) if len(destCfg.Statements) == 0 { continue } destinationConnection, err := shared.GetConnectionById(ctx, a.connclient, destConnectionId) if err != nil { - return nil, fmt.Errorf("unable to get destination connection (%s) by id: %w", destConnectionId, err) + return nil, fmt.Errorf( + "unable to get destination connection (%s) by id: %w", + destConnectionId, + err, + ) } execErrors := &PostTableSyncError{ ConnectionId: destConnectionId, @@ -161,5 +178,6 @@ func runContextNotFound(err error) bool { if ok && connectErr.Code() == connect.CodeNotFound { return true } - return strings.Contains(err.Error(), "unable to find key") || strings.Contains(err.Error(), "no run context exists with the provided key") + return strings.Contains(err.Error(), "unable to find key") || + strings.Contains(err.Error(), "no run context exists with the provided key") } diff --git a/worker/pkg/workflows/datasync/activities/shared/shared.go b/worker/pkg/workflows/datasync/activities/shared/shared.go index a61010962e..64159a793d 100644 --- a/worker/pkg/workflows/datasync/activities/shared/shared.go +++ 
b/worker/pkg/workflows/datasync/activities/shared/shared.go @@ -183,7 +183,9 @@ func GetRedisConfig() *neosync_redis.RedisConfig { } } -func BuildBenthosRedisTlsConfig(redisConfig *neosync_redis.RedisConfig) *neosync_benthos.RedisTlsConfig { +func BuildBenthosRedisTlsConfig( + redisConfig *neosync_redis.RedisConfig, +) *neosync_benthos.RedisTlsConfig { var tls *neosync_benthos.RedisTlsConfig if redisConfig.Tls != nil && redisConfig.Tls.Enabled { tls = &neosync_benthos.RedisTlsConfig{ @@ -251,9 +253,12 @@ func GetConnectionById( connclient mgmtv1alpha1connect.ConnectionServiceClient, connectionId string, ) (*mgmtv1alpha1.Connection, error) { - getConnResp, err := connclient.GetConnection(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ - Id: connectionId, - })) + getConnResp, err := connclient.GetConnection( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ + Id: connectionId, + }), + ) if err != nil { return nil, err } diff --git a/worker/pkg/workflows/datasync/activities/sync-activity-opts/activity.go b/worker/pkg/workflows/datasync/activities/sync-activity-opts/activity.go index 482dae04be..78528890a6 100644 --- a/worker/pkg/workflows/datasync/activities/sync-activity-opts/activity.go +++ b/worker/pkg/workflows/datasync/activities/sync-activity-opts/activity.go @@ -53,7 +53,10 @@ func (a *Activity) RetrieveActivityOptions( ) logger.Debug("retrieving activity options") - jobResp, err := a.jobclient.GetJob(ctx, connect.NewRequest(&mgmtv1alpha1.GetJobRequest{Id: req.JobId})) + jobResp, err := a.jobclient.GetJob( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetJobRequest{Id: req.JobId}), + ) if err != nil { return nil, fmt.Errorf("unable to get job by id: %w", err) } @@ -129,10 +132,14 @@ func getSyncActivityOptionsFromJob(job *mgmtv1alpha1.Job) *workflow.ActivityOpti } if job.SyncOptions != nil { if job.SyncOptions.StartToCloseTimeout != nil { - syncActivityOptions.StartToCloseTimeout = time.Duration(*job.SyncOptions.StartToCloseTimeout) + 
syncActivityOptions.StartToCloseTimeout = time.Duration( + *job.SyncOptions.StartToCloseTimeout, + ) } if job.SyncOptions.ScheduleToCloseTimeout != nil { - syncActivityOptions.ScheduleToCloseTimeout = time.Duration(*job.SyncOptions.ScheduleToCloseTimeout) + syncActivityOptions.ScheduleToCloseTimeout = time.Duration( + *job.SyncOptions.ScheduleToCloseTimeout, + ) } if job.SyncOptions.RetryPolicy != nil { if job.SyncOptions.RetryPolicy.MaximumAttempts != nil { @@ -150,7 +157,8 @@ func getSyncActivityOptionsFromJob(job *mgmtv1alpha1.Job) *workflow.ActivityOpti }, } } - if syncActivityOptions.StartToCloseTimeout == 0 && syncActivityOptions.ScheduleToCloseTimeout == 0 { + if syncActivityOptions.StartToCloseTimeout == 0 && + syncActivityOptions.ScheduleToCloseTimeout == 0 { syncActivityOptions.StartToCloseTimeout = defaultStartCloseTimeout } if syncActivityOptions.RetryPolicy == nil { diff --git a/worker/pkg/workflows/datasync/workflow/register/register.go b/worker/pkg/workflows/datasync/workflow/register/register.go index 8202b7f1aa..2f42da7fd9 100644 --- a/worker/pkg/workflows/datasync/workflow/register/register.go +++ b/worker/pkg/workflows/datasync/workflow/register/register.go @@ -47,7 +47,12 @@ func Register( retrieveActivityOpts := syncactivityopts_activity.New(jobclient, postgresSchemaDrift) accountStatusActivity := accountstatus_activity.New(userclient) runPostTableSyncActivity := posttablesync_activity.New(jobclient, sqlmanager, connclient) - jobhookByTimingActivity := jobhooks_by_timing_activity.New(jobclient, connclient, sqlmanager, eelicense) + jobhookByTimingActivity := jobhooks_by_timing_activity.New( + jobclient, + connclient, + sqlmanager, + eelicense, + ) redisCleanUpActivity := syncrediscleanup_activity.New(redisclient) wf := datasync_workflow.New(eelicense) diff --git a/worker/pkg/workflows/datasync/workflow/workflow.go b/worker/pkg/workflows/datasync/workflow/workflow.go index b5c260b902..5b9a8e77a3 100644 --- 
a/worker/pkg/workflows/datasync/workflow/workflow.go +++ b/worker/pkg/workflows/datasync/workflow/workflow.go @@ -127,12 +127,19 @@ func executeWorkflow(wfctx workflow.Context, req *WorkflowRequest) (*WorkflowRes err = workflow.ExecuteActivity( withCheckAccountStatusActivityOptions(ctx), a.CheckAccountStatus, - &accountstatus_activity.CheckAccountStatusRequest{AccountId: actOptResp.AccountId, RequestedRecordCount: actOptResp.RequestedRecordCount}). + &accountstatus_activity.CheckAccountStatusRequest{ + AccountId: actOptResp.AccountId, + RequestedRecordCount: actOptResp.RequestedRecordCount, + }, + ). Get(ctx, &initialCheckAccountStatusResponse) if err != nil { logger.Error("encountered error while checking account status", "error", err) cancelHandler() - return nil, fmt.Errorf("unable to continue workflow due to error when checking account status: %w", err) + return nil, fmt.Errorf( + "unable to continue workflow due to error when checking account status: %w", + err, + ) } if !initialCheckAccountStatusResponse.IsValid { logger.Warn("account is no longer is valid state") @@ -141,7 +148,11 @@ func executeWorkflow(wfctx workflow.Context, req *WorkflowRequest) (*WorkflowRes if initialCheckAccountStatusResponse.Reason != nil { reason = *initialCheckAccountStatusResponse.Reason } - return nil, fmt.Errorf("halting job run due to account in invalid state. Reason: %q: %w", reason, errInvalidAccountStatusError) + return nil, fmt.Errorf( + "halting job run due to account in invalid state. 
Reason: %q: %w", + reason, + errInvalidAccountStatusError, + ) } info := workflow.GetInfo(ctx) @@ -165,12 +176,27 @@ func executeWorkflow(wfctx workflow.Context, req *WorkflowRequest) (*WorkflowRes return &WorkflowResponse{}, nil } - err = execRunJobHooksByTiming(ctx, &jobhooks_by_timing_activity.RunJobHooksByTimingRequest{JobId: req.JobId, Timing: mgmtv1alpha1.GetActiveJobHooksByTimingRequest_TIMING_PRESYNC}, logger) + err = execRunJobHooksByTiming( + ctx, + &jobhooks_by_timing_activity.RunJobHooksByTimingRequest{ + JobId: req.JobId, + Timing: mgmtv1alpha1.GetActiveJobHooksByTimingRequest_TIMING_PRESYNC, + }, + logger, + ) if err != nil { return nil, err } - err = runSchemaInitWorkflowByDestination(ctx, logger, actOptResp.AccountId, req.JobId, info.WorkflowExecution.ID, actOptResp.Destinations, actOptResp.PostgresSchemaDrift) + err = runSchemaInitWorkflowByDestination( + ctx, + logger, + actOptResp.AccountId, + req.JobId, + info.WorkflowExecution.ID, + actOptResp.Destinations, + actOptResp.PostgresSchemaDrift, + ) if err != nil { return nil, err } @@ -199,10 +225,17 @@ func executeWorkflow(wfctx workflow.Context, req *WorkflowRequest) (*WorkflowRes err = workflow.ExecuteActivity( withCheckAccountStatusActivityOptions(ctx), a.CheckAccountStatus, - &accountstatus_activity.CheckAccountStatusRequest{AccountId: actOptResp.AccountId}). + &accountstatus_activity.CheckAccountStatusRequest{ + AccountId: actOptResp.AccountId, + }, + ). 
Get(ctx, &result) if err != nil { - logger.Error("encountered error while checking account status", "error", err) + logger.Error( + "encountered error while checking account status", + "error", + err, + ) stopChan.Send(ctx, true) shouldStop = true cancelHandler() @@ -224,7 +257,11 @@ func executeWorkflow(wfctx workflow.Context, req *WorkflowRequest) (*WorkflowRes return } if ctx.Err() != nil { - logger.Warn("workflow canceled due to error or stop signal", "error", ctx.Err()) + logger.Warn( + "workflow canceled due to error or stop signal", + "error", + ctx.Err(), + ) return } } @@ -253,7 +290,15 @@ func executeWorkflow(wfctx workflow.Context, req *WorkflowRequest) (*WorkflowRes completed := sync.Map{} executeSyncActivity := func(bc *benthosbuilder.BenthosConfigResponse, logger log.Logger) { - future := invokeSync(bc, ctx, &started, &completed, logger, &bcResp.AccountId, actOptResp.SyncActivityOptions) + future := invokeSync( + bc, + ctx, + &started, + &completed, + logger, + &bcResp.AccountId, + actOptResp.SyncActivityOptions, + ) inFlight++ workselector.AddFuture(future, func(f workflow.Future) { var wfResult tablesync_workflow.TableSyncResponse @@ -266,7 +311,12 @@ func executeWorkflow(wfctx workflow.Context, req *WorkflowRequest) (*WorkflowRes cancelHandler() detachedCtx, _ := workflow.NewDisconnectedContext(ctx) - redisErr := runRedisCleanUpActivity(detachedCtx, logger, req.JobId, bcResp.BenthosConfigs) + redisErr := runRedisCleanUpActivity( + detachedCtx, + logger, + req.JobId, + bcResp.BenthosConfigs, + ) if redisErr != nil { logger.Error("redis clean up activity did not complete") } @@ -275,7 +325,13 @@ func executeWorkflow(wfctx workflow.Context, req *WorkflowRequest) (*WorkflowRes logger.Info("config sync completed", "name", bc.Name) err = runPostTableSyncActivity(ctx, logger, actOptResp, bc.Name) if err != nil { - logger.Error(fmt.Sprintf("post table sync activity did not complete: %s", err.Error()), "schema", bc.TableSchema, "table", bc.TableName) + 
logger.Error( + fmt.Sprintf("post table sync activity did not complete: %s", err.Error()), + "schema", + bc.TableSchema, + "table", + bc.TableName, + ) } }) } @@ -300,7 +356,10 @@ func executeWorkflow(wfctx workflow.Context, req *WorkflowRequest) (*WorkflowRes if ctx.Err() != nil { if errors.Is(ctx.Err(), context.Canceled) { - return nil, fmt.Errorf("workflow canceled due to error or stop signal: %w", ctx.Err()) + return nil, fmt.Errorf( + "workflow canceled due to error or stop signal: %w", + ctx.Err(), + ) } return nil, ctx.Err() } @@ -321,7 +380,10 @@ func executeWorkflow(wfctx workflow.Context, req *WorkflowRequest) (*WorkflowRes if ctx.Err() != nil { if errors.Is(ctx.Err(), context.Canceled) { - return nil, fmt.Errorf("workflow canceled due to error or stop signal: %w", ctx.Err()) + return nil, fmt.Errorf( + "workflow canceled due to error or stop signal: %w", + ctx.Err(), + ) } return nil, fmt.Errorf("exiting workflow in root sync due to err: %w", ctx.Err()) } @@ -330,7 +392,10 @@ func executeWorkflow(wfctx workflow.Context, req *WorkflowRequest) (*WorkflowRes for _, bc := range splitConfigs.Dependents { if ctx.Err() != nil { if errors.Is(ctx.Err(), context.Canceled) { - return nil, fmt.Errorf("workflow canceled due to error or stop signal: %w", ctx.Err()) + return nil, fmt.Errorf( + "workflow canceled due to error or stop signal: %w", + ctx.Err(), + ) } return nil, fmt.Errorf("exiting workflow in dependent sync due err: %w", ctx.Err()) } @@ -349,16 +414,24 @@ func executeWorkflow(wfctx workflow.Context, req *WorkflowRequest) (*WorkflowRes // Ensures concurrency limits are respected. 
if inFlight >= maxConcurrency { - logger.Debug("max concurrency reached; blocking until one sync finishes for a dependent") + logger.Debug( + "max concurrency reached; blocking until one sync finishes for a dependent", + ) workselector.Select(ctx) if activityErr != nil { return nil, activityErr } if ctx.Err() != nil { if errors.Is(ctx.Err(), context.Canceled) { - return nil, fmt.Errorf("workflow canceled due to error or stop signal: %w", ctx.Err()) + return nil, fmt.Errorf( + "workflow canceled due to error or stop signal: %w", + ctx.Err(), + ) } - return nil, fmt.Errorf("exiting workflow in dependent sync due to err: %w", ctx.Err()) + return nil, fmt.Errorf( + "exiting workflow in dependent sync due to err: %w", + ctx.Err(), + ) } } @@ -368,7 +441,14 @@ func executeWorkflow(wfctx workflow.Context, req *WorkflowRequest) (*WorkflowRes logger.Info("data syncs completed") - err = execRunJobHooksByTiming(ctx, &jobhooks_by_timing_activity.RunJobHooksByTimingRequest{JobId: req.JobId, Timing: mgmtv1alpha1.GetActiveJobHooksByTimingRequest_TIMING_POSTSYNC}, logger) + err = execRunJobHooksByTiming( + ctx, + &jobhooks_by_timing_activity.RunJobHooksByTimingRequest{ + JobId: req.JobId, + Timing: mgmtv1alpha1.GetActiveJobHooksByTimingRequest_TIMING_POSTSYNC, + }, + logger, + ) if err != nil { return nil, err } @@ -382,7 +462,11 @@ func executeWorkflow(wfctx workflow.Context, req *WorkflowRequest) (*WorkflowRes return &WorkflowResponse{}, nil } -func execRunJobHooksByTiming(ctx workflow.Context, req *jobhooks_by_timing_activity.RunJobHooksByTimingRequest, logger log.Logger) error { +func execRunJobHooksByTiming( + ctx workflow.Context, + req *jobhooks_by_timing_activity.RunJobHooksByTimingRequest, + logger log.Logger, +) error { logger.Info(fmt.Sprintf("scheduling %q RunJobHooksByTiming for execution", req.Timing)) var resp *jobhooks_by_timing_activity.RunJobHooksByTimingResponse var timingActivity *jobhooks_by_timing_activity.Activity @@ -415,7 +499,11 @@ func 
runSchemaInitWorkflowByDestination( for _, destination := range destinations { // right now only mysql supports schema drift schemaDrift := shouldUseSchemaDrift(destination, postgresSchemaDrift) - logger.Info("scheduling Schema Initialization workflow for execution.", "destinationId", destination.GetId()) + logger.Info( + "scheduling Schema Initialization workflow for execution.", + "destinationId", + destination.GetId(), + ) siWf := &schemainit_workflow.Workflow{} var wfResult schemainit_workflow.SchemaInitResponse id := fmt.Sprintf("init-schema-%s", destination.GetId()) @@ -432,11 +520,16 @@ func runSchemaInitWorkflowByDestination( JobRunId: jobRunId, DestinationId: destination.GetId(), UseSchemaDrift: schemaDrift, - }).Get(ctx, &wfResult) + }). + Get(ctx, &wfResult) if err != nil { return err } - logger.Info("completed Schema Initialization workflow.", "destinationId", destination.GetId()) + logger.Info( + "completed Schema Initialization workflow.", + "destinationId", + destination.GetId(), + ) } return nil } @@ -596,7 +689,8 @@ func invokeSync( TableSchema: config.TableSchema, TableName: config.TableName, ColumnIdentityCursors: config.ColumnIdentityCursors, - }).Get(ctx, &wfResult) + }). + Get(ctx, &wfResult) if err == nil { tn := neosync_benthos.BuildBenthosTable(config.TableSchema, config.TableName) err = updateCompletedMap(tn, completed, config.Columns) @@ -614,7 +708,10 @@ func updateCompletedMap(tableName string, completed *sync.Map, columns []string) if loaded { currCols, ok := val.([]string) if !ok { - return fmt.Errorf("unable to retrieve completed columns from completed map. Expected []string, received: %T", val) + return fmt.Errorf( + "unable to retrieve completed columns from completed map. Expected []string, received: %T", + val, + ) } currCols = append(currCols, columns...) 
completed.Store(tableName, currCols) @@ -642,7 +739,10 @@ func toStringSliceMap(m *sync.Map) (map[string][]string, error) { return result, typeErr } -func isConfigReady(config *benthosbuilder.BenthosConfigResponse, completed *sync.Map) (bool, error) { +func isConfigReady( + config *benthosbuilder.BenthosConfigResponse, + completed *sync.Map, +) (bool, error) { if completed == nil { return false, fmt.Errorf("completed map is nil: cannot determine if config is ready") } diff --git a/worker/pkg/workflows/ee/account_hooks/activities/execute/activity.go b/worker/pkg/workflows/ee/account_hooks/activities/execute/activity.go index 3ceb67b177..bd33eb81e4 100644 --- a/worker/pkg/workflows/ee/account_hooks/activities/execute/activity.go +++ b/worker/pkg/workflows/ee/account_hooks/activities/execute/activity.go @@ -73,9 +73,12 @@ func (a *Activity) ExecuteAccountHook( slogger.Debug("retrieving hook") - resp, err := a.accounthookclient.GetAccountHook(ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountHookRequest{ - Id: req.HookId, - })) + resp, err := a.accounthookclient.GetAccountHook( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetAccountHookRequest{ + Id: req.HookId, + }), + ) if err != nil { return nil, fmt.Errorf("unable to retrieve hook: %w", err) } @@ -133,7 +136,13 @@ func executeWebhook( event *accounthook_events.Event, logger *slog.Logger, ) error { - logger.Debug(fmt.Sprintf("webhook url: %s, skipVerify: %t", webhook.GetUrl(), webhook.GetDisableSslVerification())) + logger.Debug( + fmt.Sprintf( + "webhook url: %s, skipVerify: %t", + webhook.GetUrl(), + webhook.GetDisableSslVerification(), + ), + ) jsonPayload, err := getPayload(event) if err != nil { return fmt.Errorf("unable to get payload: %w", err) @@ -169,7 +178,12 @@ func executeWebhookRequest( signature string, skipSslVerification bool, ) error { - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonPayload)) + httpReq, err := http.NewRequestWithContext( + ctx, + 
http.MethodPost, + url, + bytes.NewBuffer(jsonPayload), + ) if err != nil { return fmt.Errorf("unable to create webhook request: %w", err) } @@ -185,7 +199,9 @@ func executeWebhookRequest( } if skipSslVerification { client.Transport = &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // we want to enable this if it's user specified + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, //nolint:gosec // we want to enable this if it's user specified } } resp, err := client.Do(httpReq) @@ -196,7 +212,11 @@ func executeWebhookRequest( if resp.StatusCode < 200 || resp.StatusCode >= 300 { body, _ := io.ReadAll(resp.Body) - return fmt.Errorf("webhook request failed with status %d: %s", resp.StatusCode, string(body)) + return fmt.Errorf( + "webhook request failed with status %d: %s", + resp.StatusCode, + string(body), + ) } return nil diff --git a/worker/pkg/workflows/ee/account_hooks/activities/hooks-by-event/activity.go b/worker/pkg/workflows/ee/account_hooks/activities/hooks-by-event/activity.go index c4844a4ed3..c4d50117ea 100644 --- a/worker/pkg/workflows/ee/account_hooks/activities/hooks-by-event/activity.go +++ b/worker/pkg/workflows/ee/account_hooks/activities/hooks-by-event/activity.go @@ -62,10 +62,13 @@ func (a *Activity) GetAccountHooksByEvent( slogger.Debug("retrieving hooks by event") - resp, err := a.accounthookclient.GetActiveAccountHooksByEvent(ctx, connect.NewRequest(&mgmtv1alpha1.GetActiveAccountHooksByEventRequest{ - AccountId: req.AccountId, - Event: req.EventName, - })) + resp, err := a.accounthookclient.GetActiveAccountHooksByEvent( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetActiveAccountHooksByEventRequest{ + AccountId: req.AccountId, + Event: req.EventName, + }), + ) if err != nil { return nil, fmt.Errorf("unable to retrieve active hooks by event: %w", err) } diff --git a/worker/pkg/workflows/ee/account_hooks/workflow/workflow.go b/worker/pkg/workflows/ee/account_hooks/workflow/workflow.go index 
948f36127c..c4a03ead2a 100644 --- a/worker/pkg/workflows/ee/account_hooks/workflow/workflow.go +++ b/worker/pkg/workflows/ee/account_hooks/workflow/workflow.go @@ -18,7 +18,10 @@ type ProcessAccountHookRequest struct { type ProcessAccountHookResponse struct{} -func ProcessAccountHook(wfctx workflow.Context, req *ProcessAccountHookRequest) (*ProcessAccountHookResponse, error) { +func ProcessAccountHook( + wfctx workflow.Context, + req *ProcessAccountHookRequest, +) (*ProcessAccountHookResponse, error) { var hooksByEventActivity *hooks_by_event_activity.Activity var resp *hooks_by_event_activity.RunHooksByEventResponse err := workflow.ExecuteActivity( diff --git a/worker/pkg/workflows/ee/piidetect/workflows/job/activities/activities.go b/worker/pkg/workflows/ee/piidetect/workflows/job/activities/activities.go index 8785499445..1f313913ba 100644 --- a/worker/pkg/workflows/ee/piidetect/workflows/job/activities/activities.go +++ b/worker/pkg/workflows/ee/piidetect/workflows/job/activities/activities.go @@ -52,7 +52,10 @@ type GetPiiDetectJobDetailsResponse struct { } // Used to retrieve information about the PII Detect job to funnel in to the remaining workflow. -func (a *Activities) GetPiiDetectJobDetails(ctx context.Context, req *GetPiiDetectJobDetailsRequest) (*GetPiiDetectJobDetailsResponse, error) { +func (a *Activities) GetPiiDetectJobDetails( + ctx context.Context, + req *GetPiiDetectJobDetailsRequest, +) (*GetPiiDetectJobDetailsResponse, error) { logger := log.With(activity.GetLogger(ctx), "jobId", req.JobId) jobResp, err := a.jobclient.GetJob(ctx, connect.NewRequest(&mgmtv1alpha1.GetJobRequest{ @@ -96,16 +99,28 @@ type GetLastSuccessfulWorkflowIdResponse struct { // Used to retrieve the last successful workflow run from a job's schedule // This is used for incremental PII syncs in order to find the last result set of scanned PII. // This is then funneled in to computing the diff of what tables have changed and should be rescanned. 
-func (a *Activities) GetLastSuccessfulWorkflowId(ctx context.Context, req *GetLastSuccessfulWorkflowIdRequest) (*GetLastSuccessfulWorkflowIdResponse, error) { +func (a *Activities) GetLastSuccessfulWorkflowId( + ctx context.Context, + req *GetLastSuccessfulWorkflowIdRequest, +) (*GetLastSuccessfulWorkflowIdResponse, error) { logger := log.With(activity.GetLogger(ctx), "accountId", req.AccountId, "jobId", req.JobId) - workflowIds, err := getRecentRunsFromHandle(ctx, a.tmprlScheduleClient.GetHandle(ctx, req.JobId)) + workflowIds, err := getRecentRunsFromHandle( + ctx, + a.tmprlScheduleClient.GetHandle(ctx, req.JobId), + ) if err != nil { logger.Error("unable to get recent runs from handle", "error", err) return &GetLastSuccessfulWorkflowIdResponse{WorkflowId: nil}, nil } logger.Debug("retrieved workflow ids", "workflowIds", workflowIds) - lastSuccessfulRun, err := a.getMostRecentSuccessfulRun(ctx, req.AccountId, req.JobId, workflowIds, logger) + lastSuccessfulRun, err := a.getMostRecentSuccessfulRun( + ctx, + req.AccountId, + req.JobId, + workflowIds, + logger, + ) if err != nil { logger.Error("unable to get most recent successful run", "error", err) return &GetLastSuccessfulWorkflowIdResponse{WorkflowId: nil}, nil @@ -116,7 +131,12 @@ func (a *Activities) GetLastSuccessfulWorkflowId(ctx context.Context, req *GetLa return &GetLastSuccessfulWorkflowIdResponse{WorkflowId: &lastSuccessfulRun}, nil } -func (a *Activities) getMostRecentSuccessfulRun(ctx context.Context, accountId, jobId string, workflowIds []string, logger log.Logger) (string, error) { +func (a *Activities) getMostRecentSuccessfulRun( + ctx context.Context, + accountId, jobId string, + workflowIds []string, + logger log.Logger, +) (string, error) { for _, workflowId := range workflowIds { jobReport, found, err := a.getJobPiiDetectReport(ctx, accountId, workflowId, jobId) if err != nil { @@ -186,7 +206,10 @@ type GetTablesToPiiScanResponse struct { // This retrieves all tables from the source connection 
and filters them based on the user provided filter. // If an incremental config is provided, it will also filter the tables based on the previous successful workflow run. -func (a *Activities) GetTablesToPiiScan(ctx context.Context, req *GetTablesToPiiScanRequest) (*GetTablesToPiiScanResponse, error) { +func (a *Activities) GetTablesToPiiScan( + ctx context.Context, + req *GetTablesToPiiScanRequest, +) (*GetTablesToPiiScanResponse, error) { logger := log.With(activity.GetLogger(ctx), "sourceConnectionId", req.SourceConnectionId) slogger := temporallogger.NewSlogger(logger) @@ -204,15 +227,30 @@ func (a *Activities) GetTablesToPiiScan(ctx context.Context, req *GetTablesToPii var previousReports map[TableIdentifier]*TableReport if req.IncrementalConfig != nil { logger.Debug("getting tables from previous run to further filter tables") - tableReports, err := a.getTableReportsFromPreviousRun(ctx, req.AccountId, req.IncrementalConfig.LastWorkflowId, req.JobId, logger) + tableReports, err := a.getTableReportsFromPreviousRun( + ctx, + req.AccountId, + req.IncrementalConfig.LastWorkflowId, + req.JobId, + logger, + ) if err != nil { return nil, fmt.Errorf("unable to get tables from previous run: %w", err) } previousReports = tableReports oldTableCount := len(filteredTables) - filteredTables = filterTablesByFingerprint(filteredTables, getFingerprintsFromReports(tableReports)) + filteredTables = filterTablesByFingerprint( + filteredTables, + getFingerprintsFromReports(tableReports), + ) newTableCount := len(filteredTables) - logger.Debug("filtered tables in incremental scan", "oldTableCount", oldTableCount, "newTableCount", newTableCount) + logger.Debug( + "filtered tables in incremental scan", + "oldTableCount", + oldTableCount, + "newTableCount", + newTableCount, + ) } previousReportsArray := make([]*TableReport, 0, len(previousReports)) @@ -220,10 +258,15 @@ func (a *Activities) GetTablesToPiiScan(ctx context.Context, req *GetTablesToPii previousReportsArray = 
append(previousReportsArray, report) } - return &GetTablesToPiiScanResponse{Tables: filteredTables, PreviousReports: previousReportsArray}, nil + return &GetTablesToPiiScanResponse{ + Tables: filteredTables, + PreviousReports: previousReportsArray, + }, nil } -func getFingerprintsFromReports(reports map[TableIdentifier]*TableReport) map[TableIdentifier]string { +func getFingerprintsFromReports( + reports map[TableIdentifier]*TableReport, +) map[TableIdentifier]string { fingerprints := make(map[TableIdentifier]string) for identifier, report := range reports { fingerprints[identifier] = report.ScanFingerprint @@ -231,7 +274,10 @@ func getFingerprintsFromReports(reports map[TableIdentifier]*TableReport) map[Ta return fingerprints } -func filterTablesByFingerprint(tables []TableIdentifierWithFingerprint, fingerprints map[TableIdentifier]string) []TableIdentifierWithFingerprint { +func filterTablesByFingerprint( + tables []TableIdentifierWithFingerprint, + fingerprints map[TableIdentifier]string, +) []TableIdentifierWithFingerprint { filteredTables := []TableIdentifierWithFingerprint{} for _, table := range tables { fingerprint, ok := fingerprints[table.TableIdentifier] @@ -248,7 +294,11 @@ func filterTablesByFingerprint(tables []TableIdentifierWithFingerprint, fingerpr return filteredTables } -func (a *Activities) getTableReportsFromPreviousRun(ctx context.Context, accountId, workflowId, jobId string, logger log.Logger) (map[TableIdentifier]*TableReport, error) { +func (a *Activities) getTableReportsFromPreviousRun( + ctx context.Context, + accountId, workflowId, jobId string, + logger log.Logger, +) (map[TableIdentifier]*TableReport, error) { runCtx, found, err := a.getJobPiiDetectReport(ctx, accountId, workflowId, jobId) if err != nil { return nil, fmt.Errorf("unable to get job pii detect report: %w", err) @@ -287,19 +337,28 @@ func getTableColumnFingerprint(tableSchema, tableName string, columns []string) return fmt.Sprintf("%x", h.Sum(nil)) } -func (a *Activities) 
getJobPiiDetectReport(ctx context.Context, accountId, workflowId, jobId string) (*JobPiiDetectReport, bool, error) { - runCtxResp, err := a.jobclient.GetRunContext(ctx, connect.NewRequest(&mgmtv1alpha1.GetRunContextRequest{ - Id: &mgmtv1alpha1.RunContextKey{ - AccountId: accountId, - JobRunId: workflowId, - ExternalId: BuildJobReportExternalId(jobId), - }, - })) +func (a *Activities) getJobPiiDetectReport( + ctx context.Context, + accountId, workflowId, jobId string, +) (*JobPiiDetectReport, bool, error) { + runCtxResp, err := a.jobclient.GetRunContext( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetRunContextRequest{ + Id: &mgmtv1alpha1.RunContextKey{ + AccountId: accountId, + JobRunId: workflowId, + ExternalId: BuildJobReportExternalId(jobId), + }, + }), + ) if err != nil { if isConnectNotFoundError(err) { return nil, false, nil } - return nil, false, fmt.Errorf("unable to get run context for job pii detect report: %w", err) + return nil, false, fmt.Errorf( + "unable to get run context for job pii detect report: %w", + err, + ) } runCtxBytes := runCtxResp.Msg.GetValue() var runCtx JobPiiDetectReport @@ -318,7 +377,10 @@ func isConnectNotFoundError(err error) bool { return false } -func (a *Activities) getFilteredTables(allTables []TableIdentifierWithFingerprint, filter *mgmtv1alpha1.JobTypeConfig_JobTypePiiDetect_TableScanFilter) []TableIdentifierWithFingerprint { +func (a *Activities) getFilteredTables( + allTables []TableIdentifierWithFingerprint, + filter *mgmtv1alpha1.JobTypeConfig_JobTypePiiDetect_TableScanFilter, +) []TableIdentifierWithFingerprint { if filter == nil { return allTables } @@ -358,7 +420,9 @@ func (a *Activities) getFilteredTables(allTables []TableIdentifierWithFingerprin } // Helper function to convert proto TableIdentifier to our internal TableIdentifier -func convertProtoTablesToTableIdentifiers(protoTables []*mgmtv1alpha1.JobTypeConfig_JobTypePiiDetect_TableIdentifier) []TableIdentifier { +func convertProtoTablesToTableIdentifiers( + 
protoTables []*mgmtv1alpha1.JobTypeConfig_JobTypePiiDetect_TableIdentifier, +) []TableIdentifier { tables := make([]TableIdentifier, len(protoTables)) for i, pt := range protoTables { tables[i] = TableIdentifier{ @@ -369,10 +433,17 @@ func convertProtoTablesToTableIdentifiers(protoTables []*mgmtv1alpha1.JobTypeCon return tables } -func (a *Activities) getAllTablesFromConnection(ctx context.Context, sourceConnectionId string, logger *slog.Logger) ([]TableIdentifierWithFingerprint, error) { - connResp, err := a.connclient.GetConnection(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ - Id: sourceConnectionId, - })) +func (a *Activities) getAllTablesFromConnection( + ctx context.Context, + sourceConnectionId string, + logger *slog.Logger, +) ([]TableIdentifierWithFingerprint, error) { + connResp, err := a.connclient.GetConnection( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ + Id: sourceConnectionId, + }), + ) if err != nil { return nil, fmt.Errorf("unable to get connection: %w", err) } @@ -400,7 +471,11 @@ func (a *Activities) getAllTablesFromConnection(ctx context.Context, sourceConne for identifier, columns := range dbCols { tableFingerprints = append(tableFingerprints, TableIdentifierWithFingerprint{ TableIdentifier: identifier, - Fingerprint: getTableColumnFingerprint(identifier.Schema, identifier.Table, columns), + Fingerprint: getTableColumnFingerprint( + identifier.Schema, + identifier.Table, + columns, + ), }) } @@ -453,7 +528,10 @@ func (t *TableReport) ToTableIdentifier() TableIdentifier { } // After all of the tables have been scanned, this saves the final report for the run in the run context. 
-func (a *Activities) SaveJobPiiDetectReport(ctx context.Context, req *SaveJobPiiDetectReportRequest) (*SaveJobPiiDetectReportResponse, error) { +func (a *Activities) SaveJobPiiDetectReport( + ctx context.Context, + req *SaveJobPiiDetectReportRequest, +) (*SaveJobPiiDetectReportResponse, error) { info := activity.GetInfo(ctx) jobRunId := info.WorkflowExecution.ID diff --git a/worker/pkg/workflows/ee/piidetect/workflows/job/workflow.go b/worker/pkg/workflows/ee/piidetect/workflows/job/workflow.go index 7f22a66f64..73ffe6c4b5 100644 --- a/worker/pkg/workflows/ee/piidetect/workflows/job/workflow.go +++ b/worker/pkg/workflows/ee/piidetect/workflows/job/workflow.go @@ -33,7 +33,10 @@ type PiiDetectResponse struct { ReportKey *mgmtv1alpha1.RunContextKey } -func (w *Workflow) JobPiiDetect(ctx workflow.Context, req *PiiDetectRequest) (*PiiDetectResponse, error) { +func (w *Workflow) JobPiiDetect( + ctx workflow.Context, + req *PiiDetectRequest, +) (*PiiDetectResponse, error) { logger := log.With( workflow.GetLogger(ctx), "jobId", req.JobId, @@ -88,8 +91,13 @@ func executeWorkflow( activities *piidetect_job_activities.Activities, ) (*PiiDetectResponse, error) { var filter *mgmtv1alpha1.JobTypeConfig_JobTypePiiDetect_TableScanFilter - if jobDetailsResp != nil && jobDetailsResp.PiiDetectConfig != nil && jobDetailsResp.PiiDetectConfig.TableScanFilter != nil { - logger.Debug("using table scan filter", "filter", jobDetailsResp.PiiDetectConfig.TableScanFilter) + if jobDetailsResp != nil && jobDetailsResp.PiiDetectConfig != nil && + jobDetailsResp.PiiDetectConfig.TableScanFilter != nil { + logger.Debug( + "using table scan filter", + "filter", + jobDetailsResp.PiiDetectConfig.TableScanFilter, + ) filter = jobDetailsResp.PiiDetectConfig.TableScanFilter } @@ -118,7 +126,11 @@ func executeWorkflow( return nil, fmt.Errorf("unable to get last successful workflow id: %w", err) } if lastSuccessfulWorkflowIdResp.WorkflowId != nil { - logger.Debug("using last successful workflow id", 
"workflowId", *lastSuccessfulWorkflowIdResp.WorkflowId) + logger.Debug( + "using last successful workflow id", + "workflowId", + *lastSuccessfulWorkflowIdResp.WorkflowId, + ) incrementalConfig = &piidetect_job_activities.GetIncrementalTablesConfig{ LastWorkflowId: *lastSuccessfulWorkflowIdResp.WorkflowId, } @@ -207,7 +219,11 @@ func buildFinalReport( } } - successfulTableReports := make([]*piidetect_job_activities.TableReport, 0, len(fullSuccessfulTableReports)) + successfulTableReports := make( + []*piidetect_job_activities.TableReport, + 0, + len(fullSuccessfulTableReports), + ) for _, report := range fullSuccessfulTableReports { successfulTableReports = append(successfulTableReports, report) } @@ -281,10 +297,20 @@ func orchestrateTables( logger.Error("activity did not complete", "err", err) return } - logger.Debug("table pii detect completed", "table", table.Table, "schema", table.Schema) + logger.Debug( + "table pii detect completed", + "table", + table.Table, + "schema", + table.Schema, + ) err = mu.Lock(ctx) if err != nil { - logger.Error("unable to lock mutex after table pii detect completed", "err", err) + logger.Error( + "unable to lock mutex after table pii detect completed", + "err", + err, + ) return } defer mu.Unlock() @@ -299,7 +325,9 @@ func orchestrateTables( return nil } - previousReportsMap := make(map[piidetect_job_activities.TableIdentifier]*piidetect_job_activities.TableReport) + previousReportsMap := make( + map[piidetect_job_activities.TableIdentifier]*piidetect_job_activities.TableReport, + ) for _, report := range tablesToScanResp.PreviousReports { previousReportsMap[piidetect_job_activities.TableIdentifier{Schema: report.TableSchema, Table: report.TableName}] = report } diff --git a/worker/pkg/workflows/ee/piidetect/workflows/register/register.go b/worker/pkg/workflows/ee/piidetect/workflows/register/register.go index a0dffbe426..ccfcb131a7 100644 --- a/worker/pkg/workflows/ee/piidetect/workflows/register/register.go +++ 
b/worker/pkg/workflows/ee/piidetect/workflows/register/register.go @@ -32,13 +32,23 @@ func Register( w.RegisterWorkflow(tablePiiDetectWorkflow.TablePiiDetect) w.RegisterWorkflow(jobPiiDetectWorkflow.JobPiiDetect) - tablePiiDetectActivitites := piidetect_table_activities.New(connclient, openaiclient.Chat.Completions, connectiondatabuilder, jobclient) + tablePiiDetectActivitites := piidetect_table_activities.New( + connclient, + openaiclient.Chat.Completions, + connectiondatabuilder, + jobclient, + ) w.RegisterActivity(tablePiiDetectActivitites.GetColumnData) w.RegisterActivity(tablePiiDetectActivitites.DetectPiiRegex) w.RegisterActivity(tablePiiDetectActivitites.DetectPiiLLM) w.RegisterActivity(tablePiiDetectActivitites.SaveTablePiiDetectReport) - jobPiiDetectActivitites := piidetect_job_activities.New(jobclient, connclient, connectiondatabuilder, tmprlScheduleClient) + jobPiiDetectActivitites := piidetect_job_activities.New( + jobclient, + connclient, + connectiondatabuilder, + tmprlScheduleClient, + ) w.RegisterActivity(jobPiiDetectActivitites.GetPiiDetectJobDetails) w.RegisterActivity(jobPiiDetectActivitites.GetTablesToPiiScan) w.RegisterActivity(jobPiiDetectActivitites.SaveJobPiiDetectReport) diff --git a/worker/pkg/workflows/ee/piidetect/workflows/table/activities/activities.go b/worker/pkg/workflows/ee/piidetect/workflows/table/activities/activities.go index a26c243899..6902791b3d 100644 --- a/worker/pkg/workflows/ee/piidetect/workflows/table/activities/activities.go +++ b/worker/pkg/workflows/ee/piidetect/workflows/table/activities/activities.go @@ -24,7 +24,11 @@ import ( ) type OpenAiCompletionsClient interface { - New(ctx context.Context, body openai.ChatCompletionNewParams, opts ...option.RequestOption) (res *openai.ChatCompletion, err error) + New( + ctx context.Context, + body openai.ChatCompletionNewParams, + opts ...option.RequestOption, + ) (res *openai.ChatCompletion, err error) } type Activities struct { @@ -65,7 +69,10 @@ type ColumnData struct { 
Comment *string } -func (a *Activities) GetColumnData(ctx context.Context, req *GetColumnDataRequest) (*GetColumnDataResponse, error) { +func (a *Activities) GetColumnData( + ctx context.Context, + req *GetColumnDataRequest, +) (*GetColumnDataResponse, error) { logger := activity.GetLogger(ctx) slogger := temporallogger.NewSlogger(logger) @@ -102,9 +109,12 @@ func (a *Activities) getTableDetailsFromConnection( tableName string, logger *slog.Logger, ) ([]*mgmtv1alpha1.DatabaseColumn, error) { - connResp, err := a.connclient.GetConnection(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ - Id: connectionId, - })) + connResp, err := a.connclient.GetConnection( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ + Id: connectionId, + }), + ) if err != nil { return nil, err } @@ -221,7 +231,10 @@ type DetectPiiRegexResponse struct { PiiColumns map[string]PiiCategory // Changed to map column names to their PII category } -func (a *Activities) DetectPiiRegex(ctx context.Context, req *DetectPiiRegexRequest) (*DetectPiiRegexResponse, error) { +func (a *Activities) DetectPiiRegex( + ctx context.Context, + req *DetectPiiRegexRequest, +) (*DetectPiiRegexResponse, error) { logger := activity.GetLogger(ctx) piiColumns := make(map[string]PiiCategory) @@ -367,10 +380,17 @@ const ( maxDataSamples = uint(5) ) -func (a *Activities) getSampleData(ctx context.Context, req *DetectPiiLLMRequest, logger *slog.Logger) (Records, error) { - connResp, err := a.connclient.GetConnection(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ - Id: req.ConnectionId, - })) +func (a *Activities) getSampleData( + ctx context.Context, + req *DetectPiiLLMRequest, + logger *slog.Logger, +) (Records, error) { + connResp, err := a.connclient.GetConnection( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ + Id: req.ConnectionId, + }), + ) if err != nil { return nil, err } @@ -401,7 +421,9 @@ Here is the table name: {{.TableName}} Here are the fields and 
(optionally) values: {{.RecordData}}` -var piiDetectionPromptTmpl = template.Must(template.New("pii_detection_prompt").Parse(piiDetectionPrompt)) +var piiDetectionPromptTmpl = template.Must( + template.New("pii_detection_prompt").Parse(piiDetectionPrompt), +) const ( systemMessage = "You are a helpful assistant that classifies database fields for PII." @@ -459,12 +481,18 @@ func getPrompt(records Records, tableName, userPrompt string, maxRecords uint) ( // If we're at 0 records and still over limit, something else is wrong if currentMaxRecords == 0 { - return "", fmt.Errorf("prompt exceeds token limit (%d) even with no sample data", maxTokenLimit) + return "", fmt.Errorf( + "prompt exceeds token limit (%d) even with no sample data", + maxTokenLimit, + ) } } } -func (a *Activities) DetectPiiLLM(ctx context.Context, req *DetectPiiLLMRequest) (*DetectPiiLLMResponse, error) { +func (a *Activities) DetectPiiLLM( + ctx context.Context, + req *DetectPiiLLMRequest, +) (*DetectPiiLLMResponse, error) { logger := activity.GetLogger(ctx) slogger := temporallogger.NewSlogger(logger) @@ -480,9 +508,13 @@ func (a *Activities) DetectPiiLLM(ctx context.Context, req *DetectPiiLLMRequest) logger.Debug("LLM PII detection prompt", "prompt", userMessage) chatResp, err := a.openaiclient.New(ctx, openai.ChatCompletionNewParams{ - Temperature: openai.F(0.0), - Model: openai.F(model), - ResponseFormat: openai.F[openai.ChatCompletionNewParamsResponseFormatUnion](openai.ResponseFormatJSONObjectParam{Type: openai.F(openai.ResponseFormatJSONObjectTypeJSONObject)}), + Temperature: openai.F(0.0), + Model: openai.F(model), + ResponseFormat: openai.F[openai.ChatCompletionNewParamsResponseFormatUnion]( + openai.ResponseFormatJSONObjectParam{ + Type: openai.F(openai.ResponseFormatJSONObjectTypeJSONObject), + }, + ), Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ openai.SystemMessage(systemMessage), openai.UserMessage(userMessage), @@ -636,7 +668,10 @@ func isPiiColumn(columnName string) 
(PiiCategory, bool) { return "", false } -func getSchemasByTable(databaseColumns []*mgmtv1alpha1.DatabaseColumn, schema, table string) []*mgmtv1alpha1.DatabaseColumn { +func getSchemasByTable( + databaseColumns []*mgmtv1alpha1.DatabaseColumn, + schema, table string, +) []*mgmtv1alpha1.DatabaseColumn { output := []*mgmtv1alpha1.DatabaseColumn{} for _, databaseColumn := range databaseColumns { if databaseColumn.Schema == schema && databaseColumn.Table == table { diff --git a/worker/pkg/workflows/ee/piidetect/workflows/table/workflow.go b/worker/pkg/workflows/ee/piidetect/workflows/table/workflow.go index 3afc0d10d8..6608464067 100644 --- a/worker/pkg/workflows/ee/piidetect/workflows/table/workflow.go +++ b/worker/pkg/workflows/ee/piidetect/workflows/table/workflow.go @@ -33,7 +33,10 @@ type TablePiiDetectResponse struct { ResultKey *mgmtv1alpha1.RunContextKey } -func (w *Workflow) TablePiiDetect(ctx workflow.Context, req *TablePiiDetectRequest) (*TablePiiDetectResponse, error) { +func (w *Workflow) TablePiiDetect( + ctx workflow.Context, + req *TablePiiDetectRequest, +) (*TablePiiDetectResponse, error) { logger := log.With( workflow.GetLogger(ctx), "jobId", req.JobId, diff --git a/worker/pkg/workflows/schemainit/activities/init-schema/activity.go b/worker/pkg/workflows/schemainit/activities/init-schema/activity.go index 3dc23037ce..1fa8eb3e6c 100644 --- a/worker/pkg/workflows/schemainit/activities/init-schema/activity.go +++ b/worker/pkg/workflows/schemainit/activities/init-schema/activity.go @@ -79,7 +79,9 @@ func (a *Activity) RunSqlInitTableStatements( return builder.RunSqlInitTableStatements( ctx, req, - connectionmanager.NewUniqueSession(connectionmanager.WithSessionGroup(info.WorkflowExecution.ID)), + connectionmanager.NewUniqueSession( + connectionmanager.WithSessionGroup(info.WorkflowExecution.ID), + ), slogger, ) } diff --git a/worker/pkg/workflows/schemainit/activities/init-schema/init-schema.go 
b/worker/pkg/workflows/schemainit/activities/init-schema/init-schema.go index 4a1ab8642d..845ac2f6d7 100644 --- a/worker/pkg/workflows/schemainit/activities/init-schema/init-schema.go +++ b/worker/pkg/workflows/schemainit/activities/init-schema/init-schema.go @@ -63,13 +63,18 @@ func (b *initStatementBuilder) RunSqlInitTableStatements( ) if job.GetSource().GetOptions().GetAiGenerate() != nil { - sourceConnection, err = shared.GetConnectionById(ctx, b.connclient, *job.GetSource().GetOptions().GetAiGenerate().FkSourceConnectionId) + sourceConnection, err = shared.GetConnectionById( + ctx, + b.connclient, + *job.GetSource().GetOptions().GetAiGenerate().FkSourceConnectionId, + ) if err != nil { return nil, fmt.Errorf("unable to get connection by id: %w", err) } } - if sourceConnection.GetConnectionConfig().GetMongoConfig() != nil || sourceConnection.GetConnectionConfig().GetDynamodbConfig() != nil { + if sourceConnection.GetConnectionConfig().GetMongoConfig() != nil || + sourceConnection.GetConnectionConfig().GetDynamodbConfig() != nil { return &RunSqlInitTableStatementsResponse{}, nil } @@ -89,9 +94,17 @@ func (b *initStatementBuilder) RunSqlInitTableStatements( initSchemaRunContext := []*InitSchemaRunContext{} - destinationConnection, err := shared.GetConnectionById(ctx, b.connclient, destination.ConnectionId) + destinationConnection, err := shared.GetConnectionById( + ctx, + b.connclient, + destination.ConnectionId, + ) if err != nil { - return nil, fmt.Errorf("unable to get destination connection by id (%s): %w", destination.ConnectionId, err) + return nil, fmt.Errorf( + "unable to get destination connection by id (%s): %w", + destination.ConnectionId, + err, + ) } destinationConnectionType := shared.GetConnectionType(destinationConnection) slogger = slogger.With( @@ -102,7 +115,9 @@ func (b *initStatementBuilder) RunSqlInitTableStatements( if job.GetSource().GetOptions().GetAiGenerate() != nil { fkSrcConnId := 
job.GetSource().GetOptions().GetAiGenerate().GetFkSourceConnectionId() if fkSrcConnId == destination.GetConnectionId() { - slogger.Warn("cannot init schema when destination connection is the same as the foreign key source connection") + slogger.Warn( + "cannot init schema when destination connection is the same as the foreign key source connection", + ) shouldInitSchema = false } } @@ -110,7 +125,9 @@ func (b *initStatementBuilder) RunSqlInitTableStatements( if job.GetSource().GetOptions().GetGenerate() != nil { fkSrcConnId := job.GetSource().GetOptions().GetGenerate().GetFkSourceConnectionId() if fkSrcConnId == destination.GetConnectionId() { - slogger.Warn("cannot init schema when destination connection is the same as the foreign key source connection") + slogger.Warn( + "cannot init schema when destination connection is the same as the foreign key source connection", + ) shouldInitSchema = false } } diff --git a/worker/pkg/workflows/schemainit/activities/reconcile-schema/activity.go b/worker/pkg/workflows/schemainit/activities/reconcile-schema/activity.go index c15613ef9a..8b1493a2df 100644 --- a/worker/pkg/workflows/schemainit/activities/reconcile-schema/activity.go +++ b/worker/pkg/workflows/schemainit/activities/reconcile-schema/activity.go @@ -80,7 +80,9 @@ func (a *Activity) RunReconcileSchema( return builder.RunReconcileSchema( ctx, req, - connectionmanager.NewUniqueSession(connectionmanager.WithSessionGroup(info.WorkflowExecution.ID)), + connectionmanager.NewUniqueSession( + connectionmanager.WithSessionGroup(info.WorkflowExecution.ID), + ), slogger, ) } diff --git a/worker/pkg/workflows/schemainit/activities/reconcile-schema/reconcile-schema.go b/worker/pkg/workflows/schemainit/activities/reconcile-schema/reconcile-schema.go index 45beff5776..bc14198023 100644 --- a/worker/pkg/workflows/schemainit/activities/reconcile-schema/reconcile-schema.go +++ b/worker/pkg/workflows/schemainit/activities/reconcile-schema/reconcile-schema.go @@ -64,13 +64,18 @@ func 
(b *reconcileSchemaBuilder) RunReconcileSchema( ) if job.GetSource().GetOptions().GetAiGenerate() != nil { - sourceConnection, err = shared.GetConnectionById(ctx, b.connclient, *job.GetSource().GetOptions().GetAiGenerate().FkSourceConnectionId) + sourceConnection, err = shared.GetConnectionById( + ctx, + b.connclient, + *job.GetSource().GetOptions().GetAiGenerate().FkSourceConnectionId, + ) if err != nil { return nil, fmt.Errorf("unable to get connection by id: %w", err) } } - if sourceConnection.GetConnectionConfig().GetMongoConfig() != nil || sourceConnection.GetConnectionConfig().GetDynamodbConfig() != nil { + if sourceConnection.GetConnectionConfig().GetMongoConfig() != nil || + sourceConnection.GetConnectionConfig().GetDynamodbConfig() != nil { return &RunReconcileSchemaResponse{}, nil } @@ -87,9 +92,17 @@ func (b *reconcileSchemaBuilder) RunReconcileSchema( uniqueTables := getUniqueTablesMapFromJob(job) - destinationConnection, err := shared.GetConnectionById(ctx, b.connclient, destination.ConnectionId) + destinationConnection, err := shared.GetConnectionById( + ctx, + b.connclient, + destination.ConnectionId, + ) if err != nil { - return nil, fmt.Errorf("unable to get destination connection by id (%s): %w", destination.ConnectionId, err) + return nil, fmt.Errorf( + "unable to get destination connection by id (%s): %w", + destination.ConnectionId, + err, + ) } destinationConnectionType := shared.GetConnectionType(destinationConnection) slogger = slogger.With( @@ -100,7 +113,9 @@ func (b *reconcileSchemaBuilder) RunReconcileSchema( if job.GetSource().GetOptions().GetAiGenerate() != nil { fkSrcConnId := job.GetSource().GetOptions().GetAiGenerate().GetFkSourceConnectionId() if fkSrcConnId == destination.GetConnectionId() { - slogger.Warn("cannot init schema when destination connection is the same as the foreign key source connection") + slogger.Warn( + "cannot init schema when destination connection is the same as the foreign key source connection", + ) 
shouldInitSchema = false } } @@ -108,7 +123,9 @@ func (b *reconcileSchemaBuilder) RunReconcileSchema( if job.GetSource().GetOptions().GetGenerate() != nil { fkSrcConnId := job.GetSource().GetOptions().GetGenerate().GetFkSourceConnectionId() if fkSrcConnId == destination.GetConnectionId() { - slogger.Warn("cannot init schema when destination connection is the same as the foreign key source connection") + slogger.Warn( + "cannot init schema when destination connection is the same as the foreign key source connection", + ) shouldInitSchema = false } } @@ -135,7 +152,11 @@ func (b *reconcileSchemaBuilder) RunReconcileSchema( return nil, fmt.Errorf("unable to build schema diff statements: %w", err) } - reconcileSchemaErrors, err := schemaManager.ReconcileDestinationSchema(ctx, uniqueTables, schemaStatements) + reconcileSchemaErrors, err := schemaManager.ReconcileDestinationSchema( + ctx, + uniqueTables, + schemaStatements, + ) if err != nil { return nil, fmt.Errorf("unable to reconcile schema: %w", err) } @@ -145,7 +166,12 @@ func (b *reconcileSchemaBuilder) RunReconcileSchema( Errors: reconcileSchemaErrors, } - err = b.setReconcileSchemaRunCtx(ctx, reconcileSchemaRunContext, job.AccountId, destination.Id) + err = b.setReconcileSchemaRunCtx( + ctx, + reconcileSchemaRunContext, + job.AccountId, + destination.Id, + ) if err != nil { return nil, err } @@ -219,7 +245,9 @@ func getUniqueTablesMapFromJob(job *mgmtv1alpha1.Job) map[string]*sqlmanager_sha } // Parses the job mappings and returns the unique set of tables. 
-func getUniqueTablesFromMappings(mappings []*mgmtv1alpha1.JobMapping) map[string]*sqlmanager_shared.SchemaTable { +func getUniqueTablesFromMappings( + mappings []*mgmtv1alpha1.JobMapping, +) map[string]*sqlmanager_shared.SchemaTable { uniqueTables := map[string]*sqlmanager_shared.SchemaTable{} for _, mapping := range mappings { schematable := &sqlmanager_shared.SchemaTable{ diff --git a/worker/pkg/workflows/schemainit/workflow/register/register.go b/worker/pkg/workflows/schemainit/workflow/register/register.go index abacdf36da..786f026cf6 100644 --- a/worker/pkg/workflows/schemainit/workflow/register/register.go +++ b/worker/pkg/workflows/schemainit/workflow/register/register.go @@ -21,7 +21,12 @@ func Register( sqlmanager *sql_manager.SqlManager, eelicense license.EEInterface, ) { - runSqlInitTableStatements := initschema_activity.New(jobclient, connclient, sqlmanager, eelicense) + runSqlInitTableStatements := initschema_activity.New( + jobclient, + connclient, + sqlmanager, + eelicense, + ) runReconcileSchema := reconcileschema_activity.New(jobclient, connclient, sqlmanager, eelicense) siWf := schemainit_workflow.New() w.RegisterWorkflow(siWf.SchemaInit) diff --git a/worker/pkg/workflows/schemainit/workflow/workflow.go b/worker/pkg/workflows/schemainit/workflow/workflow.go index 6917346657..ff9753f382 100644 --- a/worker/pkg/workflows/schemainit/workflow/workflow.go +++ b/worker/pkg/workflows/schemainit/workflow/workflow.go @@ -25,7 +25,10 @@ func New() *Workflow { return &Workflow{} } -func (w *Workflow) SchemaInit(ctx workflow.Context, req *SchemaInitRequest) (*SchemaInitResponse, error) { +func (w *Workflow) SchemaInit( + ctx workflow.Context, + req *SchemaInitRequest, +) (*SchemaInitResponse, error) { logger := log.With( workflow.GetLogger(ctx), "accountId", req.AccountId, diff --git a/worker/pkg/workflows/shared/util.go b/worker/pkg/workflows/shared/util.go index e29821b8b1..b37ee67827 100644 --- a/worker/pkg/workflows/shared/util.go +++ 
b/worker/pkg/workflows/shared/util.go @@ -39,8 +39,12 @@ func HandleWorkflowEventLifecycle[T any]( createdFuture := workflow.ExecuteChildWorkflow( workflow.WithChildOptions(ctx, workflow.ChildWorkflowOptions{ ParentClosePolicy: enums.PARENT_CLOSE_POLICY_ABANDON, - WorkflowID: getAccountHookChildWorkflowId(runId, "job-run-created", workflow.Now(ctx)), - StaticSummary: "Account Hook: Job Run Created", + WorkflowID: getAccountHookChildWorkflowId( + runId, + "job-run-created", + workflow.Now(ctx), + ), + StaticSummary: "Account Hook: Job Run Created", }), accounthook_workflow.ProcessAccountHook, &accounthook_workflow.ProcessAccountHookRequest{ @@ -56,8 +60,12 @@ func HandleWorkflowEventLifecycle[T any]( failedFuture := workflow.ExecuteChildWorkflow( workflow.WithChildOptions(ctx, workflow.ChildWorkflowOptions{ ParentClosePolicy: enums.PARENT_CLOSE_POLICY_ABANDON, - WorkflowID: getAccountHookChildWorkflowId(runId, "job-run-failed", workflow.Now(ctx)), - StaticSummary: "Account Hook: Job Run Failed", + WorkflowID: getAccountHookChildWorkflowId( + runId, + "job-run-failed", + workflow.Now(ctx), + ), + StaticSummary: "Account Hook: Job Run Failed", }), accounthook_workflow.ProcessAccountHook, &accounthook_workflow.ProcessAccountHookRequest{ @@ -73,8 +81,12 @@ func HandleWorkflowEventLifecycle[T any]( completedFuture := workflow.ExecuteChildWorkflow( workflow.WithChildOptions(ctx, workflow.ChildWorkflowOptions{ ParentClosePolicy: enums.PARENT_CLOSE_POLICY_ABANDON, - WorkflowID: getAccountHookChildWorkflowId(runId, "job-run-succeeded", workflow.Now(ctx)), - StaticSummary: "Account Hook: Job Run Succeeded", + WorkflowID: getAccountHookChildWorkflowId( + runId, + "job-run-succeeded", + workflow.Now(ctx), + ), + StaticSummary: "Account Hook: Job Run Succeeded", }), accounthook_workflow.ProcessAccountHook, &accounthook_workflow.ProcessAccountHookRequest{ @@ -88,7 +100,11 @@ func HandleWorkflowEventLifecycle[T any]( return resp, nil } -func ensureChildSpawned(ctx 
workflow.Context, future workflow.ChildWorkflowFuture, logger log.Logger) error { +func ensureChildSpawned( + ctx workflow.Context, + future workflow.ChildWorkflowFuture, + logger log.Logger, +) error { var childWE workflow.Execution if waitErr := future.GetChildWorkflowExecution().Get(ctx, &childWE); waitErr != nil { return waitErr @@ -103,7 +119,12 @@ func getAccountHookChildWorkflowId(parentJobRunId, eventName string, now time.Ti // Builds a child workflow id that is unique for the given parent execution. Sanitizes the name and cuts to the max allowed limit func BuildChildWorkflowId(parentExecutionId, name string, ts time.Time) string { - id := fmt.Sprintf("%s-%s-%d", parentExecutionId, SanitizeWorkflowID(strings.ToLower(name)), ts.UnixNano()) + id := fmt.Sprintf( + "%s-%s-%d", + parentExecutionId, + SanitizeWorkflowID(strings.ToLower(name)), + ts.UnixNano(), + ) if len(id) > 1000 { id = id[:1000] } diff --git a/worker/pkg/workflows/tablesync/activities/sync/activity.go b/worker/pkg/workflows/tablesync/activities/sync/activity.go index 691dceda97..97b14fbbab 100644 --- a/worker/pkg/workflows/tablesync/activities/sync/activity.go +++ b/worker/pkg/workflows/tablesync/activities/sync/activity.go @@ -99,7 +99,11 @@ type SyncResponse struct { } // Deprecated -func (a *Activity) Sync(ctx context.Context, req *SyncRequest, metadata *SyncMetadata) (*SyncResponse, error) { +func (a *Activity) Sync( + ctx context.Context, + req *SyncRequest, + metadata *SyncMetadata, +) (*SyncResponse, error) { info := activity.GetInfo(ctx) _, err := a.SyncTable(ctx, &SyncTableRequest{ @@ -117,7 +121,11 @@ func (a *Activity) Sync(ctx context.Context, req *SyncRequest, metadata *SyncMet }, nil } -func (a *Activity) SyncTable(ctx context.Context, req *SyncTableRequest, metadata *SyncMetadata) (*SyncTableResponse, error) { +func (a *Activity) SyncTable( + ctx context.Context, + req *SyncTableRequest, + metadata *SyncMetadata, +) (*SyncTableResponse, error) { info := activity.GetInfo(ctx) 
session := connectionmanager.NewUniqueSession(connectionmanager.WithSessionGroup(req.JobRunId)) @@ -173,7 +181,9 @@ func (a *Activity) SyncTable(ctx context.Context, req *SyncTableRequest, metadat var continuationTokenToReturn *string hasMorePages := func(lastReadOrderValues []any) { - token := continuation_token.NewFromContents(continuation_token.NewContents(lastReadOrderValues)) + token := continuation_token.NewFromContents( + continuation_token.NewContents(lastReadOrderValues), + ) tokenStr := token.String() continuationTokenToReturn = &tokenStr } @@ -240,7 +250,10 @@ const ( allocatorBlockSize = 1_000 // todo: should be the page limit ) -func (a *Activity) getIdentityAllocator(tclient temporalclient.Client, info *activity.Info) tablesync_shared.IdentityAllocator { +func (a *Activity) getIdentityAllocator( + tclient temporalclient.Client, + info *activity.Info, +) tablesync_shared.IdentityAllocator { blockAllocator := tablesync_shared.NewTemporalBlockAllocator( tclient, info.WorkflowExecution.ID, @@ -363,7 +376,10 @@ func (a *Activity) getBenthosStream( err = streambldr.SetYAML(benthosConfig) if err != nil { - return nil, fmt.Errorf("unable to convert benthos config to yaml for stream builder: %w", err) + return nil, fmt.Errorf( + "unable to convert benthos config to yaml for stream builder: %w", + err, + ) } stream, err := a.benthosStreamManager.NewBenthosStreamFromBuilder(streambldr) @@ -393,13 +409,23 @@ func (a *Activity) getBenthosEnvironment( logger, benthos_environment.WithMeter(a.meter), benthos_environment.WithSqlConfig(&benthos_environment.SqlConfig{ - Provider: pool_sql_provider.NewConnectionProvider(a.sqlconnmanager, getConnectionById, session, logger), + Provider: pool_sql_provider.NewConnectionProvider( + a.sqlconnmanager, + getConnectionById, + session, + logger, + ), IsRetry: isRetry, InputHasMorePages: hasMorePages, InputContinuationToken: continuationToken, }), benthos_environment.WithMongoConfig(&benthos_environment.MongoConfig{ - Provider: 
pool_mongo_provider.NewProvider(a.mongoconnmanager, getConnectionById, session, logger), + Provider: pool_mongo_provider.NewProvider( + a.mongoconnmanager, + getConnectionById, + session, + logger, + ), }), benthos_environment.WithStopChannel(stopActivityChan), benthos_environment.WithBlobEnv(blobEnv), @@ -429,11 +455,19 @@ func (a *Activity) getBenthosConfig( req *mgmtv1alpha1.RunContextKey, metadata *SyncMetadata, ) (string, error) { - rcResp, err := a.jobclient.GetRunContext(ctx, connect.NewRequest(&mgmtv1alpha1.GetRunContextRequest{ - Id: req, - })) + rcResp, err := a.jobclient.GetRunContext( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetRunContextRequest{ + Id: req, + }), + ) if err != nil { - return "", fmt.Errorf("unable to retrieve benthosconfig runcontext for %s.%s: %w", metadata.Schema, metadata.Table, err) + return "", fmt.Errorf( + "unable to retrieve benthosconfig runcontext for %s.%s: %w", + metadata.Schema, + metadata.Table, + err, + ) } return string(rcResp.Msg.GetValue()), nil } @@ -443,11 +477,19 @@ func (a *Activity) getConnectionIds( req *mgmtv1alpha1.RunContextKey, metadata *SyncMetadata, ) ([]string, error) { - rcResp, err := a.jobclient.GetRunContext(ctx, connect.NewRequest(&mgmtv1alpha1.GetRunContextRequest{ - Id: req, - })) + rcResp, err := a.jobclient.GetRunContext( + ctx, + connect.NewRequest(&mgmtv1alpha1.GetRunContextRequest{ + Id: req, + }), + ) if err != nil { - return nil, fmt.Errorf("unable to retrieve connection ids runcontext for %s.%s: %w", metadata.Schema, metadata.Table, err) + return nil, fmt.Errorf( + "unable to retrieve connection ids runcontext for %s.%s: %w", + metadata.Schema, + metadata.Table, + err, + ) } var connectionIds []string err = json.Unmarshal(rcResp.Msg.GetValue(), &connectionIds) @@ -468,7 +510,10 @@ func (a *Activity) getConnectionsFromConnectionIds( idx := idx connectionId := connectionId errgrp.Go(func() error { - resp, err := a.connclient.GetConnection(errctx, 
connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{Id: connectionId})) + resp, err := a.connclient.GetConnection( + errctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{Id: connectionId}), + ) if err != nil { return fmt.Errorf("failed to retrieve connection: %w", err) } @@ -482,7 +527,9 @@ func (a *Activity) getConnectionsFromConnectionIds( return connections, nil } -func getConnectionByIdFn(connectionCache map[string]*mgmtv1alpha1.Connection) func(connectionId string) (connectionmanager.ConnectionInput, error) { +func getConnectionByIdFn( + connectionCache map[string]*mgmtv1alpha1.Connection, +) func(connectionId string) (connectionmanager.ConnectionInput, error) { return func(connectionId string) (connectionmanager.ConnectionInput, error) { connection, ok := connectionCache[connectionId] if !ok { diff --git a/worker/pkg/workflows/tablesync/shared/identity-allocator.go b/worker/pkg/workflows/tablesync/shared/identity-allocator.go index d97b32e990..adaae0114d 100644 --- a/worker/pkg/workflows/tablesync/shared/identity-allocator.go +++ b/worker/pkg/workflows/tablesync/shared/identity-allocator.go @@ -41,7 +41,10 @@ type TemporalBlockAllocator struct { runId string } -func NewTemporalBlockAllocator(temporalclient temporalclient.Client, workflowId, runId string) *TemporalBlockAllocator { +func NewTemporalBlockAllocator( + temporalclient temporalclient.Client, + workflowId, runId string, +) *TemporalBlockAllocator { return &TemporalBlockAllocator{ temporalclient: temporalclient, workflowId: workflowId, @@ -49,7 +52,11 @@ func NewTemporalBlockAllocator(temporalclient temporalclient.Client, workflowId, } } -func (i *TemporalBlockAllocator) GetNextBlock(ctx context.Context, token string, blockSize uint) (*IdentityRange, error) { +func (i *TemporalBlockAllocator) GetNextBlock( + ctx context.Context, + token string, + blockSize uint, +) (*IdentityRange, error) { handle, err := i.temporalclient.UpdateWorkflow(ctx, temporalclient.UpdateWorkflowOptions{ 
WorkflowID: i.workflowId, RunID: i.runId, @@ -61,7 +68,11 @@ func (i *TemporalBlockAllocator) GetNextBlock(ctx context.Context, token string, WaitForStage: temporalclient.WorkflowUpdateStageCompleted, }) if err != nil { - return nil, fmt.Errorf("unable to send update to get next block size for identity %s: %w", token, err) + return nil, fmt.Errorf( + "unable to send update to get next block size for identity %s: %w", + token, + err, + ) } var resp *AllocateIdentityBlockResponse err = handle.Get(ctx, &resp) @@ -89,7 +100,11 @@ type MultiIdentityAllocator struct { allocators map[string]*SingleIdentityAllocator } -func NewMultiIdentityAllocator(blockAllocator BlockAllocator, blockSize uint, seed uint64) *MultiIdentityAllocator { +func NewMultiIdentityAllocator( + blockAllocator BlockAllocator, + blockSize uint, + seed uint64, +) *MultiIdentityAllocator { return &MultiIdentityAllocator{ blockAllocator: blockAllocator, blockSize: blockSize, @@ -99,13 +114,21 @@ func NewMultiIdentityAllocator(blockAllocator BlockAllocator, blockSize uint, se } } -func (i *MultiIdentityAllocator) GetIdentity(ctx context.Context, token string, value *uint) (uint, error) { +func (i *MultiIdentityAllocator) GetIdentity( + ctx context.Context, + token string, + value *uint, +) (uint, error) { i.mu.Lock() defer i.mu.Unlock() allocator, ok := i.allocators[token] if !ok { - allocator = NewSingleIdentityAllocator(i.blockAllocator, i.blockSize, rng.NewSplit(i.seed, hashToSeed(token))) + allocator = NewSingleIdentityAllocator( + i.blockAllocator, + i.blockSize, + rng.NewSplit(i.seed, hashToSeed(token)), + ) i.allocators[token] = allocator } return allocator.GetIdentity(ctx, token, value) @@ -132,7 +155,11 @@ type SingleIdentityAllocator struct { usedValues map[uint]struct{} } -func NewSingleIdentityAllocator(blockAllocator BlockAllocator, blockSize uint, rand rng.Rand) *SingleIdentityAllocator { +func NewSingleIdentityAllocator( + blockAllocator BlockAllocator, + blockSize uint, + rand rng.Rand, +) 
*SingleIdentityAllocator { return &SingleIdentityAllocator{ blockAllocator: blockAllocator, blockSize: blockSize, @@ -141,7 +168,11 @@ func NewSingleIdentityAllocator(blockAllocator BlockAllocator, blockSize uint, r } } -func (i *SingleIdentityAllocator) GetIdentity(ctx context.Context, token string, value *uint) (uint, error) { +func (i *SingleIdentityAllocator) GetIdentity( + ctx context.Context, + token string, + value *uint, +) (uint, error) { i.mu.Lock() defer i.mu.Unlock() diff --git a/worker/pkg/workflows/tablesync/workflow/workflow.go b/worker/pkg/workflows/tablesync/workflow/workflow.go index 9fc2d7b7c7..5e2daa748a 100644 --- a/worker/pkg/workflows/tablesync/workflow/workflow.go +++ b/worker/pkg/workflows/tablesync/workflow/workflow.go @@ -38,7 +38,10 @@ func New(maxIterations int) *Workflow { } } -func (w *Workflow) TableSync(ctx workflow.Context, req *TableSyncRequest) (*TableSyncResponse, error) { +func (w *Workflow) TableSync( + ctx workflow.Context, + req *TableSyncRequest, +) (*TableSyncResponse, error) { logger := log.With( workflow.GetLogger(ctx), "accountId", req.AccountId, @@ -66,7 +69,10 @@ func (w *Workflow) TableSync(ctx workflow.Context, req *TableSyncRequest) (*Tabl var resp *sync_activity.SyncTableResponse err := workflow.ExecuteActivity( - workflow.WithActivityOptions(ctx, *req.SyncActivityOptions), // todo: check sync activity options nil + workflow.WithActivityOptions( + ctx, + *req.SyncActivityOptions, + ), // todo: check sync activity options nil syncActivity.SyncTable, sync_activity.SyncTableRequest{ Id: req.Id, @@ -115,7 +121,10 @@ func (w *Workflow) TableSync(ctx workflow.Context, req *TableSyncRequest) (*Tabl } // Sets a temporal update handle for use with allocating identity blocks for auto increment columns -func setCursorUpdateHandler(ctx workflow.Context, cursors map[string]*tablesync_shared.IdentityCursor) error { +func setCursorUpdateHandler( + ctx workflow.Context, + cursors map[string]*tablesync_shared.IdentityCursor, +) 
error { cursorMutex := workflow.NewMutex(ctx) return workflow.SetUpdateHandlerWithOptions( ctx, @@ -147,7 +156,9 @@ func setCursorUpdateHandler(ctx workflow.Context, cursors map[string]*tablesync_ Validator: func(ctx workflow.Context, req *tablesync_shared.AllocateIdentityBlockRequest) error { // Note: The validator function is a read-only function and cannot access workflow state if req == nil { - return errors.New("request is nil, expected a valid *AllocateIdentityBlockRequest") + return errors.New( + "request is nil, expected a valid *AllocateIdentityBlockRequest", + ) } if req.Id == "" || req.BlockSize == 0 { return errors.New("id and block size are required") From d00dcd9ee0a93620c724580c9a3ad217718b5aa0 Mon Sep 17 00:00:00 2001 From: Nick Z <2420177+nickzelei@users.noreply.github.com> Date: Fri, 28 Mar 2025 13:48:51 -0700 Subject: [PATCH 07/12] more lint fixes --- .../pkg/sqlmanager/postgres/postgres-manager.go | 7 +++++-- cli/internal/cmds/neosync/accounts/switch.go | 8 ++++---- cli/internal/cmds/neosync/login/login.go | 4 ++-- cli/internal/cmds/neosync/sync/ui.go | 1 - internal/benthos_slogger/logger.go | 4 ++-- worker/internal/temporal-logger/logger.go | 4 ++-- worker/pkg/benthos/redis/output_hash.go | 16 ++++++++-------- .../account_hooks/activities/execute/activity.go | 4 ++-- 8 files changed, 25 insertions(+), 23 deletions(-) diff --git a/backend/pkg/sqlmanager/postgres/postgres-manager.go b/backend/pkg/sqlmanager/postgres/postgres-manager.go index 2c81b290d3..d7184ce48f 100644 --- a/backend/pkg/sqlmanager/postgres/postgres-manager.go +++ b/backend/pkg/sqlmanager/postgres/postgres-manager.go @@ -844,8 +844,11 @@ func (p *PostgresManager) GetTableInitStatements( for _, record := range tableData { record := record var seqDefinition *string - if record.IdentityGeneration != "" && record.SeqStartValue.Valid && record.SeqMinValue.Valid && - record.SeqMaxValue.Valid && record.SeqIncrementBy.Valid && + if record.IdentityGeneration != "" && + 
record.SeqStartValue.Valid && + record.SeqMinValue.Valid && + record.SeqMaxValue.Valid && + record.SeqIncrementBy.Valid && record.SeqCycleOption.Valid && record.SeqCacheValue.Valid { seqConfig := &SequenceConfiguration{ diff --git a/cli/internal/cmds/neosync/accounts/switch.go b/cli/internal/cmds/neosync/accounts/switch.go index 90036bcd30..fc02e5d4ca 100644 --- a/cli/internal/cmds/neosync/accounts/switch.go +++ b/cli/internal/cmds/neosync/accounts/switch.go @@ -178,11 +178,11 @@ func switchAccount( return fmt.Errorf("unable to set account context: %w", err) } - fmt.Println( + fmt.Println( //nolint:forbidigo itemStyle.Render( fmt.Sprintf("\n Switched account to %s (%s) \n", account.Name, account.Id), ), - ) //nolint:forbidigo + ) return nil } @@ -217,10 +217,10 @@ func (d itemDelegate) Update(_ tea.Msg, _ *list.Model) tea.Cmd { return nil } func (d itemDelegate) Render( w io.Writer, - m list.Model, + m list.Model, //nolint:gocritic index int, listItem list.Item, -) { //nolint:gocritic +) { i, ok := listItem.(item) if !ok { return diff --git a/cli/internal/cmds/neosync/login/login.go b/cli/internal/cmds/neosync/login/login.go index 4df223270d..da676fdd47 100644 --- a/cli/internal/cmds/neosync/login/login.go +++ b/cli/internal/cmds/neosync/login/login.go @@ -151,10 +151,10 @@ func oAuthLogin( }() if err := webbrowser.Open(authorizeurlResp.Msg.Url); err != nil { - fmt.Println( + fmt.Println( //nolint:forbidigo "There was an issue opening the web browser, proceed to the following url to finish logging in to Neosync:\n", authorizeurlResp.Msg.Url, - ) //nolint + ) } select { diff --git a/cli/internal/cmds/neosync/sync/ui.go b/cli/internal/cmds/neosync/sync/ui.go index 2f7ec9b996..be2994f4cb 100644 --- a/cli/internal/cmds/neosync/sync/ui.go +++ b/cli/internal/cmds/neosync/sync/ui.go @@ -180,7 +180,6 @@ func (m *model) syncConfigs( } results := map[string]string{} - //nolint:gofmt messageMap.Range(func(key, value any) bool { d := value.(time.Duration) results[key.(string)] 
= fmt.Sprintf("%s %s %s", checkMark, key, diff --git a/internal/benthos_slogger/logger.go b/internal/benthos_slogger/logger.go index 75bfaefd7a..ed40b42e06 100644 --- a/internal/benthos_slogger/logger.go +++ b/internal/benthos_slogger/logger.go @@ -22,8 +22,8 @@ func (h *benthosLogHandler) Enabled(ctx context.Context, level slog.Level) bool func (h *benthosLogHandler) Handle( ctx context.Context, - r slog.Record, -) error { //nolint:gocritic // Needs to conform to the slog.Handler interface + r slog.Record, //nolint:gocritic // Needs to conform to the slog.Handler interface +) error { // Combine pre-defined attrs with record attrs allAttrs := make([]slog.Attr, 0, len(h.attrs)+r.NumAttrs()) allAttrs = append(allAttrs, h.attrs...) diff --git a/worker/internal/temporal-logger/logger.go b/worker/internal/temporal-logger/logger.go index d4c44123e1..f466c3dcff 100644 --- a/worker/internal/temporal-logger/logger.go +++ b/worker/internal/temporal-logger/logger.go @@ -22,8 +22,8 @@ func (h *temporalLogHandler) Enabled(ctx context.Context, level slog.Level) bool func (h *temporalLogHandler) Handle( ctx context.Context, - r slog.Record, -) error { //nolint:gocritic // Needs to conform to the slog.Handler interface + r slog.Record, //nolint:gocritic // Needs to conform to the slog.Handler interface +) error { // Combine pre-defined attrs with record attrs allAttrs := make([]slog.Attr, 0, len(h.attrs)+r.NumAttrs()) allAttrs = append(allAttrs, h.attrs...) 
diff --git a/worker/pkg/benthos/redis/output_hash.go b/worker/pkg/benthos/redis/output_hash.go index 60d24c70a0..4f2e57c7a2 100644 --- a/worker/pkg/benthos/redis/output_hash.go +++ b/worker/pkg/benthos/redis/output_hash.go @@ -47,10 +47,10 @@ func init() { redisHashOutputConfig(), func(conf *service.ParsedConfig, mgr *service.Resources) (out service.Output, maxInFlight int, err error) { if maxInFlight, err = conf.FieldMaxInFlight(); err != nil { - return + return nil, 0, err } out, err = newRedisHashWriter(conf, mgr) - return + return out, maxInFlight, err }, ) if err != nil { @@ -82,17 +82,17 @@ func newRedisHashWriter( log: mgr.Logger(), } if _, err = getClient(conf); err != nil { - return + return nil, err } if r.key, err = conf.FieldInterpolatedString(hoFieldKey); err != nil { - return + return nil, err } if r.walkMetadata, err = conf.FieldBool(hoFieldWalkMetadata); err != nil { - return + return nil, err } if r.walkJSON, err = conf.FieldBool(hoFieldWalkJSON); err != nil { - return + return nil, err } if r.fieldsMapping, err = conf.FieldBloblang(hoFieldFieldsMapping); err != nil { return nil, err @@ -101,7 +101,7 @@ func newRedisHashWriter( if !r.walkMetadata && !r.walkJSON && r.fieldsMapping == nil { return nil, errors.New("at least one mechanism for setting fields must be enabled") } - return + return r, nil } func (r *redisHashWriter) Connect(ctx context.Context) error { @@ -180,7 +180,7 @@ func (r *redisHashWriter) Write(ctx context.Context, msg *service.Message) error } if mapVal != nil { - fieldMappings, ok := mapVal.(map[string]any) //nolint:gofmt + fieldMappings, ok := mapVal.(map[string]any) if !ok { return fmt.Errorf("fieldMappings resulted in a non-object mapping: %T", mapVal) } diff --git a/worker/pkg/workflows/ee/account_hooks/activities/execute/activity.go b/worker/pkg/workflows/ee/account_hooks/activities/execute/activity.go index bd33eb81e4..98aa4c4c81 100644 --- a/worker/pkg/workflows/ee/account_hooks/activities/execute/activity.go +++ 
b/worker/pkg/workflows/ee/account_hooks/activities/execute/activity.go @@ -200,8 +200,8 @@ func executeWebhookRequest( if skipSslVerification { client.Transport = &http.Transport{ TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, //nolint:gosec // we want to enable this if it's user specified + InsecureSkipVerify: true, //nolint:gosec // we want to enable this if it's user specified + }, } } resp, err := client.Do(httpReq) From 9545c6a2ca4dd300050c9fe829bf43e7258e691e Mon Sep 17 00:00:00 2001 From: Nick Z <2420177+nickzelei@users.noreply.github.com> Date: Fri, 28 Mar 2025 13:49:10 -0700 Subject: [PATCH 08/12] enables standard mode --- .golangci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.golangci.yaml b/.golangci.yaml index 8aa0b0ff1a..07578be033 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -2,7 +2,7 @@ version: "2" run: tests: false linters: - default: none + default: standard enable: - bodyclose - dogsled From 12f90e43d6129e0c6090f4df3df4d16aeb201c52 Mon Sep 17 00:00:00 2001 From: Nick Z <2420177+nickzelei@users.noreply.github.com> Date: Fri, 28 Mar 2025 13:54:00 -0700 Subject: [PATCH 09/12] sets max len to 140 --- .golangci.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.golangci.yaml b/.golangci.yaml index 07578be033..2d7b09f728 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -111,6 +111,8 @@ formatters: rewrite-rules: - pattern: interface{} replacement: any + golines: + max-len: 140 exclusions: generated: lax paths: From ac2376d50b41ab1bdc4af3d55bf92fa690524f4a Mon Sep 17 00:00:00 2001 From: Nick Z <2420177+nickzelei@users.noreply.github.com> Date: Fri, 28 Mar 2025 13:59:15 -0700 Subject: [PATCH 10/12] merges and reruns lint --- internal/schema-manager/mysql/mysql.go | 22 +++++- internal/schema-manager/postgres/postgres.go | 76 +++++++++++++++---- internal/schema-manager/shared/schema-diff.go | 10 ++- 3 files changed, 90 insertions(+), 18 deletions(-) diff --git 
a/internal/schema-manager/mysql/mysql.go b/internal/schema-manager/mysql/mysql.go index 5916f46680..1e03d722f2 100644 --- a/internal/schema-manager/mysql/mysql.go +++ b/internal/schema-manager/mysql/mysql.go @@ -169,7 +169,12 @@ func getDatabaseDataForSchemaDiff( nonFkConstraints[key] = nonFkConstraint } for _, fkConstraint := range tableconstraint.ForeignKeyConstraints { - key := fmt.Sprintf("%s.%s.%s", fkConstraint.ReferencingSchema, fkConstraint.ReferencingTable, fkConstraint.ConstraintName) + key := fmt.Sprintf( + "%s.%s.%s", + fkConstraint.ReferencingSchema, + fkConstraint.ReferencingTable, + fkConstraint.ConstraintName, + ) fkConstraints[key] = fkConstraint } } @@ -223,7 +228,15 @@ func (d *MysqlSchemaManager) BuildSchemaDiffStatements( } // only way to update non fk constraint is to drop and recreate for _, constraint := range diff.ExistsInBoth.Different.NonForeignKeyConstraints { - dropNonFkConstraintStatements = append(dropNonFkConstraintStatements, sqlmanager_mysql.BuildDropConstraintStatement(constraint.SchemaName, constraint.TableName, constraint.ConstraintType, constraint.ConstraintName)) + dropNonFkConstraintStatements = append( + dropNonFkConstraintStatements, + sqlmanager_mysql.BuildDropConstraintStatement( + constraint.SchemaName, + constraint.TableName, + constraint.ConstraintType, + constraint.ConstraintName, + ), + ) } orderedForeignKeysToDrop := shared.BuildOrderedForeignKeyConstraintsToDrop(d.logger, diff) @@ -257,7 +270,10 @@ func (d *MysqlSchemaManager) BuildSchemaDiffStatements( } // only way to update trigger is to drop and recreate for _, trigger := range diff.ExistsInBoth.Different.Triggers { - dropTriggerStatements = append(dropTriggerStatements, sqlmanager_mysql.BuildDropTriggerStatement(trigger.TriggerSchema, trigger.TriggerName)) + dropTriggerStatements = append( + dropTriggerStatements, + sqlmanager_mysql.BuildDropTriggerStatement(trigger.TriggerSchema, trigger.TriggerName), + ) } updateColumnStatements := []string{} diff --git 
a/internal/schema-manager/postgres/postgres.go b/internal/schema-manager/postgres/postgres.go index a5c2f81254..d2326c90d8 100644 --- a/internal/schema-manager/postgres/postgres.go +++ b/internal/schema-manager/postgres/postgres.go @@ -57,7 +57,10 @@ func NewPostgresSchemaManager( }, nil } -func (d *PostgresSchemaManager) CalculateSchemaDiff(ctx context.Context, uniqueTables map[string]*sqlmanager_shared.SchemaTable) (*shared.SchemaDifferences, error) { +func (d *PostgresSchemaManager) CalculateSchemaDiff( + ctx context.Context, + uniqueTables map[string]*sqlmanager_shared.SchemaTable, +) (*shared.SchemaDifferences, error) { d.logger.Debug("calculating schema diff") tables := []*sqlmanager_shared.SchemaTable{} schemaMap := map[string][]*sqlmanager_shared.SchemaTable{} @@ -136,7 +139,12 @@ func getDatabaseDataForSchemaDiff( nonFkConstraints[key] = nonFkConstraint } for _, fkConstraint := range tableconstraint.ForeignKeyConstraints { - key := fmt.Sprintf("%s.%s.%s", fkConstraint.ReferencingSchema, fkConstraint.ReferencingTable, fkConstraint.ConstraintName) + key := fmt.Sprintf( + "%s.%s.%s", + fkConstraint.ReferencingSchema, + fkConstraint.ReferencingTable, + fkConstraint.ConstraintName, + ) fkConstraints[key] = fkConstraint } } @@ -184,7 +192,10 @@ func getDatabaseDataForSchemaDiff( }, nil } -func (d *PostgresSchemaManager) BuildSchemaDiffStatements(ctx context.Context, diff *shared.SchemaDifferences) ([]*sqlmanager_shared.InitSchemaStatements, error) { +func (d *PostgresSchemaManager) BuildSchemaDiffStatements( + ctx context.Context, + diff *shared.SchemaDifferences, +) ([]*sqlmanager_shared.InitSchemaStatements, error) { d.logger.Debug("building schema diff statements") if !d.destOpts.GetInitTableSchema() { d.logger.Info("skipping schema init as it is not enabled") @@ -200,44 +211,79 @@ func (d *PostgresSchemaManager) BuildSchemaDiffStatements(ctx context.Context, d dropNonFkConstraintStatements := []string{} for _, constraint := range 
diff.ExistsInDestination.NonForeignKeyConstraints { - dropNonFkConstraintStatements = append(dropNonFkConstraintStatements, sqlmanager_postgres.BuildDropConstraintStatement(constraint.SchemaName, constraint.TableName, constraint.ConstraintName)) + dropNonFkConstraintStatements = append( + dropNonFkConstraintStatements, + sqlmanager_postgres.BuildDropConstraintStatement(constraint.SchemaName, constraint.TableName, constraint.ConstraintName), + ) } // only way to update non fk constraint is to drop and recreate for _, constraint := range diff.ExistsInBoth.Different.NonForeignKeyConstraints { - dropNonFkConstraintStatements = append(dropNonFkConstraintStatements, sqlmanager_postgres.BuildDropConstraintStatement(constraint.SchemaName, constraint.TableName, constraint.ConstraintName)) + dropNonFkConstraintStatements = append( + dropNonFkConstraintStatements, + sqlmanager_postgres.BuildDropConstraintStatement(constraint.SchemaName, constraint.TableName, constraint.ConstraintName), + ) } dropFkConstraintStatements := []string{} for _, constraint := range diff.ExistsInDestination.ForeignKeyConstraints { - dropFkConstraintStatements = append(dropFkConstraintStatements, sqlmanager_postgres.BuildDropConstraintStatement(constraint.ReferencingSchema, constraint.ReferencingTable, constraint.ConstraintName)) + dropFkConstraintStatements = append( + dropFkConstraintStatements, + sqlmanager_postgres.BuildDropConstraintStatement( + constraint.ReferencingSchema, + constraint.ReferencingTable, + constraint.ConstraintName, + ), + ) } // only way to update fk constraint is to drop and recreate for _, constraint := range diff.ExistsInBoth.Different.ForeignKeyConstraints { - dropFkConstraintStatements = append(dropFkConstraintStatements, sqlmanager_postgres.BuildDropConstraintStatement(constraint.ReferencingSchema, constraint.ReferencingTable, constraint.ConstraintName)) + dropFkConstraintStatements = append( + dropFkConstraintStatements, + 
sqlmanager_postgres.BuildDropConstraintStatement( + constraint.ReferencingSchema, + constraint.ReferencingTable, + constraint.ConstraintName, + ), + ) } dropColumnStatements := []string{} for _, column := range diff.ExistsInDestination.Columns { - dropColumnStatements = append(dropColumnStatements, sqlmanager_postgres.BuildDropColumnStatement(column.Schema, column.Table, column.Name)) + dropColumnStatements = append( + dropColumnStatements, + sqlmanager_postgres.BuildDropColumnStatement(column.Schema, column.Table, column.Name), + ) } dropTriggerStatements := []string{} for _, trigger := range diff.ExistsInDestination.Triggers { - dropTriggerStatements = append(dropTriggerStatements, sqlmanager_postgres.BuildDropTriggerStatement(trigger.Schema, trigger.Table, trigger.TriggerName)) + dropTriggerStatements = append( + dropTriggerStatements, + sqlmanager_postgres.BuildDropTriggerStatement(trigger.Schema, trigger.Table, trigger.TriggerName), + ) } // only way to update trigger is to drop and recreate for _, trigger := range diff.ExistsInBoth.Different.Triggers { - dropTriggerStatements = append(dropTriggerStatements, sqlmanager_postgres.BuildDropTriggerStatement(trigger.Schema, trigger.Table, trigger.TriggerName)) + dropTriggerStatements = append( + dropTriggerStatements, + sqlmanager_postgres.BuildDropTriggerStatement(trigger.Schema, trigger.Table, trigger.TriggerName), + ) } dropFunctionStatements := []string{} for _, function := range diff.ExistsInDestination.Functions { - dropFunctionStatements = append(dropFunctionStatements, sqlmanager_postgres.BuildDropFunctionStatement(function.Schema, function.Name)) + dropFunctionStatements = append( + dropFunctionStatements, + sqlmanager_postgres.BuildDropFunctionStatement(function.Schema, function.Name), + ) } updateFunctionStatements := []string{} for _, function := range diff.ExistsInBoth.Different.Functions { - updateFunctionStatements = append(updateFunctionStatements, 
sqlmanager_postgres.BuildUpdateFunctionStatement(function.Schema, function.Name, function.Definition)) + updateFunctionStatements = append( + updateFunctionStatements, + sqlmanager_postgres.BuildUpdateFunctionStatement(function.Schema, function.Name, function.Definition), + ) } return []*sqlmanager_shared.InitSchemaStatements{ @@ -272,7 +318,11 @@ func (d *PostgresSchemaManager) BuildSchemaDiffStatements(ctx context.Context, d }, nil } -func (d *PostgresSchemaManager) ReconcileDestinationSchema(ctx context.Context, uniqueTables map[string]*sqlmanager_shared.SchemaTable, schemaStatements []*sqlmanager_shared.InitSchemaStatements) ([]*shared.InitSchemaError, error) { +func (d *PostgresSchemaManager) ReconcileDestinationSchema( + ctx context.Context, + uniqueTables map[string]*sqlmanager_shared.SchemaTable, + schemaStatements []*sqlmanager_shared.InitSchemaStatements, +) ([]*shared.InitSchemaError, error) { d.logger.Debug("reconciling destination schema") initErrors := []*shared.InitSchemaError{} if !d.destOpts.GetInitTableSchema() { diff --git a/internal/schema-manager/shared/schema-diff.go b/internal/schema-manager/shared/schema-diff.go index f7d2f12593..862e7f82f1 100644 --- a/internal/schema-manager/shared/schema-diff.go +++ b/internal/schema-manager/shared/schema-diff.go @@ -146,14 +146,20 @@ func (b *SchemaDifferencesBuilder) buildTableColumnDifferences() { } func (b *SchemaDifferencesBuilder) buildTableForeignKeyConstraintDifferences() { - existsInSource, existsInBoth, existsInDestination := buildDifferencesByFingerprint(b.source.ForeignKeyConstraints, b.destination.ForeignKeyConstraints) + existsInSource, existsInBoth, existsInDestination := buildDifferencesByFingerprint( + b.source.ForeignKeyConstraints, + b.destination.ForeignKeyConstraints, + ) b.diff.ExistsInSource.ForeignKeyConstraints = existsInSource b.diff.ExistsInBoth.Different.ForeignKeyConstraints = existsInBoth b.diff.ExistsInDestination.ForeignKeyConstraints = existsInDestination } func (b 
*SchemaDifferencesBuilder) buildTableNonForeignKeyConstraintDifferences() { - existsInSource, existsInBoth, existsInDestination := buildDifferencesByFingerprint(b.source.NonForeignKeyConstraints, b.destination.NonForeignKeyConstraints) + existsInSource, existsInBoth, existsInDestination := buildDifferencesByFingerprint( + b.source.NonForeignKeyConstraints, + b.destination.NonForeignKeyConstraints, + ) b.diff.ExistsInSource.NonForeignKeyConstraints = existsInSource b.diff.ExistsInBoth.Different.NonForeignKeyConstraints = existsInBoth b.diff.ExistsInDestination.NonForeignKeyConstraints = existsInDestination From 7c0daba049fb81e8b499643b58d5b1faac0d2bdb Mon Sep 17 00:00:00 2001 From: Nick Z <2420177+nickzelei@users.noreply.github.com> Date: Fri, 28 Mar 2025 14:14:40 -0700 Subject: [PATCH 11/12] Fixes generator with new line formatting --- .../benthos/transformers/generator_utils.go | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/worker/pkg/benthos/transformers/generator_utils.go b/worker/pkg/benthos/transformers/generator_utils.go index 4321378e38..9d04d8d235 100644 --- a/worker/pkg/benthos/transformers/generator_utils.go +++ b/worker/pkg/benthos/transformers/generator_utils.go @@ -112,19 +112,25 @@ func ParseBloblangSpec(benthosSpec *BenthosSpec) (*ParsedBenthosSpec, error) { var benthosSpecStr string start := false + foundRegister := false for fileScanner.Scan() { line := fileScanner.Text() if strings.Contains(line, "bloblang.NewPluginSpec") { start = true - benthosSpecStr += strings.TrimSpace(fileScanner.Text()) + benthosSpecStr += strings.TrimSpace(line) } else if start { - if strings.Contains(line, ".RegisterFunctionV2") { - benthosSpecStr += strings.TrimSpace(fileScanner.Text()) - break + benthosSpecStr += strings.TrimSpace(line) + if foundRegister { + break // Now we break after capturing one more line after RegisterFunctionV2 + } + if strings.Contains(line, "RegisterFunctionV2") { + foundRegister = true } - 
benthosSpecStr += strings.TrimSpace(fileScanner.Text()) } } + if !foundRegister { + return nil, fmt.Errorf("RegisterFunctionV2 not found in file: %s", filepath.Base(benthosSpec.SourceFile)) + } categoryRegex := regexp.MustCompile(`\.Category\("([^"]*)"\)`) var category string @@ -188,15 +194,13 @@ func ParseBloblangSpec(benthosSpec *BenthosSpec) (*ParsedBenthosSpec, error) { func extractBloblangFunctionName(input, sourceFile string) (string, error) { // Looks for bloblang.RegisterFunctionV2 and captures the function name in quotes - re := regexp.MustCompile(`\.RegisterFunctionV2\("([^"]+)"`) - + re := regexp.MustCompile(`RegisterFunctionV2\s*\(\s*"([^"]+)"`) matches := re.FindStringSubmatch(input) - if len(matches) > 1 { - return matches[1], nil + if len(matches) == 0 { + return "", fmt.Errorf("bloblang function name not found: %s", filepath.Base(sourceFile)) } - - return "", fmt.Errorf("bloblang function name not found: %s", sourceFile) + return matches[1], nil } func lowercaseFirst(s string) string { From 653da3fb6f5a0e9bd5d480ae965ff7f093eac3ff Mon Sep 17 00:00:00 2001 From: Nick Z <2420177+nickzelei@users.noreply.github.com> Date: Fri, 28 Mar 2025 14:28:57 -0700 Subject: [PATCH 12/12] updates test expectation --- internal/job/jobmapping-validator_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/job/jobmapping-validator_test.go b/internal/job/jobmapping-validator_test.go index 72c988743b..5b07446209 100644 --- a/internal/job/jobmapping-validator_test.go +++ b/internal/job/jobmapping-validator_test.go @@ -429,7 +429,7 @@ func TestValidateCircularDependencies(t *testing.T) { require.Len(t, errs, 1) require.Len(t, errs, 1) assert.Equal(t, mgmtv1alpha1.DatabaseError_DATABASE_ERROR_CODE_UNSUPPORTED_CIRCULAR_DEPENDENCY_AT_LEAST_ONE_NULLABLE, errs[0].Code) - assert.Contains(t, errs[0].Message, "Unsupported circular dependency detected. 
At least one foreign key in circular dependency must be nullable") + assert.Contains(t, errs[0].Message, "unsupported circular dependency detected. at least one foreign key in circular dependency must be nullable") }) t.Run("should not return error when cycle has nullable foreign key", func(t *testing.T) { @@ -591,7 +591,7 @@ func TestValidateCircularDependencies(t *testing.T) { require.NotEmpty(t, errs) require.Len(t, errs, 1) assert.Equal(t, mgmtv1alpha1.DatabaseError_DATABASE_ERROR_CODE_UNSUPPORTED_CIRCULAR_DEPENDENCY_AT_LEAST_ONE_NULLABLE, errs[0].Code) - assert.Contains(t, errs[0].Message, "Unsupported circular dependency detected. At least one foreign key in circular dependency must be nullable") + assert.Contains(t, errs[0].Message, "unsupported circular dependency detected. at least one foreign key in circular dependency must be nullable") }) t.Run("should skip tables not in mappings", func(t *testing.T) {