From fd98f1235ecb05841c77aa8422dd00c7e5c4f329 Mon Sep 17 00:00:00 2001 From: Julian Vanden Broeck Date: Mon, 20 Apr 2026 15:50:59 +0200 Subject: [PATCH] Refactor code into internal packages MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We also standardize logger dependencies. Centralize responsibilities across the code by reorganizing logic into domain-oriented internal packages and making dependencies explicit in function signatures. With this commit, we also: - Move storage-related logic (upload, purge, locking) into internal/storage; - Move PostgreSQL helpers (connection, SQL execution) into internal/postgresql; - Move hashing and checksum utilities into internal/crypto and improve them; - Move command hooks (pre/post execution) into the command package; - Move shared helpers (path formatting, naming, relPath, etc.) into a helpers package; - Move logger and configuration handling into internal packages; - Refactor pg_dump handling and move command-related logic into internal; Additionally, update function signatures to explicitly pass dependencies (e.g. logger instances) instead of relying on implicit or package-level state. For example: ``` lockPath(path string) → LockPath(logger *logger.LevelLog, path string) ``` This change is intentionally broad and mostly mechanical. While somewhat naive, it establishes a consistent pattern for dependency injection, improves observability, and removes hidden coupling across packages. This is a first step before deeper refactoring. --- helpers.go | 10 - internal/command/command.go | 53 +++ internal/command/command_test.go | 45 +++ hook.go => internal/command/hook.go | 29 +- hook_test.go => internal/command/hook_test.go | 28 +- config.go => internal/config/config.go | 90 +++-- .../config/config_test.go | 76 ++-- crypto.go => internal/crypto/crypto.go | 71 ++-- .../crypto/crypto_test.go | 54 +-- hash.go => internal/crypto/hash.go | 39 +- hash_test.go => internal/crypto/hash_test.go | 24 +- internal/helpers/helpers.go | 93 +++++ legacy.go => internal/legacy/legacy.go | 14 +- .../legacy/legacy_test.go | 7 +- log.go => internal/logger/log.go | 50 ++- log_test.go => internal/logger/log_test.go | 24 +- internal/metadata/metadata.go | 3 + .../postgresql/connstring.go | 7 +- .../postgresql/connstring_test.go | 4 +- sql.go => internal/postgresql/sql.go | 155 ++++---- .../postgresql/sql_test.go | 41 +- lock.go => internal/storage/lock.go | 16 +- lock_test.go => internal/storage/lock_test.go | 18 +- lock_win.go => internal/storage/lock_win.go | 14 +- purge.go => internal/storage/purge.go | 82 ++-- .../storage/purge_test.go | 24 +- upload.go => internal/storage/upload.go | 142 +++---- .../storage/upload_test.go | 7 +- main.go | 374 ++++++------------ main_test.go | 50 +-- 30 files changed, 887 insertions(+), 757 deletions(-) delete mode 100644 helpers.go create mode 100644 internal/command/command.go create mode 100644 internal/command/command_test.go rename hook.go => internal/command/hook.go (73%) rename hook_test.go => internal/command/hook_test.go (88%) rename config.go => internal/config/config.go (94%) rename config_test.go => internal/config/config_test.go (96%) rename crypto.go => internal/crypto/crypto.go (81%) rename crypto_test.go => internal/crypto/crypto_test.go (84%) rename hash.go => internal/crypto/hash.go (84%) rename hash_test.go => internal/crypto/hash_test.go (88%) create mode 100644 internal/helpers/helpers.go rename legacy.go => internal/legacy/legacy.go (94%) rename legacy_test.go => internal/legacy/legacy_test.go (97%) rename log.go => internal/logger/log.go (81%) rename log_test.go => internal/logger/log_test.go (94%) create mode 100644 internal/metadata/metadata.go rename connstring.go => internal/postgresql/connstring.go (98%) rename connstring_test.go => internal/postgresql/connstring_test.go (99%) rename sql.go => internal/postgresql/sql.go (81%) rename sql_test.go => internal/postgresql/sql_test.go (90%) rename lock.go => internal/storage/lock.go (83%) rename lock_test.go => internal/storage/lock_test.go (86%) rename lock_win.go => internal/storage/lock_win.go (85%) rename purge.go => internal/storage/purge.go (75%) rename purge_test.go => internal/storage/purge_test.go (87%) rename upload.go => internal/storage/upload.go (83%) rename upload_test.go => internal/storage/upload_test.go (93%) diff --git a/helpers.go b/helpers.go deleted file mode 100644 index 35f8f8b..0000000 --- a/helpers.go +++ /dev/null @@ -1,10 +0,0 @@ -package main - -import "io" - -func WrappedClose(c io.Closer, err *error) { - cErr := c.Close() - if cErr != nil && *err == nil { - *err = cErr - } -} diff --git a/internal/command/command.go b/internal/command/command.go new file mode 100644 index 0000000..ea631c0 --- /dev/null +++ b/internal/command/command.go @@ -0,0 +1,53 @@ +package command + +import ( + "fmt" + "os/exec" + "path/filepath" + "runtime" + + "github.com/orgrim/pg_back/internal/logger" +) + +func ExecPath(binDir, prog string) string { + binFile := prog + if runtime.GOOS == "windows" { + binFile = fmt.Sprintf("%s.exe", prog) + } + + if binDir != "" { + return filepath.Join(binDir, binFile) + } + + return binFile +} + +func PgToolVersion(logger *logger.LevelLog, binDir, tool string) int { + vs, err := exec.Command(ExecPath(binDir, tool), "--version").Output() + if err != nil { + logger.Warnf("failed to retrieve version of %s: %s", tool, err) + return 0 + } + + var maj, min, rev, numver int + n, _ := fmt.Sscanf(string(vs), tool+" (PostgreSQL) %d.%d.%d", &maj, &min, &rev) + + switch n { + case 3: + // Before PostgreSQL 10, the format si MAJ.MIN.REV + numver = (maj*100+min)*100 + rev + case 2: + // From PostgreSQL 10, the format si MAJ.REV, so the rev ends + // up in min with the scan + numver = maj*10000 + min + default: + // We have the special case of the development version, where the + // format is MAJdevel + fmt.Sscanf(string(vs), tool+" (PostgreSQL) %ddevel", &maj) + numver = maj * 10000 + } + + logger.Verboseln(tool, "version is:", numver) + + return numver +} diff --git a/internal/command/command_test.go b/internal/command/command_test.go new file mode 100644 index 0000000..b9bc181 --- /dev/null +++ b/internal/command/command_test.go @@ -0,0 +1,45 @@ +package command + +import ( + "fmt" + "runtime" + "testing" +) + +func TestExecPath(t *testing.T) { + var tests []struct { + dir string + prog string + want string + } + + if runtime.GOOS != "windows" { + tests = []struct { + dir string + prog string + want string + }{ + {"", "pg_dump", "pg_dump"}, + {"/path/to/bin", "prog", "/path/to/bin/prog"}, + } + } else { + tests = []struct { + dir string + prog string + want string + }{ + {"", "pg_dump", "pg_dump.exe"}, + {"C:\\path\\to\\bin", "prog", "C:\\path\\to\\bin\\prog.exe"}, + } + } + + for i, st := range tests { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + binDir := st.dir + got := ExecPath(binDir, st.prog) + if got != st.want { + t.Errorf("expected %q, got %q\n", st.want, got) + } + }) + } +} diff --git a/hook.go b/internal/command/hook.go similarity index 73% rename from hook.go rename to internal/command/hook.go index e776782..4e8a3df 100644 --- a/hook.go +++ b/internal/command/hook.go @@ -23,7 +23,7 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -package main +package command import ( "fmt" @@ -32,14 +32,15 @@ import ( "strings" "github.com/anmitsu/go-shlex" + "github.com/orgrim/pg_back/internal/logger" ) -func hookCommand(cmd string, logPrefix string) error { +func hookCommand(logger *logger.LevelLog, cmd string, logPrefix string) error { if cmd == "" { return fmt.Errorf("unable to run an empty command") } - l.Verboseln("parsing hook command") + logger.Verboseln("parsing hook command") words, err := shlex.Split(cmd, true) if err != nil { return fmt.Errorf("unable to parse hook command: %s", err) @@ -48,13 +49,13 @@ func hookCommand(cmd string, logPrefix string) error { prog := words[0] args := words[1:] - l.Verboseln("running:", prog, args) + logger.Verboseln("running:", prog, args) c := exec.Command(prog, args...) stdoutStderr, err := c.CombinedOutput() if err != nil { for line := range strings.SplitSeq(string(stdoutStderr), "\n") { if line != "" { - l.Errorln(logPrefix, line) + logger.Errorln(logPrefix, line) } } return err @@ -62,29 +63,29 @@ func hookCommand(cmd string, logPrefix string) error { if len(stdoutStderr) > 0 { for line := range strings.SplitSeq(string(stdoutStderr), "\n") { if line != "" { - l.Infoln(logPrefix, line) + logger.Infoln(logPrefix, line) } } } return nil } -func preBackupHook(cmd string) error { +func PreBackupHook(logger *logger.LevelLog, cmd string) error { if cmd != "" { - l.Infoln("running pre-backup command:", cmd) - if err := hookCommand(cmd, "pre-backup:"); err != nil { - l.Fatalln("hook command failed:", err) + logger.Infoln("running pre-backup command:", cmd) + if err := hookCommand(logger, cmd, "pre-backup:"); err != nil { + logger.Fatalln("hook command failed:", err) return err } } return nil } -func postBackupHook(cmd string) { +func PostBackupHook(logger *logger.LevelLog, cmd string) { if cmd != "" { - l.Infoln("running post-backup command:", cmd) - if err := hookCommand(cmd, "post-backup:"); err != nil { - l.Fatalln("hook command failed:", err) + logger.Infoln("running post-backup command:", cmd) + if err := hookCommand(logger, cmd, "post-backup:"); err != nil { + logger.Fatalln("hook command failed:", err) os.Exit(1) } } diff --git a/hook_test.go b/internal/command/hook_test.go similarity index 88% rename from hook_test.go rename to internal/command/hook_test.go index 84fd3a5..968150b 100644 --- a/hook_test.go +++ b/internal/command/hook_test.go @@ -25,7 +25,7 @@ //go:build !windows -package main +package command import ( "bytes" @@ -35,6 +35,8 @@ import ( "regexp" "strings" "testing" + + "github.com/orgrim/pg_back/internal/logger" ) func TestHookCommand(t *testing.T) { @@ -58,14 +60,14 @@ func TestHookCommand(t *testing.T) { `^\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2} ERROR: test: test\n\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2} ERROR: exit status 1\n$`, }, } - + logger := logger.NewLevelLog() for i, subt := range tests { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { buf := new(bytes.Buffer) - l.logger.SetOutput(buf) + logger.Logger.SetOutput(buf) - if err := hookCommand(subt.cmd, "test:"); err != nil { - l.Errorln(err) + if err := hookCommand(logger, subt.cmd, "test:"); err != nil { + logger.Errorln(err) } lines := strings.ReplaceAll(buf.String(), "\r", "") @@ -76,7 +78,7 @@ func TestHookCommand(t *testing.T) { if !matched { t.Errorf("expected a match of %q, got %q\n", subt.re, lines) } - l.logger.SetOutput(os.Stderr) + logger.Logger.SetOutput(os.Stderr) }) } } @@ -99,12 +101,13 @@ func TestPreBackupHook(t *testing.T) { true, }, } + logger := logger.NewLevelLog() for i, subt := range tests { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { buf := new(bytes.Buffer) - l.logger.SetOutput(buf) + logger.Logger.SetOutput(buf) - if err := preBackupHook(subt.cmd); err != nil { + if err := PreBackupHook(logger, subt.cmd); err != nil { if !subt.fails { t.Errorf("function test must not fail, got error: %q\n", err) } @@ -122,15 +125,16 @@ func TestPreBackupHook(t *testing.T) { if !matched { t.Errorf("expected a match of %q, got %q\n", subt.re, lines) } - l.logger.SetOutput(os.Stderr) + logger.Logger.SetOutput(os.Stderr) }) } } func TestPostBackupHook(t *testing.T) { + logger := logger.NewLevelLog() t.Run("0", func(t *testing.T) { if os.Getenv("_TEST_HOOK") == "1" { - postBackupHook("false") + PostBackupHook(logger, "false") return } cmd := exec.Command(os.Args[0], "-test.run=TestPostBackupHook") @@ -144,8 +148,8 @@ func TestPostBackupHook(t *testing.T) { t.Run("1", func(t *testing.T) { buf := new(bytes.Buffer) - l.logger.SetOutput(buf) - postBackupHook("") + logger.Logger.SetOutput(buf) + PostBackupHook(logger, "") lines := buf.String() if len(lines) != 0 { t.Errorf("did not expect any output, got %q\n", lines) diff --git a/config.go b/internal/config/config.go similarity index 94% rename from config.go rename to internal/config/config.go index bfd2979..5ea2716 100644 --- a/config.go +++ b/internal/config/config.go @@ -23,10 +23,9 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -package main +package config import ( - _ "embed" "errors" "fmt" "os" @@ -36,17 +35,54 @@ import ( "time" "github.com/anmitsu/go-shlex" + "github.com/orgrim/pg_back/internal/logger" + "github.com/orgrim/pg_back/internal/metadata" "github.com/spf13/pflag" "gopkg.in/ini.v1" ) var defaultCfgFile = "/etc/pg_back/pg_back.conf" -//go:embed pg_back.conf -var defaultCfg string +type DbOpts struct { + // Format of the dump + Format rune + + // Algorithm of the checksum of the file, "none" is used to + // disable checksuming + SumAlgo string + + // Number of parallel jobs for directory format + Jobs int + + // Compression level for compressed formats, -1 means the default + CompressLevel int + + // Purge configuration + PurgeInterval time.Duration + PurgeKeep int + + // Limit schemas + Schemas []string + ExcludedSchemas []string + + // Limit dumped tables + Tables []string + ExcludedTables []string + + // Other pg_dump options to use + PgDumpOpts []string + + // Whether to force the dump of large objects or not with pg_dump -b or + // -B, or let pg_dump use its default. 0 means default, 1 include + // blobs, 2 exclude blobs. + WithBlobs int + + // Connection user for that database + Username string +} // options struct holds command line and configuration file options -type options struct { +type Options struct { NoConfigFile bool BinDirectory string Directory string @@ -69,7 +105,7 @@ type options struct { PreHook string PostHook string PgDumpOpts []string - PerDbOpts map[string]*dbOpts + PerDbOpts map[string]*DbOpts CfgFile string TimeFormat string Verbose bool @@ -123,13 +159,13 @@ type options struct { AzureEndpoint string } -func defaultOptions() options { +func DefaultOptions() Options { timeFormat := time.RFC3339 if runtime.GOOS == "windows" { timeFormat = "2006-01-02_15-04-05" } - return options{ + return Options{ NoConfigFile: false, Directory: "/var/backups/postgresql", Mode: 0o600, @@ -154,14 +190,14 @@ func defaultOptions() options { // parseCliResult is use to handle utility flags like help, version, that make // the program end early -type parseCliResult struct { +type ParseCliResult struct { ShowHelp bool ShowVersion bool LegacyConfig string ShowConfig bool } -func (*parseCliResult) Error() string { +func (*ParseCliResult) Error() string { return "please exit now" } @@ -169,12 +205,12 @@ func validateMode(s string) (int, error) { if (strings.HasPrefix(s, "0") && len(s) <= 5) || (strings.HasPrefix(s, "-")) { mode, err := strconv.ParseInt(s, 0, 32) if err != nil { - return 0, fmt.Errorf("invalid permission %q", s) + return 0, fmt.Errorf("Invalid permission %q", s) } return int(mode), nil } return 0, fmt.Errorf( - "invalid permission %q, must be octal (start by 0 and max 5 digits) number or negative", + "Invalid permission %q, must be octal (start by 0 and max 5 digits) number or negative", s, ) } @@ -199,11 +235,11 @@ func validatePurgeKeepValue(k string) (int, error) { keep, err := strconv.ParseInt(k, 10, 0) if err != nil { // return -1 too when the input is not convertible to an int - return -1, fmt.Errorf("invalid input for keep: %w", err) + return -1, fmt.Errorf("Invalid input for keep: %w", err) } if keep < 0 { - return -1, fmt.Errorf("invalid input for keep: negative value: %d", keep) + return -1, fmt.Errorf("Invalid input for keep: negative value: %d", keep) } return int(keep), nil @@ -212,7 +248,7 @@ func validatePurgeKeepValue(k string) (int, error) { func validatePurgeTimeLimitValue(i string) (time.Duration, error) { if days, err := strconv.ParseInt(i, 10, 0); err != nil { if errors.Is(err, strconv.ErrRange) { - return 0, errors.New("invalid input for purge interval, number too big") + return 0, errors.New("Invalid input for purge interval, number too big") } } else { return time.Duration(-days*24) * time.Hour, nil @@ -268,11 +304,11 @@ func validateDirectory(s string) error { return nil } -func parseCli(args []string) (options, []string, error) { +func ParseCli(args []string, defaultCfg string) (Options, []string, error) { var format, mode, purgeKeep, purgeInterval string - opts := defaultOptions() - pce := &parseCliResult{} + opts := DefaultOptions() + pce := &ParseCliResult{} pflag.Usage = func() { fmt.Fprintf(os.Stderr, "pg_back dumps some PostgreSQL databases\n\n") @@ -546,9 +582,7 @@ func parseCli(args []string) (options, []string, error) { // Do not use the default pflag.Parse() that use os.Args[1:], // but pass it explicitly so that unit-tests can feed any set // of flags - if err := pflag.CommandLine.Parse(args); err != nil { - return opts, []string{}, err - } + pflag.CommandLine.Parse(args) // Record the list of flags set on the command line to allow // overriding the configuration later, if an alternate @@ -593,7 +627,7 @@ func parseCli(args []string) (options, []string, error) { } if pce.ShowVersion { - fmt.Printf("pg_back version %v\n", version) + fmt.Printf("pg_back version %v\n", metadata.Version) return opts, changed, pce } @@ -790,10 +824,10 @@ gkLoop: return nil } -func loadConfigurationFile(path string) (options, error) { +func LoadConfigurationFile(path string, l *logger.LevelLog) (Options, error) { var format, mode, purgeKeep, purgeInterval string - opts := defaultOptions() + opts := DefaultOptions() cfg, err := ini.Load(path) if err != nil { @@ -803,7 +837,7 @@ func loadConfigurationFile(path string) (options, error) { return opts, nil } - return opts, fmt.Errorf("could load configuration file: %v", err) + return opts, fmt.Errorf("Could load configuration file: %v", err) } if err := validateConfigurationFile(cfg); err != nil { @@ -959,7 +993,7 @@ func loadConfigurationFile(path string) (options, error) { // Process all sections with database specific configuration, // fallback on the values of the global section subs := cfg.Sections() - opts.PerDbOpts = make(map[string]*dbOpts, len(subs)) + opts.PerDbOpts = make(map[string]*DbOpts, len(subs)) for _, s := range subs { if s.Name() == ini.DefaultSection { @@ -968,7 +1002,7 @@ func loadConfigurationFile(path string) (options, error) { var dbFormat, dbPurgeInterval, dbPurgeKeep string - o := dbOpts{} + o := DbOpts{} dbFormat = s.Key("format").MustString(format) o.Jobs = s.Key("parallel_backup_jobs").MustInt(opts.DirJobs) o.CompressLevel = s.Key("compress_level").MustInt(opts.CompressLevel) @@ -1030,7 +1064,7 @@ func loadConfigurationFile(path string) (options, error) { return opts, nil } -func mergeCliAndConfigOptions(cliOpts options, configOpts options, onCli []string) options { +func MergeCliAndConfigOptions(cliOpts Options, configOpts Options, onCli []string) Options { opts := configOpts // Command line values take precedence on everything, including per diff --git a/config_test.go b/internal/config/config_test.go similarity index 96% rename from config_test.go rename to internal/config/config_test.go index d4f2c3f..5c20eed 100644 --- a/config_test.go +++ b/internal/config/config_test.go @@ -23,12 +23,11 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -package main +package config import ( "errors" "fmt" - "io" "os" "runtime" "testing" @@ -36,6 +35,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/orgrim/pg_back/internal/logger" "github.com/spf13/pflag" "gopkg.in/ini.v1" ) @@ -75,7 +75,6 @@ func TestValidateMode(t *testing.T) { {"-8170", -8170, false}, // valid and mean do nothing (useful when using umask) } - l.logger.SetOutput(io.Discard) for i, st := range tests { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { got, err := validateMode(st.give) @@ -103,7 +102,6 @@ func TestValidatePurgeKeepValue(t *testing.T) { {"-10", -1, true}, } - l.logger.SetOutput(io.Discard) for i, st := range tests { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { got, err := validatePurgeKeepValue(st.give) @@ -210,7 +208,7 @@ func TestDefaultOptions(t *testing.T) { timeFormat = "2006-01-02_15-04-05" } - var want = options{ + var want = Options{ Directory: "/var/backups/postgresql", Mode: 0o600, Format: 'c', @@ -231,7 +229,7 @@ func TestDefaultOptions(t *testing.T) { B2ConcurrentConnections: 5, } - got := defaultOptions() + got := DefaultOptions() if diff := cmp.Diff(want, got, cmpopts.EquateEmpty()); diff != "" { t.Errorf("DefaultOptions() mismatch (-want +got):\n%s", diff) @@ -245,10 +243,10 @@ func TestParseCli(t *testing.T) { } var ( - defaults = defaultOptions() + defaults = DefaultOptions() tests = []struct { args []string - want options + want Options help bool version bool err string @@ -256,7 +254,7 @@ func TestParseCli(t *testing.T) { }{ { []string{"-b", "test", "-Z", "2", "a", "b"}, - options{ + Options{ Directory: "test", Mode: 0o600, Dbnames: []string{"a", "b"}, @@ -284,7 +282,7 @@ func TestParseCli(t *testing.T) { }, { []string{"-t", "--without-templates"}, - options{ + Options{ Directory: "/var/backups/postgresql", Mode: 0o600, WithTemplates: false, @@ -336,7 +334,7 @@ func TestParseCli(t *testing.T) { }, { []string{"--upload", "wrong"}, - options{ + Options{ Directory: "/var/backups/postgresql", Mode: 0o600, Format: 'c', @@ -365,7 +363,7 @@ func TestParseCli(t *testing.T) { }, { []string{"--download", "wrong"}, - options{ + Options{ Directory: "/var/backups/postgresql", Mode: 0o600, Format: 'c', @@ -402,7 +400,7 @@ func TestParseCli(t *testing.T) { }, { []string{"--cipher-pass", "mypass"}, - options{ + Options{ Directory: "/var/backups/postgresql", Mode: 0o600, Format: 'c', @@ -431,7 +429,7 @@ func TestParseCli(t *testing.T) { }, { []string{"--cipher-private-key", "mykey"}, - options{ + Options{ Directory: "/var/backups/postgresql", Mode: 0o600, Format: 'c', @@ -460,7 +458,7 @@ func TestParseCli(t *testing.T) { }, { []string{"--cipher-public-key", "fakepubkey"}, - options{ + Options{ Directory: "/var/backups/postgresql", Mode: 0o600, Format: 'c', @@ -505,7 +503,7 @@ func TestParseCli(t *testing.T) { }, { []string{"--b2-concurrent-connections", "0"}, - defaultOptions(), + DefaultOptions(), false, false, "b2 concurrent connections must be more than 0 (current 0)", @@ -513,7 +511,7 @@ func TestParseCli(t *testing.T) { }, { []string{"--delete-uploaded", "yes"}, - options{ + Options{ Directory: "/var/backups/postgresql", Mode: 0o600, Format: 'c', @@ -543,7 +541,7 @@ func TestParseCli(t *testing.T) { }, { []string{"--delete-uploaded", "true"}, - options{ + Options{ Directory: "/var/backups/postgresql", Mode: 0o600, Format: 'c', @@ -577,7 +575,7 @@ func TestParseCli(t *testing.T) { for i, st := range tests { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { var ( - opts options + opts Options err error ) @@ -591,14 +589,14 @@ func TestParseCli(t *testing.T) { _, w, _ := os.Pipe() os.Stderr = w os.Stdout = w - opts, _, err = parseCli(st.args) + opts, _, err = ParseCli(st.args, "") os.Stderr = oldStderr os.Stdout = oldStdout } else { - opts, _, err = parseCli(st.args) + opts, _, err = ParseCli(st.args, "") } - var errVal *parseCliResult + var errVal *ParseCliResult if err != nil { if errors.As(err, &errVal) { @@ -640,12 +638,12 @@ func TestLoadConfigurationFile(t *testing.T) { var tests = []struct { params []string fail bool - want options + want Options }{ { []string{"backup_directory = test", "port = 5433", "backup_file_mode = 0700"}, false, - options{ + Options{ Directory: "test", Mode: 0o700, Port: 5433, @@ -675,7 +673,7 @@ func TestLoadConfigurationFile(t *testing.T) { "backup_file_mode = 0400", }, false, - options{ + Options{ Directory: "test", Mode: 0o400, Dbnames: []string{"a", "b", "postgres"}, @@ -700,7 +698,7 @@ func TestLoadConfigurationFile(t *testing.T) { { []string{"timestamp_format = rfc3339"}, false, - options{ + Options{ Directory: "/var/backups/postgresql", Mode: 0o600, Format: 'c', @@ -724,7 +722,7 @@ func TestLoadConfigurationFile(t *testing.T) { { []string{"timestamp_format = legacy"}, false, - options{ + Options{ Directory: "/var/backups/postgresql", Mode: 0o600, Format: 'c', @@ -748,12 +746,12 @@ func TestLoadConfigurationFile(t *testing.T) { { []string{"timestamp_format = wrong"}, true, - defaultOptions(), + DefaultOptions(), }, { // with an error output is the default []string{}, true, - defaultOptions(), + DefaultOptions(), }, { []string{ @@ -766,7 +764,7 @@ func TestLoadConfigurationFile(t *testing.T) { "compress_level = 2", }, false, - options{ + Options{ Directory: "test", Mode: 0o600, Format: 'c', @@ -780,7 +778,7 @@ func TestLoadConfigurationFile(t *testing.T) { CfgFile: "/etc/pg_back/pg_back.conf", TimeFormat: timeFormat, PgDumpOpts: []string{"-O", "-x"}, - PerDbOpts: map[string]*dbOpts{"db": &dbOpts{ + PerDbOpts: map[string]*DbOpts{"db": &DbOpts{ Format: 'c', SumAlgo: "none", Jobs: 2, @@ -811,7 +809,7 @@ func TestLoadConfigurationFile(t *testing.T) { "with_blobs = false", }, false, - options{ + Options{ Directory: "test", Mode: 0o600, Format: 'c', @@ -825,7 +823,7 @@ func TestLoadConfigurationFile(t *testing.T) { CfgFile: "/etc/pg_back/pg_back.conf", TimeFormat: timeFormat, PgDumpOpts: []string{"-O", "-x"}, - PerDbOpts: map[string]*dbOpts{"db": &dbOpts{ + PerDbOpts: map[string]*DbOpts{"db": &DbOpts{ Format: 'c', SumAlgo: "none", CompressLevel: 3, @@ -846,7 +844,7 @@ func TestLoadConfigurationFile(t *testing.T) { { []string{"b2_concurrent_connections = 0"}, true, - defaultOptions(), + DefaultOptions(), }, } @@ -878,8 +876,8 @@ func TestLoadConfigurationFile(t *testing.T) { defer remove() } - var got options - got, err = loadConfigurationFile(f.Name()) + var got Options + got, err = LoadConfigurationFile(f.Name(), logger.NewLevelLog()) if err != nil && !st.fail { t.Errorf("expected an error: %s", err) } @@ -896,7 +894,7 @@ func TestMergeCliAndConfigoptions(t *testing.T) { timeFormat = "2006-01-02_15-04-05" } - want := options{ + want := Options{ BinDirectory: "/bin", Directory: "test", Mode: 0o600, @@ -950,14 +948,14 @@ func TestMergeCliAndConfigoptions(t *testing.T) { "dbname", } - got := mergeCliAndConfigOptions(want, defaultOptions(), cliOptList) + got := MergeCliAndConfigOptions(want, DefaultOptions(), cliOptList) if diff := cmp.Diff(want, got, cmpopts.EquateEmpty()); diff != "" { t.Errorf("mergeCliAndConfigOptions() mismatch (-want +got):\n%s", diff) } } func TestError(t *testing.T) { - err := &parseCliResult{} + err := &ParseCliResult{} s := fmt.Sprintf("%s", err) if s != "please exit now" { diff --git a/crypto.go b/internal/crypto/crypto.go similarity index 81% rename from crypto.go rename to internal/crypto/crypto.go index 12d36c3..e12f20b 100644 --- a/crypto.go +++ b/internal/crypto/crypto.go @@ -23,7 +23,7 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -package main +package crypto import ( "errors" @@ -34,9 +34,31 @@ import ( "strings" "filippo.io/age" + "github.com/orgrim/pg_back/internal/helpers" + "github.com/orgrim/pg_back/internal/logger" ) -func ageEncrypt(src io.Reader, dst io.Writer, params encryptParams) error { +type EncryptParams struct { + Logger *logger.LevelLog + + // Encrypt with a passphrase + Passphrase string + + // Encrypt with an AGE public key encoded in Bech32 + PublicKey string +} + +type DecryptParams struct { + Logger *logger.LevelLog + + // A passphrase to use for decryption + Passphrase string + + // An AGE private key encoded in Bech32 + PrivateKey string +} + +func (params *EncryptParams) ageEncrypt(src io.Reader, dst io.Writer) error { if params.PublicKey != "" { return ageEncryptPublicKey(src, dst, params.PublicKey) } @@ -85,7 +107,7 @@ func ageEncryptInternal(src io.Reader, dst io.Writer, recipient age.Recipient) e return nil } -func ageDecrypt(src io.Reader, dst io.Writer, params decryptParams) error { +func (params *DecryptParams) ageDecrypt(src io.Reader, dst io.Writer) error { if params.PrivateKey != "" { return ageDecryptPrivateKey(src, dst, params.PrivateKey) } @@ -132,10 +154,9 @@ func ageDecryptInternal(src io.Reader, dst io.Writer, identity age.Identity) err return nil } -func encryptFile( +func (params *EncryptParams) EncryptFile( path string, mode int, - params encryptParams, keep bool, ) (_ []string, err error) { encrypted := make([]string, 0) @@ -146,30 +167,30 @@ func encryptFile( } if i.IsDir() { - l.Verboseln("dump is a directory, encrypting all files inside") + params.Logger.Verboseln("dump is a directory, encrypting all files inside") err = filepath.Walk(path, func(path string, info os.FileInfo, err error) error { if err != nil { return err } if info.Mode().IsRegular() { - l.Verboseln("encrypting:", path) + params.Logger.Verboseln("encrypting:", path) src, err := os.Open(path) if err != nil { - l.Errorln(err) + params.Logger.Errorln(err) return err } - defer WrappedClose(src, &err) + defer helpers.WrappedClose(src, &err) dstFile := fmt.Sprintf("%s.age", path) dst, err := os.Create(dstFile) if err != nil { - l.Errorln(err) + params.Logger.Errorln(err) return err } - defer WrappedClose(dst, &err) + defer helpers.WrappedClose(dst, &err) - if err := ageEncrypt(src, dst, params); err != nil { + if err := params.ageEncrypt(src, dst); err != nil { // explicitly ignore error on close and remove dst.Close() //nolint:errcheck os.Remove(dstFile) //nolint:errcheck @@ -187,7 +208,7 @@ func encryptFile( } if !keep { - l.Verboseln("removing source file:", path) + params.Logger.Verboseln("removing source file:", path) if err := src.Close(); err != nil { return fmt.Errorf("could not close %s: %w", path, err) } @@ -203,25 +224,25 @@ func encryptFile( return encrypted, fmt.Errorf("error walking the path %q: %v", path, err) } } else { - l.Verboseln("encrypting:", path) + params.Logger.Verboseln("encrypting:", path) src, err := os.Open(path) if err != nil { - l.Errorln(err) + params.Logger.Errorln(err) return encrypted, err } - defer WrappedClose(src, &err) + defer helpers.WrappedClose(src, &err) dstFile := fmt.Sprintf("%s.age", path) dst, err := os.Create(dstFile) if err != nil { - l.Errorln(err) + params.Logger.Errorln(err) return encrypted, err } - defer WrappedClose(dst, &err) + defer helpers.WrappedClose(dst, &err) - if err := ageEncrypt(src, dst, params); err != nil { + if err := params.ageEncrypt(src, dst); err != nil { // explicitly ignore error here, we already return an error dst.Close() //nolint:errcheck os.Remove(dstFile) //nolint:errcheck @@ -235,7 +256,7 @@ func encryptFile( } } if !keep { - l.Verboseln("removing source file:", path) + params.Logger.Verboseln("removing source file:", path) if err := src.Close(); err != nil { return encrypted, fmt.Errorf("could not close %s: %w", path, err) } @@ -248,15 +269,15 @@ func encryptFile( return encrypted, err } -func decryptFile(path string, params decryptParams) (err error) { - l.Infoln("decrypting", path) +func (params *DecryptParams) DecryptFile(path string) (err error) { + params.Logger.Infoln("decrypting", path) src, err := os.Open(path) if err != nil { return err } - defer WrappedClose(src, &err) + defer helpers.WrappedClose(src, &err) dstFile := strings.TrimSuffix(path, ".age") dst, err := os.Create(dstFile) @@ -264,9 +285,9 @@ func decryptFile(path string, params decryptParams) (err error) { return err } - defer WrappedClose(dst, &err) + defer helpers.WrappedClose(dst, &err) - if err := ageDecrypt(src, dst, params); err != nil { + if err := params.ageDecrypt(src, dst); err != nil { // explicitly ignore error on close and remove dst.Close() //nolint:errcheck os.Remove(dstFile) //nolint:errcheck diff --git a/crypto_test.go b/internal/crypto/crypto_test.go similarity index 84% rename from crypto_test.go rename to internal/crypto/crypto_test.go index 1f9fd05..7cd3356 100644 --- a/crypto_test.go +++ b/internal/crypto/crypto_test.go @@ -23,7 +23,7 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -package main +package crypto import ( "bytes" @@ -44,9 +44,9 @@ func TestAgeEncrypt_NilParams_Failure(t *testing.T) { content := "to be encrypted" reader := strings.NewReader(content) writer := &bytes.Buffer{} - params := encryptParams{} + params := EncryptParams{} - err := ageEncrypt(reader, writer, params) + err := params.ageEncrypt(reader, writer) if err == nil { t.Errorf("Expected empty encryption params to fail") } @@ -56,9 +56,9 @@ func TestAgeDecrypt_NilParams_Failure(t *testing.T) { content := "to be encrypted" reader := strings.NewReader(content) writer := &bytes.Buffer{} - params := decryptParams{} + params := DecryptParams{} - err := ageDecrypt(reader, writer, params) + err := params.ageDecrypt(reader, writer) if err == nil { t.Errorf("Expected empty encryption params to fail") } @@ -71,9 +71,9 @@ func TestAgeDecrypt_InvalidPrivateKey_Failure(t *testing.T) { } reader := bytes.NewReader(encrypted) writer := &bytes.Buffer{} - params := decryptParams{PrivateKey: TEST_PUBLIC_KEY} + params := DecryptParams{PrivateKey: TEST_PUBLIC_KEY} - err = ageDecrypt(reader, writer, params) + err = params.ageDecrypt(reader, writer) if err == nil { t.Errorf("Expected invalid private key to fail") } @@ -83,9 +83,9 @@ func TestAgeDecrypt_InvalidPublicKey_Failure(t *testing.T) { content := "to be encrypted" reader := strings.NewReader(content) writer := &bytes.Buffer{} - params := decryptParams{PrivateKey: TEST_PRIVATE_KEY} + params := DecryptParams{PrivateKey: TEST_PRIVATE_KEY} - err := ageDecrypt(reader, writer, params) + err := params.ageDecrypt(reader, writer) if err == nil { t.Errorf("Expected invalid public key to fail") } @@ -123,11 +123,11 @@ func TestAgeDecrypt_Golden_Success(t *testing.T) { } reader := bytes.NewReader(encrypted) writer := &bytes.Buffer{} - params := decryptParams{ + params := DecryptParams{ PrivateKey: TEST_PRIVATE_KEY, } - err = ageDecrypt(reader, writer, params) + err = params.ageDecrypt(reader, writer) if err != nil { t.Fatalf("could not decrypt golden message: %v", err) } @@ -146,9 +146,9 @@ func TestAgeEncrypt_PublicKey_Loopback_Success(t *testing.T) { content := "to be encrypted" reader := strings.NewReader(content) writer := &bytes.Buffer{} - params := encryptParams{PublicKey: identity.Recipient().String()} + params := EncryptParams{PublicKey: identity.Recipient().String()} - err = ageEncrypt(reader, writer, params) + err = params.ageEncrypt(reader, writer) if err != nil { t.Errorf("Unexpected error when encrypting") } @@ -160,8 +160,8 @@ func TestAgeEncrypt_PublicKey_Loopback_Success(t *testing.T) { reader = strings.NewReader(ciphertext) writer = &bytes.Buffer{} - decryptParams := decryptParams{PrivateKey: identity.String()} - err = ageDecrypt(reader, writer, decryptParams) + decryptParams := DecryptParams{PrivateKey: identity.String()} + err = decryptParams.ageDecrypt(reader, writer) if err != nil { t.Errorf("Unexpected error when decrypting") } @@ -175,9 +175,9 @@ func TestAgeEncrypt_Passphrase_Loopback_Success(t *testing.T) { content := "to be encrypted" reader := strings.NewReader(content) writer := &bytes.Buffer{} - params := encryptParams{Passphrase: "supersecret"} + params := EncryptParams{Passphrase: "supersecret"} - err := ageEncrypt(reader, writer, params) + err := params.ageEncrypt(reader, writer) if err != nil { t.Errorf("Unexpected error when encrypting") } @@ -189,8 +189,8 @@ func TestAgeEncrypt_Passphrase_Loopback_Success(t *testing.T) { reader = strings.NewReader(ciphertext) writer = &bytes.Buffer{} - decryptParams := decryptParams{Passphrase: "supersecret"} - err = ageDecrypt(reader, writer, decryptParams) + decryptParams := DecryptParams{Passphrase: "supersecret"} + err = decryptParams.ageDecrypt(reader, writer) if err != nil { t.Errorf("Unexpected error when decrypting") } @@ -209,9 +209,9 @@ func TestAgeEncrypt_WrongPrivateKey_Loopback_Failure(t *testing.T) { content := "to be encrypted" reader := strings.NewReader(content) writer := &bytes.Buffer{} - params := encryptParams{PublicKey: identity.Recipient().String()} + params := EncryptParams{PublicKey: identity.Recipient().String()} - err = ageEncrypt(reader, writer, params) + err = params.ageEncrypt(reader, writer) if err != nil { t.Errorf("Unexpected error when encrypting") } @@ -228,8 +228,8 @@ func TestAgeEncrypt_WrongPrivateKey_Loopback_Failure(t *testing.T) { reader = strings.NewReader(ciphertext) writer = &bytes.Buffer{} - decryptParams := decryptParams{PrivateKey: wrongIdentity.String()} - err = ageDecrypt(reader, writer, decryptParams) + decryptParams := DecryptParams{PrivateKey: wrongIdentity.String()} + err = decryptParams.ageDecrypt(reader, writer) if err == nil { t.Errorf("Decryption should have failed") } @@ -239,9 +239,9 @@ func TestAgeEncrypt_WrongPassphrase_Loopback_Failure(t *testing.T) { content := "to be encrypted" reader := strings.NewReader(content) writer := &bytes.Buffer{} - params := encryptParams{Passphrase: "supersecret"} + params := EncryptParams{Passphrase: "supersecret"} - err := ageEncrypt(reader, writer, params) + err := params.ageEncrypt(reader, writer) if err != nil { t.Errorf("Unexpected error when encrypting") } @@ -253,8 +253,8 @@ func TestAgeEncrypt_WrongPassphrase_Loopback_Failure(t *testing.T) { reader = strings.NewReader(ciphertext) writer = &bytes.Buffer{} - decryptParams := decryptParams{Passphrase: "wrong"} - err = ageDecrypt(reader, writer, decryptParams) + decryptParams := DecryptParams{Passphrase: "wrong"} + err = decryptParams.ageDecrypt(reader, writer) if err == nil { t.Fatalf("Decryption should have failed") } diff --git a/hash.go b/internal/crypto/hash.go similarity index 84% rename from hash.go rename to internal/crypto/hash.go index 82a12a6..390bd1b 100644 --- a/hash.go +++ b/internal/crypto/hash.go @@ -23,7 +23,7 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -package main +package crypto import ( "crypto/sha1" @@ -34,6 +34,9 @@ import ( "io" "os" "path/filepath" + + "github.com/orgrim/pg_back/internal/helpers" + "github.com/orgrim/pg_back/internal/logger" ) func computeChecksum(path string, h hash.Hash) (_ string, err error) { @@ -43,7 +46,7 @@ func computeChecksum(path string, h hash.Hash) (_ string, err error) { if err != nil { return "", err } - defer WrappedClose(f, &err) + defer helpers.WrappedClose(f, &err) if _, err := io.Copy(h, f); err != nil { return "", err @@ -51,7 +54,12 @@ func computeChecksum(path string, h hash.Hash) (_ string, err error) { return string(h.Sum(nil)), err } -func checksumFile(path string, mode int, algo string) (_ string, err error) { +func ChecksumFile( + logger *logger.LevelLog, + path string, + mode int, + algo string, +) (_ string, err error) { var h hash.Hash switch algo { @@ -77,22 +85,22 @@ func checksumFile(path string, mode int, algo string) (_ string, err error) { } sumFile := fmt.Sprintf("%s.%s", path, algo) - l.Verbosef("create checksum file: %s", sumFile) + logger.Verbosef("create checksum file: %s", sumFile) o, err := os.Create(sumFile) if err != nil { - l.Errorln(err) + logger.Errorln(err) return "", err } - defer WrappedClose(o, &err) + defer helpers.WrappedClose(o, &err) if i.IsDir() { - l.Verboseln("dump is a directory, checksumming all file inside") + logger.Verboseln("dump is a directory, checksumming all file inside") err = filepath.Walk(path, func(path string, info os.FileInfo, err error) error { if err != nil { return err } if info.Mode().IsRegular() { - l.Verboseln("computing checksum of:", path) + logger.Verboseln("computing checksum of:", path) r, cerr := computeChecksum(path, h) if cerr != nil { return fmt.Errorf("could not checksum %s: %s", path, cerr) @@ -112,14 +120,14 @@ func checksumFile(path string, mode int, algo string) (_ string, err error) { // Open the file and use io.Copy to feed the data to the hash, // like in the example of the doc, then write the result to a // file that the standard shaXXXsum tools can understand - l.Verboseln("computing checksum of:", path) + logger.Verboseln("computing checksum of:", path) r, _ := computeChecksum(path, h) if _, err := fmt.Fprintf(o, "%x %s\n", r, path); err != nil { return "", fmt.Errorf("could not write checksum to %s: %w", path, err) } } - l.Verboseln("computing checksum with MODE", mode, path) + logger.Verboseln("computing checksum with MODE", mode, path) if mode > 0 { if err := os.Chmod(o.Name(), os.FileMode(mode)); err != nil { return "", fmt.Errorf("could not chmod checksum file %s: %s", path, err) @@ -128,7 +136,8 @@ func checksumFile(path string, mode int, algo string) (_ string, err error) { return sumFile, err } -func checksumFileList( +func ChecksumFileList( + logger *logger.LevelLog, paths []string, mode int, algo string, @@ -154,20 +163,20 @@ func checksumFileList( } sumPath := fmt.Sprintf("%s.%s", sumFilePrefix, algo) - l.Verbosef("create or use checksum file: %s", sumPath) + logger.Verbosef("create or use checksum file: %s", sumPath) o, err := os.OpenFile(sumPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0666) if err != nil { return "", fmt.Errorf("could not open %s: %w", sumPath, err) } - defer WrappedClose(o, &err) + defer helpers.WrappedClose(o, &err) failed := false for _, path := range paths { - l.Verboseln("computing checksum of:", path) + logger.Verboseln("computing checksum of:", path) r, err := computeChecksum(path, h) if err != nil { - l.Errorf("could not checksum %s: %s", path, err) + logger.Errorf("could not checksum %s: %s", path, err) failed = true continue } diff --git a/hash_test.go b/internal/crypto/hash_test.go similarity index 88% rename from hash_test.go rename to internal/crypto/hash_test.go index cdcbb66..fd04116 100644 --- a/hash_test.go +++ b/internal/crypto/hash_test.go @@ -23,7 +23,7 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -package main +package crypto import ( "errors" @@ -34,6 +34,8 @@ import ( "path/filepath" "runtime" "testing" + + "github.com/orgrim/pg_back/internal/logger" ) func TestChecksumFile(t *testing.T) { @@ -89,19 +91,21 @@ func TestChecksumFile(t *testing.T) { } } + logger := logger.NewLevelLog() + // bad algo - if _, err := checksumFile("", 0o700, "none"); err != nil { + if _, err := ChecksumFile(logger, "", 0o700, "none"); err != nil { t.Errorf("expected , got %q\n", err) } - if _, err := checksumFile("", 0o700, "other"); err == nil { + if _, err := ChecksumFile(logger, "", 0o700, "other"); err == nil { t.Errorf("expected err, got \n") } // test each algo with the file for i, st := range tests { t.Run(fmt.Sprintf("f%v", i), func(t *testing.T) { - if _, err := checksumFile("test", 0o700, st.algo); err != nil { + if _, err := ChecksumFile(logger, "test", 0o700, st.algo); err != nil { t.Errorf("checksumFile returned: %v", err) } @@ -123,8 +127,10 @@ func TestChecksumFile(t *testing.T) { // bad files var e *os.PathError - l.logger.SetOutput(io.Discard) - if _, err := checksumFile("", 0o700, "sha1"); !errors.As(err, &e) { + + logger.Logger.SetOutput(io.Discard) + + if _, err := ChecksumFile(logger, "", 0o700, "sha1"); !errors.As(err, &e) { t.Errorf("expected an *os.PathError, got %q\n", err) } @@ -132,13 +138,13 @@ func TestChecksumFile(t *testing.T) { t.Errorf("could not chmod test.sha1 file: %q", err) } - if _, err := checksumFile("test", 0o700, "sha1"); !errors.As(err, &e) { + if _, err := ChecksumFile(logger, "test", 0o700, "sha1"); !errors.As(err, &e) { t.Errorf("expected an *os.PathError, got %q\n", err) } if err := os.Chmod("test.sha1", 0644); err != nil { t.Errorf("could not chmod test.sha1 file: %q", err) } - l.logger.SetOutput(os.Stderr) + logger.Logger.SetOutput(os.Stderr) // create a directory and some files if err := os.Mkdir("test.d", 0755); err != nil { @@ -158,7 +164,7 @@ func TestChecksumFile(t *testing.T) { // test each algo with the directory for i, st := range tests { t.Run(fmt.Sprintf("d%v", i), func(t *testing.T) { - if _, err := checksumFile("test.d", 0o700, st.algo); err != nil { + if _, err := ChecksumFile(logger, "test.d", 0o700, st.algo); err != nil { t.Errorf("checksumFile returned: %v", err) } diff --git a/internal/helpers/helpers.go b/internal/helpers/helpers.go new file mode 100644 index 0000000..5897c97 --- /dev/null +++ b/internal/helpers/helpers.go @@ -0,0 +1,93 @@ +package helpers + +import ( + "fmt" + "io" + "os" + "path/filepath" + "strings" + "time" + + "github.com/orgrim/pg_back/internal/logger" +) + +func WrappedClose(c io.Closer, err *error) { + cErr := c.Close() + if cErr != nil && *err == nil { + *err = cErr + } +} +func CleanDBName(dbname string) string { + // We do not want a database name starting with a dot to avoid creating hidden files + if strings.HasPrefix(dbname, ".") { + dbname = "_" + dbname + } + + // If there is a path separator in the database name, we do not want to + // create the dump in a subdirectory or in a parent directory + if strings.ContainsRune(dbname, os.PathSeparator) { + dbname = strings.ReplaceAll(dbname, string(os.PathSeparator), "_") + } + + // Always remove slashes to avoid issues with filenames on windows + if strings.ContainsRune(dbname, '/') { + dbname = strings.ReplaceAll(dbname, "/", "_") + } + + return dbname +} + +func FormatDumpPath( + dir string, + timeFormat string, + suffix string, + dbname string, + when time.Time, + compressLevel int, +) string { + var f, s, d string + + // Avoid attacks on the database name + dbname = CleanDBName(dbname) + + d = dir + if dbname != "" { + d = strings.ReplaceAll(dir, "{dbname}", dbname) + } + + s = suffix + if suffix == "" { + s = "dump" + } + + // Output is "dir(formatted)/dbname_date.suffix" when the + // input time is not zero, otherwise do not include the date + // and time. Reference time for time.Format(): "Mon Jan 2 + // 15:04:05 MST 2006" + if when.IsZero() { + f = fmt.Sprintf("%s.%s", dbname, s) + } else { + f = fmt.Sprintf("%s_%s.%s", dbname, when.Format(timeFormat), s) + } + + if suffix == "sql" && compressLevel > 0 { + f = f + ".gz" + } + + return filepath.Join(d, f) +} + +func RelPath(logger *logger.LevelLog, basedir, path string) string { + target, err := filepath.Rel(basedir, path) + if err != nil { + logger.Warnf("could not get relative path from %s: %s\n", path, err) + target = path + } + + prefix := fmt.Sprintf("..%c", os.PathSeparator) + for strings.HasPrefix(target, prefix) { + target = strings.TrimPrefix(target, prefix) + } + + return target +} diff --git a/legacy.go b/internal/legacy/legacy.go similarity index 94% rename from legacy.go rename to internal/legacy/legacy.go index 9b05c18..9d4136b 100644 --- a/legacy.go +++ b/internal/legacy/legacy.go @@ -23,7 +23,7 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -package main +package legacy import ( "fmt" @@ -32,6 +32,8 @@ import ( "strings" "github.com/anmitsu/go-shlex" + "github.com/orgrim/pg_back/internal/helpers" + "github.com/orgrim/pg_back/internal/logger" ) // Read the input file and return all lines that look like legacy configuration @@ -130,7 +132,7 @@ out: return strings.Trim(string(buf), " \t\v") } -func convertLegacyConf(oldConf []string) string { +func convertLegacyConf(logger *logger.LevelLog, oldConf []string) string { var result string table := map[string]string{ @@ -178,7 +180,7 @@ func convertLegacyConf(oldConf []string) string { } words, err := shlex.Split(v, true) if err != nil { - l.Warnf("could not parse value of PGBK_OPTS \"%s\": %s", value, err) + logger.Warnf("could not parse value of PGBK_OPTS \"%s\": %s", value, err) continue } @@ -263,19 +265,19 @@ func convertLegacyConf(oldConf []string) string { return result } -func convertLegacyConfFile(path string) (err error) { +func ConvertLegacyConfFile(logger *logger.LevelLog, path string) (err error) { f, err := os.Open(path) if err != nil { return fmt.Errorf("could not convert configuration: %w", err) } - defer WrappedClose(f, &err) + defer helpers.WrappedClose(f, &err) contents, err := readLegacyConf(f) if err != nil { return fmt.Errorf("could not convert configuration: %w", err) } - fmt.Printf("%s", convertLegacyConf(contents)) + fmt.Printf("%s", convertLegacyConf(logger, contents)) return err } diff --git a/legacy_test.go b/internal/legacy/legacy_test.go similarity index 97% rename from legacy_test.go rename to internal/legacy/legacy_test.go index 051b462..12d1913 100644 --- a/legacy_test.go +++ b/internal/legacy/legacy_test.go @@ -23,7 +23,7 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -package main +package legacy import ( "bytes" @@ -32,6 +32,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/orgrim/pg_back/internal/logger" ) func TestReadLegacyConf(t *testing.T) { @@ -171,10 +172,10 @@ func TestConvertLegacyConf(t *testing.T) { }, "format = plain\n" + "pg_dump_options = --create\n"}, } - + logger := logger.NewLevelLog() for i, st := range tests { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { - got := convertLegacyConf(st.input) + got := convertLegacyConf(logger, st.input) if got != st.want { t.Errorf("got %v, want %v", got, st.want) } diff --git a/log.go b/internal/logger/log.go similarity index 81% rename from log.go rename to internal/logger/log.go index b52f7f2..c564b87 100644 --- a/log.go +++ b/internal/logger/log.go @@ -23,7 +23,7 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -package main +package logger import ( "log" @@ -33,17 +33,15 @@ import ( // LevelLog custom type to allow a verbose mode and handling of levels // with a prefix type LevelLog struct { - logger *log.Logger + Logger *log.Logger verbose bool quiet bool } -var l = NewLevelLog() - // NewLevelLog setups a logger with the proper configuration for the underlying log func NewLevelLog() *LevelLog { return &LevelLog{ - logger: log.New(os.Stderr, "", log.LstdFlags|log.Lmsgprefix), + Logger: log.New(os.Stderr, "", log.LstdFlags|log.Lmsgprefix), verbose: false, quiet: false, } @@ -61,74 +59,74 @@ func (l *LevelLog) SetVerbosity(verbose bool, quiet bool) { l.verbose = verbose if verbose { - l.logger.SetFlags(log.LstdFlags | log.Lmsgprefix | log.Lmicroseconds) + l.Logger.SetFlags(log.LstdFlags | log.Lmsgprefix | log.Lmicroseconds) } } // Verbosef prints with log.Printf a message with DEBUG: prefix using log.Printf, only when verbose mode is true func (l *LevelLog) Verbosef(format string, v ...interface{}) { if l.verbose { - l.logger.SetPrefix("DEBUG: ") - l.logger.Printf(format, v...) + l.Logger.SetPrefix("DEBUG: ") + l.Logger.Printf(format, v...) } } // Verboseln prints a message with DEBUG: prefix using log.Println, only when verbose mode is true func (l *LevelLog) Verboseln(v ...interface{}) { if l.verbose { - l.logger.SetPrefix("DEBUG: ") - l.logger.Println(v...) + l.Logger.SetPrefix("DEBUG: ") + l.Logger.Println(v...) } } // Infof prints a message with INFO: prefix using log.Printf func (l *LevelLog) Infof(format string, v ...interface{}) { if !l.quiet { - l.logger.SetPrefix("INFO: ") - l.logger.Printf(format, v...) + l.Logger.SetPrefix("INFO: ") + l.Logger.Printf(format, v...) } } // Infoln prints a message with INFO: prefix using log.Println func (l *LevelLog) Infoln(v ...interface{}) { if !l.quiet { - l.logger.SetPrefix("INFO: ") - l.logger.Println(v...) + l.Logger.SetPrefix("INFO: ") + l.Logger.Println(v...) } } // Warnf prints a message with WARN: prefix using log.Printf func (l *LevelLog) Warnf(format string, v ...interface{}) { - l.logger.SetPrefix("WARN: ") - l.logger.Printf(format, v...) + l.Logger.SetPrefix("WARN: ") + l.Logger.Printf(format, v...) } // Warnln prints a message with WARN: prefix using log.Println func (l *LevelLog) Warnln(v ...interface{}) { - l.logger.SetPrefix("WARN: ") - l.logger.Println(v...) + l.Logger.SetPrefix("WARN: ") + l.Logger.Println(v...) } // Errorf prints a message with ERROR: prefix using log.Printf func (l *LevelLog) Errorf(format string, v ...interface{}) { - l.logger.SetPrefix("ERROR: ") - l.logger.Printf(format, v...) + l.Logger.SetPrefix("ERROR: ") + l.Logger.Printf(format, v...) } // Errorln prints a message with ERROR: prefix using log.Println func (l *LevelLog) Errorln(v ...interface{}) { - l.logger.SetPrefix("ERROR: ") - l.logger.Println(v...) + l.Logger.SetPrefix("ERROR: ") + l.Logger.Println(v...) } // Fatalf prints a message with FATAL: prefix using log.Printf func (l *LevelLog) Fatalf(format string, v ...interface{}) { - l.logger.SetPrefix("FATAL: ") - l.logger.Printf(format, v...) + l.Logger.SetPrefix("FATAL: ") + l.Logger.Printf(format, v...) } // Fatalln prints a message with FATAL: prefix using log.Println func (l *LevelLog) Fatalln(v ...interface{}) { - l.logger.SetPrefix("FATAL: ") - l.logger.Println(v...) + l.Logger.SetPrefix("FATAL: ") + l.Logger.Println(v...) } diff --git a/log_test.go b/internal/logger/log_test.go similarity index 94% rename from log_test.go rename to internal/logger/log_test.go index 2ce6590..c8cb40c 100644 --- a/log_test.go +++ b/internal/logger/log_test.go @@ -23,7 +23,7 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -package main +package logger import ( "bytes" @@ -64,7 +64,7 @@ func TestLevelLogVerbose(t *testing.T) { for i, subt := range tests { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { buf := new(bytes.Buffer) - l.logger.SetOutput(buf) + l.Logger.SetOutput(buf) l.SetVerbosity(subt.verbose, false) if subt.fOrln { l.Verbosef("%s", subt.message) @@ -82,7 +82,7 @@ func TestLevelLogVerbose(t *testing.T) { if !matched { t.Errorf("log output should match %q is %q", subt.re, line) } - l.logger.SetOutput(os.Stderr) + l.Logger.SetOutput(os.Stderr) }) } } @@ -102,7 +102,7 @@ func TestLevelLogInfo(t *testing.T) { for i, subt := range tests { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { buf := new(bytes.Buffer) - l.logger.SetOutput(buf) + l.Logger.SetOutput(buf) if subt.fOrln { l.Infof("%s", subt.message) } else { @@ -118,7 +118,7 @@ func TestLevelLogInfo(t *testing.T) { if !matched { t.Errorf("log output should match %q is %q", subt.re, line) } - l.logger.SetOutput(os.Stderr) + l.Logger.SetOutput(os.Stderr) }) } } @@ -138,7 +138,7 @@ func TestLevelLogWarn(t *testing.T) { for i, subt := range tests { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { buf := new(bytes.Buffer) - l.logger.SetOutput(buf) + l.Logger.SetOutput(buf) if subt.fOrln { l.Warnf("%s", subt.message) } else { @@ -154,7 +154,7 @@ func TestLevelLogWarn(t *testing.T) { if !matched { t.Errorf("log output should match %q is %q", subt.re, line) } - l.logger.SetOutput(os.Stderr) + l.Logger.SetOutput(os.Stderr) }) } } @@ -174,7 +174,7 @@ func TestLevelLogError(t *testing.T) { for i, subt := range tests { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { buf := new(bytes.Buffer) - l.logger.SetOutput(buf) + l.Logger.SetOutput(buf) if subt.fOrln { l.Errorf("%s", subt.message) } else { @@ -190,7 +190,7 @@ func TestLevelLogError(t *testing.T) { if !matched { t.Errorf("log output should match %q is %q", subt.re, line) } - l.logger.SetOutput(os.Stderr) + l.Logger.SetOutput(os.Stderr) }) } } @@ -210,7 +210,7 @@ func TestLevelLogFatal(t *testing.T) { for i, subt := range tests { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { buf := new(bytes.Buffer) - l.logger.SetOutput(buf) + l.Logger.SetOutput(buf) if subt.fOrln { l.Fatalf("%s", subt.message) } else { @@ -226,7 +226,7 @@ func TestLevelLogFatal(t *testing.T) { if !matched { t.Errorf("log output should match %q is %q", subt.re, line) } - l.logger.SetOutput(os.Stderr) + l.Logger.SetOutput(os.Stderr) }) } } @@ -238,7 +238,7 @@ func TestLevelLogQuiet(t *testing.T) { l.SetVerbosity(true, true) buf := new(bytes.Buffer) - l.logger.SetOutput(buf) + l.Logger.SetOutput(buf) l.Verbosef("test") if buf.Len() > 0 { diff --git a/internal/metadata/metadata.go b/internal/metadata/metadata.go new file mode 100644 index 0000000..d09854d --- /dev/null +++ b/internal/metadata/metadata.go @@ -0,0 +1,3 @@ +package metadata + +var Version string = "2.6.0" diff --git a/connstring.go b/internal/postgresql/connstring.go similarity index 98% rename from connstring.go rename to internal/postgresql/connstring.go index 2d68f27..725ce9f 100644 --- a/connstring.go +++ b/internal/postgresql/connstring.go @@ -23,7 +23,7 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -package main +package postgresql import ( "fmt" @@ -511,10 +511,10 @@ func makeUrlConnInfo(infos map[string]string) string { return u.String() } -// prepareConnInfo returns a connexion string computed from the input +// PrepareConnInfo returns a connexion string computed from the input // values. When the dbname is already a connection string or a postgresql:// // URI, it only add the application_name keyword if not set. -func prepareConnInfo(host string, port int, username string, dbname string) (*ConnInfo, error) { +func PrepareConnInfo(host string, port int, username string, dbname string) (*ConnInfo, error) { var ( conninfo *ConnInfo err error @@ -554,7 +554,6 @@ func prepareConnInfo(host string, port int, username string, dbname string) (*Co } if _, ok := conninfo.Infos["application_name"]; !ok { - l.Verboseln("using pg_back as application_name") conninfo.Infos["application_name"] = "pg_back" } diff --git a/connstring_test.go b/internal/postgresql/connstring_test.go similarity index 99% rename from connstring_test.go rename to internal/postgresql/connstring_test.go index ff31efe..5611b32 100644 --- a/connstring_test.go +++ b/internal/postgresql/connstring_test.go @@ -23,7 +23,7 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -package main +package postgresql import ( "fmt" @@ -349,7 +349,7 @@ func TestPrepareConnInfo(t *testing.T) { for i, subt := range tests { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { - res, _ := prepareConnInfo(subt.host, subt.port, subt.username, subt.dbname) + res, _ := PrepareConnInfo(subt.host, subt.port, subt.username, subt.dbname) if res.String() != subt.want { t.Errorf("got '%s', want '%s'", res, subt.want) } diff --git a/sql.go b/internal/postgresql/sql.go similarity index 81% rename from sql.go rename to internal/postgresql/sql.go index b794e21..aa415d9 100644 --- a/sql.go +++ b/internal/postgresql/sql.go @@ -23,7 +23,7 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -package main +package postgresql import ( "database/sql" @@ -34,20 +34,22 @@ import ( "github.com/jackc/pgtype" _ "github.com/jackc/pgx/v4/stdlib" + "github.com/orgrim/pg_back/internal/helpers" + "github.com/orgrim/pg_back/internal/logger" ) -type pg struct { +type Pg struct { conn *sql.DB version int xlogOrWal string - superuser bool + Superuser bool } -func pgGetVersionNum(db *sql.DB) (int, error) { +func pgGetVersionNum(logger *logger.LevelLog, db *sql.DB) (int, error) { var version int query := "select setting from pg_settings where name = 'server_version_num'" - l.Verboseln("executing SQL query:", query) + logger.Verboseln("executing SQL query:", query) err := db.QueryRow(query).Scan(&version) if err != nil { return 0, fmt.Errorf("could not get PostgreSQL server version: %s", err) @@ -56,11 +58,11 @@ func pgGetVersionNum(db *sql.DB) (int, error) { return version, nil } -func pgAmISuperuser(db *sql.DB) (bool, error) { +func pgAmISuperuser(logger *logger.LevelLog, db *sql.DB) (bool, error) { var isSuper bool query := "select rolsuper from pg_roles where rolname = current_user" - l.Verboseln("executing SQL query:", query) + logger.Verboseln("executing SQL query:", query) err := db.QueryRow(query).Scan(&isSuper) if err != nil { return false, fmt.Errorf("could not check if db user is superuser: %s", err) @@ -69,9 +71,9 @@ func pgAmISuperuser(db *sql.DB) (bool, error) { return isSuper, nil } -func dbOpen(conninfo *ConnInfo) (*pg, error) { +func DbOpen(logger *logger.LevelLog, conninfo *ConnInfo) (*Pg, error) { connstr := conninfo.String() - l.Verbosef("connecting to PostgreSQL with: \"%s\"", connstr) + logger.Verbosef("connecting to PostgreSQL with: \"%s\"", connstr) db, err := sql.Open("pgx", connstr) if err != nil { return nil, fmt.Errorf("could not open database: %s", err) @@ -82,15 +84,15 @@ func dbOpen(conninfo *ConnInfo) (*pg, error) { return nil, fmt.Errorf("could not connect to database: %s", err) } - newDB := new(pg) + newDB := new(Pg) newDB.conn = db - newDB.version, err = pgGetVersionNum(db) + newDB.version, err = pgGetVersionNum(logger, db) if err != nil { db.Close() //nolint:errcheck return nil, err } - l.Verboseln("server num version is:", newDB.version) + logger.Verboseln("server num version is:", newDB.version) // Keyword xlog has been replaced by wal as of PostgreSQL 10 if newDB.version >= 100000 { newDB.xlogOrWal = "wal" @@ -98,7 +100,7 @@ func dbOpen(conninfo *ConnInfo) (*pg, error) { newDB.xlogOrWal = "xlog" } - newDB.superuser, err = pgAmISuperuser(db) + newDB.Superuser, err = pgAmISuperuser(logger, db) if err != nil { db.Close() //nolint:errcheck return nil, err @@ -107,8 +109,9 @@ func dbOpen(conninfo *ConnInfo) (*pg, error) { return newDB, nil } -func (db *pg) Close() error { - l.Verboseln("closing connection to PostgreSQL") +func (db *Pg) Close() error { + // TODO: re-add logging when adding a new logger field to the DB struct + // logger.Verboseln("closing connection to PostgreSQL") return db.conn.Close() } @@ -135,7 +138,7 @@ func sqlQuoteIdent(s string) string { return strings.ReplaceAll(s, "\"", "\"\"") } -func listAllDatabases(db *pg, withTemplates bool) (_ []string, err error) { +func listAllDatabases(logger *logger.LevelLog, db *Pg, withTemplates bool) (_ []string, err error) { var ( query string dbname string @@ -148,12 +151,12 @@ func listAllDatabases(db *pg, withTemplates bool) (_ []string, err error) { } dbs := make([]string, 0) - l.Verboseln("executing SQL query:", query) + logger.Verboseln("executing SQL query:", query) rows, err := db.conn.Query(query) if err != nil { return dbs, fmt.Errorf("could not list databases: %s", err) } - defer WrappedClose(rows, &err) + defer helpers.WrappedClose(rows, &err) for rows.Next() { err := rows.Scan(&dbname) @@ -168,8 +171,9 @@ func listAllDatabases(db *pg, withTemplates bool) (_ []string, err error) { return dbs, nil } -func listDatabases( - db *pg, +func ListDatabases( + logger *logger.LevelLog, + db *Pg, withTemplates bool, excludedDbs []string, includedDbs []string, @@ -182,7 +186,7 @@ func listDatabases( // When an explicit list of database is given, allow to select // templates if len(includedDbs) > 0 { - databases, err = listAllDatabases(db, true) + databases, err = listAllDatabases(logger, db, true) if err != nil { return databases, err } @@ -197,11 +201,11 @@ func listDatabases( continue nextidb } } - l.Warnf("database \"%s\" does not exists, excluded", d) + logger.Warnf("database \"%s\" does not exists, excluded", d) } databases = realDbs } else { - databases, err = listAllDatabases(db, withTemplates) + databases, err = listAllDatabases(logger, db, withTemplates) if err != nil { return databases, err } @@ -225,24 +229,30 @@ func listDatabases( return databases, nil } -type pgVersionError struct { +type PgVersionError struct { s string } -func (e *pgVersionError) Error() string { +func (e *PgVersionError) Error() string { return e.s } -type pgPrivError struct { +type PgPrivError struct { s string } -func (e *pgPrivError) Error() string { +func (e *PgPrivError) Error() string { return e.s } // pg_dumpacl stuff -func dumpCreateDBAndACL(db *pg, dbname string, force bool) (_ string, err error) { +func DumpCreateDBAndACL( + logger *logger.LevelLog, + db *Pg, + dbname string, + pgDumpVersion int, + force bool, +) (_ string, err error) { var s string if dbname == "" { @@ -252,18 +262,20 @@ func dumpCreateDBAndACL(db *pg, dbname string, force bool) (_ string, err error) // this query only work from 9.0, where datcollate and datctype were // added to pg_database if db.version < 90000 { - return "", &pgVersionError{s: "cluster version is older than 9.0, not dumping ACL"} + return "", &PgVersionError{s: "cluster version is older than 9.0, not dumping ACL"} } // this is no longer necessary after 11. Dumping ACL is the // job of pg_dump so we have to check its version, not the // server - if pgToolVersion("pg_dump") >= 110000 && !force { - l.Verboseln("no need to dump create database query and database ACL with pg_dump from >=11") + if pgDumpVersion >= 110000 && !force { + logger.Verboseln( + "no need to dump create database query and database ACL with pg_dump from >=11", + ) return "", nil } - l.Infoln("dumping database creation and ACL commands of", dbname) + logger.Infoln("dumping database creation and ACL commands of", dbname) query := "SELECT coalesce(rolname, (select rolname from pg_roles where oid=(select datdba from pg_database where datname='template0'))), " + " pg_encoding_to_char(d.encoding), " + @@ -272,12 +284,12 @@ func dumpCreateDBAndACL(db *pg, dbname string, force bool) (_ string, err error) "FROM pg_database d" + " LEFT JOIN pg_roles u ON (datdba = u.oid) " + "WHERE datallowconn AND datname = $1" - l.Verboseln("executing SQL query:", query) + logger.Verboseln("executing SQL query:", query) rows, err := db.conn.Query(query, dbname) if err != nil { return "", fmt.Errorf("could not query database information for %s: %s", dbname, err) } - defer WrappedClose(rows, &err) + defer helpers.WrappedClose(rows, &err) for rows.Next() { var ( @@ -455,7 +467,12 @@ func makeACLCommands(aclitem string, dbname string, owner string) string { return s } -func dumpDBConfig(db *pg, dbname string) (_ string, err error) { +func DumpDBConfig( + logger *logger.LevelLog, + db *Pg, + dbname string, + pgDumpVersion int, +) (_ string, err error) { var s string if dbname == "" { @@ -464,7 +481,7 @@ func dumpDBConfig(db *pg, dbname string) (_ string, err error) { // this query only work from 9.0, where pg_db_role_setting was introduced if db.version < 90000 { - return "", &pgVersionError{ + return "", &PgVersionError{ s: "cluster version is older than 9.0, not dumping database configuration", } } @@ -472,20 +489,20 @@ func dumpDBConfig(db *pg, dbname string) (_ string, err error) { // this is no longer necessary after 11. Dumping ACL is the // job of pg_dump so we have to check its version, not the // server - if pgToolVersion("pg_dump") >= 110000 { - l.Verboseln("no need to dump database configuration with pg_dump from >=11") + if pgDumpVersion >= 110000 { + logger.Verboseln("no need to dump database configuration with pg_dump from >=11") return "", nil } - l.Infoln("dumping database configuration commands of", dbname) + logger.Infoln("dumping database configuration commands of", dbname) // dump per database config query := "SELECT CASE setrole WHEN 0 THEN NULL ELSE pg_get_userbyid(setrole) END, unnest(setconfig) FROM pg_db_role_setting WHERE setdatabase = (SELECT oid FROM pg_database WHERE datname = $1) ORDER BY 1, 2" - l.Verboseln("executing SQL query:", query) + logger.Verboseln("executing SQL query:", query) rows, err := db.conn.Query(query, dbname) if err != nil { return "", fmt.Errorf("could not query database configuration for %s: %s", dbname, err) } - defer WrappedClose(rows, &err) + defer helpers.WrappedClose(rows, &err) for rows.Next() { var ( @@ -526,17 +543,17 @@ func dumpDBConfig(db *pg, dbname string) (_ string, err error) { return s, err } -func showSettings(db *pg) (_ string, err error) { +func ShowSettings(logger *logger.LevelLog, db *Pg) (_ string, err error) { var s, query string if db.version < 80400 { - return "", &pgVersionError{ + return "", &PgVersionError{ s: "cluster version is older than 8.4, not dumping configuration", } } - if !db.superuser { - return "", &pgPrivError{s: "current user is not superuser, not dumping configuration"} + if !db.Superuser { + return "", &PgPrivError{s: "current user is not superuser, not dumping configuration"} } if db.version >= 90500 { @@ -549,12 +566,12 @@ func showSettings(db *pg) (_ string, err error) { query = "SELECT name, setting FROM pg_settings WHERE sourcefile IS NOT NULL" } - l.Verboseln("executing SQL query:", query) + logger.Verboseln("executing SQL query:", query) rows, err := db.conn.Query(query) if err != nil { return "", fmt.Errorf("could not query instance configuration: %s", err) } - defer WrappedClose(rows, &err) + defer helpers.WrappedClose(rows, &err) for rows.Next() { var ( @@ -564,7 +581,7 @@ func showSettings(db *pg) (_ string, err error) { err := rows.Scan(&name, &value) if err != nil { - l.Errorln(err) + logger.Errorln(err) continue } @@ -586,19 +603,19 @@ func showSettings(db *pg) (_ string, err error) { // when dumping settings from pg_settings, some // settings may not be found because their value can // set a higher levels than configuration files - return s, &pgVersionError{s: "cluster version is older than 9.5, settings from configuration files could be missing if the SET command was used"} + return s, &PgVersionError{s: "cluster version is older than 9.5, settings from configuration files could be missing if the SET command was used"} } } -func extractFileFromSettings(db *pg, name string) (_ string, err error) { +func ExtractFileFromSettings(logger *logger.LevelLog, db *Pg, name string) (_ string, err error) { query := "SELECT setting, pg_read_file(setting, 0, (pg_stat_file(setting)).size) FROM pg_settings WHERE name = $1" - l.Verboseln("executing SQL query:", query) + logger.Verboseln("executing SQL query:", query) rows, err := db.conn.Query(query, name) if err != nil { return "", fmt.Errorf("could not query file contents from settings: %s", err) } - defer WrappedClose(rows, &err) + defer helpers.WrappedClose(rows, &err) var result string @@ -610,7 +627,7 @@ func extractFileFromSettings(db *pg, name string) (_ string, err error) { err := rows.Scan(&path, &contents) if err != nil { - l.Errorln(err) + logger.Errorln(err) continue } @@ -631,18 +648,18 @@ func (*pgReplicaHasLocks) Error() string { return "replication not paused because of AccessExclusiveLock" } -func pauseReplication(db *pg) (err error) { +func pauseReplication(logger *logger.LevelLog, db *Pg) (err error) { // If an AccessExclusiveLock is granted when the replay is // paused, it will remain and pg_dump would be stuck forever query := fmt.Sprintf("SELECT pg_%s_replay_pause() "+ "WHERE NOT EXISTS (SELECT 1 FROM pg_locks WHERE mode = 'AccessExclusiveLock') "+ "AND pg_is_in_recovery();", db.xlogOrWal) - l.Verboseln("executing SQL query:", query) + logger.Verboseln("executing SQL query:", query) rows, err := db.conn.Query(query) if err != nil { return fmt.Errorf("could not pause replication: %s", err) } - defer WrappedClose(rows, &err) + defer helpers.WrappedClose(rows, &err) // The query returns a single row with one column of type void, // which is and empty string, on success. It does not return @@ -660,7 +677,7 @@ func pauseReplication(db *pg) (err error) { return err } -func canPauseReplication(db *pg) (_ bool, err error) { +func canPauseReplication(logger *logger.LevelLog, db *Pg) (_ bool, err error) { // hot standby exists from 9.0 if db.version < 90000 { return false, nil @@ -668,12 +685,12 @@ func canPauseReplication(db *pg) (_ bool, err error) { query := fmt.Sprintf("SELECT 1 FROM pg_proc "+ "WHERE proname='pg_%s_replay_pause' AND pg_is_in_recovery()", db.xlogOrWal) - l.Verboseln("executing SQL query:", query) + logger.Verboseln("executing SQL query:", query) rows, err := db.conn.Query(query) if err != nil { return false, fmt.Errorf("could not check if replication is pausable: %s", err) } - defer WrappedClose(rows, &err) + defer helpers.WrappedClose(rows, &err) // The query returns 1 on success, no row on failure var one int @@ -690,9 +707,9 @@ func canPauseReplication(db *pg) (_ bool, err error) { return true, err } -func pauseReplicationWithTimeout(db *pg, timeOut int) error { +func PauseReplicationWithTimeout(logger *logger.LevelLog, db *Pg, timeOut int) error { - if ok, err := canPauseReplication(db); !ok { + if ok, err := canPauseReplication(logger, db); !ok { return err } @@ -701,7 +718,7 @@ func pauseReplicationWithTimeout(db *pg, timeOut int) error { stop := make(chan bool) fail := make(chan error) - l.Infoln("pausing replication") + logger.Infoln("pausing replication") // We want to retry pausing replication at a defined interval // but not forever. We cannot put the timeout in the same @@ -711,9 +728,9 @@ func pauseReplicationWithTimeout(db *pg, timeOut int) error { defer ticker.Stop() for { - if err := pauseReplication(db); err != nil { + if err := pauseReplication(logger, db); err != nil { if errors.As(err, &rerr) { - l.Warnln(err) + logger.Warnln(err) } else { fail <- err return @@ -736,7 +753,7 @@ func pauseReplicationWithTimeout(db *pg, timeOut int) error { // goroutine if we hit the timeout select { case <-done: - l.Infoln("replication paused") + logger.Infoln("replication paused") case <-time.After(time.Duration(timeOut) * time.Second): stop <- true return fmt.Errorf("replication not paused after %v", time.Duration(timeOut)*time.Second) @@ -747,14 +764,14 @@ func pauseReplicationWithTimeout(db *pg, timeOut int) error { return nil } -func resumeReplication(db *pg) error { - if ok, err := canPauseReplication(db); !ok { +func ResumeReplication(logger *logger.LevelLog, db *Pg) error { + if ok, err := canPauseReplication(logger, db); !ok { return err } - l.Infoln("resuming replication") + logger.Infoln("resuming replication") query := fmt.Sprintf("SELECT pg_%s_replay_resume() WHERE pg_is_in_recovery();", db.xlogOrWal) - l.Verboseln("executing SQL query:", query) + logger.Verboseln("executing SQL query:", query) _, err := db.conn.Exec(query) if err != nil { return fmt.Errorf("could not resume replication: %s", err) diff --git a/sql_test.go b/internal/postgresql/sql_test.go similarity index 90% rename from sql_test.go rename to internal/postgresql/sql_test.go index e0406ee..5f52ecb 100644 --- a/sql_test.go +++ b/internal/postgresql/sql_test.go @@ -23,7 +23,7 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -package main +package postgresql import ( "fmt" @@ -34,10 +34,13 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/orgrim/pg_back/internal/command" + "github.com/orgrim/pg_back/internal/logger" ) var ( - testdb *pg + testdb *Pg + l = logger.NewLevelLog() ) func needPgConn(t *testing.T) { @@ -50,17 +53,19 @@ func needPgConn(t *testing.T) { if err != nil { t.Fatalf("unable to parse PGBK_TEST_CONNINFO: %s", err) } - testdb, err = dbOpen(conninfo) + testdb, err = DbOpen(l, conninfo) if err != nil { - t.Fatalf("expected an ok on dbOpen(), got %s", err) + t.Fatalf("expected an ok on DbOpen(), got %s", err) } } } -func needPgDump(t *testing.T) { - if pgToolVersion("pg_dump") >= 110000 { +func needPgDump(t *testing.T) int { + v := command.PgToolVersion(l, "", "pg_dump") + if v >= 110000 { t.Skip("testing with a pg_dump version > 11") } + return v } func TestSqlQuoteLiteral(t *testing.T) { @@ -156,17 +161,17 @@ func TestDbOpen(t *testing.T) { if err != nil { t.Fatalf("unable to parse PGBK_TEST_CONNINFO: %s", err) } - db, err := dbOpen(conninfo) + db, err := DbOpen(l, conninfo) if err != nil { - t.Fatalf("expected an ok on dbOpen(), got %s", err) + t.Fatalf("expected an ok on DbOpen(), got %s", err) } if err := db.Close(); err != nil { t.Errorf("expected an okon db.Close(), got %s", err) } - testdb, err = dbOpen(conninfo) + testdb, err = DbOpen(l, conninfo) if err != nil { - t.Fatalf("expected an ok on dbOpen(), got %s", err) + t.Fatalf("expected an ok on DbOpen(), got %s", err) } } @@ -183,7 +188,7 @@ func TestListAllDatabases(t *testing.T) { for i, st := range tests { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { - got, err := listAllDatabases(testdb, st.templates) + got, err := listAllDatabases(l, testdb, st.templates) if err != nil { t.Errorf("expected non nil error, got %q", err) } @@ -222,7 +227,7 @@ func TestListDatabases(t *testing.T) { for i, st := range tests { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { - got, err := listDatabases(testdb, st.withTemplates, st.excludedDbs, st.includedDbs) + got, err := ListDatabases(l, testdb, st.withTemplates, st.excludedDbs, st.includedDbs) if err != nil { t.Errorf("expected non nil error, got %q", err) } @@ -244,11 +249,11 @@ func TestDumpDBConfig(t *testing.T) { } needPgConn(t) - needPgDump(t) + pg_dump_v := needPgDump(t) for i, st := range tests { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { - got, err := dumpDBConfig(testdb, "b1") + got, err := DumpDBConfig(l, testdb, "b1", pg_dump_v) if err != nil { t.Errorf("expected non nil error, got %q", err) } @@ -263,7 +268,7 @@ func TestDumpDBConfig(t *testing.T) { func TestShowSettings(t *testing.T) { needPgConn(t) - got, err := showSettings(testdb) + got, err := ShowSettings(l, testdb) if err != nil { t.Errorf("expected non nil error, got %q", err) } @@ -288,7 +293,7 @@ func TestShowSettings(t *testing.T) { func TestDumpCreateDBAndACL(t *testing.T) { needPgConn(t) - needPgDump(t) + pg_dump_v := needPgDump(t) var tests = []struct { db string @@ -306,7 +311,7 @@ func TestDumpCreateDBAndACL(t *testing.T) { for _, st := range tests { t.Run(st.db, func(t *testing.T) { - got, err := dumpCreateDBAndACL(testdb, st.db, false) + got, err := DumpCreateDBAndACL(l, testdb, st.db, pg_dump_v, false) if err != nil { t.Errorf("expected non nil error, got %q", err) } @@ -321,7 +326,7 @@ func TestDumpCreateDBAndACL(t *testing.T) { func TestExtractFileFromSettings(t *testing.T) { needPgConn(t) - got, err := extractFileFromSettings(testdb, "hba_file") + got, err := ExtractFileFromSettings(l, testdb, "hba_file") if err != nil { t.Errorf("expected non nil error, got %q", err) } diff --git a/lock.go b/internal/storage/lock.go similarity index 83% rename from lock.go rename to internal/storage/lock.go index a5e5b4c..c65d3c2 100644 --- a/lock.go +++ b/internal/storage/lock.go @@ -26,18 +26,20 @@ //go:build !windows // +build !windows -package main +package storage import ( "os" "path/filepath" "syscall" + + "github.com/orgrim/pg_back/internal/logger" ) -// lockPath open and try to lock a file with a non-blocking exclusive +// LockPath open and try to lock a file with a non-blocking exclusive // lock. return the open file, which must be held open to keep the // lock, wether it could be locked and a potentiel error. -func lockPath(path string) (*os.File, bool, error) { +func LockPath(logger *logger.LevelLog, path string) (*os.File, bool, error) { if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { return nil, false, err } @@ -47,7 +49,7 @@ func lockPath(path string) (*os.File, bool, error) { return nil, false, err } - l.Verboseln("locking", path, "with flock()") + logger.Verboseln("locking", path, "with flock()") if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB); err != nil { switch err { case syscall.EWOULDBLOCK: @@ -61,11 +63,11 @@ func lockPath(path string) (*os.File, bool, error) { return f, true, nil } -// unlockPath releases the lock from the open file and removes the +// UnlockPath releases the lock from the open file and removes the // underlying path -func unlockPath(f *os.File) error { +func UnlockPath(logger *logger.LevelLog, f *os.File) error { path := f.Name() - l.Verboseln("unlocking", path, "with flock()") + logger.Verboseln("unlocking", path, "with flock()") if err := syscall.Flock(int(f.Fd()), syscall.LOCK_UN); err != nil { return err } diff --git a/lock_test.go b/internal/storage/lock_test.go similarity index 86% rename from lock_test.go rename to internal/storage/lock_test.go index 5167a54..dc98321 100644 --- a/lock_test.go +++ b/internal/storage/lock_test.go @@ -23,7 +23,7 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -package main +package storage import ( "errors" @@ -31,9 +31,12 @@ import ( "path/filepath" "runtime" "testing" + + "github.com/orgrim/pg_back/internal/logger" ) func TestLockPath(t *testing.T) { + logger := logger.NewLevelLog() // Work from a tempdir dir, err := os.MkdirTemp("", "test_lockpath") if err != nil { @@ -49,19 +52,19 @@ func TestLockPath(t *testing.T) { var e *os.PathError // On windows the directory is created even with a mode of the tempdir that should make it fail if runtime.GOOS != "windows" { - _, _, err = lockPath(filepath.Join(dir, "subfail", "subfail", "lockfail")) + _, _, err = LockPath(logger, filepath.Join(dir, "subfail", "subfail", "lockfail")) if !errors.As(err, &e) { t.Errorf("expected a *os.PathError, got %q\n", err) } } // path is subdir of tempdir to make os.create fail - _, _, err = lockPath(filepath.Join(dir, "subfail")) + _, _, err = LockPath(logger, filepath.Join(dir, "subfail")) if !errors.As(err, &e) { t.Errorf("expected a *os.PathError, got %q\n", err) } // lock a path with success - f, l, err := lockPath(filepath.Join(dir, "lock")) + f, l, err := LockPath(logger, filepath.Join(dir, "lock")) if err != nil { t.Errorf("expected got error %q\n", err) } @@ -71,7 +74,7 @@ func TestLockPath(t *testing.T) { } // fail to lock it again - f1, l1, err := lockPath(filepath.Join(dir, "lock")) + f1, l1, err := LockPath(logger, filepath.Join(dir, "lock")) if err != nil { t.Errorf("expected got error %q\n", err) } @@ -82,6 +85,7 @@ func TestLockPath(t *testing.T) { } func TestUnlockPath(t *testing.T) { + logger := logger.NewLevelLog() f, err := os.CreateTemp("", "test_unlockpath") if err != nil { t.Fatal("could not create tempfile") @@ -89,14 +93,14 @@ func TestUnlockPath(t *testing.T) { defer os.Remove(f.Name()) // unlock shall always work even if the file is not locked - err = unlockPath(f) + err = UnlockPath(logger, f) if err != nil { t.Errorf("got error %q on non locked file\n", err) } // error when the locked file as already been removed os.Remove(f.Name()) - err = unlockPath(f) + err = UnlockPath(logger, f) if err == nil { t.Errorf("got instead of \"bad file descriptor\" error") } diff --git a/lock_win.go b/internal/storage/lock_win.go similarity index 85% rename from lock_win.go rename to internal/storage/lock_win.go index d0d523e..2061458 100644 --- a/lock_win.go +++ b/internal/storage/lock_win.go @@ -26,17 +26,19 @@ //go:build windows // +build windows -package main +package storage import ( "fmt" "os" "path/filepath" + + "github.com/orgrim/pg_back/internal/logger" ) -// lockPath on windows just creates a file without locking, it only tests if +// LockPath on windows just creates a file without locking, it only tests if // the file exist to consider it locked -func lockPath(path string) (*os.File, bool, error) { +func LockPath(logger *logger.LevelLog, path string) (*os.File, bool, error) { if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { return nil, false, err } @@ -53,7 +55,7 @@ func lockPath(path string) (*os.File, bool, error) { return nil, false, err } - l.Verboseln("creating lock file", path) + logger.Verboseln("creating lock file", path) f, err := os.Create(path) if err != nil { return nil, false, err @@ -63,9 +65,9 @@ func lockPath(path string) (*os.File, bool, error) { // unlockPath releases the lock from the open file and removes the // underlying path -func unlockPath(f *os.File) error { +func UnlockPath(logger *logger.LevelLog, f *os.File) error { path := f.Name() - l.Verboseln("removing lock file", path) + logger.Verboseln("removing lock file", path) f.Close() return os.Remove(path) } diff --git a/purge.go b/internal/storage/purge.go similarity index 75% rename from purge.go rename to internal/storage/purge.go index 9749ae3..6b889d6 100644 --- a/purge.go +++ b/internal/storage/purge.go @@ -23,7 +23,7 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -package main +package storage import ( "errors" @@ -35,6 +35,9 @@ import ( "sort" "strings" "time" + + "github.com/orgrim/pg_back/internal/helpers" + "github.com/orgrim/pg_back/internal/logger" ) type purgeJob struct { @@ -53,8 +56,8 @@ func genPurgeJobs(items []Item, dbname string) []purgeJob { ) for _, item := range items { - if strings.HasPrefix(item.key, cleanDBName(dbname)+"_") { - dateNExt := strings.TrimPrefix(item.key, cleanDBName(dbname)+"_") + if strings.HasPrefix(item.Key, helpers.CleanDBName(dbname)+"_") { + dateNExt := strings.TrimPrefix(item.Key, helpers.CleanDBName(dbname)+"_") parts := strings.SplitN(dateNExt, ".", 2) var ( @@ -98,10 +101,10 @@ func genPurgeJobs(items []Item, dbname string) []purgeJob { job.datetime = date } - if item.isDir { - job.dirs = append(job.dirs, item.key) + if item.IsDir { + job.dirs = append(job.dirs, item.Key) } else { - job.files = append(job.files, item.key) + job.files = append(job.files, item.Key) } jobs[parts[0]] = job @@ -123,18 +126,24 @@ func genPurgeJobs(items []Item, dbname string) []purgeJob { return jobList } -func purgeDumps(directory string, dbname string, keep int, limit time.Time) (err error) { - l.Verboseln("purge:", dbname, "limit:", limit, "keep:", keep) +func PurgeDumps( + logger *logger.LevelLog, + directory string, + dbname string, + keep int, + limit time.Time, +) (err error) { + logger.Verboseln("purge:", dbname, "limit:", limit, "keep:", keep) // The dbname can be put in the path of the backup directory, so we // have to compute it first. This is why a dbname is required to purge // old dumps - dirpath := filepath.Dir(formatDumpPath(directory, "", "", dbname, time.Time{}, 0)) + dirpath := filepath.Dir(helpers.FormatDumpPath(directory, "", "", dbname, time.Time{}, 0)) dir, err := os.Open(dirpath) if err != nil { return fmt.Errorf("could not purge %s: %s", dirpath, err) } - defer WrappedClose(dir, &err) + defer helpers.WrappedClose(dir, &err) files := make([]Item, 0) for { @@ -149,7 +158,7 @@ func purgeDumps(directory string, dbname string, keep int, limit time.Time) (err return fmt.Errorf("could not purge %s: %s", dirpath, err) } - files = append(files, Item{key: f[0].Name(), modtime: f[0].ModTime(), isDir: f[0].IsDir()}) + files = append(files, Item{Key: f[0].Name(), modtime: f[0].ModTime(), IsDir: f[0].IsDir()}) } // Parse and group by date. We remove groups of files produced by @@ -160,11 +169,11 @@ func purgeDumps(directory string, dbname string, keep int, limit time.Time) (err // Show the files kept in verbose mode for _, j := range jobs[:keep] { for _, f := range j.files { - l.Verboseln("keeping (count)", filepath.Join(dirpath, f)) + logger.Verboseln("keeping (count)", filepath.Join(dirpath, f)) } for _, d := range j.dirs { - l.Verboseln("keeping (count)", filepath.Join(dirpath, d)) + logger.Verboseln("keeping (count)", filepath.Join(dirpath, d)) } } @@ -174,26 +183,26 @@ func purgeDumps(directory string, dbname string, keep int, limit time.Time) (err if j.datetime.Before(limit) { for _, f := range j.files { path := filepath.Join(dirpath, f) - l.Infoln("removing", path) + logger.Infoln("removing", path) if err = os.Remove(path); err != nil { - l.Errorln(err) + logger.Errorln(err) } } for _, d := range j.dirs { path := filepath.Join(dirpath, d) - l.Infoln("removing", path) + logger.Infoln("removing", path) if err = os.RemoveAll(path); err != nil { - l.Errorln(err) + logger.Errorln(err) } } } else { for _, f := range j.files { - l.Verboseln("keeping (age)", filepath.Join(dirpath, f)) + logger.Verboseln("keeping (age)", filepath.Join(dirpath, f)) } for _, d := range j.dirs { - l.Verboseln("keeping (age)", filepath.Join(dirpath, d)) + logger.Verboseln("keeping (age)", filepath.Join(dirpath, d)) } } } @@ -206,7 +215,8 @@ func purgeDumps(directory string, dbname string, keep int, limit time.Time) (err return nil } -func purgeRemoteDumps( +func PurgeRemoteDumps( + logger *logger.LevelLog, repo Repo, uploadPrefix string, directory string, @@ -214,23 +224,23 @@ func purgeRemoteDumps( keep int, limit time.Time, ) error { - l.Verboseln("remote purge:", dbname, "limit:", limit, "keep:", keep) + logger.Verboseln("remote purge:", dbname, "limit:", limit, "keep:", keep) // The dbname can be put in the directory tree of the dump, in this // case the directory containing {dbname} in its name is kept on the // remote path along with any subdirectory. So we have to include it in // the filter when listing remote files - dirpath := filepath.Dir(formatDumpPath(directory, "", "", dbname, time.Time{}, 0)) + dirpath := filepath.Dir(helpers.FormatDumpPath(directory, "", "", dbname, time.Time{}, 0)) prefix := filepath.Join( uploadPrefix, - relPath(directory, filepath.Join(dirpath, cleanDBName(dbname))), + helpers.RelPath(logger, directory, filepath.Join(dirpath, helpers.CleanDBName(dbname))), ) - l.Verboseln("remote file prefix:", prefix) + logger.Verboseln("remote file prefix:", prefix) // Get the list of files from the repository, this includes the // contents of dumps in the directory format. - remoteFiles, err := repo.List(prefix) + remoteFiles, err := repo.List(logger, prefix) if err != nil { return fmt.Errorf("could not purge: %w", err) } @@ -244,13 +254,13 @@ func purgeRemoteDumps( files := make([]Item, 0) for _, i := range remoteFiles { - f, err := filepath.Rel(parentDir, i.key) + f, err := filepath.Rel(parentDir, i.Key) if err != nil { - l.Warnf("could not process remote file %s: %s", i.key, err) + logger.Warnf("could not process remote file %s: %s", i.Key, err) continue } - files = append(files, Item{key: f, modtime: i.modtime, isDir: i.isDir}) + files = append(files, Item{Key: f, modtime: i.modtime, IsDir: i.IsDir}) } // Parse and group by date. We remove groups of files produced by @@ -261,11 +271,11 @@ func purgeRemoteDumps( // Show the files kept in verbose mode for _, j := range jobs[:keep] { for _, f := range j.files { - l.Verboseln("keeping remote (count)", filepath.Join(parentDir, f)) + logger.Verboseln("keeping remote (count)", filepath.Join(parentDir, f)) } for _, d := range j.dirs { - l.Verboseln("keeping remote (count)", filepath.Join(parentDir, d)) + logger.Verboseln("keeping remote (count)", filepath.Join(parentDir, d)) } } @@ -275,27 +285,27 @@ func purgeRemoteDumps( if j.datetime.Before(limit) { for _, f := range j.files { path := filepath.Join(parentDir, f) - l.Infoln("removing remote", path) + logger.Infoln("removing remote", path) if err = repo.Remove(path); err != nil { - l.Errorln(err) + logger.Errorln(err) } } for _, d := range j.dirs { path := filepath.Join(parentDir, d) - l.Infoln("removing remote", path) + logger.Infoln("removing remote", path) if err = repo.Remove(path); err != nil { - l.Errorln(err) + logger.Errorln(err) } } } else { for _, f := range j.files { - l.Verboseln("keeping remote (age)", filepath.Join(parentDir, f)) + logger.Verboseln("keeping remote (age)", filepath.Join(parentDir, f)) } for _, d := range j.dirs { - l.Verboseln("keeping remote (age)", filepath.Join(parentDir, d)) + logger.Verboseln("keeping remote (age)", filepath.Join(parentDir, d)) } } } diff --git a/purge_test.go b/internal/storage/purge_test.go similarity index 87% rename from purge_test.go rename to internal/storage/purge_test.go index b84822d..7e4bc02 100644 --- a/purge_test.go +++ b/internal/storage/purge_test.go @@ -23,7 +23,7 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -package main +package storage import ( "fmt" @@ -32,11 +32,15 @@ import ( "runtime" "testing" "time" + + "github.com/orgrim/pg_back/internal/helpers" + "github.com/orgrim/pg_back/internal/logger" ) // func purgeDumps(directory string, dbname string, keep int, limit time.Time) error func TestPurgeDumps(t *testing.T) { // work in a tempdir + logger := logger.NewLevelLog() dir, err := os.MkdirTemp("", "test_purge_dumps") if err != nil { t.Fatal("could not create tempdir:", err) @@ -51,7 +55,7 @@ func TestPurgeDumps(t *testing.T) { if runtime.GOOS != "windows" { os.Chmod(filepath.Dir(wd), 0444) - err = purgeDumps(wd, "", 0, time.Time{}) + err = PurgeDumps(logger, wd, "", 0, time.Time{}) if err == nil { t.Errorf("empty path gave error \n") } @@ -60,7 +64,7 @@ func TestPurgeDumps(t *testing.T) { // empty dbname when := time.Now().Add(-time.Hour) - tf := formatDumpPath(wd, "2006-01-02_15-04-05", "dump", "", when, 0) + tf := helpers.FormatDumpPath(wd, "2006-01-02_15-04-05", "dump", "", when, 0) f, err := os.Create(tf) if err != nil { t.Errorf("could not create temp file %s: %s", tf, err) @@ -69,7 +73,7 @@ func TestPurgeDumps(t *testing.T) { f.Close() os.Chtimes(tf, when, when) - err = purgeDumps(wd, "", 0, time.Now()) + err = PurgeDumps(logger, wd, "", 0, time.Now()) if err != nil { t.Errorf("empty dbname (file: %s) gave error %s", tf, err) } @@ -79,22 +83,22 @@ func TestPurgeDumps(t *testing.T) { // file without write perms if runtime.GOOS != "windows" { - tf = formatDumpPath(wd, time.RFC3339, "dump", "db", time.Now().Add(-time.Hour), 0) + tf = helpers.FormatDumpPath(wd, time.RFC3339, "dump", "db", time.Now().Add(-time.Hour), 0) os.WriteFile(tf, []byte("truc\n"), 0644) os.Chmod(filepath.Dir(tf), 0555) - err = purgeDumps(wd, "db", 0, time.Now()) + err = PurgeDumps(logger, wd, "db", 0, time.Now()) if err == nil { t.Errorf("bad perms on file did not gave an error") } os.Chmod(filepath.Dir(tf), 0755) // dir without write perms - tf = formatDumpPath(wd, time.RFC3339, "d", "db", time.Now().Add(-time.Hour), 0) + tf = helpers.FormatDumpPath(wd, time.RFC3339, "d", "db", time.Now().Add(-time.Hour), 0) os.MkdirAll(tf, 0755) os.Chmod(filepath.Dir(tf), 0555) - err = purgeDumps(wd, "db", 0, time.Now()) + err = PurgeDumps(logger, wd, "db", 0, time.Now()) if err == nil { t.Errorf("bad perms on dir did not gave an error") } @@ -137,12 +141,12 @@ func TestPurgeDumps(t *testing.T) { } for i := 1; i <= 3; i++ { when := time.Now().Add(-time.Hour * time.Duration(i)) - tf = formatDumpPath(wd, st.format, "dump", "db", when, 0) + tf = helpers.FormatDumpPath(wd, st.format, "dump", "db", when, 0) os.WriteFile(tf, []byte("truc\n"), 0644) os.Chtimes(tf, when, when) } - if err := purgeDumps(wd, "db", st.keep, st.limit); err != nil { + if err := PurgeDumps(logger, wd, "db", st.keep, st.limit); err != nil { t.Errorf("purgeDumps returned: %v", err) } diff --git a/upload.go b/internal/storage/upload.go similarity index 83% rename from upload.go rename to internal/storage/upload.go index e94b722..7b997de 100644 --- a/upload.go +++ b/internal/storage/upload.go @@ -23,7 +23,7 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF // THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -package main +package storage import ( "context" @@ -45,6 +45,9 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/orgrim/pg_back/internal/config" + "github.com/orgrim/pg_back/internal/helpers" + "github.com/orgrim/pg_back/internal/logger" "github.com/pkg/sftp" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" @@ -56,14 +59,14 @@ import ( // A Repo is a remote service where we can upload files type Repo interface { // Upload a path to the remote naming it target - Upload(path string, target string) error + Upload(logger *logger.LevelLog, path string, target string) error // Download target from the remote and store it into path - Download(target string, path string) error + Download(logger *logger.LevelLog, target string, path string) error // List remote files starting with a prefix. the prefix can be empty to // list all files - List(prefix string) ([]Item, error) + List(logger *logger.LevelLog, prefix string) ([]Item, error) // Remove path from the remote Remove(path string) error @@ -73,9 +76,9 @@ type Repo interface { } type Item struct { - key string + Key string modtime time.Time - isDir bool + IsDir bool } // Replace any backslashes from windows to forward slashed @@ -83,7 +86,7 @@ func forwardSlashes(target string) string { return strings.ReplaceAll(target, fmt.Sprintf("%c", os.PathSeparator), "/") } -func NewRepo(kind string, opts options) (Repo, error) { +func NewRepo(kind string, opts config.Options) (Repo, error) { var ( repo Repo err error @@ -143,7 +146,7 @@ type s3repo struct { session *session.Session } -func NewB2Repo(opts options) (*b2repo, error) { +func NewB2Repo(opts config.Options) (*b2repo, error) { r := &b2repo{ appKey: opts.B2AppKey, bucket: opts.B2Bucket, @@ -153,11 +156,11 @@ func NewB2Repo(opts options) (*b2repo, error) { keyID: opts.B2KeyID, } - l.Verbosef( - "starting b2 client with %d connections to endpoint to bucket %s \n", - r.concurrentConnections, - r.bucket, - ) + // logger.Verbosef( + // "starting b2 client with %d connections to endpoint to bucket %s \n", + // r.concurrentConnections, + // r.bucket, + // ) client, err := b2.NewClient(r.ctx, r.keyID, r.appKey) if err != nil { @@ -177,19 +180,19 @@ func NewB2Repo(opts options) (*b2repo, error) { return r, nil } -func (r *b2repo) Upload(path string, target string) (err error) { +func (r *b2repo) Upload(logger *logger.LevelLog, path string, target string) (err error) { f, err := os.Open(path) if err != nil { return err } - defer WrappedClose(f, &err) + defer helpers.WrappedClose(f, &err) w := r.b2Bucket.Object(target).NewWriter(r.ctx) - defer WrappedClose(w, &err) + defer helpers.WrappedClose(w, &err) w.ConcurrentUploads = r.concurrentConnections - l.Infof("uploading %s to B2 bucket %s\n", path, r.bucket) + logger.Infof("uploading %s to B2 bucket %s\n", path, r.bucket) if _, err := io.Copy(w, f); err != nil { return err } @@ -197,20 +200,20 @@ func (r *b2repo) Upload(path string, target string) (err error) { return nil } -func (r *b2repo) Download(target string, path string) (err error) { +func (r *b2repo) Download(logger *logger.LevelLog, target string, path string) (err error) { f, err := os.Create(path) if err != nil { return fmt.Errorf("download error: %w", err) } - defer WrappedClose(f, &err) + defer helpers.WrappedClose(f, &err) bucket := r.b2Bucket - l.Infof("downloading %s from B2 bucket %s to %s\n", target, r.bucket, path) + logger.Infof("downloading %s from B2 bucket %s to %s\n", target, r.bucket, path) rf := bucket.Object(target).NewReader(r.ctx) rf.ConcurrentDownloads = r.concurrentConnections - defer WrappedClose(rf, &err) + defer helpers.WrappedClose(rf, &err) if err != nil { return err @@ -227,7 +230,7 @@ func (r *b2repo) Close() error { return nil } -func (r *b2repo) List(prefix string) ([]Item, error) { +func (r *b2repo) List(_ *logger.LevelLog, prefix string) ([]Item, error) { files := make([]Item, 0) @@ -242,7 +245,7 @@ func (r *b2repo) List(prefix string) ([]Item, error) { } files = append(files, Item{ - key: obj.Name(), + Key: obj.Name(), modtime: attributes.LastModified, }, ) @@ -259,7 +262,7 @@ func (r *b2repo) Remove(path string) error { return r.b2Bucket.Object(path).Delete(ctx) } -func NewS3Repo(opts options) (*s3repo, error) { +func NewS3Repo(opts config.Options) (*s3repo, error) { r := &s3repo{ region: opts.S3Region, bucket: opts.S3Bucket, @@ -315,16 +318,16 @@ func (r *s3repo) Close() error { return nil } -func (r *s3repo) Upload(path string, target string) (err error) { +func (r *s3repo) Upload(logger *logger.LevelLog, path string, target string) (err error) { file, err := os.Open(path) if err != nil { return fmt.Errorf("upload error: %w", err) } - defer WrappedClose(file, &err) + defer helpers.WrappedClose(file, &err) uploader := s3manager.NewUploader(r.session) - l.Infof("uploading %s to S3 bucket %s\n", path, r.bucket) + logger.Infof("uploading %s to S3 bucket %s\n", path, r.bucket) _, err = uploader.Upload(&s3manager.UploadInput{ Bucket: aws.String(r.bucket), Key: aws.String(forwardSlashes(target)), @@ -338,16 +341,16 @@ func (r *s3repo) Upload(path string, target string) (err error) { return nil } -func (r *s3repo) Download(target string, path string) (err error) { +func (r *s3repo) Download(logger *logger.LevelLog, target string, path string) (err error) { file, err := os.Create(path) if err != nil { return fmt.Errorf("download error: %w", err) } - defer WrappedClose(file, &err) + defer helpers.WrappedClose(file, &err) downloader := s3manager.NewDownloader(r.session) - l.Infof("downloading %s from S3 bucket %s to %s\n", target, r.bucket, path) + logger.Infof("downloading %s from S3 bucket %s to %s\n", target, r.bucket, path) _, err = downloader.Download(file, &s3.GetObjectInput{ Bucket: aws.String(r.bucket), Key: aws.String(forwardSlashes(target)), @@ -360,7 +363,7 @@ func (r *s3repo) Download(target string, path string) (err error) { return nil } -func (r *s3repo) List(prefix string) ([]Item, error) { +func (r *s3repo) List(_ *logger.LevelLog, prefix string) ([]Item, error) { svc := s3.New(r.session) files := make([]Item, 0) @@ -380,7 +383,7 @@ func (r *s3repo) List(prefix string) ([]Item, error) { for _, item := range resp.Contents { file := Item{ - key: *item.Key, + Key: *item.Key, modtime: *item.LastModified, } @@ -557,7 +560,7 @@ func pubKeyAuth(identity string, passphrase string) ([]ssh.Signer, error) { return signers, nil } -func NewSFTPRepo(opts options) (*sftpRepo, error) { +func NewSFTPRepo(opts config.Options) (*sftpRepo, error) { r := &sftpRepo{ host: opts.SFTPHost, port: opts.SFTPPort, @@ -633,14 +636,14 @@ func (r *sftpRepo) Close() error { return errors.Join(r.client.Close(), r.conn.Close()) } -func (r *sftpRepo) Upload(path string, target string) (err error) { - l.Infof("uploading %s to %s:%s using sftp\n", path, r.host, r.baseDir) +func (r *sftpRepo) Upload(logger *logger.LevelLog, path string, target string) (err error) { + logger.Infof("uploading %s to %s:%s using sftp\n", path, r.host, r.baseDir) src, err := os.Open(path) if err != nil { return fmt.Errorf("sftp: could not open source %s: %w", path, err) } - defer WrappedClose(src, &err) + defer helpers.WrappedClose(src, &err) rpath := filepath.Join(r.baseDir, target) targetDir := filepath.Dir(rpath) @@ -650,7 +653,7 @@ func (r *sftpRepo) Upload(path string, target string) (err error) { rpath = strings.ReplaceAll(rpath, string(os.PathSeparator), "/") targetDir = strings.ReplaceAll(targetDir, string(os.PathSeparator), "/") } - l.Verboseln("sftp remote path is:", rpath) + logger.Verboseln("sftp remote path is:", rpath) // Target directory must be created first if targetDir != "." && targetDir != "/" { @@ -663,7 +666,7 @@ func (r *sftpRepo) Upload(path string, target string) (err error) { if err != nil { return fmt.Errorf("sftp: could not open destination %s: %w", rpath, err) } - defer WrappedClose(dst, &err) + defer helpers.WrappedClose(dst, &err) if _, err := io.Copy(dst, src); err != nil { return fmt.Errorf("sftp: could not send data with sftp: %s", err) @@ -672,14 +675,14 @@ func (r *sftpRepo) Upload(path string, target string) (err error) { return err } -func (r *sftpRepo) Download(target string, path string) (err error) { - l.Infof("downloading %s from %s:%s using sftp\n", target, r.host, r.baseDir) +func (r *sftpRepo) Download(logger *logger.LevelLog, target string, path string) (err error) { + logger.Infof("downloading %s from %s:%s using sftp\n", target, r.host, r.baseDir) dst, err := os.Create(path) if err != nil { return fmt.Errorf("sftp: could not open or create %s: %w", path, err) } - defer WrappedClose(dst, &err) + defer helpers.WrappedClose(dst, &err) rpath := filepath.Join(r.baseDir, target) @@ -687,13 +690,13 @@ func (r *sftpRepo) Download(target string, path string) (err error) { if os.PathSeparator != '/' { rpath = strings.ReplaceAll(rpath, string(os.PathSeparator), "/") } - l.Verboseln("sftp remote path is:", rpath) + logger.Verboseln("sftp remote path is:", rpath) src, err := r.client.Open(rpath) if err != nil { return fmt.Errorf("sftp: could not open %s on %s: %w", rpath, r.host, err) } - defer WrappedClose(src, &err) + defer helpers.WrappedClose(src, &err) if _, err := io.Copy(dst, src); err != nil { return fmt.Errorf("sftp: could not receive data with sftp: %s", err) @@ -702,7 +705,7 @@ func (r *sftpRepo) Download(target string, path string) (err error) { return err } -func (r *sftpRepo) List(prefix string) (items []Item, rerr error) { +func (r *sftpRepo) List(logger *logger.LevelLog, prefix string) (items []Item, rerr error) { items = make([]Item, 0) // sftp requires slash as path separator @@ -714,16 +717,15 @@ func (r *sftpRepo) List(prefix string) (items []Item, rerr error) { w := r.client.Walk(baseDir) for w.Step() { if err := w.Err(); err != nil { - l.Warnln("could not list remote file:", err) rerr = err continue } - // relPath() makes use of functions of the filepath std module + // RelPath() makes use of functions of the filepath std module // that take care of putting back the proper os.PathSeparator // if it find some slashes, so we can compare paths without // worrying about path separators - path := relPath(baseDir, w.Path()) + path := helpers.RelPath(logger, baseDir, w.Path()) if !strings.HasPrefix(path, prefix) { continue @@ -731,9 +733,9 @@ func (r *sftpRepo) List(prefix string) (items []Item, rerr error) { finfo := w.Stat() items = append(items, Item{ - key: path, + Key: path, modtime: finfo.ModTime(), - isDir: finfo.IsDir(), + IsDir: finfo.IsDir(), }) } @@ -762,7 +764,7 @@ type gcsRepo struct { client *storage.Client } -func NewGCSRepo(opts options) (*gcsRepo, error) { +func NewGCSRepo(opts config.Options) (*gcsRepo, error) { r := &gcsRepo{ bucket: opts.GCSBucket, url: opts.GCSEndPoint, @@ -792,19 +794,19 @@ func (r *gcsRepo) Close() error { return r.client.Close() } -func (r *gcsRepo) Upload(path string, target string) (err error) { +func (r *gcsRepo) Upload(logger *logger.LevelLog, path string, target string) (err error) { file, err := os.Open(path) if err != nil { return fmt.Errorf("upload error: %w", err) } - defer WrappedClose(file, &err) + defer helpers.WrappedClose(file, &err) obj := r.client.Bucket(r.bucket).Object(forwardSlashes(target)).NewWriter(context.Background()) // The upload is done asynchronously, the error returned by Close() // says if it was successful - defer WrappedClose(obj, &err) + defer helpers.WrappedClose(obj, &err) - l.Infof("uploading %s to GCS bucket %s\n", path, r.bucket) + logger.Infof("uploading %s to GCS bucket %s\n", path, r.bucket) if _, err := io.Copy(obj, file); err != nil { return fmt.Errorf("could not write data to GCS object: %w", err) } @@ -812,12 +814,12 @@ func (r *gcsRepo) Upload(path string, target string) (err error) { return err } -func (r *gcsRepo) Download(target string, path string) (err error) { +func (r *gcsRepo) Download(logger *logger.LevelLog, target string, path string) (err error) { file, err := os.Create(path) if err != nil { return fmt.Errorf("download error: %w", err) } - defer WrappedClose(file, &err) + defer helpers.WrappedClose(file, &err) obj, err := r.client.Bucket(r.bucket). Object(forwardSlashes(target)). @@ -825,9 +827,9 @@ func (r *gcsRepo) Download(target string, path string) (err error) { if err != nil { return fmt.Errorf("download error: %w", err) } - defer WrappedClose(obj, &err) + defer helpers.WrappedClose(obj, &err) - l.Infof("downloading %s from GCS bucket %s to %s\n", target, r.bucket, path) + logger.Infof("downloading %s from GCS bucket %s to %s\n", target, r.bucket, path) if _, err := io.Copy(file, obj); err != nil { return fmt.Errorf("could not read data from GCS object: %w", err) } @@ -835,7 +837,7 @@ func (r *gcsRepo) Download(target string, path string) (err error) { return err } -func (r *gcsRepo) List(prefix string) (items []Item, rerr error) { +func (r *gcsRepo) List(logger *logger.LevelLog, prefix string) (items []Item, rerr error) { items = make([]Item, 0) it := r.client.Bucket(r.bucket). @@ -847,13 +849,13 @@ func (r *gcsRepo) List(prefix string) (items []Item, rerr error) { } if err != nil { - l.Warnln("could not list remote file:", err) + logger.Warnln("could not list remote file:", err) rerr = err break } items = append(items, Item{ - key: attrs.Name, + Key: attrs.Name, modtime: attrs.Updated, }) } @@ -877,7 +879,7 @@ type azRepo struct { client *azblob.Client } -func NewAzRepo(opts options) (*azRepo, error) { +func NewAzRepo(opts config.Options) (*azRepo, error) { r := &azRepo{ container: opts.AzureContainer, account: opts.AzureAccount, @@ -922,14 +924,14 @@ func NewAzRepo(opts options) (*azRepo, error) { return r, nil } -func (r *azRepo) Upload(path string, target string) (err error) { +func (r *azRepo) Upload(logger *logger.LevelLog, path string, target string) (err error) { file, err := os.Open(path) if err != nil { return fmt.Errorf("upload error: %w", err) } - defer WrappedClose(file, &err) + defer helpers.WrappedClose(file, &err) - l.Infof("uploading %s to Azure container %s\n", path, r.container) + logger.Infof("uploading %s to Azure container %s\n", path, r.container) _, err = r.client.UploadFile(context.Background(), r.container, path, file, nil) if err != nil { return fmt.Errorf("could not upload %s to Azure: %w", path, err) @@ -938,14 +940,14 @@ func (r *azRepo) Upload(path string, target string) (err error) { return err } -func (r *azRepo) Download(target string, path string) (err error) { +func (r *azRepo) Download(logger *logger.LevelLog, target string, path string) (err error) { file, err := os.Create(path) if err != nil { return fmt.Errorf("download error: %w", err) } - defer WrappedClose(file, &err) + defer helpers.WrappedClose(file, &err) - l.Infof("downloading %s from Azure container %s\n", target, r.container) + logger.Infof("downloading %s from Azure container %s\n", target, r.container) _, err = r.client.DownloadFile(context.Background(), r.container, target, file, nil) if err != nil { return fmt.Errorf("could not download %s from Azure: %w", target, err) @@ -954,7 +956,7 @@ func (r *azRepo) Download(target string, path string) (err error) { return err } -func (r *azRepo) List(prefix string) ([]Item, error) { +func (r *azRepo) List(_ *logger.LevelLog, prefix string) ([]Item, error) { p := forwardSlashes(prefix) pager := r.client.NewListBlobsFlatPager(r.container, &azblob.ListBlobsFlatOptions{ Prefix: &p, @@ -969,7 +971,7 @@ func (r *azRepo) List(prefix string) ([]Item, error) { for _, v := range resp.Segment.BlobItems { file := Item{ - key: *v.Name, + Key: *v.Name, modtime: *v.Properties.LastModified, } diff --git a/upload_test.go b/internal/storage/upload_test.go similarity index 93% rename from upload_test.go rename to internal/storage/upload_test.go index 0400905..f238569 100644 --- a/upload_test.go +++ b/internal/storage/upload_test.go @@ -1,4 +1,4 @@ -package main +package storage import ( "fmt" @@ -6,6 +6,9 @@ import ( "path/filepath" "runtime" "testing" + + "github.com/orgrim/pg_back/internal/helpers" + "github.com/orgrim/pg_back/internal/logger" ) func TestExpandHomeDir(t *testing.T) { @@ -81,7 +84,7 @@ func TestRelPath(t *testing.T) { for i, st := range tests { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { - got := relPath(st.basedir, st.path) + got := helpers.RelPath(logger.NewLevelLog(), st.basedir, st.path) if got != st.want { t.Errorf("got: %v, want %v", got, st.want) } diff --git a/main.go b/main.go index 496d3c8..844912e 100644 --- a/main.go +++ b/main.go @@ -31,22 +31,35 @@ import ( "os" "os/exec" "path/filepath" - "runtime" "slices" "strings" "sync" "time" + + _ "embed" + + pgcommand "github.com/orgrim/pg_back/internal/command" + "github.com/orgrim/pg_back/internal/config" + "github.com/orgrim/pg_back/internal/crypto" + "github.com/orgrim/pg_back/internal/helpers" + "github.com/orgrim/pg_back/internal/legacy" + "github.com/orgrim/pg_back/internal/logger" + "github.com/orgrim/pg_back/internal/postgresql" + "github.com/orgrim/pg_back/internal/storage" ) -var version = "2.6.0" +//go:embed pg_back.conf +var defaultCfg string + var binDir string +var l = logger.NewLevelLog() type dump struct { // Name of the database to dump Database string // Per database pg_dump options to filter schema, tables, etc. - Options *dbOpts + Options *config.DbOpts // Path is the output file or directory of the dump // a directory is output with the directory format of pg_dump @@ -63,7 +76,7 @@ type dump struct { TimeFormat string // Connection parameters - ConnString *ConnInfo + ConnString *postgresql.ConnInfo // Cipher passphrase, when not empty cipher the file CipherPassphrase string @@ -82,44 +95,6 @@ type dump struct { PgDumpVersion int } -type dbOpts struct { - // Format of the dump - Format rune - - // Algorithm of the checksum of the file, "none" is used to - // disable checksuming - SumAlgo string - - // Number of parallel jobs for directory format - Jobs int - - // Compression level for compressed formats, -1 means the default - CompressLevel int - - // Purge configuration - PurgeInterval time.Duration - PurgeKeep int - - // Limit schemas - Schemas []string - ExcludedSchemas []string - - // Limit dumped tables - Tables []string - ExcludedTables []string - - // Other pg_dump options to use - PgDumpOpts []string - - // Whether to force the dump of large objects or not with pg_dump -b or - // -B, or let pg_dump use its default. 0 means default, 1 include - // blobs, 2 exclude blobs. - WithBlobs int - - // Connection user for that database - Username string -} - func main() { // Use another function to allow the use of defer for cleanup, as // os.Exit() does not run deferred functions @@ -133,8 +108,8 @@ func run() (retVal error) { // Parse commanline arguments first so that we can quit if we // have shown usage or version string. We may have to load a // non default configuration file - cliOpts, cliOptList, err := parseCli(os.Args[1:]) - var pce *parseCliResult + cliOpts, cliOptList, err := config.ParseCli(os.Args[1:], defaultCfg) + var pce *config.ParseCliResult if err != nil { if errors.As(err, &pce) { // Convert the configuration file if a path as been @@ -144,7 +119,7 @@ func run() (retVal error) { // output the result on stdout and exit to let the user // check the result if len(pce.LegacyConfig) > 0 { - if err := convertLegacyConfFile(pce.LegacyConfig); err != nil { + if err := legacy.ConvertLegacyConfFile(l, pce.LegacyConfig); err != nil { return err } } @@ -157,15 +132,15 @@ func run() (retVal error) { // Enable verbose mode or quiet mode as soon as possible l.SetVerbosity(cliOpts.Verbose, cliOpts.Quiet) - var cliOptions options + var cliOptions config.Options if cliOpts.NoConfigFile { l.Infoln("Skipping reading config file") - cliOptions = defaultOptions() + cliOptions = config.DefaultOptions() } else { // Load configuration file and allow the default configuration // file to be absent - cliOptions, err = loadConfigurationFile(cliOpts.CfgFile) + cliOptions, err = config.LoadConfigurationFile(cliOpts.CfgFile, l) if err != nil { return err } @@ -173,7 +148,7 @@ func run() (retVal error) { // override options from the configuration file with ones from // the command line - opts := mergeCliAndConfigOptions(cliOpts, cliOptions, cliOptList) + opts := config.MergeCliAndConfigOptions(cliOpts, cliOptions, cliOptList) err = ensureCipherParamsPresent(&opts) if err != nil { @@ -228,7 +203,8 @@ func run() (retVal error) { } if opts.Decrypt { - params := decryptParams{ + params := crypto.DecryptParams{ + Logger: l, PrivateKey: opts.CipherPrivateKey, Passphrase: opts.CipherPassphrase, } @@ -251,20 +227,24 @@ func run() (retVal error) { } // Ensure that pg_dump accepts the options we will give it - pgDumpVersion := pgToolVersion("pg_dump") + pgDumpVersion := pgcommand.PgToolVersion(l, binDir, "pg_dump") + if pgDumpVersion < 80400 { return fmt.Errorf("provided pg_dump is older than 8.4, unable use it") } // Parse the connection information l.Verboseln("processing input connection parameters") - conninfo, err := prepareConnInfo(opts.Host, opts.Port, opts.Username, opts.ConnDb) + conninfo, err := postgresql.PrepareConnInfo(opts.Host, opts.Port, opts.Username, opts.ConnDb) if err != nil { return fmt.Errorf("could not compute connection string: %w", err) } + if conninfo.Infos["application_name"] == "pg_back" { + l.Verboseln("using pg_back as application_name") + } - defer postBackupHook(opts.PostHook) - if err := preBackupHook(opts.PreHook); err != nil { + defer pgcommand.PostBackupHook(l, opts.PostHook) + if err := pgcommand.PreBackupHook(l, opts.PreHook); err != nil { return err } @@ -304,11 +284,11 @@ func run() (retVal error) { defer close(producedFiles) // Connect before running pg_dumpall so that we know if the user is superuser - db, err := dbOpen(conninfo) + db, err := postgresql.DbOpen(l, conninfo) if err != nil { return fmt.Errorf("connection to PostgreSQL failed: %w", err) } - defer WrappedClose(db, &retVal) + defer helpers.WrappedClose(db, &retVal) // Generate a single datetime that will be used in all files generated by pg_back var fileTime time.Time @@ -317,13 +297,13 @@ func run() (retVal error) { } if !opts.DumpOnly { - if !db.superuser { + if !db.Superuser { l.Infoln("connection user is not superuser, some information will not be dumped") } // Then we can implicitely avoid dumping role password when using a // regular user - dumpRolePasswords := opts.WithRolePasswords && db.superuser + dumpRolePasswords := opts.WithRolePasswords && db.Superuser if dumpRolePasswords { l.Infoln("dumping globals") } else { @@ -335,8 +315,8 @@ func run() (retVal error) { l.Infoln("dumping instance configuration") var ( - verr *pgVersionError - perr *pgPrivError + verr *postgresql.PgVersionError + perr *postgresql.PgPrivError ) if err := dumpSettings(opts.Directory, opts.Mode, opts.TimeFormat, db, producedFiles, fileTime); err != nil { @@ -352,13 +332,19 @@ func run() (retVal error) { } } - databases, err := listDatabases(db, opts.WithTemplates, opts.ExcludeDbs, opts.Dbnames) + databases, err := postgresql.ListDatabases( + l, + db, + opts.WithTemplates, + opts.ExcludeDbs, + opts.Dbnames, + ) if err != nil { return err } l.Verboseln("databases to dump:", databases) - if err := pauseReplicationWithTimeout(db, opts.PauseTimeout); err != nil { + if err := postgresql.PauseReplicationWithTimeout(l, db, opts.PauseTimeout); err != nil { return err } @@ -442,8 +428,8 @@ func run() (retVal error) { force = true } - b, err = dumpCreateDBAndACL(db, dbname, force) - var verr *pgVersionError + b, err = postgresql.DumpCreateDBAndACL(l, db, dbname, pgDumpVersion, force) + var verr *postgresql.PgVersionError if err != nil { if !errors.As(err, &verr) { l.Errorln(err) @@ -457,9 +443,9 @@ func run() (retVal error) { if canDumpConfig { l.Verboseln("dumping configuration of", dbname) - c, err = dumpDBConfig(db, dbname) + c, err = postgresql.DumpDBConfig(l, db, dbname, pgDumpVersion) if err != nil { - var verr *pgVersionError + var verr *postgresql.PgVersionError if !errors.As(err, &verr) { l.Errorln(err) exitCode = 1 @@ -473,7 +459,14 @@ func run() (retVal error) { // Write ACL and configuration to an SQL file if len(b) > 0 || len(c) > 0 { - aclpath := formatDumpPath(d.Directory, d.TimeFormat, "createdb.sql", dbname, d.When, 0) + aclpath := helpers.FormatDumpPath( + d.Directory, + d.TimeFormat, + "createdb.sql", + dbname, + d.When, + 0, + ) if err := os.MkdirAll(filepath.Dir(aclpath), 0700); err != nil { l.Errorln(err) exitCode = 1 @@ -515,7 +508,7 @@ func run() (retVal error) { } } - if err := resumeReplication(db); err != nil { + if err := postgresql.ResumeReplication(l, db); err != nil { l.Errorln(err) } if err := db.Close(); err != nil { @@ -539,31 +532,31 @@ func run() (retVal error) { // (globals and settings) like databases l.Infoln("purging old dumps") - var repo Repo + var repo storage.Repo switch opts.Upload { case "s3": - repo, err = NewS3Repo(opts) + repo, err = storage.NewS3Repo(opts) if err != nil { return fmt.Errorf("failed to prepare upload to S3: %w", err) } case "b2": - repo, err = NewB2Repo(opts) + repo, err = storage.NewB2Repo(opts) if err != nil { return fmt.Errorf("failed to prepare upload to B2: %w", err) } case "sftp": - repo, err = NewSFTPRepo(opts) + repo, err = storage.NewSFTPRepo(opts) if err != nil { return fmt.Errorf("failed to prepare upload over SFTP: %w", err) } case "gcs": - repo, err = NewGCSRepo(opts) + repo, err = storage.NewGCSRepo(opts) if err != nil { return fmt.Errorf("failed to prepare upload to GCS: %w", err) } case "azure": - repo, err = NewAzRepo(opts) + repo, err = storage.NewAzRepo(opts) if err != nil { return fmt.Errorf("failed to prepare upload to Azure: %w", err) } @@ -576,12 +569,12 @@ func run() (retVal error) { } limit := now.Add(o.PurgeInterval) - if err := purgeDumps(opts.Directory, dbname, o.PurgeKeep, limit); err != nil { + if err := storage.PurgeDumps(l, opts.Directory, dbname, o.PurgeKeep, limit); err != nil { retVal = err } if opts.PurgeRemote && repo != nil { - if err := purgeRemoteDumps(repo, opts.UploadPrefix, opts.Directory, dbname, o.PurgeKeep, limit); err != nil { + if err := storage.PurgeRemoteDumps(l, repo, opts.UploadPrefix, opts.Directory, dbname, o.PurgeKeep, limit); err != nil { retVal = err } } @@ -590,12 +583,12 @@ func run() (retVal error) { if !opts.DumpOnly { for _, other := range []string{"pg_globals", "pg_settings", "hba_file", "ident_file"} { limit := now.Add(defDbOpts.PurgeInterval) - if err := purgeDumps(opts.Directory, other, defDbOpts.PurgeKeep, limit); err != nil { + if err := storage.PurgeDumps(l, opts.Directory, other, defDbOpts.PurgeKeep, limit); err != nil { retVal = err } if opts.PurgeRemote && repo != nil { - if err := purgeRemoteDumps(repo, opts.UploadPrefix, opts.Directory, other, defDbOpts.PurgeKeep, limit); err != nil { + if err := storage.PurgeRemoteDumps(l, repo, opts.UploadPrefix, opts.Directory, other, defDbOpts.PurgeKeep, limit); err != nil { retVal = err } } @@ -605,8 +598,8 @@ func run() (retVal error) { return } -func defaultDbOpts(opts options) *dbOpts { - dbo := dbOpts{ +func defaultDbOpts(opts config.Options) *config.DbOpts { + dbo := config.DbOpts{ Format: opts.Format, Jobs: opts.DirJobs, CompressLevel: opts.CompressLevel, @@ -629,8 +622,8 @@ func (d *dump) dump(fc chan<- sumFileJob) error { // dump to prevent stacking pg_back processes if pg_dump last // longer than a schedule of pg_back. If the lock cannot be // acquired, skip the dump and exit with an error. - lock := formatDumpPath(d.Directory, d.TimeFormat, "lock", dbname, time.Time{}, 0) - flock, locked, err := lockPath(lock) + lock := helpers.FormatDumpPath(d.Directory, d.TimeFormat, "lock", dbname, time.Time{}, 0) + flock, locked, err := storage.LockPath(l, lock) if err != nil { return fmt.Errorf("unable to lock %s: %s", lock, err) } @@ -659,7 +652,7 @@ func (d *dump) dump(fc chan<- sumFileJob) error { d.When = time.Now() } - file := formatDumpPath( + file := helpers.FormatDumpPath( d.Directory, d.TimeFormat, fileEnd, @@ -669,7 +662,7 @@ func (d *dump) dump(fc chan<- sumFileJob) error { ) formatOpt := fmt.Sprintf("-F%c", d.Options.Format) - command := execPath("pg_dump") + command := pgcommand.ExecPath(binDir, "pg_dump") args := []string{formatOpt, "-f", file, "-w"} if fileEnd == "d" && d.Options.Jobs > 1 { @@ -746,7 +739,7 @@ func (d *dump) dump(fc chan<- sumFileJob) error { l.Errorf("[%s] %s\n", dbname, line) } } - if err := unlockPath(flock); err != nil { + if err := storage.UnlockPath(l, flock); err != nil { l.Errorf("could not release lock for %s: %s", dbname, err) flock.Close() //nolint:errcheck } @@ -760,7 +753,7 @@ func (d *dump) dump(fc chan<- sumFileJob) error { } } - if err := unlockPath(flock); err != nil { + if err := storage.UnlockPath(l, flock); err != nil { flock.Close() //nolint:errcheck return fmt.Errorf("could not release lock for %s: %s", dbname, err) } @@ -839,7 +832,7 @@ func dumper(jobs <-chan *dump, results chan<- *dump, fc chan<- sumFileJob) { } } -func ensureCipherParamsPresent(opts *options) error { +func ensureCipherParamsPresent(opts *config.Options) error { // Nothing needs to be done if we are not encrypting or decrypting if !opts.Encrypt && !opts.Decrypt { return nil @@ -862,134 +855,16 @@ func ensureCipherParamsPresent(opts *options) error { return nil } -func relPath(basedir, path string) string { - target, err := filepath.Rel(basedir, path) - if err != nil { - l.Warnf("could not get relative path from %s: %s\n", path, err) - target = path - } - - prefix := fmt.Sprintf("..%c", os.PathSeparator) - for strings.HasPrefix(target, prefix) { - target = strings.TrimPrefix(target, prefix) - } - - return target -} - -func execPath(prog string) string { - binFile := prog - if runtime.GOOS == "windows" { - binFile = fmt.Sprintf("%s.exe", prog) - } - - if binDir != "" { - return filepath.Join(binDir, binFile) - } - - return binFile -} - -func cleanDBName(dbname string) string { - // We do not want a database name starting with a dot to avoid creating hidden files - if strings.HasPrefix(dbname, ".") { - dbname = "_" + dbname - } - - // If there is a path separator in the database name, we do not want to - // create the dump in a subdirectory or in a parent directory - if strings.ContainsRune(dbname, os.PathSeparator) { - dbname = strings.ReplaceAll(dbname, string(os.PathSeparator), "_") - } - - // Always remove slashes to avoid issues with filenames on windows - if strings.ContainsRune(dbname, '/') { - dbname = strings.ReplaceAll(dbname, "/", "_") - } - - return dbname -} - -func formatDumpPath( - dir string, - timeFormat string, - suffix string, - dbname string, - when time.Time, - compressLevel int, -) string { - var f, s, d string - - // Avoid attacks on the database name - dbname = cleanDBName(dbname) - - d = dir - if dbname != "" { - d = strings.ReplaceAll(dir, "{dbname}", dbname) - } - - s = suffix - if suffix == "" { - s = "dump" - } - - // Output is "dir(formatted)/dbname_date.suffix" when the - // input time is not zero, otherwise do not include the date - // and time. Reference time for time.Format(): "Mon Jan 2 - // 15:04:05 MST 2006" - if when.IsZero() { - f = fmt.Sprintf("%s.%s", dbname, s) - } else { - f = fmt.Sprintf("%s_%s.%s", dbname, when.Format(timeFormat), s) - } - - if suffix == "sql" && compressLevel > 0 { - f = f + ".gz" - } - - return filepath.Join(d, f) -} - -func pgToolVersion(tool string) int { - vs, err := exec.Command(execPath(tool), "--version").Output() - if err != nil { - l.Warnf("failed to retrieve version of %s: %s", tool, err) - return 0 - } - - var maj, min, rev, numver int - n, _ := fmt.Sscanf(string(vs), tool+" (PostgreSQL) %d.%d.%d", &maj, &min, &rev) - - switch n { - case 3: - // Before PostgreSQL 10, the format si MAJ.MIN.REV - numver = (maj*100+min)*100 + rev - case 2: - // From PostgreSQL 10, the format si MAJ.REV, so the rev ends - // up in min with the scan - numver = maj*10000 + min - default: - // We have the special case of the development version, where the - // format is MAJdevel - fmt.Sscanf(string(vs), tool+" (PostgreSQL) %ddevel", &maj) - numver = maj * 10000 - } - - l.Verboseln(tool, "version is:", numver) - - return numver -} - func dumpGlobals( dir string, mode int, timeFormat string, withRolePasswords bool, - conninfo *ConnInfo, + conninfo *postgresql.ConnInfo, fc chan<- sumFileJob, when time.Time, ) error { - command := execPath("pg_dumpall") + command := pgcommand.ExecPath(binDir, "pg_dumpall") args := []string{"-g", "-w"} // pg_dumpall only connects to another database if it is given @@ -1003,7 +878,8 @@ func dumpGlobals( // information var env []string - pgDumpallVersion := pgToolVersion("pg_dumpall") + pgDumpallVersion := pgcommand.PgToolVersion(l, binDir, "pg_dumpall") + if pgDumpallVersion < 90300 { env = os.Environ() env = append(env, conninfo.MakeEnv()...) @@ -1026,7 +902,7 @@ func dumpGlobals( when = time.Now() } - file := formatDumpPath(dir, timeFormat, "sql", "pg_globals", when, 0) + file := helpers.FormatDumpPath(dir, timeFormat, "sql", "pg_globals", when, 0) args = append(args, "-f", file) if err := os.MkdirAll(filepath.Dir(file), 0700); err != nil { @@ -1071,7 +947,7 @@ func dumpSettings( dir string, mode int, timeFormat string, - db *pg, + db *postgresql.Pg, fc chan<- sumFileJob, when time.Time, ) error { @@ -1079,13 +955,13 @@ func dumpSettings( when = time.Now() } - file := formatDumpPath(dir, timeFormat, "out", "pg_settings", when, 0) + file := helpers.FormatDumpPath(dir, timeFormat, "out", "pg_settings", when, 0) if err := os.MkdirAll(filepath.Dir(file), 0o700); err != nil { return err } - s, err := showSettings(db) + s, err := postgresql.ShowSettings(l, db) if err != nil { return err } @@ -1123,7 +999,7 @@ func dumpConfigFiles( dir string, mode int, timeFormat string, - db *pg, + db *postgresql.Pg, fc chan<- sumFileJob, when time.Time, ) error { @@ -1131,13 +1007,13 @@ func dumpConfigFiles( if when.IsZero() { when = time.Now() } - file := formatDumpPath(dir, timeFormat, "out", param, when, 0) + file := helpers.FormatDumpPath(dir, timeFormat, "out", param, when, 0) if err := os.MkdirAll(filepath.Dir(file), 0700); err != nil { return err } - s, err := extractFileFromSettings(db, param) + s, err := postgresql.ExtractFileFromSettings(l, db, param) if err != nil { return err } @@ -1176,13 +1052,13 @@ func dumpConfigFiles( return nil } -func listRemoteFiles(repoName string, opts options, globs []string) error { - repo, err := NewRepo(repoName, opts) +func listRemoteFiles(repoName string, opts config.Options, globs []string) error { + repo, err := storage.NewRepo(repoName, opts) if err != nil { return err } - remoteFiles, err := repo.List("") + remoteFiles, err := repo.List(l, "") if err != nil { return fmt.Errorf("could not list contents of remote location: %w", err) } @@ -1194,7 +1070,7 @@ func listRemoteFiles(repoName string, opts options, globs []string) error { } for _, glob := range globs { - keep, err = filepath.Match(glob, i.key) + keep, err = filepath.Match(glob, i.Key) if err != nil { return fmt.Errorf("bad patern: %w", err) } @@ -1208,14 +1084,14 @@ func listRemoteFiles(repoName string, opts options, globs []string) error { continue } - fmt.Println(i.key) + fmt.Println(i.Key) } return nil } -func downloadFiles(repoName string, opts options, dir string, globs []string) error { - repo, err := NewRepo(repoName, opts) +func downloadFiles(repoName string, opts config.Options, dir string, globs []string) error { + repo, err := storage.NewRepo(repoName, opts) if err != nil { return err } @@ -1225,7 +1101,7 @@ func downloadFiles(repoName string, opts options, dir string, globs []string) er return fmt.Errorf("no filter given to download files, use globs as command line arguments") } - remoteFiles, err := repo.List("") + remoteFiles, err := repo.List(l, "") if err != nil { return fmt.Errorf("could not list contents of remote location: %w", err) } @@ -1233,7 +1109,7 @@ func downloadFiles(repoName string, opts options, dir string, globs []string) er for _, i := range remoteFiles { keep := false for _, glob := range globs { - keep, err = filepath.Match(glob, i.key) + keep, err = filepath.Match(glob, i.Key) if err != nil { return fmt.Errorf("bad patern: %w", err) } @@ -1244,27 +1120,27 @@ func downloadFiles(repoName string, opts options, dir string, globs []string) er } if !keep { - l.Verboseln("skipping:", i.key) + l.Verboseln("skipping:", i.Key) continue } - if i.isDir { + if i.IsDir { l.Warnf( "%s is a directory, append %c* to the filter to download its contents", - i.key, + i.Key, os.PathSeparator, ) continue } // Create any parent directory under target dir - path := filepath.Join(dir, i.key) + path := filepath.Join(dir, i.Key) parent := filepath.Dir(path) if err := os.MkdirAll(parent, 0700); err != nil { return fmt.Errorf("could not create directory %s: %w", parent, err) } - if err := repo.Download(i.key, path); err != nil { + if err := repo.Download(l, i.Key, path); err != nil { return err } } @@ -1272,7 +1148,7 @@ func downloadFiles(repoName string, opts options, dir string, globs []string) er return nil } -func decryptDirectory(dir string, params decryptParams, workers int, globs []string) error { +func decryptDirectory(dir string, params crypto.DecryptParams, workers int, globs []string) error { // Run a pool of workers to decrypt concurrently var wg sync.WaitGroup @@ -1298,7 +1174,7 @@ func decryptDirectory(dir string, params decryptParams, workers int, globs []str } l.Verbosef("[%d] processing: %s\n", id, file) - if err := decryptFile(file, params); err != nil { + if err := params.DecryptFile(file); err != nil { l.Errorln(err) failed = true } @@ -1409,28 +1285,12 @@ type sumFileJob struct { SumAlgo string } -type encryptParams struct { - // Encrypt with a passphrase - Passphrase string - - // Encrypt with an AGE public key encoded in Bech32 - PublicKey string -} - -type decryptParams struct { - // A passphrase to use for decryption - Passphrase string - - // An AGE private key encoded in Bech32 - PrivateKey string -} - type encryptFileJob struct { // Path of the file or directory to checksum Path string // How to encrypt the file - Params encryptParams + Params crypto.EncryptParams KeepSrc bool @@ -1456,7 +1316,7 @@ type uploadJob struct { // postProcessFiles is the entrypoint for common tasks to perform on files // produced during execution, checksum and encryption. Different go routines // are spawn to process the files as soon as possible -func postProcessFiles(inFiles chan sumFileJob, wg *sync.WaitGroup, opts options) chan error { +func postProcessFiles(inFiles chan sumFileJob, wg *sync.WaitGroup, opts config.Options) chan error { // Create a channel for errors so that we can inform the main goroutine // that a job failed and have the program exit with a non-zero // status. This chan is buffered with the number of goroutines using it @@ -1498,7 +1358,7 @@ func postProcessFiles(inFiles chan sumFileJob, wg *sync.WaitGroup, opts options) if j.SumAlgo != "none" { l.Infoln("computing checksum of", j.Path) - p, err := checksumFile(j.Path, opts.Mode, j.SumAlgo) + p, err := crypto.ChecksumFile(l, j.Path, opts.Mode, j.SumAlgo) if err != nil { l.Errorln("checksum failed:", err) if !failed { @@ -1512,7 +1372,8 @@ func postProcessFiles(inFiles chan sumFileJob, wg *sync.WaitGroup, opts options) if opts.Encrypt { encIn <- encryptFileJob{ Path: p, - Params: encryptParams{ + Params: crypto.EncryptParams{ + Logger: l, Passphrase: opts.CipherPassphrase, PublicKey: opts.CipherPublicKey, }, @@ -1532,7 +1393,8 @@ func postProcessFiles(inFiles chan sumFileJob, wg *sync.WaitGroup, opts options) if opts.Encrypt { encIn <- encryptFileJob{ Path: j.Path, - Params: encryptParams{ + Params: crypto.EncryptParams{ + Logger: l, Passphrase: opts.CipherPassphrase, PublicKey: opts.CipherPublicKey, }, @@ -1593,7 +1455,7 @@ func postProcessFiles(inFiles chan sumFileJob, wg *sync.WaitGroup, opts options) if opts.Encrypt { l.Infoln("encrypting", j.Path) - encFiles, err := encryptFile(j.Path, opts.Mode, j.Params, j.KeepSrc) + encFiles, err := j.Params.EncryptFile(j.Path, opts.Mode, j.KeepSrc) if err != nil { l.Errorln("encryption failed:", err) if !failed { @@ -1643,7 +1505,7 @@ func postProcessFiles(inFiles chan sumFileJob, wg *sync.WaitGroup, opts options) if j.SumAlgo != "none" { l.Infoln("computing checksum of", j.SumFile) - p, err := checksumFileList(j.Paths, opts.Mode, j.SumAlgo, j.SumFile) + p, err := crypto.ChecksumFileList(l, j.Paths, opts.Mode, j.SumAlgo, j.SumFile) if err != nil { l.Errorln("checksum of encrypted files failed:", err) if !failed { @@ -1664,7 +1526,7 @@ func postProcessFiles(inFiles chan sumFileJob, wg *sync.WaitGroup, opts options) }(i) } - repo, err := NewRepo(opts.Upload, opts) + repo, err := storage.NewRepo(opts.Upload, opts) if err != nil { l.Errorln(err) ret <- err @@ -1687,7 +1549,7 @@ func postProcessFiles(inFiles chan sumFileJob, wg *sync.WaitGroup, opts options) if opts.Upload != "none" && repo != nil { // Prepend the global prefix to the relative path of the dump - if err := repo.Upload(j.Path, filepath.Join(opts.UploadPrefix, relPath(opts.Directory, j.Path))); err != nil { + if err := repo.Upload(l, j.Path, filepath.Join(opts.UploadPrefix, helpers.RelPath(l, opts.Directory, j.Path))); err != nil { l.Errorln(err) if !failed { ret <- err diff --git a/main_test.go b/main_test.go index 70bfd76..d1f2ad4 100644 --- a/main_test.go +++ b/main_test.go @@ -26,51 +26,13 @@ package main import ( - "fmt" - "runtime" "testing" -) - -func TestExecPath(t *testing.T) { - var tests []struct { - dir string - prog string - want string - } - - if runtime.GOOS != "windows" { - tests = []struct { - dir string - prog string - want string - }{ - {"", "pg_dump", "pg_dump"}, - {"/path/to/bin", "prog", "/path/to/bin/prog"}, - } - } else { - tests = []struct { - dir string - prog string - want string - }{ - {"", "pg_dump", "pg_dump.exe"}, - {"C:\\path\\to\\bin", "prog", "C:\\path\\to\\bin\\prog.exe"}, - } - } - for i, st := range tests { - t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { - binDir = st.dir - got := execPath(st.prog) - if got != st.want { - t.Errorf("expected %q, got %q\n", st.want, got) - } - }) - } -} + "github.com/orgrim/pg_back/internal/config" +) func TestEnsureCipherParamsPresent_NoEncryptNoDecrypt_NoParams_ReturnsNil(t *testing.T) { - opts := options{} + opts := config.Options{} err := ensureCipherParamsPresent(&opts) if err != nil { @@ -79,7 +41,7 @@ func TestEnsureCipherParamsPresent_NoEncryptNoDecrypt_NoParams_ReturnsNil(t *tes } func TestEnsureCipherParamsPresent_NoEncryptNoDecrypt_HasParams_ReturnsNil(t *testing.T) { - opts := options{ + opts := config.Options{ CipherPublicKey: "foo1", CipherPrivateKey: "bar99", CipherPassphrase: "secretwords", @@ -92,7 +54,7 @@ func TestEnsureCipherParamsPresent_NoEncryptNoDecrypt_HasParams_ReturnsNil(t *te } func TestEnsureCipherParamsPresent_Encrypt_NoParams_Failure(t *testing.T) { - opts := options{ + opts := config.Options{ Encrypt: true, CipherPrivateKey: "bar99", } @@ -104,7 +66,7 @@ func TestEnsureCipherParamsPresent_Encrypt_NoParams_Failure(t *testing.T) { } func TestEnsureCipherParamsPresent_Encrypt_NoParamsButEnv_Success(t *testing.T) { - opts := options{ + opts := config.Options{ Encrypt: true, } t.Setenv("PGBK_CIPHER_PASS", "works")