diff --git a/docs/docs/connections.mdx b/docs/docs/connections.mdx index 08a8ac2632..93a666ef29 100644 --- a/docs/docs/connections.mdx +++ b/docs/docs/connections.mdx @@ -167,6 +167,14 @@ In addition to the regular ssh config file, wave also has its own config file to | ssh:userknownhostsfile | A list containing the paths of any user host key database files used to keep track of authorized connections. Can be used to overwrite the value in `~/.ssh/config` or to set it if the ssh config is being ignored.| | ssh:globalknownhostsfile | A list containing the paths of any global host key database files used to keep track of authorized connections. Can be used to overwrite the value in `~/.ssh/config` or to set it if the ssh config is being ignored.| +### SSH Agent Detection + +Wave resolves the identity agent path in this order: + +- If `SSH_AUTH_SOCK` is set, that socket or pipe is used. +- If `SSH_AUTH_SOCK` is empty on Windows, Wave falls back to the built-in OpenSSH agent pipe `\\.\\pipe\\openssh-ssh-agent`. Ensure the **OpenSSH Authentication Agent** service is running, or set `SSH_AUTH_SOCK` to a custom pipe if you use another agent. +- If `SSH_AUTH_SOCK` is empty on macOS/Linux, Wave attempts to detect the agent path from the shell; you can override by setting `SSH_AUTH_SOCK` or `ssh:identityagent` explicitly. + ### Example Internal Configurations Here are a couple examples of things you can do using the internal configuration file `connections.json`: diff --git a/go.mod b/go.mod index ba6d1da584..351609305b 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/wavetermdev/waveterm go 1.24.6 require ( + github.com/Microsoft/go-winio v0.6.2 github.com/alexflint/go-filemutex v1.3.0 github.com/aws/aws-sdk-go-v2 v1.40.1 github.com/aws/aws-sdk-go-v2/config v1.32.0 diff --git a/go.sum b/go.sum index d38e9d0cc2..f7fc0cfe41 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/0xrawsec/golang-utils v1.3.2 h1:ww4jrtHRSnX9xrGzJYbalx5nXoZewy4zPxiY+ubJgtg= github.com/0xrawsec/golang-utils v1.3.2/go.mod h1:m7AzHXgdSAkFCD9tWWsApxNVxMlyy7anpPVOyT/yM7E= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/alexflint/go-filemutex v1.3.0 h1:LgE+nTUWnQCyRKbpoceKZsPQbs84LivvgwUymZXdOcM= github.com/alexflint/go-filemutex v1.3.0/go.mod h1:U0+VA/i30mGBlLCrFPGtTe9y6wGQfNAWPBTekHQ+c8A= github.com/aws/aws-sdk-go-v2 v1.40.1 h1:difXb4maDZkRH0x//Qkwcfpdg1XQVXEAEs2DdXldFFc= diff --git a/pkg/remote/connparse/connparse.go b/pkg/remote/connparse/connparse.go index 18c4e5e274..4fef7d6949 100644 --- a/pkg/remote/connparse/connparse.go +++ b/pkg/remote/connparse/connparse.go @@ -25,6 +25,22 @@ const ( var windowsDriveRegex = regexp.MustCompile(`^[a-zA-Z]:`) var wslConnRegex = regexp.MustCompile(`^wsl://[^/]+`) +func needsPrecedingSlash(path string) bool { + if len(path) <= 1 { + return false + } + if windowsDriveRegex.MatchString(path) { + return false + } + disallowedPrefixes := []string{"/", "~", "./", "../", ".\\", "..\\"} + for _, prefix := range disallowedPrefixes { + if strings.HasPrefix(path, prefix) { + return false + } + } + return path != ".." +} + type Connection struct { Scheme string Host string @@ -94,24 +110,29 @@ func GetConnNameFromContext(ctx context.Context) (string, error) { // ParseURI parses a connection URI and returns the connection type, host/path, and parameters. func ParseURI(uri string) (*Connection, error) { - split := strings.SplitN(uri, "://", 2) var scheme string var rest string - if len(split) > 1 { - scheme = split[0] - rest = strings.TrimPrefix(split[1], "//") + + if strings.HasPrefix(uri, "//") { + rest = strings.TrimPrefix(uri, "//") } else { - rest = split[0] + split := strings.SplitN(uri, "://", 2) + if len(split) > 1 { + scheme = split[0] + rest = strings.TrimPrefix(split[1], "//") + } else { + rest = split[0] + } } var host string var remotePath string parseGenericPath := func() { - split = strings.SplitN(rest, "/", 2) - host = split[0] - if len(split) > 1 && split[1] != "" { - remotePath = split[1] + parts := strings.SplitN(rest, "/", 2) + host = parts[0] + if len(parts) > 1 && parts[1] != "" { + remotePath = parts[1] } else if strings.HasSuffix(rest, "/") { // preserve trailing slash remotePath = "/" @@ -133,8 +154,8 @@ func ParseURI(uri string) (*Connection, error) { if scheme == "" { scheme = ConnectionTypeWsh addPrecedingSlash = false - if len(rest) != len(uri) { - // This accounts for when the uri starts with "//", which would get trimmed in the first split. + if strings.HasPrefix(uri, "//") { + // Handles remote shorthand like //host/path and WSL URIs //wsl://distro/path parseWshPath() } else if strings.HasPrefix(rest, "/~") { host = wshrpc.LocalConnName @@ -155,7 +176,7 @@ func ParseURI(uri string) (*Connection, error) { } if strings.HasPrefix(remotePath, "/~") { remotePath = strings.TrimPrefix(remotePath, "/") - } else if addPrecedingSlash && (len(remotePath) > 1 && !windowsDriveRegex.MatchString(remotePath) && !strings.HasPrefix(remotePath, "/") && !strings.HasPrefix(remotePath, "~") && !strings.HasPrefix(remotePath, "./") && !strings.HasPrefix(remotePath, "../") && !strings.HasPrefix(remotePath, ".\\") && !strings.HasPrefix(remotePath, "..\\") && remotePath != "..") { + } else if addPrecedingSlash && needsPrecedingSlash(remotePath) { remotePath = "/" + remotePath } } diff --git a/pkg/remote/sshagent_unix.go b/pkg/remote/sshagent_unix.go new file mode 100644 index 0000000000..41629ce89a --- /dev/null +++ b/pkg/remote/sshagent_unix.go @@ -0,0 +1,10 @@ +//go:build !windows + +package remote + +import "net" + +// dialIdentityAgent connects to a Unix domain socket identity agent. +func dialIdentityAgent(agentPath string) (net.Conn, error) { + return net.Dial("unix", agentPath) +} diff --git a/pkg/remote/sshagent_unix_test.go b/pkg/remote/sshagent_unix_test.go new file mode 100644 index 0000000000..bd95494972 --- /dev/null +++ b/pkg/remote/sshagent_unix_test.go @@ -0,0 +1,35 @@ +//go:build !windows + +package remote + +import ( + "net" + "path/filepath" + "testing" +) + +func TestDialIdentityAgentUnix(t *testing.T) { + socketPath := filepath.Join(t.TempDir(), "agent.sock") + + ln, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatalf("listen unix socket: %v", err) + } + defer ln.Close() + + acceptDone := make(chan struct{}) + go func() { + conn, _ := ln.Accept() + if conn != nil { + conn.Close() + } + close(acceptDone) + }() + + conn, err := dialIdentityAgent(socketPath) + if err != nil { + t.Fatalf("dialIdentityAgent: %v", err) + } + conn.Close() + <-acceptDone +} diff --git a/pkg/remote/sshagent_windows.go b/pkg/remote/sshagent_windows.go new file mode 100644 index 0000000000..8c11c0182d --- /dev/null +++ b/pkg/remote/sshagent_windows.go @@ -0,0 +1,16 @@ +//go:build windows + +package remote + +import ( + "net" + "time" + + "github.com/Microsoft/go-winio" +) + +// dialIdentityAgent connects to the Windows OpenSSH agent named pipe. +func dialIdentityAgent(agentPath string) (net.Conn, error) { + timeout := 2 * time.Second + return winio.DialPipe(agentPath, &timeout) +} diff --git a/pkg/remote/sshagent_windows_test.go b/pkg/remote/sshagent_windows_test.go new file mode 100644 index 0000000000..ba62fe167d --- /dev/null +++ b/pkg/remote/sshagent_windows_test.go @@ -0,0 +1,21 @@ +//go:build windows + +package remote + +import ( + "testing" + "time" +) + +func TestDialIdentityAgentWindowsTimeout(t *testing.T) { + start := time.Now() + _, err := dialIdentityAgent(`\\.\\pipe\\waveterm-nonexistent-agent`) + if err == nil { + t.Skip("unexpectedly connected to a test pipe; skipping") + } + // Optionally verify error indicates connection/timeout failure + t.Logf("dialIdentityAgent returned expected error: %v", err) + if time.Since(start) > 3*time.Second { + t.Fatalf("dialIdentityAgent exceeded expected timeout window") + } +} diff --git a/pkg/remote/sshclient.go b/pkg/remote/sshclient.go index c7419fd940..76f02a613f 100644 --- a/pkg/remote/sshclient.go +++ b/pkg/remote/sshclient.go @@ -17,6 +17,7 @@ import ( "os/exec" "os/user" "path/filepath" + "runtime" "strings" "sync" "time" @@ -233,12 +234,12 @@ func createPasswordCallbackPrompt(connCtx context.Context, remoteDisplayName str } }() blocklogger.Infof(connCtx, "[conndebug] Password Authentication requested from connection %s...\n", remoteDisplayName) - + if password != nil { blocklogger.Infof(connCtx, "[conndebug] using password from secret store, sending to ssh\n") return *password, nil } - + ctx, cancelFn := context.WithTimeout(connCtx, 60*time.Second) defer cancelFn() queryText := fmt.Sprintf( @@ -612,10 +613,11 @@ func createClientConfig(connCtx context.Context, sshKeywords *wconfig.ConnKeywor // IdentitiesOnly indicates that only the keys listed in the identity and certificate files or passed as arguments should be used, even if there are matches in the SSH Agent, PKCS11Provider, or SecurityKeyProvider. See https://man.openbsd.org/ssh_config#IdentitiesOnly // TODO: Update if we decide to support PKCS11Provider and SecurityKeyProvider - if !utilfn.SafeDeref(sshKeywords.SshIdentitiesOnly) { - conn, err := net.Dial("unix", utilfn.SafeDeref(sshKeywords.SshIdentityAgent)) + agentPath := strings.TrimSpace(utilfn.SafeDeref(sshKeywords.SshIdentityAgent)) + if !utilfn.SafeDeref(sshKeywords.SshIdentitiesOnly) && agentPath != "" { + conn, err := dialIdentityAgent(agentPath) if err != nil { - log.Printf("Failed to open Identity Agent Socket: %v", err) + log.Printf("Failed to open Identity Agent Socket %q: %v", agentPath, err) } else { agentClient = agent.NewClient(conn) authSockSigners, _ = agentClient.Signers() @@ -900,17 +902,32 @@ func findSshConfigKeywords(hostPattern string) (connKeywords *wconfig.ConnKeywor return nil, err } if identityAgentRaw == "" { - shellPath := shellutil.DetectLocalShellPath() - authSockCommand := exec.Command(shellPath, "-c", "echo ${SSH_AUTH_SOCK}") - sshAuthSock, err := authSockCommand.Output() - if err == nil { - agentPath, err := wavebase.ExpandHomeDir(trimquotes.TryTrimQuotes(strings.TrimSpace(string(sshAuthSock)))) + if envSock := os.Getenv("SSH_AUTH_SOCK"); envSock != "" { + agentPath, err := wavebase.ExpandHomeDir(trimquotes.TryTrimQuotes(envSock)) if err != nil { return nil, err } sshKeywords.SshIdentityAgent = utilfn.Ptr(agentPath) + } else if runtime.GOOS == "windows" { + sshKeywords.SshIdentityAgent = utilfn.Ptr(`\\.\\pipe\\openssh-ssh-agent`) } else { - log.Printf("unable to find SSH_AUTH_SOCK: %v\n", err) + shellPath := shellutil.DetectLocalShellPath() + authSockCommand := exec.Command(shellPath, "-c", "echo ${SSH_AUTH_SOCK}") + sshAuthSock, err := authSockCommand.Output() + if err == nil { + trimmedSock := strings.TrimSpace(string(sshAuthSock)) + if trimmedSock == "" { + log.Printf("SSH_AUTH_SOCK is empty in shell environment") + } else { + agentPath, err := wavebase.ExpandHomeDir(trimquotes.TryTrimQuotes(trimmedSock)) + if err != nil { + return nil, err + } + sshKeywords.SshIdentityAgent = utilfn.Ptr(agentPath) + } + } else { + log.Printf("unable to find SSH_AUTH_SOCK: %v\n", err) + } } } else { agentPath, err := wavebase.ExpandHomeDir(trimquotes.TryTrimQuotes(identityAgentRaw))