Skip to content

Commit ab1d346

Browse files
📝 Add docstrings to feat/ssh-agent-windows
Docstrings generation was requested by @andya1lan. * #2644 (comment) The following files were modified: * `pkg/remote/connparse/connparse.go` * `pkg/remote/sshagent_windows.go` * `pkg/remote/sshclient.go`
1 parent 323db7f commit ab1d346

File tree

3 files changed

+96
-26
lines changed

3 files changed

+96
-26
lines changed

pkg/remote/connparse/connparse.go

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -92,26 +92,31 @@ func GetConnNameFromContext(ctx context.Context) (string, error) {
9292
return handler.GetRpcContext().Conn, nil
9393
}
9494

95-
// ParseURI parses a connection URI and returns the connection type, host/path, and parameters.
95+
// It recognizes explicit schemes (scheme://...), shorthand forms starting with "//host/path" and WSL-style URIs (wsl://distro/path). When no scheme is provided the scheme defaults to "wsh" and the host may be set to the current connection marker or to the local connection name for local shorthand. For the "wsh" scheme: missing host defaults to the local connection name; paths beginning with "/~" are normalized by removing the leading slash; other paths may receive a prepended "/" except when they look like Windows drive paths, start with ".", "~", or already start with a slash. Trailing slashes in the original URI are preserved in the parsed Path.
9696
func ParseURI(uri string) (*Connection, error) {
97-
split := strings.SplitN(uri, "://", 2)
9897
var scheme string
9998
var rest string
100-
if len(split) > 1 {
101-
scheme = split[0]
102-
rest = strings.TrimPrefix(split[1], "//")
99+
100+
if strings.HasPrefix(uri, "//") {
101+
rest = strings.TrimPrefix(uri, "//")
103102
} else {
104-
rest = split[0]
103+
split := strings.SplitN(uri, "://", 2)
104+
if len(split) > 1 {
105+
scheme = split[0]
106+
rest = strings.TrimPrefix(split[1], "//")
107+
} else {
108+
rest = split[0]
109+
}
105110
}
106111

107112
var host string
108113
var remotePath string
109114

110115
parseGenericPath := func() {
111-
split = strings.SplitN(rest, "/", 2)
112-
host = split[0]
113-
if len(split) > 1 && split[1] != "" {
114-
remotePath = split[1]
116+
parts := strings.SplitN(rest, "/", 2)
117+
host = parts[0]
118+
if len(parts) > 1 && parts[1] != "" {
119+
remotePath = parts[1]
115120
} else if strings.HasSuffix(rest, "/") {
116121
// preserve trailing slash
117122
remotePath = "/"
@@ -133,8 +138,9 @@ func ParseURI(uri string) (*Connection, error) {
133138
if scheme == "" {
134139
scheme = ConnectionTypeWsh
135140
addPrecedingSlash = false
136-
if len(rest) != len(uri) {
137-
// This accounts for when the uri starts with "//", which would get trimmed in the first split.
141+
if strings.HasPrefix(uri, "//") {
142+
rest = strings.TrimPrefix(uri, "//")
143+
// Handles remote shorthand like //host/path and WSL URIs //wsl://distro/path
138144
parseWshPath()
139145
} else if strings.HasPrefix(rest, "/~") {
140146
host = wshrpc.LocalConnName
@@ -166,4 +172,4 @@ func ParseURI(uri string) (*Connection, error) {
166172
Path: remotePath,
167173
}
168174
return conn, nil
169-
}
175+
}

pkg/remote/sshagent_windows.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
//go:build windows
2+
3+
package remote
4+
5+
import (
6+
"net"
7+
"time"
8+
9+
"github.com/Microsoft/go-winio"
10+
)
11+
12+
// dialIdentityAgent connects to the Windows OpenSSH agent named pipe at the given path and returns the established connection or an error.
13+
func dialIdentityAgent(agentPath string) (net.Conn, error) {
14+
timeout := 2 * time.Second
15+
return winio.DialPipe(agentPath, &timeout)
16+
}

pkg/remote/sshclient.go

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"os/exec"
1818
"os/user"
1919
"path/filepath"
20+
"runtime"
2021
"strings"
2122
"sync"
2223
"time"
@@ -224,6 +225,14 @@ func createPublicKeyCallback(connCtx context.Context, sshKeywords *wconfig.ConnK
224225
}
225226
}
226227

228+
// createPasswordCallbackPrompt returns a function that obtains a password for SSH authentication.
229+
//
230+
// The returned callback returns a password string or an error when password acquisition fails.
231+
// If the optional `password` pointer is non-nil, its value is returned directly without prompting.
232+
// Otherwise the callback prompts the user (with a 60 second timeout) for a password using the
233+
// provided connection context and includes `remoteDisplayName` in the prompt. On prompt or input
234+
// errors the callback returns a ConnectionError that wraps the underlying error and includes
235+
// `debugInfo` for diagnostics. The callback also converts panics into errors.
227236
func createPasswordCallbackPrompt(connCtx context.Context, remoteDisplayName string, password *string, debugInfo *ConnectionDebugInfo) func() (secret string, err error) {
228237
return func() (secret string, outErr error) {
229238
defer func() {
@@ -233,12 +242,12 @@ func createPasswordCallbackPrompt(connCtx context.Context, remoteDisplayName str
233242
}
234243
}()
235244
blocklogger.Infof(connCtx, "[conndebug] Password Authentication requested from connection %s...\n", remoteDisplayName)
236-
245+
237246
if password != nil {
238247
blocklogger.Infof(connCtx, "[conndebug] using password from secret store, sending to ssh\n")
239248
return *password, nil
240249
}
241-
250+
242251
ctx, cancelFn := context.WithTimeout(connCtx, 60*time.Second)
243252
defer cancelFn()
244253
queryText := fmt.Sprintf(
@@ -598,6 +607,22 @@ func createHostKeyCallback(ctx context.Context, sshKeywords *wconfig.ConnKeyword
598607
return waveHostKeyCallback, hostKeyAlgorithms, nil
599608
}
600609

610+
// createClientConfig builds an ssh.ClientConfig configured for the target described
611+
// by sshKeywords and using connCtx for context-aware operations.
612+
//
613+
// The returned ClientConfig is populated with the selected user, authentication
614+
// methods (publickey, keyboard-interactive, password) ordered by the host's
615+
// PreferredAuthentications, a host key verification callback, and the host key
616+
// algorithms appropriate for the destination. Batch mode, IdentitiesOnly, and
617+
// preferred-authentication flags from sshKeywords are honored.
618+
//
619+
// This function may:
620+
// - open and query an SSH identity agent socket if configured and allowed;
621+
// - retrieve a password from the configured secret store when SshPasswordSecretName
622+
// is set.
623+
//
624+
// It returns a non-nil error when required setup steps fail (for example, secret
625+
// retrieval or host key callback construction).
601626
func createClientConfig(connCtx context.Context, sshKeywords *wconfig.ConnKeywords, debugInfo *ConnectionDebugInfo) (*ssh.ClientConfig, error) {
602627
chosenUser := utilfn.SafeDeref(sshKeywords.SshUser)
603628
chosenHostName := utilfn.SafeDeref(sshKeywords.SshHostName)
@@ -612,10 +637,11 @@ func createClientConfig(connCtx context.Context, sshKeywords *wconfig.ConnKeywor
612637

613638
// IdentitiesOnly indicates that only the keys listed in the identity and certificate files or passed as arguments should be used, even if there are matches in the SSH Agent, PKCS11Provider, or SecurityKeyProvider. See https://man.openbsd.org/ssh_config#IdentitiesOnly
614639
// TODO: Update if we decide to support PKCS11Provider and SecurityKeyProvider
615-
if !utilfn.SafeDeref(sshKeywords.SshIdentitiesOnly) {
616-
conn, err := net.Dial("unix", utilfn.SafeDeref(sshKeywords.SshIdentityAgent))
640+
agentPath := strings.TrimSpace(utilfn.SafeDeref(sshKeywords.SshIdentityAgent))
641+
if !utilfn.SafeDeref(sshKeywords.SshIdentitiesOnly) && agentPath != "" {
642+
conn, err := dialIdentityAgent(agentPath)
617643
if err != nil {
618-
log.Printf("Failed to open Identity Agent Socket: %v", err)
644+
log.Printf("Failed to open Identity Agent Socket %q: %v", agentPath, err)
619645
} else {
620646
agentClient = agent.NewClient(conn)
621647
authSockSigners, _ = agentClient.Signers()
@@ -801,7 +827,19 @@ func ConnectToClient(connCtx context.Context, opts *SSHOpts, currentClient *ssh.
801827

802828
// note that a `var == "yes"` will default to false
803829
// but `var != "no"` will default to true
804-
// when given unexpected strings
830+
// findSshConfigKeywords reads SSH configuration for the provided hostPattern and returns a populated
831+
// wconfig.ConnKeywords describing the resolved connection parameters.
832+
//
833+
// The returned ConnKeywords includes resolved values for user, hostname, port, identity files,
834+
// batch mode, publickey/password/keyboard-interactive authentication flags, preferred
835+
// authentications, AddKeysToAgent, IdentitiesOnly, IdentityAgent (with home‑dir expansion and
836+
// platform-aware fallbacks), ProxyJump entries, and user/global known_hosts files. Identity file
837+
// paths are trimmed of surrounding quotes; boolean-style options are normalized from common SSH
838+
// values (e.g., "yes"/"no"). ProxyJump entries are split on commas and empty/"none" values are
839+
// ignored. Known-hosts file fields are split on whitespace.
840+
//
841+
// An error is returned if reading or expanding SSH configuration values fails. Panics are
842+
// converted into errors and returned.
805843
func findSshConfigKeywords(hostPattern string) (connKeywords *wconfig.ConnKeywords, outErr error) {
806844
defer func() {
807845
panicErr := panichandler.PanicHandler("sshclient:find-ssh-config-keywords", recover())
@@ -900,17 +938,27 @@ func findSshConfigKeywords(hostPattern string) (connKeywords *wconfig.ConnKeywor
900938
return nil, err
901939
}
902940
if identityAgentRaw == "" {
903-
shellPath := shellutil.DetectLocalShellPath()
904-
authSockCommand := exec.Command(shellPath, "-c", "echo ${SSH_AUTH_SOCK}")
905-
sshAuthSock, err := authSockCommand.Output()
906-
if err == nil {
907-
agentPath, err := wavebase.ExpandHomeDir(trimquotes.TryTrimQuotes(strings.TrimSpace(string(sshAuthSock))))
941+
if envSock := os.Getenv("SSH_AUTH_SOCK"); envSock != "" {
942+
agentPath, err := wavebase.ExpandHomeDir(trimquotes.TryTrimQuotes(envSock))
908943
if err != nil {
909944
return nil, err
910945
}
911946
sshKeywords.SshIdentityAgent = utilfn.Ptr(agentPath)
947+
} else if runtime.GOOS == "windows" {
948+
sshKeywords.SshIdentityAgent = utilfn.Ptr(`\\.\\pipe\\openssh-ssh-agent`)
912949
} else {
913-
log.Printf("unable to find SSH_AUTH_SOCK: %v\n", err)
950+
shellPath := shellutil.DetectLocalShellPath()
951+
authSockCommand := exec.Command(shellPath, "-c", "echo ${SSH_AUTH_SOCK}")
952+
sshAuthSock, err := authSockCommand.Output()
953+
if err == nil {
954+
agentPath, err := wavebase.ExpandHomeDir(trimquotes.TryTrimQuotes(strings.TrimSpace(string(sshAuthSock))))
955+
if err != nil {
956+
return nil, err
957+
}
958+
sshKeywords.SshIdentityAgent = utilfn.Ptr(agentPath)
959+
} else {
960+
log.Printf("unable to find SSH_AUTH_SOCK: %v\n", err)
961+
}
914962
}
915963
} else {
916964
agentPath, err := wavebase.ExpandHomeDir(trimquotes.TryTrimQuotes(identityAgentRaw))
@@ -1040,4 +1088,4 @@ func mergeKeywords(oldKeywords *wconfig.ConnKeywords, newKeywords *wconfig.ConnK
10401088
}
10411089

10421090
return &outKeywords
1043-
}
1091+
}

0 commit comments

Comments
 (0)