diff --git a/remote.go b/remote.go index cff5ec3..fb46c8e 100644 --- a/remote.go +++ b/remote.go @@ -34,6 +34,7 @@ import ( sshutil "github.com/aucloud/go-sshutil" "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" ) type RemoteCmd struct { @@ -107,6 +108,38 @@ func NewRemoteKeyAuthRunner(ctx context.Context, user, host, key string) (*Remot return &Remote{client}, nil } +func NewRemoteAgentAuthRunner(ctx context.Context, user, host, agentSocket string) (*Remote, error) { + + if _, err := os.Stat(agentSocket); os.IsNotExist(err) { + return nil, fmt.Errorf("agent socket %s does not exist: %w", agentSocket, err) + } + agentConn, err := net.Dial("unix", agentSocket) + if err != nil { + return nil, fmt.Errorf("failed to open SSH agent socket %s: %v", agentSocket, err) + } + agentClient := agent.NewClient(agentConn) + config := &ssh.ClientConfig{ + User: user, + // FIXME: This is insecure. We should verify RSA fingerprints of hosts... + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Auth: []ssh.AuthMethod{ssh.PublicKeysCallback(agentClient.Signers)}, + } + addr, err := ResolveHostname(host) + if err != nil { + return nil, fmt.Errorf("failed to resolve hostname %s: %w", host, err) + } + client, err := sshutil.NewClient( + ctx, + sshutil.ConstantAddrResolver{addr}, + config, + sshutil.DefaultConnectBackoff(), + ) + if err != nil { + return nil, fmt.Errorf("failed to establish an SSH connection to %s: %w", host, err) + } + return &Remote{client}, nil +} + func NewRemoteKeyAuthRunnerViaJumphost(ctx context.Context, user, host, jumphost, key string) (*Remote, error) { if _, err := os.Stat(key); os.IsNotExist(err) { return nil, fmt.Errorf("error reading private ssh key %s: %w", key, err) diff --git a/runcmd_test.go b/runcmd_test.go index 885c3f2..31b288d 100644 --- a/runcmd_test.go +++ b/runcmd_test.go @@ -41,7 +41,19 @@ var ( /* FIXME: Mock an SSH server func TestKeyAuth(t *testing.T) { - rRunner, err := NewRemoteKeyAuthRunner(user, host, key) + rRunner, err := NewRemoteKeyAuthRunner(context.TODO(), user, host, key) + if err != nil { + t.Error(err) + } + if err := testRun(rRunner); err != nil { + t.Error(err) + } +} +*/ + +/* FIXME: Mock an SSH server +func TestKeyAuth(t *testing.T) { + rRunner, err := NewRemoteAgentAuthRunner(context.TODO(), os.Getenv("USER"), "localhost", os.Getenv("SSH_AUTH_SOCK")) if err != nil { t.Error(err) } @@ -58,7 +70,7 @@ func TestPassAuth(t *testing.T) { os.Exit(1) } }() - rRunner, err := NewRemotePassAuthRunner(user, host, pass) + rRunner, err := NewRemotePassAuthRunner(context.TODO(), user, host, pass) if err != nil { t.Error(err) } @@ -80,7 +92,7 @@ func TestLocalRun(t *testing.T) { /* FIXME: Mock anSSH server func TestRemoteRun(t *testing.T) { - rRunner, err := NewRemoteKeyAuthRunner(user, host, key) + rRunner, err := NewRemoteKeyAuthRunner(context.TODO(), user, host, key) if err != nil { t.Error(err) } @@ -102,7 +114,7 @@ func TestLocalStartWait(t *testing.T) { /* FIXME: Mock an SSH server func TestRemoteStartWait(t *testing.T) { - rRunner, err := NewRemoteKeyAuthRunner(user, host, key) + rRunner, err := NewRemoteKeyAuthRunner(context.TODO(), user, host, key) if err != nil { t.Error(err) } @@ -242,7 +254,7 @@ func testPipe(localToRemote bool) error { if err != nil { return err } - rRunner, err := NewRemoteKeyAuthRunner(user, host, key) + rRunner, err := NewRemoteKeyAuthRunner(context.TODO(), user, host, key) if err != nil { return err }