diff --git a/pkg/cmd/refresh/refresh.go b/pkg/cmd/refresh/refresh.go index 8499156a..d82f17da 100644 --- a/pkg/cmd/refresh/refresh.go +++ b/pkg/cmd/refresh/refresh.go @@ -22,6 +22,7 @@ type RefreshStore interface { ssh.SSHConfigurerV2Store GetCurrentUser() (*entity.User, error) GetCurrentUserKeys() (*entity.UserKeys, error) + WriteAuthorizedKey(publicKey string) error Chmod(string, fs.FileMode) error MkdirAll(string, fs.FileMode) error GetBrevCloudflaredBinaryPath() (string, error) @@ -62,7 +63,7 @@ func RunRefreshBetter(store RefreshStore) error { return breverrors.WrapAndTrace(err) } - cu, err := GetConfigUpdater(store) + cu, keys, err := GetConfigUpdater(store) if err != nil { return breverrors.WrapAndTrace(err) } @@ -72,6 +73,11 @@ func RunRefreshBetter(store RefreshStore) error { return breverrors.WrapAndTrace(err) } + err = store.WriteAuthorizedKey(keys.PublicKey) + if err != nil { + return breverrors.WrapAndTrace(err) + } + privateKeyPath, err := store.GetPrivateKeyPath() if err != nil { return breverrors.WrapAndTrace(err) @@ -91,7 +97,7 @@ func RunRefresh(store RefreshStore) error { return breverrors.WrapAndTrace(err) } - cu, err := GetConfigUpdater(store) + cu, keys, err := GetConfigUpdater(store) if err != nil { return breverrors.WrapAndTrace(err) } @@ -101,6 +107,11 @@ func RunRefresh(store RefreshStore) error { return breverrors.WrapAndTrace(err) } + err = store.WriteAuthorizedKey(keys.PublicKey) + if err != nil { + return breverrors.WrapAndTrace(err) + } + privateKeyPath, err := store.GetPrivateKeyPath() if err != nil { return breverrors.WrapAndTrace(err) @@ -139,20 +150,20 @@ func RunRefreshAsync(rstore RefreshStore) *RefreshRes { return &res } -func GetConfigUpdater(store RefreshStore) (*ssh.ConfigUpdater, error) { +func GetConfigUpdater(store RefreshStore) (*ssh.ConfigUpdater, *entity.UserKeys, error) { configs, err := ssh.GetSSHConfigs(store) if err != nil { - return nil, breverrors.WrapAndTrace(err) + return nil, nil, breverrors.WrapAndTrace(err) } keys, err := store.GetCurrentUserKeys() if err != nil { - return nil, breverrors.WrapAndTrace(err) + return nil, nil, breverrors.WrapAndTrace(err) } cu := ssh.NewConfigUpdater(store, configs, keys.PrivateKey) - return cu, nil + return cu, keys, nil } func GetCloudflare(refreshStore RefreshStore) store.Cloudflared { diff --git a/pkg/files/files.go b/pkg/files/files.go index 1b148b31..7fa48991 100644 --- a/pkg/files/files.go +++ b/pkg/files/files.go @@ -7,6 +7,7 @@ import ( "os" "os/exec" "path/filepath" + "strings" breverrors "github.com/brevdev/brev-cli/pkg/errors" "golang.org/x/text/encoding/charmap" @@ -75,6 +76,10 @@ func GetSSHPrivateKeyPath(home string) string { return fpath } +func GetAuthorizedKeysPath(home string) string { + return filepath.Join(home, ".ssh", "authorized_keys") +} + func GetUserSSHConfigPath(home string) (string, error) { sshConfigPath := filepath.Join(home, ".ssh", "config") return sshConfigPath, nil @@ -210,6 +215,39 @@ func OverwriteJSON(fs afero.Fs, filepath string, v interface{}) error { // write +// WriteAuthorizedKey ensures the given public key is present in ~/.ssh/authorized_keys. +// It appends the key only if it's not already there. +func WriteAuthorizedKey(fs afero.Fs, publicKey string, home string) error { + authorizedKeysPath := GetAuthorizedKeysPath(home) + err := fs.MkdirAll(filepath.Dir(authorizedKeysPath), 0o700) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + publicKey = strings.TrimSpace(publicKey) + + existing, err := afero.ReadFile(fs, authorizedKeysPath) + if err != nil && !os.IsNotExist(err) { + return breverrors.WrapAndTrace(err) + } + + if strings.Contains(string(existing), publicKey) { + return nil + } + + content := string(existing) + if len(content) > 0 && !strings.HasSuffix(content, "\n") { + content += "\n" + } + content += publicKey + "\n" + + err = afero.WriteFile(fs, authorizedKeysPath, []byte(content), 0o600) + if err != nil { + return breverrors.WrapAndTrace(err) + } + return nil +} + func WriteSSHPrivateKey(fs afero.Fs, data string, home string) error { pkPath := GetSSHPrivateKeyPath(home) err := fs.MkdirAll(filepath.Dir(pkPath), defaultFilePermission) diff --git a/pkg/store/ssh.go b/pkg/store/ssh.go index 61f195ed..d1dcc367 100644 --- a/pkg/store/ssh.go +++ b/pkg/store/ssh.go @@ -235,6 +235,18 @@ func (f FileStore) WritePrivateKey(pem string) error { return nil } +func (f FileStore) WriteAuthorizedKey(publicKey string) error { + home, err := f.UserHomeDir() + if err != nil { + return breverrors.WrapAndTrace(err) + } + err = files.WriteAuthorizedKey(f.fs, publicKey, home) + if err != nil { + return breverrors.WrapAndTrace(err) + } + return nil +} + func (f FileStore) GetPrivateKeyPath() (string, error) { home, err := f.UserHomeDir() if err != nil {